1package server
2
3import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "slices"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/golang-jwt/jwt/v4"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/oauth"
16 "github.com/haileyok/cocoon/oauth/constants"
17 "github.com/haileyok/cocoon/oauth/dpop"
18 "github.com/haileyok/cocoon/oauth/provider"
19 "github.com/labstack/echo/v4"
20)
21
22type OauthTokenRequest struct {
23 provider.AuthenticateClientRequestBase
24 GrantType string `form:"grant_type" json:"grant_type"`
25 Code *string `form:"code" json:"code,omitempty"`
26 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"`
27 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"`
28 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"`
29}
30
// OauthTokenResponse is the JSON body returned by the token endpoint after a
// successful authorization_code or refresh_token grant.
type OauthTokenResponse struct {
	// AccessToken is the signed JWT access token.
	AccessToken string `json:"access_token"`
	// TokenType is "DPoP" when the token is bound to a DPoP key thumbprint,
	// otherwise "Bearer".
	TokenType string `json:"token_type"`
	// RefreshToken is the refresh token to use for the next refresh grant.
	RefreshToken string `json:"refresh_token"`
	// Scope is the scope string taken from the authorization parameters.
	Scope string `json:"scope"`
	// ExpiresIn is the remaining access-token lifetime, in whole seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Sub is the DID the token was issued for.
	Sub string `json:"sub"`
}
39
40func (s *Server) handleOauthToken(e echo.Context) error {
41 var req OauthTokenRequest
42 if err := e.Bind(&req); err != nil {
43 s.logger.Error("error binding token request", "error", err)
44 return helpers.ServerError(e, nil)
45 }
46
47 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
48 if err != nil {
49 if errors.Is(err, dpop.ErrUseDpopNonce) {
50 return e.JSON(400, map[string]string{
51 "error": "use_dpop_nonce",
52 })
53 }
54 s.logger.Error("error getting dpop proof", "error", err)
55 return helpers.InputError(e, nil)
56 }
57
58 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
59 AllowMissingDpopProof: true,
60 })
61 if err != nil {
62 s.logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
63 return helpers.InputError(e, to.StringPtr(err.Error()))
64 }
65
66 // TODO: this should come from an oauth provier config
67 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
68 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
69 }
70
71 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
72 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
73 }
74
75 if req.GrantType == "authorization_code" {
76 if req.Code == nil {
77 return helpers.InputError(e, to.StringPtr(`"code" is required"`))
78 }
79
80 var authReq provider.OauthAuthorizationRequest
81 // get the lil guy and delete him
82 if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
83 s.logger.Error("error finding authorization request", "error", err)
84 return helpers.ServerError(e, nil)
85 }
86
87 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
88 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
89 }
90
91 if authReq.Parameters.CodeChallenge != nil {
92 if req.CodeVerifier == nil {
93 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
94 }
95
96 if len(*req.CodeVerifier) < 43 {
97 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
98 }
99
100 switch *&authReq.Parameters.CodeChallengeMethod {
101 case "", "plain":
102 if authReq.Parameters.CodeChallenge != req.CodeVerifier {
103 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
104 }
105 case "S256":
106 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
107 if err != nil {
108 s.logger.Error("error decoding code challenge", "error", err)
109 return helpers.ServerError(e, nil)
110 }
111
112 h := sha256.New()
113 h.Write([]byte(*req.CodeVerifier))
114 compdChal := h.Sum(nil)
115
116 if !bytes.Equal(inputChal, compdChal) {
117 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
118 }
119 default:
120 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
121 }
122 } else if req.CodeVerifier != nil {
123 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
124 }
125
126 repo, err := s.getRepoActorByDid(*authReq.Sub)
127 if err != nil {
128 helpers.InputError(e, to.StringPtr("unable to find actor"))
129 }
130
131 now := time.Now()
132 eat := now.Add(constants.TokenMaxAge)
133 id := oauth.GenerateTokenId()
134
135 refreshToken := oauth.GenerateRefreshToken()
136
137 accessClaims := jwt.MapClaims{
138 "scope": authReq.Parameters.Scope,
139 "aud": s.config.Did,
140 "sub": repo.Repo.Did,
141 "iat": now.Unix(),
142 "exp": eat.Unix(),
143 "jti": id,
144 "client_id": authReq.ClientId,
145 }
146
147 if authReq.Parameters.DpopJkt != nil {
148 accessClaims["cnf"] = *authReq.Parameters.DpopJkt
149 }
150
151 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
152 accessString, err := accessToken.SignedString(s.privateKey)
153 if err != nil {
154 return err
155 }
156
157 if err := s.db.Create(&provider.OauthToken{
158 ClientId: authReq.ClientId,
159 ClientAuth: *clientAuth,
160 Parameters: authReq.Parameters,
161 ExpiresAt: eat,
162 DeviceId: "",
163 Sub: repo.Repo.Did,
164 Code: *authReq.Code,
165 Token: accessString,
166 RefreshToken: refreshToken,
167 Ip: authReq.Ip,
168 }, nil).Error; err != nil {
169 s.logger.Error("error creating token in db", "error", err)
170 return helpers.ServerError(e, nil)
171 }
172
173 // prob not needed
174 tokenType := "Bearer"
175 if authReq.Parameters.DpopJkt != nil {
176 tokenType = "DPoP"
177 }
178
179 e.Response().Header().Set("content-type", "application/json")
180
181 return e.JSON(200, OauthTokenResponse{
182 AccessToken: accessString,
183 RefreshToken: refreshToken,
184 TokenType: tokenType,
185 Scope: authReq.Parameters.Scope,
186 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
187 Sub: repo.Repo.Did,
188 })
189 }
190
191 if req.GrantType == "refresh_token" {
192 if req.RefreshToken == nil {
193 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
194 }
195
196 var oauthToken provider.OauthToken
197 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
198 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
199 return helpers.ServerError(e, nil)
200 }
201
202 if client.Metadata.ClientID != oauthToken.ClientId {
203 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
204 }
205
206 if clientAuth.Method != oauthToken.ClientAuth.Method {
207 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
208 }
209
210 if *oauthToken.Parameters.DpopJkt != proof.JKT {
211 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
212 }
213
214 ageRes := oauth.GetSessionAgeFromToken(oauthToken)
215
216 if ageRes.SessionExpired {
217 return helpers.InputError(e, to.StringPtr("Session expired"))
218 }
219
220 if ageRes.RefreshExpired {
221 return helpers.InputError(e, to.StringPtr("Refresh token expired"))
222 }
223
224 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
225 // why? ref impl
226 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
227 }
228
229 nextTokenId := oauth.GenerateTokenId()
230 nextRefreshToken := oauth.GenerateRefreshToken()
231
232 now := time.Now()
233 eat := now.Add(constants.TokenMaxAge)
234
235 accessClaims := jwt.MapClaims{
236 "scope": oauthToken.Parameters.Scope,
237 "aud": s.config.Did,
238 "sub": oauthToken.Sub,
239 "iat": now.Unix(),
240 "exp": eat.Unix(),
241 "jti": nextTokenId,
242 "client_id": oauthToken.ClientId,
243 }
244
245 if oauthToken.Parameters.DpopJkt != nil {
246 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
247 }
248
249 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
250 accessString, err := accessToken.SignedString(s.privateKey)
251 if err != nil {
252 return err
253 }
254
255 if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
256 s.logger.Error("error updating token", "error", err)
257 return helpers.ServerError(e, nil)
258 }
259
260 // prob not needed
261 tokenType := "Bearer"
262 if oauthToken.Parameters.DpopJkt != nil {
263 tokenType = "DPoP"
264 }
265
266 return e.JSON(200, OauthTokenResponse{
267 AccessToken: accessString,
268 RefreshToken: nextRefreshToken,
269 TokenType: tokenType,
270 Scope: oauthToken.Parameters.Scope,
271 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
272 Sub: oauthToken.Sub,
273 })
274 }
275
276 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
277}