1package server
2
3import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "slices"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/golang-jwt/jwt/v4"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/oauth"
16 "github.com/haileyok/cocoon/oauth/constants"
17 "github.com/haileyok/cocoon/oauth/dpop"
18 "github.com/haileyok/cocoon/oauth/provider"
19 "github.com/labstack/echo/v4"
20)
21
// OauthTokenRequest is the body of a request to the OAuth token endpoint.
// The client-authentication fields come from the embedded
// provider.AuthenticateClientRequestBase; the remaining fields may be bound
// from form or JSON input according to their struct tags.
type OauthTokenRequest struct {
	provider.AuthenticateClientRequestBase
	// GrantType selects the token flow; handleOauthToken currently accepts
	// "authorization_code" and "refresh_token".
	GrantType string `form:"grant_type" json:"grant_type"`
	// Code is the authorization code being exchanged (authorization_code grant).
	Code *string `form:"code" json:"code,omitempty"`
	// CodeVerifier is the PKCE verifier checked against the code challenge
	// stored with the authorization request.
	CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"`
	// RedirectURI must equal the redirect URI of the authorization request.
	RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"`
	// RefreshToken is the token to rotate (refresh_token grant).
	RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"`
}
30
// OauthTokenResponse is the JSON body returned by the OAuth token endpoint on
// success, for both the authorization_code and refresh_token grants.
type OauthTokenResponse struct {
	// AccessToken is the signed JWT access token.
	AccessToken string `json:"access_token"`
	// TokenType is "DPoP" when the token is bound to a DPoP key, else "Bearer".
	TokenType string `json:"token_type"`
	// RefreshToken is the (newly issued or rotated) refresh token.
	RefreshToken string `json:"refresh_token"`
	// Scope is the scope recorded in the authorization request parameters.
	Scope string `json:"scope"`
	// ExpiresIn is the access token lifetime in whole seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Sub is the DID of the account the token was issued for.
	Sub string `json:"sub"`
}
39
40func (s *Server) handleOauthToken(e echo.Context) error {
41 var req OauthTokenRequest
42 if err := e.Bind(&req); err != nil {
43 s.logger.Error("error binding token request", "error", err)
44 return helpers.ServerError(e, nil)
45 }
46
47 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
48 if err != nil {
49 if errors.Is(err, dpop.ErrUseDpopNonce) {
50 nonce := s.oauthProvider.NextNonce()
51 if nonce != "" {
52 e.Response().Header().Set("DPoP-Nonce", nonce)
53 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
54 }
55 return e.JSON(400, map[string]string{
56 "error": "use_dpop_nonce",
57 })
58 }
59 s.logger.Error("error getting dpop proof", "error", err)
60 return helpers.InputError(e, nil)
61 }
62
63 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
64 AllowMissingDpopProof: true,
65 })
66 if err != nil {
67 s.logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
68 return helpers.InputError(e, to.StringPtr(err.Error()))
69 }
70
71 // TODO: this should come from an oauth provier config
72 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
73 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
74 }
75
76 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
77 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
78 }
79
80 if req.GrantType == "authorization_code" {
81 if req.Code == nil {
82 return helpers.InputError(e, to.StringPtr(`"code" is required"`))
83 }
84
85 var authReq provider.OauthAuthorizationRequest
86 // get the lil guy and delete him
87 if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
88 s.logger.Error("error finding authorization request", "error", err)
89 return helpers.ServerError(e, nil)
90 }
91
92 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
93 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
94 }
95
96 if authReq.Parameters.CodeChallenge != nil {
97 if req.CodeVerifier == nil {
98 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
99 }
100
101 if len(*req.CodeVerifier) < 43 {
102 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
103 }
104
105 switch *&authReq.Parameters.CodeChallengeMethod {
106 case "", "plain":
107 if authReq.Parameters.CodeChallenge != req.CodeVerifier {
108 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
109 }
110 case "S256":
111 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
112 if err != nil {
113 s.logger.Error("error decoding code challenge", "error", err)
114 return helpers.ServerError(e, nil)
115 }
116
117 h := sha256.New()
118 h.Write([]byte(*req.CodeVerifier))
119 compdChal := h.Sum(nil)
120
121 if !bytes.Equal(inputChal, compdChal) {
122 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
123 }
124 default:
125 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
126 }
127 } else if req.CodeVerifier != nil {
128 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
129 }
130
131 repo, err := s.getRepoActorByDid(*authReq.Sub)
132 if err != nil {
133 helpers.InputError(e, to.StringPtr("unable to find actor"))
134 }
135
136 now := time.Now()
137 eat := now.Add(constants.TokenMaxAge)
138 id := oauth.GenerateTokenId()
139
140 refreshToken := oauth.GenerateRefreshToken()
141
142 accessClaims := jwt.MapClaims{
143 "scope": authReq.Parameters.Scope,
144 "aud": s.config.Did,
145 "sub": repo.Repo.Did,
146 "iat": now.Unix(),
147 "exp": eat.Unix(),
148 "jti": id,
149 "client_id": authReq.ClientId,
150 }
151
152 if authReq.Parameters.DpopJkt != nil {
153 accessClaims["cnf"] = *authReq.Parameters.DpopJkt
154 }
155
156 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
157 accessString, err := accessToken.SignedString(s.privateKey)
158 if err != nil {
159 return err
160 }
161
162 if err := s.db.Create(&provider.OauthToken{
163 ClientId: authReq.ClientId,
164 ClientAuth: *clientAuth,
165 Parameters: authReq.Parameters,
166 ExpiresAt: eat,
167 DeviceId: "",
168 Sub: repo.Repo.Did,
169 Code: *authReq.Code,
170 Token: accessString,
171 RefreshToken: refreshToken,
172 Ip: authReq.Ip,
173 }, nil).Error; err != nil {
174 s.logger.Error("error creating token in db", "error", err)
175 return helpers.ServerError(e, nil)
176 }
177
178 // prob not needed
179 tokenType := "Bearer"
180 if authReq.Parameters.DpopJkt != nil {
181 tokenType = "DPoP"
182 }
183
184 e.Response().Header().Set("content-type", "application/json")
185
186 return e.JSON(200, OauthTokenResponse{
187 AccessToken: accessString,
188 RefreshToken: refreshToken,
189 TokenType: tokenType,
190 Scope: authReq.Parameters.Scope,
191 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
192 Sub: repo.Repo.Did,
193 })
194 }
195
196 if req.GrantType == "refresh_token" {
197 if req.RefreshToken == nil {
198 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
199 }
200
201 var oauthToken provider.OauthToken
202 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
203 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
204 return helpers.ServerError(e, nil)
205 }
206
207 if client.Metadata.ClientID != oauthToken.ClientId {
208 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
209 }
210
211 if clientAuth.Method != oauthToken.ClientAuth.Method {
212 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
213 }
214
215 if *oauthToken.Parameters.DpopJkt != proof.JKT {
216 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
217 }
218
219 ageRes := oauth.GetSessionAgeFromToken(oauthToken)
220
221 if ageRes.SessionExpired {
222 return helpers.InputError(e, to.StringPtr("Session expired"))
223 }
224
225 if ageRes.RefreshExpired {
226 return helpers.InputError(e, to.StringPtr("Refresh token expired"))
227 }
228
229 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
230 // why? ref impl
231 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
232 }
233
234 nextTokenId := oauth.GenerateTokenId()
235 nextRefreshToken := oauth.GenerateRefreshToken()
236
237 now := time.Now()
238 eat := now.Add(constants.TokenMaxAge)
239
240 accessClaims := jwt.MapClaims{
241 "scope": oauthToken.Parameters.Scope,
242 "aud": s.config.Did,
243 "sub": oauthToken.Sub,
244 "iat": now.Unix(),
245 "exp": eat.Unix(),
246 "jti": nextTokenId,
247 "client_id": oauthToken.ClientId,
248 }
249
250 if oauthToken.Parameters.DpopJkt != nil {
251 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
252 }
253
254 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
255 accessString, err := accessToken.SignedString(s.privateKey)
256 if err != nil {
257 return err
258 }
259
260 if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
261 s.logger.Error("error updating token", "error", err)
262 return helpers.ServerError(e, nil)
263 }
264
265 // prob not needed
266 tokenType := "Bearer"
267 if oauthToken.Parameters.DpopJkt != nil {
268 tokenType = "DPoP"
269 }
270
271 return e.JSON(200, OauthTokenResponse{
272 AccessToken: accessString,
273 RefreshToken: nextRefreshToken,
274 TokenType: tokenType,
275 Scope: oauthToken.Parameters.Scope,
276 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
277 Sub: oauthToken.Sub,
278 })
279 }
280
281 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
282}