diff --git a/.ocamlformat b/.ocamlformat index c9d45a2..150000d 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,2 +1 @@ profile = ocamlformat -version = 0.27.0 diff --git a/bin/main.ml b/bin/main.ml index 071577e..0d3645a 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -7,6 +7,23 @@ let handlers = ; (get, "/robots.txt", Api.Robots.handler) ; (get, "/xrpc/_health", Api.Health.handler) ; (get, "/.well-known/did.json", Api.Well_known.did_json) + ; ( get + , "/.well-known/oauth-protected-resource" + , Api.Well_known.oauth_protected_resource ) + ; ( get + , "/.well-known/oauth-authorization-server" + , Api.Well_known.oauth_authorization_server ) + ; (* oauth *) + (options, "/oauth/par", Api.Oauth_.Par.options_handler) + ; (post, "/oauth/par", Api.Oauth_.Par.post_handler) + ; (get, "/oauth/authorize", Api.Oauth_.Authorize.get_handler) + ; (post, "/oauth/authorize", Api.Oauth_.Authorize.post_handler) + ; (options, "/oauth/token", Api.Oauth_.Token.options_handler) + ; (post, "/oauth/token", Api.Oauth_.Token.post_handler) + ; (* account *) + (get, "/account/login", Api.Account_.Login.get_handler) + ; (post, "/account/login", Api.Account_.Login.post_handler) + ; (get, "/account/logout", Api.Account_.Logout.handler) ; (* unauthed *) ( get , "/xrpc/com.atproto.server.describeServer" @@ -15,7 +32,7 @@ let handlers = ; ( get , "/xrpc/com.atproto.identity.resolveHandle" , Api.Identity.ResolveHandle.handler ) - ; (* account *) + ; (* account management *) ( post , "/xrpc/com.atproto.server.createInviteCode" , Api.Server.CreateInviteCode.handler ) @@ -65,16 +82,22 @@ let handlers = , "/xrpc/com.atproto.actor.putPreferences" , Api.Actor.PutPreferences.handler ) ] +let static_routes = + [Dream.get "/public/**" (Dream.static "_build/default/public")] + let main = let%lwt db = Data_store.connect ~create:true () in let%lwt () = Data_store.init db in Dream.serve ~interface:"0.0.0.0" ~port:8008 @@ Dream.logger + @@ Dream.set_secret (Env.jwt_key |> Kleidos.privkey_to_multikey) + @@ Dream.cookie_sessions @@ Xrpc.service_proxy_middleware db - @@ Dream.router + @@ Xrpc.dpop_middleware @@ Xrpc.cors_middleware @@ Dream.router @@ List.map (fun (fn, path, handler) -> fn path (fun req -> handler ({req; db} : Xrpc.init)) ) handlers + @ static_routes let () = Lwt_main.run main diff --git a/dune b/dune new file mode 100644 index 0000000..a40e8d6 --- /dev/null +++ b/dune @@ -0,0 +1,21 @@ +(subdir + public/ + (rule + (target index.css) + (deps + %{workspace_root}/tools/tailwindcss/tailwindcss + (:input %{workspace_root}/public/main.css) + (source_tree %{workspace_root}/public) + (source_tree %{workspace_root}/pegasus/lib/templates)) + (action + (chdir + %{workspace_root} + (run + %{workspace_root}/tools/tailwindcss/tailwindcss + -m + -i + %{input} + -o + %{target}))))) + +(copy_files public/*) diff --git a/dune-project b/dune-project index fc0cc0a..96f72d3 100644 --- a/dune-project +++ b/dune-project @@ -30,6 +30,7 @@ (url "git+https://github.com/roddyyaga/ppx_rapper.git") (package (name ppx_rapper_lwt))) + (package (name pegasus) (synopsis "An atproto Personal Data Server implementation") @@ -46,16 +47,21 @@ (cohttp-lwt-unix (>= 6.1.1)) (dns-client (>= 10.2.0)) dream + html_of_jsx + mlx (re (>= 1.13.2)) (safepass (>= 3.1)) (timedesc (>= 3.1.0)) + (uri (>= 4.4.0)) (uuidm (>= 0.9.10)) (yojson (>= 3.0.0)) (lwt_ppx (>= 5.9.1)) (ppx_deriving_yojson (>= 3.9.1)) ppx_rapper ppx_rapper_lwt - (alcotest :with-test))) + (alcotest :with-test) + (ocamlformat-mlx :with-dev-setup) + (ocamlmerlin-mlx :with-dev-setup))) (package (name mist) @@ -97,3 +103,14 @@ (hacl-star (>= 0.7.2)) (mirage-crypto-ec (>= 2.0.1)) (multibase (>= 0.1.0)))) + +(package + (name tailwindcss) (allow_empty)) + +(dialect + (name mlx) + (implementation + (extension mlx) + (merlin_reader mlx) + (preprocess + (run mlx-pp %{input-file})))) diff --git a/ipld/lib/dag_cbor.ml b/ipld/lib/dag_cbor.ml index 31d4a32..0a20c0e 100644 --- a/ipld/lib/dag_cbor.ml +++ b/ipld/lib/dag_cbor.ml @@ -197,8 +197,8 @@ module Encoder = struct write_type_and_argument t 5 (Int64.of_int len) ; ordered_map_keys m |> List.iter (fun k -> - write_string t k ; - write_value t (String_map.find k m) ) + write_string t k ; + write_value t (String_map.find k m) ) | `Link cid -> write_cid t cid diff --git a/ipld/test/test_dag_cbor.ml b/ipld/test/test_dag_cbor.ml index d3588d2..77cc592 100644 --- a/ipld/test/test_dag_cbor.ml +++ b/ipld/test/test_dag_cbor.ml @@ -3,7 +3,7 @@ module String_map = Dag_cbor.String_map let rec stringify_map m = String_map.bindings m |> List.map (fun (k, v) -> - Format.sprintf "\"%s\": %s" k (stringify_ipld_value v) ) + Format.sprintf "\"%s\": %s" k (stringify_ipld_value v) ) |> String.concat ", " |> Format.sprintf "{%s}" and stringify_ipld_value (value : Dag_cbor.value) = @@ -109,9 +109,9 @@ let test_encode_primitives () = Hashtbl.add cases (to_base_16 (Dag_cbor.encode `Null)) (Bytes.of_string "f6") ; cases |> Hashtbl.iter (fun key value -> - Alcotest.(check bytes) - ("encoded bytes for " ^ key) - value (Bytes.of_string key) ) + Alcotest.(check bytes) + ("encoded bytes for " ^ key) + value (Bytes.of_string key) ) let test_round_trip () = let test_cid = diff --git a/kleidos/kleidos.ml b/kleidos/kleidos.ml index f31a808..6998a09 100644 --- a/kleidos/kleidos.ml +++ b/kleidos/kleidos.ml @@ -220,3 +220,7 @@ let verify ~pubkey ~msg ~signature : bool = let pubkey_to_did_key pubkey : string = let pubkey, (module Curve : CURVE) = pubkey in Curve.pubkey_to_did_key pubkey + +let privkey_to_multikey privkey : string = + let privkey, (module Curve : CURVE) = privkey in + Curve.privkey_to_multikey privkey diff --git a/mist/lib/mst.ml b/mist/lib/mst.ml index 180393f..510806c 100644 --- a/mist/lib/mst.ml +++ b/mist/lib/mst.ml @@ -239,12 +239,12 @@ struct | None, [] -> Lwt.return 0 | Some left, [] -> ( - match%lwt retrieve_node_raw t left with - | Some node -> - let%lwt height = get_node_height t node in - Lwt.return (height + 1) - | None -> - failwith ("couldn't find node " ^ Cid.to_string left) ) + match%lwt retrieve_node_raw t left with + | Some node -> + let%lwt height = get_node_height t node in + Lwt.return (height + 1) + | None -> + failwith ("couldn't find node " ^ Cid.to_string left) ) | _, leaf :: _ -> ( match leaf.p with | 0 -> @@ -497,12 +497,14 @@ struct let%lwt blocks = match Util.at_index index seq with | Some (Leaf (k, v, _)) when k = key -> ( - (* include the found leaf block to prove existence *) - match%lwt Store.get_bytes t.blockstore v with - | Some leaf_bytes -> - Lwt.return (Block_map.set v leaf_bytes Block_map.empty) - | None -> - Lwt.return Block_map.empty ) + (* include the found leaf block to prove existence *) + match%lwt + Store.get_bytes t.blockstore v + with + | Some leaf_bytes -> + Lwt.return (Block_map.set v leaf_bytes Block_map.empty) + | None -> + Lwt.return Block_map.empty ) | _ -> ( let prev = if index - 1 >= 0 then Util.at_index (index - 1) seq else None @@ -529,23 +531,22 @@ struct let%lwt bm = match left_leaf with | Some cid_left -> ( - match%lwt Store.get_bytes t.blockstore cid_left with - | Some b -> - Lwt.return - (Block_map.set cid_left b Block_map.empty) - | None -> - Lwt.return Block_map.empty ) + match%lwt Store.get_bytes t.blockstore cid_left with + | Some b -> + Lwt.return (Block_map.set cid_left b Block_map.empty) + | None -> + Lwt.return Block_map.empty ) | None -> Lwt.return Block_map.empty in let%lwt bm = match right_leaf with | Some cid_right -> ( - match%lwt Store.get_bytes t.blockstore cid_right with - | Some b -> - Lwt.return (Block_map.set cid_right b bm) - | None -> - Lwt.return bm ) + match%lwt Store.get_bytes t.blockstore cid_right with + | Some b -> + Lwt.return (Block_map.set cid_right b bm) + | None -> + Lwt.return bm ) | None -> Lwt.return bm in @@ -571,11 +572,11 @@ struct | Some (Tree c) -> proof_for_left_sibling t c key | Some (Leaf (_, v_left, _)) -> ( - match%lwt Store.get_bytes t.blockstore v_left with - | Some b -> - Lwt.return (Block_map.set v_left b Block_map.empty) - | None -> - Lwt.return Block_map.empty ) + match%lwt Store.get_bytes t.blockstore v_left with + | Some b -> + Lwt.return (Block_map.set v_left b Block_map.empty) + | None -> + Lwt.return Block_map.empty ) | _ -> Lwt.return Block_map.empty in @@ -612,11 +613,11 @@ struct | Some (Tree c) -> proof_for_right_sibling t c key | Some (Leaf (_, v_right, _)) -> ( - match%lwt Store.get_bytes t.blockstore v_right with - | Some b -> - Lwt.return (Block_map.set v_right b Block_map.empty) - | None -> - Lwt.return Block_map.empty ) + match%lwt Store.get_bytes t.blockstore v_right with + | Some b -> + Lwt.return (Block_map.set v_right b Block_map.empty) + | None -> + Lwt.return Block_map.empty ) | _ -> Lwt.return Block_map.empty ) | None -> diff --git a/mist/test/test_util.ml b/mist/test/test_util.ml index 021f256..629b827 100644 --- a/mist/test/test_util.ml +++ b/mist/test/test_util.ml @@ -8,10 +8,10 @@ let test_leading_zeros () = Hashtbl.add cases "app.bsky.feed.post/9adeb165882c" 8 ; cases |> Hashtbl.iter (fun key value -> - Alcotest.(check int) - ("leading zeros on hash " ^ key) - value - (leading_zeros_on_hash key) ) + Alcotest.(check int) + ("leading zeros on hash " ^ key) + value + (leading_zeros_on_hash key) ) let test_shared_prefix_length () = let cases = Hashtbl.create 5 in @@ -22,9 +22,9 @@ let test_shared_prefix_length () = Hashtbl.add cases ("2653ae71", "0653ae71") 0 ; cases |> Hashtbl.iter (fun (a, b) value -> - Alcotest.(check int) - ("prefix length between " ^ a ^ " and " ^ b) - value (shared_prefix_length a b) ) + Alcotest.(check int) + ("prefix length between " ^ a ^ " and " ^ b) + value (shared_prefix_length a b) ) let () = Alcotest.run "util" diff --git a/pegasus.opam b/pegasus.opam index fb505ea..39b379d 100644 --- a/pegasus.opam +++ b/pegasus.opam @@ -18,9 +18,12 @@ depends: [ "cohttp-lwt-unix" {>= "6.1.1"} "dns-client" {>= "10.2.0"} "dream" + "html_of_jsx" + "mlx" "re" {>= "1.13.2"} "safepass" {>= "3.1"} "timedesc" {>= "3.1.0"} + "uri" {>= "4.4.0"} "uuidm" {>= "0.9.10"} "yojson" {>= "3.0.0"} "lwt_ppx" {>= "5.9.1"} @@ -28,6 +31,8 @@ depends: [ "ppx_rapper" "ppx_rapper_lwt" "alcotest" {with-test} + "ocamlformat-mlx" {with-dev-setup} + "ocamlmerlin-mlx" {with-dev-setup} "odoc" {with-doc} ] build: [ diff --git a/pegasus/lib/api/account_/login.ml b/pegasus/lib/api/account_/login.ml new file mode 100644 index 0000000..95be648 --- /dev/null +++ b/pegasus/lib/api/account_/login.ml @@ -0,0 +1,49 @@ +let get_handler = + Xrpc.handler (fun ctx -> + let redirect_url = + if List.length @@ Dream.all_queries ctx.req > 0 then + Uri.make ~path:"/oauth/authorize" ~query:(Util.copy_query ctx.req) () + |> Uri.to_string + else "/account" + in + let csrf_token = Dream.csrf_token ctx.req in + let html = + JSX.render (Templates.Login.make ~redirect_url ~csrf_token ()) + in + Dream.html html ) + +let post_handler = + Xrpc.handler (fun ctx -> + let csrf_token = Dream.csrf_token ctx.req in + match%lwt Dream.form ctx.req with + | `Ok fields -> ( + let identifier = List.assoc "identifier" fields in + let password = List.assoc "password" fields in + let redirect_url = + List.assoc_opt "redirect_url" fields + |> Option.value ~default:"/account" + in + let%lwt actor = + Data_store.try_login ~id:identifier ~password ctx.db + in + match actor with + | None -> + let html = + JSX.render + (Templates.Login.make ~redirect_url + ~error:"Invalid username or password. Please try again." + ~csrf_token () ) + in + Dream.html ~status:`Unauthorized html + | Some {did; _} -> + let%lwt () = Dream.invalidate_session ctx.req in + let%lwt () = Dream.set_session_field ctx.req "did" did in + Dream.redirect ctx.req redirect_url ) + | _ -> + let html = + JSX.render + (Templates.Login.make ~redirect_url:"/account" + ~error:"Invalid credentials provided. Please try again." + ~csrf_token () ) + in + Dream.html ~status:`Unauthorized html ) diff --git a/pegasus/lib/api/account_/logout.ml b/pegasus/lib/api/account_/logout.ml new file mode 100644 index 0000000..4839ff9 --- /dev/null +++ b/pegasus/lib/api/account_/logout.ml @@ -0,0 +1,4 @@ +let handler = + Xrpc.handler (fun ctx -> + let%lwt () = Dream.invalidate_session ctx.req in + Dream.redirect ctx.req "/account/login" ) diff --git a/pegasus/lib/api/actor/putPreferences.ml b/pegasus/lib/api/actor/putPreferences.ml index ddaa25b..b84cbc2 100644 --- a/pegasus/lib/api/actor/putPreferences.ml +++ b/pegasus/lib/api/actor/putPreferences.ml @@ -1,5 +1,5 @@ let handler = - Xrpc.handler ~auth:Authorization (fun {req; db; auth} -> + Xrpc.handler ~auth:Authorization (fun {req; auth; db; _} -> let did = Auth.get_authed_did_exn auth in let%lwt body = Dream.body req in let prefs = diff --git a/pegasus/lib/api/identity/resolveHandle.ml b/pegasus/lib/api/identity/resolveHandle.ml index ba74c06..f7af272 100644 --- a/pegasus/lib/api/identity/resolveHandle.ml +++ b/pegasus/lib/api/identity/resolveHandle.ml @@ -14,9 +14,9 @@ let handler = Dream.json @@ Yojson.Safe.to_string @@ response_to_yojson {did= actor.did} | None -> ( - match%lwt Id_resolver.Handle.resolve handle with - | Ok did -> - Dream.json @@ Yojson.Safe.to_string @@ response_to_yojson {did} - | Error e -> - Errors.log_exn (Failure e) ; - Errors.internal_error ~msg:"could not resolve handle" () ) ) + match%lwt Id_resolver.Handle.resolve handle with + | Ok did -> + Dream.json @@ Yojson.Safe.to_string @@ response_to_yojson {did} + | Error e -> + Errors.log_exn (Failure e) ; + Errors.internal_error ~msg:"could not resolve handle" () ) ) diff --git a/pegasus/lib/api/identity/updateHandle.ml b/pegasus/lib/api/identity/updateHandle.ml index d104e5e..75ccc81 100644 --- a/pegasus/lib/api/identity/updateHandle.ml +++ b/pegasus/lib/api/identity/updateHandle.ml @@ -1,7 +1,7 @@ type request = {handle: string} [@@deriving yojson] let handler = - Xrpc.handler ~auth:Authorization (fun {req; auth; db} -> + Xrpc.handler ~auth:Authorization (fun {req; auth; db; _} -> let did = Auth.get_authed_did_exn auth in let%lwt body = Dream.body req in let handle = @@ -15,60 +15,57 @@ let handler = | Error e -> raise e | Ok () -> ( - match%lwt Data_store.get_actor_by_identifier handle db with - | Some _ -> - Errors.invalid_request ~name:"InvalidHandle" - "handle already in use" - | None -> - let%lwt () = Data_store.update_actor_handle ~did ~handle db in - let%lwt _ = - if String.starts_with ~prefix:"did:plc:" did then - match%lwt Plc.get_audit_log did with - | Error e -> - Dream.error (fun log -> log ~request:req "%s" e) ; - Errors.internal_error ~msg:"failed to fetch did doc" () - | Ok log -> ( - let latest = List.rev log |> List.hd in - let aka = - match - List.mem ("at://" ^ handle) - latest.operation.also_known_as - with - | true -> - latest.operation.also_known_as - | false -> - ("at://" ^ handle) :: latest.operation.also_known_as - in - let%lwt signing_key = - match%lwt Data_store.get_actor_by_identifier did db with - | Some {signing_key; _} -> - Lwt.return @@ Kleidos.parse_multikey_str signing_key - | _ -> - Errors.internal_error () - in - let signed = - Plc.sign_operation signing_key - (Operation - { type'= "plc_operation" - ; prev= Some latest.cid - ; also_known_as= aka - ; rotation_keys= latest.operation.rotation_keys - ; verification_methods= - latest.operation.verification_methods - ; services= latest.operation.services } ) - in - match%lwt Plc.submit_operation did signed with - | Ok _ -> - Lwt.return_unit - | Error (status, msg) -> - Dream.error (fun log -> - log ~request:req "%d %s" status msg ) ; - Errors.internal_error - ~msg:"failed to submit plc operation" () ) - else Lwt.return_unit - in - let () = - Ttl_cache.String_cache.remove Id_resolver.Did.cache did - in - let%lwt _ = Sequencer.sequence_identity db ~did ~handle () in - Dream.empty `OK ) ) + match%lwt Data_store.get_actor_by_identifier handle db with + | Some _ -> + Errors.invalid_request ~name:"InvalidHandle" "handle already in use" + | None -> + let%lwt () = Data_store.update_actor_handle ~did ~handle db in + let%lwt _ = + if String.starts_with ~prefix:"did:plc:" did then + match%lwt Plc.get_audit_log did with + | Error e -> + Dream.error (fun log -> log ~request:req "%s" e) ; + Errors.internal_error ~msg:"failed to fetch did doc" () + | Ok log -> ( + let latest = List.rev log |> List.hd in + let aka = + match + List.mem ("at://" ^ handle) + latest.operation.also_known_as + with + | true -> + latest.operation.also_known_as + | false -> + ("at://" ^ handle) :: latest.operation.also_known_as + in + let%lwt signing_key = + match%lwt Data_store.get_actor_by_identifier did db with + | Some {signing_key; _} -> + Lwt.return @@ Kleidos.parse_multikey_str signing_key + | _ -> + Errors.internal_error () + in + let signed = + Plc.sign_operation signing_key + (Operation + { type'= "plc_operation" + ; prev= Some latest.cid + ; also_known_as= aka + ; rotation_keys= latest.operation.rotation_keys + ; verification_methods= + latest.operation.verification_methods + ; services= latest.operation.services } ) + in + match%lwt Plc.submit_operation did signed with + | Ok _ -> + Lwt.return_unit + | Error (status, msg) -> + Dream.error (fun log -> + log ~request:req "%d %s" status msg ) ; + Errors.internal_error + ~msg:"failed to submit plc operation" () ) + else Lwt.return_unit + in + let () = Ttl_cache.String_cache.remove Id_resolver.Did.cache did in + let%lwt _ = Sequencer.sequence_identity db ~did ~handle () in + Dream.empty `OK ) ) diff --git a/pegasus/lib/api/oauth_/authorize.ml b/pegasus/lib/api/oauth_/authorize.ml new file mode 100644 index 0000000..1a43f31 --- /dev/null +++ b/pegasus/lib/api/oauth_/authorize.ml @@ -0,0 +1,184 @@ +open Oauth +open Oauth.Types + +let get_session_user (ctx : Xrpc.context) = + match Dream.session_field ctx.req "did" with + | Some did -> + Lwt.return_some did + | None -> + Lwt.return_none + +let get_handler = + Xrpc.handler (fun ctx -> + let login_redirect = + Uri.make ~path:"/account/login" ~query:(Util.copy_query ctx.req) () + |> Uri.to_string |> Dream.redirect ctx.req + in + let client_id = Dream.query ctx.req "client_id" in + let request_uri = Dream.query ctx.req "request_uri" in + match (client_id, request_uri) with + | None, _ | _, None -> + login_redirect + | Some client_id, Some request_uri -> ( + let prefix = Constants.request_uri_prefix in + if not (String.starts_with ~prefix request_uri) then login_redirect + else + let request_id = + String.sub request_uri (String.length prefix) + (String.length request_uri - String.length prefix) + in + match%lwt Queries.get_par_request ctx.db request_id with + | None -> + login_redirect + | Some req_record -> ( + if req_record.client_id <> client_id then login_redirect + else + let req = + Yojson.Safe.from_string req_record.request_data + |> par_request_of_yojson + |> Result.map_error (fun _ -> + Errors.internal_error ~msg:"failed to parse par request" + () ) + |> Result.get_ok + in + let%lwt metadata = + try%lwt Client.fetch_client_metadata client_id + with _ -> + Errors.internal_error + ~msg:"failed to fetch client metadata" () + in + let code = + "cod-" + ^ Uuidm.to_string + (Uuidm.v4_gen (Random.State.make_self_init ()) ()) + in + let expires_at = Util.now_ms () + Constants.code_expiry_ms in + let%lwt () = + Queries.insert_auth_code ctx.db + { code + ; request_id + ; authorized_by= None + ; authorized_at= None + ; expires_at + ; used= false } + in + match%lwt get_session_user ctx with + | None -> + login_redirect + | Some did -> ( + match req.login_hint with + | Some hint when hint <> did -> + login_redirect + | _ -> + let%lwt handle = + match%lwt + Data_store.get_actor_by_identifier did ctx.db + with + | Some {handle; _} -> + Lwt.return handle + | None -> + Errors.internal_error + ~msg:"failed to resolve user" () + in + let scopes = String.split_on_char ' ' req.scope in + let csrf_token = Dream.csrf_token ctx.req in + let html = + JSX.render + (Templates.Oauth_authorize.make ~metadata ~handle + ~scopes ~code ~request_uri ~csrf_token () ) + in + Dream.html html ) ) ) ) + +let post_handler = + Xrpc.handler (fun ctx -> + match%lwt get_session_user ctx with + | None -> + Errors.auth_required "missing authentication" + | Some user_did -> ( + match%lwt Dream.form ctx.req with + | `Ok fields -> ( + let action = List.assoc_opt "action" fields in + let code = List.assoc_opt "code" fields in + let request_uri = List.assoc_opt "request_uri" fields in + match (action, code, request_uri) with + | Some "deny", _, Some request_uri -> ( + let prefix = Constants.request_uri_prefix in + let request_id = + String.sub request_uri (String.length prefix) + (String.length request_uri - String.length prefix) + in + let%lwt req_record = + Queries.get_par_request ctx.db request_id + in + match req_record with + | Some rec_ -> + let req = + Yojson.Safe.from_string rec_.request_data + |> par_request_of_yojson |> Result.get_ok + in + let params = + [ ("error", "access_denied") + ; ("error_description", "Unable to authorize user.") + ; ("state", req.state) + ; ("iss", "https://" ^ Env.hostname) ] + in + let query = + String.concat "&" + (List.map + (fun (k, v) -> k ^ "=" ^ Uri.pct_encode v) + params ) + in + Dream.redirect ctx.req (req.redirect_uri ^ "?" ^ query) + | None -> + Errors.invalid_request "request expired" ) + | Some "allow", Some code, Some _request_uri -> ( + let%lwt code_record = Queries.get_auth_code ctx.db code in + match code_record with + | None -> + Errors.invalid_request "invalid code" + | Some code_rec -> ( + if code_rec.authorized_by <> None then + Errors.invalid_request "code already authorized" + else if code_rec.used then + Errors.invalid_request "code already used" + else if Util.now_ms () > code_rec.expires_at then + Errors.invalid_request "code expired" + else + let%lwt () = + Queries.activate_auth_code ctx.db code user_did + in + let%lwt req_record = + Queries.get_par_request ctx.db code_rec.request_id + in + match req_record with + | None -> + Errors.internal_error ~msg:"request not found" () + | Some rec_ -> + let req = + Yojson.Safe.from_string rec_.request_data + |> par_request_of_yojson |> Result.get_ok + in + let params = + [ ("code", code) + ; ("state", req.state) + ; ("iss", "https://" ^ Env.hostname) ] + in + let query = + String.concat "&" + (List.map + (fun (k, v) -> k ^ "=" ^ Uri.pct_encode v) + params ) + in + let separator = + match req.response_mode with + | Some "fragment" -> + "#" + | _ -> + "?" + in + Dream.redirect ctx.req + (req.redirect_uri ^ separator ^ query) ) ) + | _ -> + Errors.invalid_request "invalid request" ) + | _ -> + Errors.invalid_request "invalid request" ) ) diff --git a/pegasus/lib/api/oauth_/par.ml b/pegasus/lib/api/oauth_/par.ml new file mode 100644 index 0000000..e321105 --- /dev/null +++ b/pegasus/lib/api/oauth_/par.ml @@ -0,0 +1,41 @@ +open Oauth +open Oauth.Types + +let options_handler = Xrpc.handler (fun _ -> Dream.empty `No_Content) + +let post_handler = + Xrpc.handler ~auth:DPoP (fun ctx -> + let proof = Auth.get_dpop_proof_exn ctx.auth in + let%lwt req = Xrpc.parse_body ctx.req par_request_of_yojson in + let%lwt client = + try%lwt Client.fetch_client_metadata req.client_id + with e -> + Errors.log_exn ~req:ctx.req e ; + Errors.invalid_request "failed to fetch client metadata" + in + if req.response_type <> "code" then + Errors.invalid_request "only response_type=code supported" + else if req.code_challenge_method <> "S256" then + Errors.invalid_request "only code_challenge_method=S256 supported" + else if not (List.mem req.redirect_uri client.redirect_uris) then + Errors.invalid_request "invalid redirect_uri" + else + let request_id = + "req-" + ^ Uuidm.to_string (Uuidm.v4_gen (Random.State.make_self_init ()) ()) + in + let request_uri = Constants.request_uri_prefix ^ request_id in + let expires_at = Util.now_ms () + Constants.par_request_ttl_ms in + let request : oauth_request = + { request_id + ; client_id= req.client_id + ; request_data= Yojson.Safe.to_string (par_request_to_yojson req) + ; dpop_jkt= Some proof.jkt + ; expires_at + ; created_at= Util.now_ms () } + in + let%lwt () = Queries.insert_par_request ctx.db request in + Dream.json ~status:`Created + @@ Yojson.Safe.to_string + @@ `Assoc + [("request_uri", `String request_uri); ("expires_in", `Int 300)] ) diff --git a/pegasus/lib/api/oauth_/token.ml b/pegasus/lib/api/oauth_/token.ml new file mode 100644 index 0000000..99bd63a --- /dev/null +++ b/pegasus/lib/api/oauth_/token.ml @@ -0,0 +1,179 @@ +open Oauth + +let options_handler = Xrpc.handler (fun _ -> Dream.empty `No_Content) + +let post_handler = + Xrpc.handler ~auth:DPoP (fun ctx -> + let%lwt req = Xrpc.parse_body ctx.req Types.token_request_of_yojson in + let proof = Auth.get_dpop_proof_exn ctx.auth in + match req.grant_type with + | "authorization_code" -> ( + match req.code with + | None -> + Errors.invalid_request "code required" + | Some code -> ( + let%lwt code_record = Queries.consume_auth_code ctx.db code in + match code_record with + | None -> + Errors.invalid_request "invalid code" + | Some code_rec -> ( + if Util.now_ms () > code_rec.expires_at then + Errors.invalid_request "code expired" + else + match code_rec.authorized_by with + | None -> + Errors.invalid_request "code not authorized" + | Some did -> ( + let%lwt par_req = + Queries.get_par_request ctx.db code_rec.request_id + in + match par_req with + | None -> + Errors.internal_error ~msg:"request not found" () + | Some par_record -> + let orig_req = + Yojson.Safe.from_string par_record.request_data + |> Types.par_request_of_yojson |> Result.get_ok + in + ( match req.redirect_uri with + | None -> + Errors.invalid_request "redirect_uri required" + | Some uri when uri <> orig_req.redirect_uri -> + Errors.invalid_request "redirect_uri mismatch" + | _ -> + () ) ; + ( match req.code_verifier with + | None -> + Errors.invalid_request "code_verifier required" + | Some verifier -> + let computed = + Digestif.SHA256.digest_string verifier + |> Digestif.SHA256.to_raw_string + |> Base64.( + encode_exn ~pad:false + ~alphabet:uri_safe_alphabet ) + in + if orig_req.code_challenge <> computed then + Errors.invalid_request "invalid code_verifier" + ) ; + ( match par_record.dpop_jkt with + | Some stored when stored <> proof.jkt -> + Errors.invalid_request "DPoP key mismatch" + | _ -> + () ) ; + let token_id = + "tok-" + ^ Uuidm.to_string + (Uuidm.v4_gen + (Random.State.make_self_init ()) + () ) + in + let refresh_token = + "ref-" + ^ Uuidm.to_string + (Uuidm.v4_gen + (Random.State.make_self_init ()) + () ) + in + let now_sec = int_of_float (Unix.gettimeofday ()) in + let expires_in = + Constants.access_token_expiry_ms / 1000 + in + let exp_sec = now_sec + expires_in in + let expires_at = exp_sec * 1000 in + let claims = + `Assoc + [ ("jti", `String token_id) + ; ("sub", `String did) + ; ("iat", `Int now_sec) + ; ("exp", `Int exp_sec) + ; ("scope", `String orig_req.scope) + ; ("aud", `String ("https://" ^ Env.hostname)) + ; ("cnf", `Assoc [("jkt", `String proof.jkt)]) ] + in + let access_token = + Jwt.sign_jwt claims ~typ:"at+jwt" Env.jwt_key + in + let%lwt () = + Queries.insert_oauth_token ctx.db + { refresh_token + ; client_id= req.client_id + ; did + ; dpop_jkt= proof.jkt + ; scope= orig_req.scope + ; expires_at } + in + let nonce = Dpop.next_nonce () in + Dream.json + ~headers: + [ ("DPoP-Nonce", nonce) + ; ("Access-Control-Expose-Headers", "DPoP-Nonce") + ; ("Cache-Control", "no-store") ] + @@ Yojson.Safe.to_string + @@ `Assoc + [ ("access_token", `String access_token) + ; ("token_type", `String "DPoP") + ; ("refresh_token", `String refresh_token) + ; ("expires_in", `Int expires_in) + ; ("scope", `String orig_req.scope) + ; ("sub", `String did) ] ) ) ) ) + | "refresh_token" -> ( + match req.refresh_token with + | None -> + Errors.invalid_request "refresh_token required" + | Some refresh_token -> ( + let%lwt token_record = + Queries.get_oauth_token_by_refresh ctx.db refresh_token + in + match token_record with + | None -> + Errors.invalid_request "invalid refresh token" + | Some session -> + if session.client_id <> req.client_id then + Errors.invalid_request "client_id mismatch" + else if session.dpop_jkt <> proof.jkt then + Errors.invalid_request "DPoP key mismatch" + else + let new_token_id = + "tok-" + ^ Uuidm.to_string + (Uuidm.v4_gen (Random.State.make_self_init ()) ()) + in + let new_refresh = + "ref-" + ^ Uuidm.to_string + (Uuidm.v4_gen (Random.State.make_self_init ()) ()) + in + let now_sec = int_of_float (Unix.gettimeofday ()) in + let expires_in = Constants.access_token_expiry_ms / 1000 in + let exp_sec = now_sec + expires_in in + let new_expires_at = exp_sec * 1000 in + let claims = + `Assoc + [ ("jti", `String new_token_id) + ; ("sub", `String session.did) + ; ("iat", `Int now_sec) + ; ("exp", `Int exp_sec) + ; ("scope", `String session.scope) + ; ("aud", `String ("https://" ^ Env.hostname)) + ; ("cnf", `Assoc [("jkt", `String proof.jkt)]) ] + in + let new_access_token = + Jwt.sign_jwt claims ~typ:"at+jwt" Env.jwt_key + in + let%lwt () = + Queries.update_oauth_token ctx.db + ~old_refresh_token:refresh_token + ~new_refresh_token:new_refresh ~expires_at:new_expires_at + in + Dream.json ~headers:[("Cache-Control", "no-store")] + @@ Yojson.Safe.to_string + @@ `Assoc + [ ("access_token", `String new_access_token) + ; ("token_type", `String "DPoP") + ; ("refresh_token", `String new_refresh) + ; ("expires_in", `Int expires_in) + ; ("scope", `String session.scope) + ; ("sub", `String session.did) ] ) ) + | _ -> + Errors.invalid_request ("unsupported grant_type: " ^ req.grant_type) ) diff --git a/pegasus/lib/api/repo/createAccount.ml b/pegasus/lib/api/repo/createAccount.ml index 1029089..c082948 100644 --- a/pegasus/lib/api/repo/createAccount.ml +++ b/pegasus/lib/api/repo/createAccount.ml @@ -57,11 +57,11 @@ let handler = let%lwt did = match input.did with | Some did -> ( - match%lwt Data_store.get_actor_by_identifier did ctx.db with - | Some _ -> - Errors.invalid_request "an account with that did already exists" - | None -> - Lwt.return did ) + match%lwt Data_store.get_actor_by_identifier did ctx.db with + | Some _ -> + Errors.invalid_request "an account with that did already exists" + | None -> + Lwt.return did ) | None -> ( let sk_did = Kleidos.K256.pubkey_to_did_key signing_pubkey in let rotation_did_keys = @@ -79,11 +79,11 @@ let handler = let%lwt _ = match input.invite_code with | Some code -> ( - match%lwt Data_store.use_invite ~code ctx.db with - | Some _ -> - Lwt.return () - | None -> - failwith "failed to use invite code" ) + match%lwt Data_store.use_invite ~code ctx.db with + | Some _ -> + Lwt.return () + | None -> + failwith "failed to use invite code" ) | None -> Lwt.return () in diff --git a/pegasus/lib/api/server/createSession.ml b/pegasus/lib/api/server/createSession.ml index 283d1ab..2141941 100644 --- a/pegasus/lib/api/server/createSession.ml +++ b/pegasus/lib/api/server/createSession.ml @@ -17,7 +17,7 @@ type response = [@@deriving yojson {strict= false}] let handler = - Xrpc.handler (fun {req; db; auth} -> + Xrpc.handler (fun {req; auth; db; _} -> let%lwt {identifier; password; _} = Xrpc.parse_body req request_of_yojson in diff --git a/pegasus/lib/api/server/getServiceAuth.ml b/pegasus/lib/api/server/getServiceAuth.ml index 2590e31..cb8433e 100644 --- a/pegasus/lib/api/server/getServiceAuth.ml +++ b/pegasus/lib/api/server/getServiceAuth.ml @@ -1,7 +1,7 @@ type response = {token: string} [@@deriving yojson {strict= false}] let handler = - Xrpc.handler ~auth:Authorization (fun {req; auth; db} -> + Xrpc.handler ~auth:Authorization (fun {req; auth; db; _} -> let did = Auth.get_authed_did_exn auth in let aud, lxm = match (Dream.query req "aud", Dream.query req "lxm") with diff --git a/pegasus/lib/api/well_known.ml b/pegasus/lib/api/well_known.ml index 07a3a1d..982ddc3 100644 --- a/pegasus/lib/api/well_known.ml +++ b/pegasus/lib/api/well_known.ml @@ -1,3 +1,10 @@ +open struct + let make_url pth = + Uri.(make ~scheme:"https" ~host:Env.hostname ~path:pth () |> to_string) + + let pds_url = `String (make_url "") +end + let did_json = Xrpc.handler (fun _ -> Dream.json @@ Yojson.Safe.to_string @@ -8,5 +15,53 @@ let did_json = , `Assoc [ ("id", `String "#atproto_pds") ; ("type", `String "AtprotoPersonalDataServer") - ; ("serviceEndpoint", `String ("https://" ^ Env.hostname)) ] ) - ] ) + ; ("serviceEndpoint", pds_url) ] ) ] ) + +let oauth_protected_resource = + Xrpc.handler (fun _ -> + Dream.json @@ Yojson.Safe.to_string + @@ `Assoc + [ ("authorization_servers", `List [pds_url]) + ; ("bearer_methods_supported", `List [`String "header"]) + ; ("resource", pds_url) + ; ("resource_documentation", `String "https://atproto.com") + ; ("scopes_supported", `List []) ] ) + +let oauth_authorization_server = + Xrpc.handler (fun _ -> + Dream.json @@ Yojson.Safe.to_string + @@ `Assoc + [ ("issuer", pds_url) + ; ("authorization_endpoint", `String (make_url "/oauth/authorize")) + ; ("token_endpoint", `String (make_url "/oauth/token")) + ; ( "pushed_authorization_request_endpoint" + , `String (make_url "/oauth/par") ) + ; ("require_pushed_authorization_requests", `Bool true) + ; ( "scopes_supported" + , `List + [ `String "atproto" + ; `String "transition:email" + ; `String "transition:generic" + ; `String "transition:chat.bsky" ] ) + ; ("subject_types_supported", `List [`String "public"]) + ; ("response_types_supported", `List [`String "code"]) + ; ( "response_modes_supported" + , `List [`String "query"; `String "fragment"] ) + ; ( "grant_types_supported" + , `List [`String "authorization_code"; `String "refresh_token"] ) + ; ("code_challenge_methods_supported", `List [`String "S256"]) + ; ("ui_locales_supported", `List [`String "en-US"]) + ; ( "display_values_supported" + , `List [`String "page"; `String "popup"; `String "touch"] ) + ; ("authorization_response_iss_parameter_supported", `Bool true) + ; ( "request_object_signing_alg_values_supported" + , `List [`String "ES256"; `String "ES256K"] ) + ; ("request_object_encryption_alg_values_supported", `List []) + ; ("request_object_encryption_enc_values_supported", `List []) + ; ( "token_endpoint_auth_methods_supported" + , `List [`String "none"; `String "private_key_jwt"] ) + ; ( "token_endpoint_auth_signing_alg_values_supported" + , `List [`String "ES256"; `String "ES256K"] ) + ; ( "dpop_signing_alg_values_supported" + , `List [`String "ES256"; `String "ES256K"] ) + ; ("client_id_metadata_document_supported", `Bool true) ] ) diff --git a/pegasus/lib/auth.ml b/pegasus/lib/auth.ml index 5738224..a4119f7 100644 --- a/pegasus/lib/auth.ml +++ b/pegasus/lib/auth.ml @@ -15,6 +15,8 @@ type credentials = | Admin | Access of {did: string} | Refresh of {did: string; jti: string} + | OAuth of {did: string; proof: Oauth.Dpop.proof} + | DPoP of {proof: Oauth.Dpop.proof} let verify_bearer_jwt t token expected_scope = match Jwt.verify_jwt token Env.jwt_key with @@ -42,7 +44,7 @@ let verify_auth ?(refresh = false) credentials did = match credentials with | Admin -> true - | Access {did= creds} when creds = did -> + | (Access {did= creds} | OAuth {did= creds; _}) when creds = did -> true | Refresh {did= creds; _} when creds = did && refresh -> true @@ -50,12 +52,18 @@ let verify_auth ?(refresh = false) credentials did = false let get_authed_did_exn = function - | Access {did} -> + | Access {did} | OAuth {did; _} -> did | Refresh {did; _} -> did | _ -> - Errors.auth_required "Invalid authorization header" + Errors.auth_required "invalid authorization header" + +let get_dpop_proof_exn = function + | OAuth {proof; _} | DPoP {proof} -> + proof + | _ -> + Errors.invalid_request "invalid DPoP header" let get_session_info identifier db = let%lwt actor = @@ -84,7 +92,7 @@ let get_session_info identifier db = module Verifiers = struct open struct let parse_header req expected_type = - match Dream.header req "authorization" with + match Dream.header req "Authorization" with | Some header -> ( match String.split_on_char ' ' header with | [typ; token] @@ -95,24 +103,26 @@ module Verifiers = struct Error "invalid authorization header" ) | None -> Error "missing authorization header" + end - let parse_basic req = - match parse_header req "Basic" with - | Ok token -> ( - match Base64.decode token with - | Ok decoded -> ( - match Str.bounded_split (Str.regexp_string ":") decoded 2 with - | [username; password] -> - Ok (username, password) - | _ -> - Error "invalid basic authorization header" ) - | Error _ -> + let parse_basic req = + match parse_header req "Basic" with + | Ok token -> ( + match Base64.decode token with + | Ok decoded -> ( + match Str.bounded_split (Str.regexp_string ":") decoded 2 with + | [username; password] -> + Ok (username, password) + | _ -> Error "invalid basic authorization header" ) | Error _ -> - Error "invalid basic authorization header" + Error "invalid basic authorization header" ) + | Error _ -> + Error "invalid basic authorization header" - let parse_bearer req = parse_header req "Bearer" - end + let parse_bearer req = parse_header req "Bearer" + + let parse_dpop req = parse_header req "DPoP" type ctx = {req: Dream.request; db: Data_store.t} @@ -122,7 +132,7 @@ module Verifiers = struct fun {req; _} -> match Dream.header req "authorization" with | Some _ -> - Lwt.return_error @@ Errors.auth_required "Invalid authorization header" + Lwt.return_error @@ Errors.auth_required "invalid authorization header" | None -> Lwt.return_ok Unauthenticated @@ -134,49 +144,115 @@ module Verifiers = struct | "admin", p when p = Env.admin_password -> Lwt.return_ok Admin | _ -> - Lwt.return_error @@ Errors.auth_required "Invalid credentials" ) + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) | Error _ -> - Lwt.return_error @@ Errors.auth_required "Invalid authorization header" + Lwt.return_error @@ Errors.auth_required "invalid authorization header" - let access : verifier = + let bearer : verifier = fun {req; db} -> match parse_bearer req with | Ok jwt -> ( - match%lwt verify_bearer_jwt db jwt "com.atproto.access" with - | Ok {sub= did; _} -> ( - match%lwt Data_store.get_actor_by_identifier did db with - | Some {deactivated_at= None; _} -> - Lwt.return_ok (Access {did}) - | Some {deactivated_at= Some _; _} -> - Lwt.return_error - @@ Errors.auth_required ~name:"AccountDeactivated" - "Account is deactivated" - | None -> - Lwt.return_error @@ Errors.auth_required "Invalid credentials" ) - | Error _ -> - Lwt.return_error @@ Errors.auth_required "Invalid credentials" ) + match%lwt verify_bearer_jwt db jwt "com.atproto.access" with + | Ok {sub= did; _} -> ( + match%lwt Data_store.get_actor_by_identifier did db with + | Some {deactivated_at= None; _} -> + Lwt.return_ok (Access {did}) + | Some {deactivated_at= Some _; _} -> + Lwt.return_error + @@ Errors.auth_required ~name:"AccountDeactivated" + "account is deactivated" + | None -> + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) + | Error _ -> + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) | Error _ -> - Lwt.return_error @@ Errors.auth_required "Invalid authorization header" + Lwt.return_error @@ Errors.auth_required "invalid authorization header" + + let dpop : verifier = + fun {req; _} -> + let dpop_header = Dream.header req "DPoP" in + match + Oauth.Dpop.verify_dpop_proof + ~mthd:(Dream.method_to_string @@ Dream.method_ req) + ~url:(Dream.target req) ~dpop_header () + with + | Error "use_dpop_nonce" -> + Lwt.return_error @@ Errors.use_dpop_nonce () + | Error e -> + Lwt.return_error @@ Errors.invalid_request ("dpop error: " ^ e) + | Ok proof -> + Lwt.return_ok (DPoP {proof}) + + let oauth : verifier = + fun {req; db} -> + match parse_dpop req with + | Error e -> + Lwt.return_error @@ Errors.invalid_request ("dpop error: " ^ e) + | Ok token -> ( + match%lwt dpop {req; db} with + | Error e -> + Lwt.return_error e + | Ok (DPoP {proof}) -> ( + match Jwt.verify_jwt token Env.jwt_key with + | Error e -> + Lwt.return_error @@ Errors.auth_required e + | Ok (_header, claims) -> ( + let open Yojson.Safe.Util in + try + let did = claims |> member "sub" |> to_string in + let exp = claims |> member "exp" |> to_int in + let jkt_claim = + claims |> member "cnf" |> member "jkt" |> to_string + in + let now = int_of_float (Unix.gettimeofday ()) in + if jkt_claim <> proof.jkt then + Lwt.return_error @@ Errors.auth_required "dpop key mismatch" + else if exp < now then + Lwt.return_error @@ Errors.auth_required "token expired" + else + let%lwt session = + try%lwt + let%lwt sess = get_session_info did db in + Lwt.return_ok sess + with _ -> + Lwt.return_error + @@ Errors.auth_required "invalid credentials" + in + match session with + | Ok {active= Some true; _} -> + Lwt.return_ok (OAuth {did; proof}) + | Ok _ -> + Lwt.return_error + @@ Errors.auth_required ~name:"AccountDeactivated" + "account is deactivated" + | Error _ -> + Lwt.return_error + @@ Errors.auth_required "invalid credentials" + with _ -> + Lwt.return_error @@ Errors.auth_required "malformed JWT claims" ) + ) + | Ok _ -> + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) let refresh : verifier = fun {req; db} -> match parse_bearer req with | Ok jwt -> ( - match%lwt verify_bearer_jwt db jwt "com.atproto.refresh" with - | Ok {sub= did; jti; _} -> ( - match%lwt Data_store.get_actor_by_identifier did db with - | Some {deactivated_at= None; _} -> - Lwt.return_ok (Refresh {did; jti}) - | Some {deactivated_at= Some _; _} -> - Lwt.return_error - @@ Errors.auth_required ~name:"AccountDeactivated" - "Account is deactivated" - | None -> - Lwt.return_error @@ Errors.auth_required "Invalid credentials" ) - | Error "" | Error _ -> - Lwt.return_error @@ Errors.auth_required "Invalid credentials" ) + match%lwt verify_bearer_jwt db jwt "com.atproto.refresh" with + | Ok {sub= did; jti; _} -> ( + match%lwt Data_store.get_actor_by_identifier did db with + | Some {deactivated_at= None; _} -> + Lwt.return_ok (Refresh {did; jti}) + | Some {deactivated_at= Some _; _} -> + Lwt.return_error + @@ Errors.auth_required ~name:"AccountDeactivated" + "account is deactivated" + | None -> + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) + | Error "" | Error _ -> + Lwt.return_error @@ Errors.auth_required "invalid credentials" ) | Error _ -> - Lwt.return_error @@ Errors.auth_required "Invalid authorization header" + Lwt.return_error @@ Errors.auth_required "invalid authorization header" let authorization : verifier = fun ctx -> @@ -187,24 +263,38 @@ module Verifiers = struct | Some ("Basic" :: _) -> admin ctx | Some ("Bearer" :: _) -> - access ctx + bearer ctx + | Some ("DPoP" :: _) -> + oauth ctx | _ -> Lwt.return_error @@ Errors.auth_required ~name:"InvalidToken" - "Unexpected authorization type" + "unexpected authorization type" let any : verifier = fun ctx -> try authorization ctx with _ -> unauthenticated ctx - type t = Unauthenticated | Admin | Access | Refresh | Authorization | Any + type t = + | Unauthenticated + | Admin + | Bearer + | DPoP + | OAuth + | Refresh + | Authorization + | Any let of_t = function | Unauthenticated -> unauthenticated | Admin -> admin - | Access -> - access + | Bearer -> + bearer + | DPoP -> + dpop + | OAuth -> + oauth | Refresh -> refresh | Authorization -> diff --git a/pegasus/lib/data_store.ml b/pegasus/lib/data_store.ml index 7b27372..3628d7a 100644 --- a/pegasus/lib/data_store.ml +++ b/pegasus/lib/data_store.ml @@ -36,7 +36,7 @@ module Queries = struct created_at INTEGER NOT NULL, deactivated_at INTEGER ) - |sql}] + |sql}] () conn in let$! () = @@ -52,36 +52,126 @@ module Queries = struct [%rapper execute {sql| CREATE TABLE IF NOT EXISTS invite_codes ( - code TEXT PRIMARY KEY, - did TEXT NOT NULL, - remaining INTEGER NOT NULL - ) - |sql}] + code TEXT PRIMARY KEY, + did TEXT NOT NULL, + remaining INTEGER NOT NULL + ) + |sql}] () conn in let$! () = [%rapper execute {sql| CREATE TABLE IF NOT EXISTS firehose ( - seq INTEGER PRIMARY KEY, - time INTEGER NOT NULL, - t TEXT NOT NULL, - data BLOB NOT NULL - ) - |sql}] + seq INTEGER PRIMARY KEY, + time INTEGER NOT NULL, + t TEXT NOT NULL, + data BLOB NOT NULL + ) + |sql}] () conn in - [%rapper - execute - (* no need to store issued tokens, just revoked ones; stolen from millipds https://github.com/DavidBuchanan314/millipds/blob/8f89a01e7d367a2a46f379960e9ca50347dcce71/src/millipds/database.py#L253 *) - {sql| CREATE TABLE IF NOT EXISTS revoked_tokens ( - did TEXT NOT NULL, - jti TEXT NOT NULL, - revoked_at INTEGER NOT NULL, - PRIMARY KEY (did, jti) - ) - |sql}] - () conn + let$! () = + [%rapper + execute + (* no need to store issued tokens, just revoked ones; stolen from millipds https://github.com/DavidBuchanan314/millipds/blob/8f89a01e7d367a2a46f379960e9ca50347dcce71/src/millipds/database.py#L253 *) + {sql| CREATE TABLE IF NOT EXISTS revoked_tokens ( + did TEXT NOT NULL, + jti TEXT NOT NULL, + revoked_at INTEGER NOT NULL, + PRIMARY KEY (did, jti) + ) + |sql}] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TABLE IF NOT EXISTS oauth_requests ( + request_id TEXT PRIMARY KEY, + client_id TEXT NOT NULL, + request_data TEXT NOT NULL, + dpop_jkt TEXT, + expires_at INTEGER NOT NULL, + created_at INTEGER NOT NULL + ) + |sql}] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TABLE IF NOT EXISTS oauth_codes ( + code TEXT PRIMARY KEY, + request_id TEXT NOT NULL REFERENCES oauth_requests(request_id) ON DELETE CASCADE, + authorized_by TEXT, + authorized_at INTEGER, + expires_at INTEGER NOT NULL, + used BOOLEAN DEFAULT FALSE + ) + |sql}] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TABLE IF NOT EXISTS oauth_tokens ( + refresh_token TEXT UNIQUE NOT NULL, + client_id TEXT NOT NULL, + did TEXT NOT NULL, + dpop_jkt TEXT, + scope TEXT NOT NULL, + expires_at INTEGER NOT NULL + ) + |sql}] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE INDEX IF NOT EXISTS oauth_requests_expires_idx ON oauth_requests(expires_at); + CREATE INDEX IF NOT EXISTS oauth_codes_expires_idx ON oauth_codes(expires_at); + CREATE INDEX IF NOT EXISTS oauth_tokens_refresh_idx ON oauth_tokens(refresh_token); + |sql}] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TRIGGER IF NOT EXISTS cleanup_expired_oauth_requests + AFTER INSERT ON oauth_requests + BEGIN + DELETE FROM oauth_requests WHERE expires_at < unixepoch() * 1000; + END + |sql} + syntax_off] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TRIGGER IF NOT EXISTS cleanup_expired_oauth_codes + AFTER INSERT ON oauth_codes + BEGIN + DELETE FROM oauth_codes WHERE expires_at < unixepoch() * 1000 OR used = 1; + END + |sql} + syntax_off] + () conn + in + let$! () = + [%rapper + execute + {sql| CREATE TRIGGER IF NOT EXISTS cleanup_expired_oauth_tokens + AFTER INSERT ON oauth_tokens + BEGIN + DELETE FROM oauth_tokens WHERE expires_at < unixepoch() * 1000; + END + |sql} + syntax_off] + () conn + in + Lwt.return_ok () let create_actor = [%rapper @@ -221,6 +311,8 @@ end type t = Util.caqti_pool let connect ?create ?write () : t Lwt.t = + if create = Some true then + Util.mkfile_p Util.Constants.pegasus_db_filepath ~perm:0o644 ; Util.connect_sqlite ?create ?write Util.Constants.pegasus_db_location let init conn : unit Lwt.t = Util.use_pool conn Queries.create_tables diff --git a/pegasus/lib/dune b/pegasus/lib/dune index f8fdf84..03d9eab 100644 --- a/pegasus/lib/dune +++ b/pegasus/lib/dune @@ -9,6 +9,7 @@ cohttp-lwt-unix dns-client.unix dream + html_of_jsx ipld kleidos lwt @@ -18,12 +19,13 @@ safepass str timedesc + uri uuidm yojson lwt_ppx ppx_deriving_yojson.runtime ppx_rapper_lwt) (preprocess - (pps lwt_ppx ppx_deriving_yojson ppx_rapper))) + (pps html_of_jsx.ppx lwt_ppx ppx_deriving_yojson ppx_rapper))) (include_subdirs qualified) diff --git a/pegasus/lib/env.ml b/pegasus/lib/env.ml index f06f5eb..b2c65be 100644 --- a/pegasus/lib/env.ml +++ b/pegasus/lib/env.ml @@ -1,15 +1,35 @@ +let getenv name = + try Sys.getenv name + with Not_found -> failwith ("Missing environment variable " ^ name) + let data_dir = Option.value ~default:"./data" @@ Sys.getenv_opt "DATA_DIR" -let hostname = Sys.getenv "PDS_HOSTNAME" +let hostname = getenv "PDS_HOSTNAME" let did = Option.value ~default:("did:web:" ^ hostname) @@ Sys.getenv_opt "PDS_DID" -let invite_required = Sys.getenv "INVITE_CODE_REQUIRED" = "true" +let invite_required = getenv "INVITE_CODE_REQUIRED" = "true" + +let rotation_key = getenv "ROTATION_KEY_MULTIBASE" |> Kleidos.parse_multikey_str -let rotation_key = - Sys.getenv "ROTATION_KEY_MULTIBASE" |> Kleidos.parse_multikey_str +let jwt_key = getenv "JWK_MULTIBASE" |> Kleidos.parse_multikey_str -let jwt_key = Sys.getenv "JWK_MULTIBASE" |> Kleidos.parse_multikey_str +let admin_password = getenv "ADMIN_PASSWORD" -let admin_password = Sys.getenv "ADMIN_PASSWORD" +let dpop_nonce_secret = + match Sys.getenv_opt "DPOP_NONCE_SECRET" with + | Some sec -> + let secret = + Base64.(decode_exn ~alphabet:uri_safe_alphabet ~pad:false) sec + |> Bytes.of_string + in + if Bytes.length secret = 32 then secret + else failwith "DPOP_NONCE_SECRET must be 32 bytes in base64uri" + | None -> + let secret = Mirage_crypto_rng_unix.getrandom 32 in + Dream.warning (fun log -> + log "DPOP_NONCE_SECRET not set; using DPOP_NONCE_SECRET=%s" + ( Base64.(encode ~alphabet:uri_safe_alphabet ~pad:false) secret + |> Result.get_ok ) ) ; + Bytes.of_string secret diff --git a/pegasus/lib/errors.ml b/pegasus/lib/errors.ml index 9b4248b..1c3fa23 100644 --- a/pegasus/lib/errors.ml +++ b/pegasus/lib/errors.ml @@ -4,6 +4,8 @@ exception InternalServerError of (string * string) exception AuthError of (string * string) +exception UseDpopNonceError + let is_xrpc_error = function | InvalidRequestError _ | InternalServerError _ | AuthError _ -> true @@ -19,6 +21,8 @@ let internal_error ?(name = "InternalServerError") let auth_required ?(name = "AuthRequired") msg = raise (AuthError (name, msg)) +let use_dpop_nonce () = raise UseDpopNonceError + let exn_to_response exn = let format_response error msg status = Dream.json ~status @@ Yojson.Safe.to_string @@ -31,6 +35,8 @@ let exn_to_response exn = format_response error message `Internal_Server_Error | AuthError (error, message) -> format_response error message `Unauthorized + | UseDpopNonceError -> + Dream.json ~status:`Bad_Request {|{ "error": "use_dpop_nonce" }|} | _ -> format_response "InternalServerError" "Internal server error" `Internal_Server_Error diff --git a/pegasus/lib/id_resolver.ml b/pegasus/lib/id_resolver.ml index 32d6593..43ef302 100644 --- a/pegasus/lib/id_resolver.ml +++ b/pegasus/lib/id_resolver.ml @@ -1,5 +1,4 @@ open Cohttp_lwt -open Cohttp_lwt_unix let did_regex = Str.regexp {|^did:([a-z]+):([a-zA-Z0-9._:%\-]*[a-zA-Z0-9._\-])$|} @@ -12,7 +11,7 @@ module Handle = struct let uri = Uri.of_string ("https://" ^ handle ^ "/.well-known/atproto-did") in - let%lwt {status; _}, body = Client.get uri in + let%lwt {status; _}, body = Util.http_get uri in match status with | `OK -> let%lwt did = Body.to_string body in @@ -164,7 +163,7 @@ module Did = struct ~path:(Uri.pct_encode did) () in let%lwt {status; _}, body = - Client.get uri + Util.http_get uri ~headers:(Cohttp.Header.of_list [("Accept", "application/json")]) in match status with @@ -186,7 +185,7 @@ module Did = struct ~path:"/.well-known/did.json" () in let%lwt {status; _}, body = - Client.get uri + Util.http_get uri ~headers:(Cohttp.Header.of_list [("Accept", "application/json")]) in match status with diff --git a/pegasus/lib/jwt.ml b/pegasus/lib/jwt.ml index 26f4160..3e32ff3 100644 --- a/pegasus/lib/jwt.ml +++ b/pegasus/lib/jwt.ml @@ -19,9 +19,9 @@ let b64_encode str = let b64_decode str = match Base64.decode ~pad:false ~alphabet:Base64.uri_safe_alphabet str with | Ok s -> - Ok s + s | Error (`Msg e) -> - Error e + failwith e let extract_signature_components signature = if Bytes.length signature <> 64 then failwith "expected 64 byte jwt signature" @@ -30,7 +30,7 @@ let extract_signature_components signature = let s = Bytes.sub signature 32 32 in (r, s) -let sign_jwt payload signing_key = +let sign_jwt payload ?(typ = "JWT") signing_key = let _, (module Curve : Kleidos.CURVE) = signing_key in let alg = match Curve.name with @@ -51,7 +51,7 @@ let sign_jwt payload signing_key = failwith "invalid curve" in let header_json = - `Assoc [("alg", `String alg); ("crv", `String crv); ("typ", `String "JWT")] + `Assoc [("alg", `String alg); ("crv", `String crv); ("typ", `String typ)] in let encoded_header = header_json |> Yojson.Safe.to_string |> b64_encode in let encoded_payload = payload |> Yojson.Safe.to_string |> b64_encode in @@ -65,32 +65,24 @@ let sign_jwt payload signing_key = let decode_jwt jwt = match String.split_on_char '.' jwt with | [header_b64; payload_b64; _] -> ( - match (b64_decode header_b64, b64_decode payload_b64) with - | Ok header_str, Ok payload_str -> ( - try - let header = Yojson.Safe.from_string header_str in - let payload = Yojson.Safe.from_string payload_str in - Ok (header, payload) - with _ -> Error "invalid json in jwt" ) - | Error e, _ | _, Error e -> - Error e ) + try + let header = Yojson.Safe.from_string (b64_decode header_b64) in + let payload = Yojson.Safe.from_string (b64_decode payload_b64) in + Ok (header, payload) + with _ -> Error "invalid jwt" ) | _ -> Error "invalid jwt format" let verify_jwt jwt pubkey = match String.split_on_char '.' jwt with - | [header_b64; payload_b64; signature_b64] -> ( - match b64_decode signature_b64 with - | Error e -> - Error e - | Ok signature_str -> - let signature = Bytes.of_string signature_str in - let signing_input = header_b64 ^ "." ^ payload_b64 in - let verified = - Kleidos.verify ~pubkey ~msg:(Bytes.of_string signing_input) ~signature - in - if verified then decode_jwt jwt - else Error "jwt signature verification failed" ) + | [header_b64; payload_b64; signature_b64] -> + let signature = Bytes.of_string (b64_decode signature_b64) in + let signing_input = header_b64 ^ "." ^ payload_b64 in + let verified = + Kleidos.verify ~pubkey ~msg:(Bytes.of_string signing_input) ~signature + in + if verified then decode_jwt jwt + else Error "jwt signature verification failed" | _ -> Error "invalid jwt format" @@ -98,7 +90,9 @@ let generate_jwt did = let now_s = int_of_float (Unix.gettimeofday ()) in let access_exp = now_s + Defaults.access_token_exp in let refresh_exp = now_s + Defaults.refresh_token_exp in - let jti = Uuidm.v4_gen (Random.get_state ()) () |> Uuidm.to_string in + let jti = + Uuidm.v4_gen (Random.State.make_self_init ()) () |> Uuidm.to_string + in let access_payload = symmetric_jwt_to_yojson { scope= "com.atproto.access" diff --git a/pegasus/lib/oauth/client.ml b/pegasus/lib/oauth/client.ml new file mode 100644 index 0000000..d460457 --- /dev/null +++ b/pegasus/lib/oauth/client.ml @@ -0,0 +1,45 @@ +open Types + +let fetch_client_metadata client_id : client_metadata Lwt.t = + let%lwt {status; _}, res = Util.http_get (Uri.of_string client_id) in + if status <> `OK then + let%lwt () = Cohttp_lwt.Body.drain_body res in + failwith + (Printf.sprintf "client metadata not found; http %d" + (Cohttp.Code.code_of_status status) ) + else + let%lwt body = Cohttp_lwt.Body.to_string res in + let json = Yojson.Safe.from_string body in + let metadata = + match client_metadata_of_yojson json with + | Ok metadata -> + metadata + | Error err -> + failwith err + in + if metadata.client_id <> client_id then failwith "client_id mismatch" + else + let scopes = String.split_on_char ' ' metadata.scope in + if not (List.mem "atproto" scopes) then + failwith "scope must include 'atproto'" + else + List.iter + (function + | "authorization_code" | "refresh_token" -> + () + | grant -> + failwith ("invalid grant type: " ^ grant) ) + metadata.grant_types ; + List.iter + (fun uri -> + let u = Uri.of_string uri in + let host = Uri.host u in + match Uri.scheme u with + | Some "https" when host <> Some "localhost" -> + () + | Some "http" when host = Some "127.0.0.1" || host = Some "[::1]" -> + () + | _ -> + failwith ("invalid redirect_uri: " ^ uri) ) + metadata.redirect_uris ; + Lwt.return metadata diff --git a/pegasus/lib/oauth/constants.ml b/pegasus/lib/oauth/constants.ml new file mode 100644 index 0000000..cedbe25 --- /dev/null +++ b/pegasus/lib/oauth/constants.ml @@ -0,0 +1,15 @@ +let max_dpop_age_s = 60 + +let dpop_rotation_interval_ms = 60_000L + +let jti_ttl_s = 3600 + +let jti_cache_size = 10_000 + +let par_request_ttl_ms = 300_000 + +let code_expiry_ms = 300_000 + +let access_token_expiry_ms = 60 * 60 * 1000 + +let request_uri_prefix = "urn:ietf:params:oauth:request_uri:" diff --git a/pegasus/lib/oauth/dpop.ml b/pegasus/lib/oauth/dpop.ml new file mode 100644 index 0000000..a083146 --- /dev/null +++ b/pegasus/lib/oauth/dpop.ml @@ -0,0 +1,204 @@ +type nonce_state = + { secret: bytes + ; mutable counter: int64 + ; mutable prev: string + ; mutable curr: string + ; mutable next: string + ; rotation_interval_ms: int64 } + +type ec_jwk = {crv: string; kty: string; x: string; y: string} +[@@deriving yojson] + +type proof = {jti: string; jkt: string; htm: string; htu: string} +[@@deriving yojson] + +let jti_cache : (string, int) Hashtbl.t = + Hashtbl.create Constants.jti_cache_size + +let cleanup_jti_cache () = + let now = int_of_float (Unix.gettimeofday ()) in + Hashtbl.filter_map_inplace + (fun _ expires_at -> if expires_at > now then Some expires_at else None) + jti_cache + +let compute_nonce secret counter = + let data = Bytes.create 8 in + Bytes.set_int64_be data 0 counter ; + Digestif.SHA256.( + hmac_bytes ~key:(Bytes.to_string secret) data + |> to_raw_string |> Jwt.b64_encode ) + +let create_nonce_state secret = + let counter = + Int64.div + (Int64.of_float (Unix.gettimeofday () *. 1000.)) + Constants.dpop_rotation_interval_ms + in + { secret + ; counter + ; prev= compute_nonce secret (Int64.pred counter) + ; curr= compute_nonce secret counter + ; next= compute_nonce secret (Int64.succ counter) + ; rotation_interval_ms= Constants.dpop_rotation_interval_ms } + +let nonce_state = ref (create_nonce_state Env.dpop_nonce_secret) + +let next_nonce () = + let now_counter = + Int64.div + (Int64.of_float (Unix.gettimeofday () *. 1000.)) + !nonce_state.rotation_interval_ms + in + if now_counter <> !nonce_state.counter then ( + !nonce_state.prev <- !nonce_state.curr ; + !nonce_state.curr <- !nonce_state.next ; + !nonce_state.next <- + compute_nonce !nonce_state.secret (Int64.succ now_counter) ; + !nonce_state.counter <- now_counter ) ; + !nonce_state.next + +let verify_nonce nonce = + let valid = + nonce = !nonce_state.prev || nonce = !nonce_state.curr + || nonce = !nonce_state.next + in + ignore next_nonce ; valid + +let add_jti jti = + let expires_at = int_of_float (Unix.gettimeofday ()) + Constants.jti_ttl_s in + if Hashtbl.mem jti_cache jti then false (* replay *) + else ( + Hashtbl.add jti_cache jti expires_at ; + (* clean up every once in a while *) + if Hashtbl.length jti_cache mod 100 = 0 then cleanup_jti_cache () ; + true ) + +let normalize_url url = + let uri = Uri.of_string url in + Uri.make ~scheme:"https" + ~host:(Uri.host uri |> Option.value ~default:Env.hostname) + ~path:(Uri.path uri) () + |> Uri.to_string + +let compute_jwk_thumbprint jwk = + let {crv; kty; x; y} = jwk in + let tp = + (* keys must be in lexicographic order *) + Printf.sprintf {|{"crv":"%s","kty":"%s","x":"%s","y":"%s"}|} crv kty x y + in + Digestif.SHA256.(digest_string tp |> to_raw_string |> Jwt.b64_encode) + +let verify_signature jwt jwk = + let parts = String.split_on_char '.' jwt in + match parts with + | [header_b64; payload_b64; sig_b64] -> + let signing_input = header_b64 ^ "." ^ payload_b64 in + let msg = Bytes.of_string signing_input in + let {x; y; crv; _} = jwk in + let x = x |> Jwt.b64_decode |> Bytes.of_string in + let y = y |> Jwt.b64_decode |> Bytes.of_string in + let pubkey = Bytes.cat (Bytes.of_string "\x04") (Bytes.cat x y) in + let pubkey = + ( pubkey + , match crv with + | "secp256k1" -> + (module Kleidos.K256 : Kleidos.CURVE) + | "P-256" -> + (module Kleidos.P256 : Kleidos.CURVE) + | _ -> + failwith "unsupported algorithm" ) + in + let sig_bytes = Jwt.b64_decode sig_b64 |> Bytes.of_string in + let r = Bytes.sub sig_bytes 0 32 in + let s = Bytes.sub sig_bytes 32 32 in + let signature = Bytes.cat r s in + Kleidos.verify ~pubkey ~msg ~signature + | _ -> + false + +let verify_dpop_proof ~mthd ~url ~dpop_header ?access_token () = + match dpop_header with + | None -> + Error "missing dpop header" + | Some jwt -> ( + let open Yojson.Safe.Util in + match String.split_on_char '.' jwt with + | [header_b64; payload_b64; _] -> ( + let header = Yojson.Safe.from_string (Jwt.b64_decode header_b64) in + let payload = Yojson.Safe.from_string (Jwt.b64_decode payload_b64) in + let typ = header |> member "typ" |> to_string in + if typ <> "dpop+jwt" then Error "invalid typ in dpop proof" + else + let alg = header |> member "alg" |> to_string in + if alg <> "ES256" && alg <> "ES256K" then + Error "only es256 and es256k supported for dpop" + else + let jwk = + header |> member "jwk" |> ec_jwk_of_yojson |> Result.get_ok + in + if + not + ( match (alg, jwk.crv) with + | "ES256", "P-256" -> + true + | "ES256K", "secp256k1" -> + true + | _ -> + false ) + then + Error + (Printf.sprintf "algorithm %s doesn't match curve %s" alg + jwk.crv ) + else + let jti = payload |> member "jti" |> to_string in + let htm = payload |> member "htm" |> to_string in + let htu = payload |> member "htu" |> to_string in + let iat = payload |> member "iat" |> to_int in + let nonce_claim = + payload |> member "nonce" |> to_string_option + in + match nonce_claim with + (* error must be this string; see https://datatracker.ietf.org/doc/html/rfc9449#section-8 *) + | None -> + Error "use_dpop_nonce" + | Some n when not (verify_nonce n) -> + Error "use_dpop_nonce" + | Some _ -> ( + if htm <> mthd then Error "htm mismatch" + else if + not (String.equal (normalize_url htu) (normalize_url url)) + then Error "htu mismatch" + else + let now = int_of_float (Unix.gettimeofday ()) in + if now - iat > Constants.max_dpop_age_s then + Error "dpop proof too old" + else if iat - now > 5 then Error "dpop proof in future" + else if not (add_jti jti) then + Error "dpop proof replay detected" + else if not (verify_signature jwt jwk) then + Error "invalid dpop signature" + else + let jkt = compute_jwk_thumbprint jwk in + (* verify ath if access token is provided *) + match access_token with + | Some token -> + let ath_claim = + payload |> member "ath" |> to_string_option + in + let expected_ath = + Digestif.SHA256.( + digest_string token |> to_raw_string + |> Jwt.b64_encode ) + in + if Some expected_ath <> ath_claim then + Error "ath mismatch" + else Ok {jti; jkt; htm; htu} + | None -> + let ath_claim = + payload |> member "ath" |> to_string_option + in + if ath_claim <> None then + Error "ath claim not allowed without access token" + else Ok {jti; jkt; htm; htu} ) ) + | _ -> + Error "invalid dpop jwt" ) diff --git a/pegasus/lib/oauth/queries.ml b/pegasus/lib/oauth/queries.ml new file mode 100644 index 0000000..6fa9d89 --- /dev/null +++ b/pegasus/lib/oauth/queries.ml @@ -0,0 +1,138 @@ +[@@@warning "-missing-record-field-pattern"] + +open Types + +let insert_par_request conn req = + Util.use_pool conn + @@ [%rapper + execute + {sql| + INSERT INTO oauth_requests (request_id, client_id, request_data, dpop_jkt, expires_at, created_at) + VALUES (%string{request_id}, %string{client_id}, %string{request_data}, %string?{dpop_jkt}, %int{expires_at}, %int{created_at}) + |sql} + record_in] + req + +let get_par_request conn request_id = + Util.use_pool conn + @@ [%rapper + get_opt + {sql| + SELECT @string{request_id}, @string{client_id}, @string{request_data}, + @string?{dpop_jkt}, @int{expires_at}, @int{created_at} + FROM oauth_requests + WHERE request_id = %string{request_id} + AND expires_at > %int{now} + |sql} + record_out] + ~request_id ~now:(Util.now_ms ()) + +let insert_auth_code conn code = + Util.use_pool conn + @@ [%rapper + execute + {sql| + INSERT INTO oauth_codes (code, request_id, authorized_by, authorized_at, expires_at, used) + VALUES (%string{code}, %string{request_id}, %string?{authorized_by}, %int?{authorized_at}, %int{expires_at}, 0) + |sql} + record_in] + code + +let get_auth_code conn code = + Util.use_pool conn + @@ [%rapper + get_opt + {sql| + SELECT @string{code}, @string{request_id}, @string?{authorized_by}, + @int?{authorized_at}, @int{expires_at}, @bool{used} + FROM oauth_codes + WHERE code = %string{code} + |sql} + record_out] + ~code + +let activate_auth_code conn code did = + let authorized_at = Util.now_ms () in + Util.use_pool conn + @@ [%rapper + execute + {sql| + UPDATE oauth_codes + SET authorized_by = %string{did}, + authorized_at = %int{authorized_at} + WHERE code = %string{code} + |sql}] + ~did ~authorized_at ~code + +let consume_auth_code conn code = + Util.use_pool conn + @@ [%rapper + get_opt + {sql| + UPDATE oauth_codes + SET used = 1 + WHERE code = %string{code} AND used = 0 + RETURNING @string{code}, @string{request_id}, @string?{authorized_by}, + @int?{authorized_at}, @int{expires_at}, @bool{used} + |sql} + record_out] + ~code + +let insert_oauth_token conn token = + Util.use_pool conn + @@ [%rapper + execute + {sql| + INSERT INTO oauth_tokens (refresh_token, client_id, did, dpop_jkt, scope, expires_at) + VALUES (%string{refresh_token}, %string{client_id}, %string{did}, %string{dpop_jkt}, %string{scope}, %int{expires_at}) + |sql} + record_in] + token + +let get_oauth_token_by_refresh conn refresh_token = + Util.use_pool conn + @@ [%rapper + get_opt + {sql| + SELECT @string{refresh_token}, @string{client_id}, @string{did}, + @string{dpop_jkt}, @string{scope}, @int{expires_at} + FROM oauth_tokens + WHERE refresh_token = %string{refresh_token} + |sql} + record_out] + ~refresh_token + +let update_oauth_token conn ~old_refresh_token ~new_refresh_token ~expires_at = + Util.use_pool conn + @@ [%rapper + execute + {sql| + UPDATE oauth_tokens + SET refresh_token = %string{new_refresh_token}, + expires_at = %int{expires_at} + WHERE refresh_token = %string{old_refresh_token} + |sql}] + ~new_refresh_token ~expires_at ~old_refresh_token + +let delete_oauth_token_by_refresh conn refresh_token = + Util.use_pool conn + @@ [%rapper + execute + {sql| + DELETE FROM oauth_tokens WHERE refresh_token = %string{refresh_token} + |sql}] + ~refresh_token + +let get_oauth_tokens_by_did conn did = + Util.use_pool conn + @@ [%rapper + get_many + {sql| + SELECT @string{refresh_token}, @string{client_id}, @string{did}, + @string{dpop_jkt}, @string{scope}, @int{expires_at} + FROM oauth_tokens + WHERE did = %string{did} + ORDER BY expires_at ASC + |sql} + record_out] + ~did diff --git a/pegasus/lib/oauth/types.ml b/pegasus/lib/oauth/types.ml new file mode 100644 index 0000000..0fd1333 --- /dev/null +++ b/pegasus/lib/oauth/types.ml @@ -0,0 +1,71 @@ +type par_request = + { client_id: string + ; response_type: string + ; response_mode: string option [@default None] + ; redirect_uri: string + ; scope: string + ; state: string + ; code_challenge: string + ; code_challenge_method: string + ; login_hint: string option [@default None] + ; dpop_jkt: string option [@default None] + ; client_assertion_type: string option [@default None] + ; client_assertion: string option [@default None] } +[@@deriving yojson {strict= false}] + +type token_request = + { grant_type: string + ; code: string option [@default None] + ; redirect_uri: string option [@default None] + ; code_verifier: string option [@default None] + ; refresh_token: string option [@default None] + ; client_id: string + ; client_assertion_type: string option [@default None] + ; client_assertion: string option [@default None] } +[@@deriving yojson {strict= false}] + +type client_metadata = + { client_id: string + ; client_name: string option [@default None] + ; client_uri: string + ; redirect_uris: string list + ; grant_types: string list + ; response_types: string list + ; scope: string + ; token_endpoint_auth_method: string + ; token_endpoint_auth_signing_alg: string option [@default None] + ; application_type: string + ; dpop_bound_access_tokens: bool + ; jwks_uri: string option [@default None] + ; jwks: Yojson.Safe.t option [@default None] } +[@@deriving yojson {strict= false}] + +type dpop_proof = {jti: string; jkt: string; htm: string; htu: string} +[@@deriving yojson {strict= false}] + +type oauth_request = + { request_id: string + ; client_id: string + ; request_data: string + ; dpop_jkt: string option [@default None] + ; expires_at: int + ; created_at: int } +[@@deriving yojson {strict= false}] + +type oauth_code = + { code: string + ; request_id: string + ; authorized_by: string option [@default None] + ; authorized_at: int option [@default None] + ; expires_at: int + ; used: bool } +[@@deriving yojson {strict= false}] + +type oauth_token = + { refresh_token: string + ; client_id: string + ; did: string + ; dpop_jkt: string + ; scope: string + ; expires_at: int } +[@@deriving yojson {strict= false}] diff --git a/pegasus/lib/plc.ml b/pegasus/lib/plc.ml index e680832..02740a0 100644 --- a/pegasus/lib/plc.ml +++ b/pegasus/lib/plc.ml @@ -302,7 +302,7 @@ let get_audit_log ?endpoint did : (audit_log, string) Lwt_result.t = did in let headers = Http.Header.init_with "Accept" "application/json" in - let%lwt res, body = Client.get ~headers uri in + let%lwt res, body = Util.http_get ~headers uri in match res.status with | `OK -> let%lwt body = Body.to_string body in diff --git a/pegasus/lib/repository.ml b/pegasus/lib/repository.ml index b72f12b..1dfbf3e 100644 --- a/pegasus/lib/repository.ml +++ b/pegasus/lib/repository.ml @@ -180,7 +180,7 @@ let list_all_records t collection : (string * Cid.t * record) list Lwt.t = let%lwt map = get_map t in String_map.bindings map |> List.filter (fun (path, _) -> - String.starts_with ~prefix:(path ^ "/") collection ) + String.starts_with ~prefix:(path ^ "/") collection ) |> Lwt_list.fold_left_s (fun acc (path, cid) -> match%lwt User_store.get_record t.db path with @@ -320,16 +320,16 @@ let apply_writes (t : t) (writes : repo_write list) (swap_commit : Cid.t option) let%lwt () = match old_cid with | Some _ -> ( - match%lwt User_store.get_record t.db path with - | Some record -> - let refs = - Util.find_blob_refs record.value - |> List.map (fun (r : Mist.Blob_ref.t) -> r.ref) - in - let%lwt () = User_store.clear_blob_refs t.db path refs in - Lwt.return_unit - | None -> - Lwt.return_unit ) + match%lwt User_store.get_record t.db path with + | Some record -> + let refs = + Util.find_blob_refs record.value + |> List.map (fun (r : Mist.Blob_ref.t) -> r.ref) + in + let%lwt () = User_store.clear_blob_refs t.db path refs in + Lwt.return_unit + | None -> + Lwt.return_unit ) | None -> Lwt.return_unit in diff --git a/pegasus/lib/sequencer.ml b/pegasus/lib/sequencer.ml index 5173a25..20cf58f 100644 --- a/pegasus/lib/sequencer.ml +++ b/pegasus/lib/sequencer.ml @@ -330,7 +330,7 @@ module Parse = struct let blobs = j |> member "blobs" |> to_list |> List.filter_map (fun x -> - match Cid.of_yojson x with Ok c -> Some c | _ -> None ) + match Cid.of_yojson x with Ok c -> Some c | _ -> None ) in let prev_data = match j |> member "prevData" with @@ -342,33 +342,33 @@ module Parse = struct let ops = j |> member "ops" |> to_list |> List.map (fun opj -> - let action = - match opj |> member "action" |> to_string with - | "create" -> - `Create - | "update" -> - `Update - | "delete" -> - `Delete - | _ -> - `Create - in - let path = opj |> member "path" |> to_string in - let cid = - match opj |> member "cid" with - | `Null -> - None - | v -> ( - match Cid.of_yojson v with Ok c -> Some c | _ -> None ) - in - let prev = - match opj |> member "prev" with - | `Null -> - None - | v -> ( - match Cid.of_yojson v with Ok c -> Some c | _ -> None ) - in - {action; path; cid; prev} ) + let action = + match opj |> member "action" |> to_string with + | "create" -> + `Create + | "update" -> + `Update + | "delete" -> + `Delete + | _ -> + `Create + in + let path = opj |> member "path" |> to_string in + let cid = + match opj |> member "cid" with + | `Null -> + None + | v -> ( + match Cid.of_yojson v with Ok c -> Some c | _ -> None ) + in + let prev = + match opj |> member "prev" with + | `Null -> + None + | v -> ( + match Cid.of_yojson v with Ok c -> Some c | _ -> None ) + in + {action; path; cid; prev} ) in Ok { rebase diff --git a/pegasus/lib/templates/components/button.mlx b/pegasus/lib/templates/components/button.mlx new file mode 100644 index 0000000..fd8bb0b --- /dev/null +++ b/pegasus/lib/templates/components/button.mlx @@ -0,0 +1,33 @@ +let base_classes = + "py-1 px-4 text-lg rounded-lg w-full flex items-center justify-center \ + transition delay-50 duration-300 focus-visible:outline-none disabled:text-mist-80" + +type kind = Primary | Secondary | Tertiary | Danger + +let classes = function + | Primary -> + base_classes + ^ " bg-white font-serif text-mana-200 shadow-whisper \ + hover:shadow-shimmer hover:bg-mist-20 focus-visible:shadow-shimmer \ + focus-visible:bg-mist-20 active:shadow-glow disabled:bg-mana-40" + | Secondary -> + base_classes + ^ " bg-feather font-serif underline text-mana-100 hover:no-underline \ + focus-visible:shadow-whisper active:shadow-whisper disabled:no-underline \ + disabled:bg-mana-40" + | Tertiary -> + base_classes + ^ " font-sans underline text-mana-100 hover:no-underline \ + focus-visible:text-mana-200 active:text-mana-200" + | Danger -> + base_classes + ^ " bg-white font-serif text-phoenix-100 shadow-bleed hover:bg-mist-20 \ + hover:text-phoenix-40 focus:bg-mist-20 focus:text-phoenix-40 \ + focus-visible:outline-none active:bg-phoenix-40 active:text-mist-20 \ + disabled:bg-mana-40" + +let make ?id ?name ?(kind = Primary) ?(type_ = "button") ?onclick ?value + ?(class_ = "") ~children () = + diff --git a/pegasus/lib/templates/components/input.mlx b/pegasus/lib/templates/components/input.mlx new file mode 100644 index 0000000..093216e --- /dev/null +++ b/pegasus/lib/templates/components/input.mlx @@ -0,0 +1,56 @@ +open JSX + +(* putting this inline messes with ocamlformat-mlx *) +let req_marker = " *" + +let make ?id ~name ?(class_ = "") ?(type_ = "text") ?label ?(sr_only = false) + ?value ?placeholder ?(required = false) ?(disabled = false) ?trailing () = + let id = Option.value id ~default:name in + let placeholder = if label <> None && sr_only then label else placeholder in + let input = + + in +
+ "You’re signing into " + rendered_name + " as " + rendered_handle + " and granting it the following permissions:" +
+