1package server
2
3import (
4 "net/url"
5 "strings"
6 "time"
7
8 "github.com/Azure/go-autorest/autorest/to"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/oauth"
11 "github.com/haileyok/cocoon/oauth/provider"
12 "github.com/labstack/echo/v4"
13)
14
15func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
16 reqUri := e.QueryParam("request_uri")
17 if reqUri == "" {
18 // render page for logged out dev
19 if s.config.Version == "dev" {
20 return e.Render(200, "authorize.html", map[string]any{
21 "Scopes": []string{"atproto", "transition:generic"},
22 "AppName": "DEV MODE AUTHORIZATION PAGE",
23 "Handle": "paula.cocoon.social",
24 "RequestUri": "",
25 })
26 }
27 return helpers.InputError(e, to.StringPtr("no request uri"))
28 }
29
30 repo, _, err := s.getSessionRepoOrErr(e)
31 if err != nil {
32 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
33 }
34
35 reqId, err := oauth.DecodeRequestUri(reqUri)
36 if err != nil {
37 return helpers.InputError(e, to.StringPtr(err.Error()))
38 }
39
40 var req provider.OauthAuthorizationRequest
41 if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
42 return helpers.ServerError(e, to.StringPtr(err.Error()))
43 }
44
45 clientId := e.QueryParam("client_id")
46 if clientId != req.ClientId {
47 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request"))
48 }
49
50 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId)
51 if err != nil {
52 return helpers.ServerError(e, to.StringPtr(err.Error()))
53 }
54
55 scopes := strings.Split(req.Parameters.Scope, " ")
56 appName := client.Metadata.ClientName
57
58 data := map[string]any{
59 "Scopes": scopes,
60 "AppName": appName,
61 "RequestUri": reqUri,
62 "QueryParams": e.QueryParams().Encode(),
63 "Handle": repo.Actor.Handle,
64 }
65
66 return e.Render(200, "authorize.html", data)
67}
68
// OauthAuthorizePostRequest is the form body submitted by the consent page to
// the authorize endpoint.
type OauthAuthorizePostRequest struct {
	// RequestUri is the request_uri value identifying the pending
	// authorization request.
	RequestUri string `form:"request_uri"`

	// AcceptOrRejct carries the user's decision; the POST handler treats the
	// value "reject" as a denial. The field name is misspelled but is kept
	// as-is because it is exported.
	AcceptOrRejct string `form:"accept_or_reject"`
}
73
74func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
75 repo, _, err := s.getSessionRepoOrErr(e)
76 if err != nil {
77 return e.Redirect(303, "/account/signin")
78 }
79
80 var req OauthAuthorizePostRequest
81 if err := e.Bind(&req); err != nil {
82 s.logger.Error("error binding authorize post request", "error", err)
83 return helpers.InputError(e, nil)
84 }
85
86 reqId, err := oauth.DecodeRequestUri(req.RequestUri)
87 if err != nil {
88 return helpers.InputError(e, to.StringPtr(err.Error()))
89 }
90
91 var authReq provider.OauthAuthorizationRequest
92 if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
93 return helpers.ServerError(e, to.StringPtr(err.Error()))
94 }
95
96 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId)
97 if err != nil {
98 return helpers.ServerError(e, to.StringPtr(err.Error()))
99 }
100
101 // TODO: figure out how im supposed to actually redirect
102 if req.AcceptOrRejct == "reject" {
103 return e.Redirect(303, client.Metadata.ClientURI)
104 }
105
106 if time.Now().After(authReq.ExpiresAt) {
107 return helpers.InputError(e, to.StringPtr("the request has expired"))
108 }
109
110 if authReq.Sub != nil || authReq.Code != nil {
111 return helpers.InputError(e, to.StringPtr("this request was already authorized"))
112 }
113
114 code := oauth.GenerateCode()
115
116 if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
117 s.logger.Error("error updating authorization request", "error", err)
118 return helpers.ServerError(e, nil)
119 }
120
121 q := url.Values{}
122 q.Set("state", authReq.Parameters.State)
123 q.Set("iss", "https://"+s.config.Hostname)
124 q.Set("code", code)
125
126 hashOrQuestion := "?"
127 if authReq.ClientAuth.Method != "private_key_jwt" {
128 hashOrQuestion = "#"
129 }
130
131 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
132}