1package server
2
3import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/base64"
7 "fmt"
8 "slices"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/golang-jwt/jwt/v4"
13 "github.com/haileyok/cocoon/internal/helpers"
14 "github.com/haileyok/cocoon/oauth"
15 "github.com/haileyok/cocoon/oauth/constants"
16 "github.com/haileyok/cocoon/oauth/provider"
17 "github.com/labstack/echo/v4"
18)
19
// OauthTokenRequest is the request body accepted by handleOauthToken. Fields
// are bound from either form or JSON input according to their struct tags.
// The embedded provider.AuthenticateClientRequestBase is forwarded unchanged to
// AuthenticateClient to authenticate the calling client.
type OauthTokenRequest struct {
	provider.AuthenticateClientRequestBase
	// GrantType selects the flow; handleOauthToken accepts "authorization_code"
	// and "refresh_token".
	GrantType string `form:"grant_type" json:"grant_type"`
	// Code is the authorization code to redeem (authorization_code grant).
	Code *string `form:"code" json:"code,omitempty"`
	// CodeVerifier is the PKCE verifier checked against the stored code
	// challenge (authorization_code grant).
	CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"`
	// RedirectURI must equal the redirect URI stored on the authorization
	// request (authorization_code grant).
	RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"`
	// RefreshToken identifies the stored token to rotate (refresh_token grant).
	RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"`
}
28
// OauthTokenResponse is the JSON body returned by handleOauthToken on success.
// Field order here is the order the fields appear in the encoded JSON.
type OauthTokenResponse struct {
	// AccessToken is the ES256-signed JWT issued to the client.
	AccessToken string `json:"access_token"`
	// TokenType is "DPoP" when the token is bound to a DPoP key, else "Bearer".
	TokenType string `json:"token_type"`
	// RefreshToken is the newly issued refresh token.
	RefreshToken string `json:"refresh_token"`
	// Scope is the scope recorded on the authorization request / stored token.
	Scope string `json:"scope"`
	// ExpiresIn is the access token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Sub is the DID of the account the token was issued for.
	Sub string `json:"sub"`
}
37
38func (s *Server) handleOauthToken(e echo.Context) error {
39 var req OauthTokenRequest
40 if err := e.Bind(&req); err != nil {
41 s.logger.Error("error binding token request", "error", err)
42 return helpers.ServerError(e, nil)
43 }
44
45 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
46 if err != nil {
47 s.logger.Error("error getting dpop proof", "error", err)
48 return helpers.InputError(e, to.StringPtr(err.Error()))
49 }
50
51 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
52 AllowMissingDpopProof: true,
53 })
54 if err != nil {
55 s.logger.Error("error authenticating client", "error", err)
56 return helpers.InputError(e, to.StringPtr(err.Error()))
57 }
58
59 // TODO: this should come from an oauth provier config
60 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
61 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
62 }
63
64 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
65 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
66 }
67
68 if req.GrantType == "authorization_code" {
69 if req.Code == nil {
70 return helpers.InputError(e, to.StringPtr(`"code" is required"`))
71 }
72
73 var authReq provider.OauthAuthorizationRequest
74 // get the lil guy and delete him
75 if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
76 s.logger.Error("error finding authorization request", "error", err)
77 return helpers.ServerError(e, nil)
78 }
79
80 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
81 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
82 }
83
84 if authReq.Parameters.CodeChallenge != nil {
85 if req.CodeVerifier == nil {
86 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
87 }
88
89 if len(*req.CodeVerifier) < 43 {
90 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
91 }
92
93 switch *&authReq.Parameters.CodeChallengeMethod {
94 case "", "plain":
95 if authReq.Parameters.CodeChallenge != req.CodeVerifier {
96 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
97 }
98 case "S256":
99 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
100 if err != nil {
101 s.logger.Error("error decoding code challenge", "error", err)
102 return helpers.ServerError(e, nil)
103 }
104
105 h := sha256.New()
106 h.Write([]byte(*req.CodeVerifier))
107 compdChal := h.Sum(nil)
108
109 if !bytes.Equal(inputChal, compdChal) {
110 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
111 }
112 default:
113 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
114 }
115 } else if req.CodeVerifier != nil {
116 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
117 }
118
119 repo, err := s.getRepoActorByDid(*authReq.Sub)
120 if err != nil {
121 helpers.InputError(e, to.StringPtr("unable to find actor"))
122 }
123
124 now := time.Now()
125 eat := now.Add(constants.TokenMaxAge)
126 id := oauth.GenerateTokenId()
127
128 refreshToken := oauth.GenerateRefreshToken()
129
130 accessClaims := jwt.MapClaims{
131 "scope": authReq.Parameters.Scope,
132 "aud": s.config.Did,
133 "sub": repo.Repo.Did,
134 "iat": now.Unix(),
135 "exp": eat.Unix(),
136 "jti": id,
137 "client_id": authReq.ClientId,
138 }
139
140 if authReq.Parameters.DpopJkt != nil {
141 accessClaims["cnf"] = *authReq.Parameters.DpopJkt
142 }
143
144 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
145 accessString, err := accessToken.SignedString(s.privateKey)
146 if err != nil {
147 return err
148 }
149
150 if err := s.db.Create(&provider.OauthToken{
151 ClientId: authReq.ClientId,
152 ClientAuth: *clientAuth,
153 Parameters: authReq.Parameters,
154 ExpiresAt: eat,
155 DeviceId: "",
156 Sub: repo.Repo.Did,
157 Code: *authReq.Code,
158 Token: accessString,
159 RefreshToken: refreshToken,
160 Ip: authReq.Ip,
161 }, nil).Error; err != nil {
162 s.logger.Error("error creating token in db", "error", err)
163 return helpers.ServerError(e, nil)
164 }
165
166 // prob not needed
167 tokenType := "Bearer"
168 if authReq.Parameters.DpopJkt != nil {
169 tokenType = "DPoP"
170 }
171
172 e.Response().Header().Set("content-type", "application/json")
173
174 return e.JSON(200, OauthTokenResponse{
175 AccessToken: accessString,
176 RefreshToken: refreshToken,
177 TokenType: tokenType,
178 Scope: authReq.Parameters.Scope,
179 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
180 Sub: repo.Repo.Did,
181 })
182 }
183
184 if req.GrantType == "refresh_token" {
185 if req.RefreshToken == nil {
186 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
187 }
188
189 var oauthToken provider.OauthToken
190 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
191 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
192 return helpers.ServerError(e, nil)
193 }
194
195 if client.Metadata.ClientID != oauthToken.ClientId {
196 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
197 }
198
199 if clientAuth.Method != oauthToken.ClientAuth.Method {
200 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
201 }
202
203 if *oauthToken.Parameters.DpopJkt != proof.JKT {
204 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
205 }
206
207 ageRes := oauth.GetSessionAgeFromToken(oauthToken)
208
209 if ageRes.SessionExpired {
210 return helpers.InputError(e, to.StringPtr("Session expired"))
211 }
212
213 if ageRes.RefreshExpired {
214 return helpers.InputError(e, to.StringPtr("Refresh token expired"))
215 }
216
217 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
218 // why? ref impl
219 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
220 }
221
222 nextTokenId := oauth.GenerateTokenId()
223 nextRefreshToken := oauth.GenerateRefreshToken()
224
225 now := time.Now()
226 eat := now.Add(constants.TokenMaxAge)
227
228 accessClaims := jwt.MapClaims{
229 "scope": oauthToken.Parameters.Scope,
230 "aud": s.config.Did,
231 "sub": oauthToken.Sub,
232 "iat": now.Unix(),
233 "exp": eat.Unix(),
234 "jti": nextTokenId,
235 "client_id": oauthToken.ClientId,
236 }
237
238 if oauthToken.Parameters.DpopJkt != nil {
239 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
240 }
241
242 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
243 accessString, err := accessToken.SignedString(s.privateKey)
244 if err != nil {
245 return err
246 }
247
248 if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
249 s.logger.Error("error updating token", "error", err)
250 return helpers.ServerError(e, nil)
251 }
252
253 // prob not needed
254 tokenType := "Bearer"
255 if oauthToken.Parameters.DpopJkt != nil {
256 tokenType = "DPoP"
257 }
258
259 return e.JSON(200, OauthTokenResponse{
260 AccessToken: accessString,
261 RefreshToken: nextRefreshToken,
262 TokenType: tokenType,
263 Scope: oauthToken.Parameters.Scope,
264 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
265 Sub: oauthToken.Sub,
266 })
267 }
268
269 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
270}