An atproto PDS written in Go
at main 4.0 kB view raw
1package server 2 3import ( 4 "net/url" 5 "strings" 6 "time" 7 8 "github.com/Azure/go-autorest/autorest/to" 9 "github.com/haileyok/cocoon/internal/helpers" 10 "github.com/haileyok/cocoon/oauth" 11 "github.com/haileyok/cocoon/oauth/provider" 12 "github.com/labstack/echo/v4" 13) 14 15func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 reqUri := e.QueryParam("request_uri") 17 if reqUri == "" { 18 // render page for logged out dev 19 if s.config.Version == "dev" { 20 return e.Render(200, "authorize.html", map[string]any{ 21 "Scopes": []string{"atproto", "transition:generic"}, 22 "AppName": "DEV MODE AUTHORIZATION PAGE", 23 "Handle": "paula.cocoon.social", 24 "RequestUri": "", 25 }) 26 } 27 return helpers.InputError(e, to.StringPtr("no request uri")) 28 } 29 30 repo, _, err := s.getSessionRepoOrErr(e) 31 if err != nil { 32 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 33 } 34 35 reqId, err := oauth.DecodeRequestUri(reqUri) 36 if err != nil { 37 return helpers.InputError(e, to.StringPtr(err.Error())) 38 } 39 40 var req provider.OauthAuthorizationRequest 41 if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 42 return helpers.ServerError(e, to.StringPtr(err.Error())) 43 } 44 45 clientId := e.QueryParam("client_id") 46 if clientId != req.ClientId { 47 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request")) 48 } 49 50 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId) 51 if err != nil { 52 return helpers.ServerError(e, to.StringPtr(err.Error())) 53 } 54 55 scopes := strings.Split(req.Parameters.Scope, " ") 56 appName := client.Metadata.ClientName 57 58 data := map[string]any{ 59 "Scopes": scopes, 60 "AppName": appName, 61 "RequestUri": reqUri, 62 "QueryParams": e.QueryParams().Encode(), 63 "Handle": repo.Actor.Handle, 64 } 65 66 return e.Render(200, "authorize.html", data) 67} 68 69type OauthAuthorizePostRequest struct { 70 RequestUri string `form:"request_uri"` 71 AcceptOrRejct string `form:"accept_or_reject"` 72} 73 74func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 75 repo, _, err := s.getSessionRepoOrErr(e) 76 if err != nil { 77 return e.Redirect(303, "/account/signin") 78 } 79 80 var req OauthAuthorizePostRequest 81 if err := e.Bind(&req); err != nil { 82 s.logger.Error("error binding authorize post request", "error", err) 83 return helpers.InputError(e, nil) 84 } 85 86 reqId, err := oauth.DecodeRequestUri(req.RequestUri) 87 if err != nil { 88 return helpers.InputError(e, to.StringPtr(err.Error())) 89 } 90 91 var authReq provider.OauthAuthorizationRequest 92 if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 93 return helpers.ServerError(e, to.StringPtr(err.Error())) 94 } 95 96 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId) 97 if err != nil { 98 return helpers.ServerError(e, to.StringPtr(err.Error())) 99 } 100 101 // TODO: figure out how im supposed to actually redirect 102 if req.AcceptOrRejct == "reject" { 103 return e.Redirect(303, client.Metadata.ClientURI) 104 } 105 106 if time.Now().After(authReq.ExpiresAt) { 107 return helpers.InputError(e, to.StringPtr("the request has expired")) 108 } 109 110 if authReq.Sub != nil || authReq.Code != nil { 111 return helpers.InputError(e, to.StringPtr("this request was already authorized")) 112 } 113 114 code := oauth.GenerateCode() 115 116 if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 117 s.logger.Error("error updating authorization request", "error", err) 118 return helpers.ServerError(e, nil) 119 } 120 121 q := url.Values{} 122 q.Set("state", authReq.Parameters.State) 123 q.Set("iss", "https://"+s.config.Hostname) 124 q.Set("code", code) 125 126 hashOrQuestion := "?" 127 if authReq.ClientAuth.Method != "private_key_jwt" { 128 hashOrQuestion = "#" 129 } 130 131 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode()) 132}