An atproto PDS written in Go
at main 9.2 kB view raw
1package server 2 3import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "slices" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/golang-jwt/jwt/v4" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/oauth" 16 "github.com/haileyok/cocoon/oauth/constants" 17 "github.com/haileyok/cocoon/oauth/dpop" 18 "github.com/haileyok/cocoon/oauth/provider" 19 "github.com/labstack/echo/v4" 20) 21 22type OauthTokenRequest struct { 23 provider.AuthenticateClientRequestBase 24 GrantType string `form:"grant_type" json:"grant_type"` 25 Code *string `form:"code" json:"code,omitempty"` 26 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"` 27 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"` 28 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"` 29} 30 31type OauthTokenResponse struct { 32 AccessToken string `json:"access_token"` 33 TokenType string `json:"token_type"` 34 RefreshToken string `json:"refresh_token"` 35 Scope string `json:"scope"` 36 ExpiresIn int64 `json:"expires_in"` 37 Sub string `json:"sub"` 38} 39 40func (s *Server) handleOauthToken(e echo.Context) error { 41 var req OauthTokenRequest 42 if err := e.Bind(&req); err != nil { 43 s.logger.Error("error binding token request", "error", err) 44 return helpers.ServerError(e, nil) 45 } 46 47 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil) 48 if err != nil { 49 if errors.Is(err, dpop.ErrUseDpopNonce) { 50 nonce := s.oauthProvider.NextNonce() 51 if nonce != "" { 52 e.Response().Header().Set("DPoP-Nonce", nonce) 53 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 54 } 55 return e.JSON(400, map[string]string{ 56 "error": "use_dpop_nonce", 57 }) 58 } 59 s.logger.Error("error getting dpop proof", "error", err) 60 return helpers.InputError(e, nil) 61 } 62 63 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 64 AllowMissingDpopProof: true, 65 }) 66 if err != nil { 67 s.logger.Error("error authenticating client", "client_id", req.ClientID, "error", err) 68 return helpers.InputError(e, to.StringPtr(err.Error())) 69 } 70 71 // TODO: this should come from an oauth provier config 72 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) { 73 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType))) 74 } 75 76 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) { 77 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType))) 78 } 79 80 if req.GrantType == "authorization_code" { 81 if req.Code == nil { 82 return helpers.InputError(e, to.StringPtr(`"code" is required"`)) 83 } 84 85 var authReq provider.OauthAuthorizationRequest 86 // get the lil guy and delete him 87 if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 88 s.logger.Error("error finding authorization request", "error", err) 89 return helpers.ServerError(e, nil) 90 } 91 92 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI { 93 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`)) 94 } 95 96 if authReq.Parameters.CodeChallenge != nil { 97 if req.CodeVerifier == nil { 98 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`)) 99 } 100 101 if len(*req.CodeVerifier) < 43 { 102 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`)) 103 } 104 105 switch *&authReq.Parameters.CodeChallengeMethod { 106 case "", "plain": 107 if authReq.Parameters.CodeChallenge != req.CodeVerifier { 108 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 109 } 110 case "S256": 111 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge) 112 if err != nil { 113 s.logger.Error("error decoding code challenge", "error", err) 114 return helpers.ServerError(e, nil) 115 } 116 117 h := sha256.New() 118 h.Write([]byte(*req.CodeVerifier)) 119 compdChal := h.Sum(nil) 120 121 if !bytes.Equal(inputChal, compdChal) { 122 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 123 } 124 default: 125 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod)) 126 } 127 } else if req.CodeVerifier != nil { 128 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 129 } 130 131 repo, err := s.getRepoActorByDid(*authReq.Sub) 132 if err != nil { 133 helpers.InputError(e, to.StringPtr("unable to find actor")) 134 } 135 136 now := time.Now() 137 eat := now.Add(constants.TokenMaxAge) 138 id := oauth.GenerateTokenId() 139 140 refreshToken := oauth.GenerateRefreshToken() 141 142 accessClaims := jwt.MapClaims{ 143 "scope": authReq.Parameters.Scope, 144 "aud": s.config.Did, 145 "sub": repo.Repo.Did, 146 "iat": now.Unix(), 147 "exp": eat.Unix(), 148 "jti": id, 149 "client_id": authReq.ClientId, 150 } 151 152 if authReq.Parameters.DpopJkt != nil { 153 accessClaims["cnf"] = *authReq.Parameters.DpopJkt 154 } 155 156 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 157 accessString, err := accessToken.SignedString(s.privateKey) 158 if err != nil { 159 return err 160 } 161 162 if err := s.db.Create(&provider.OauthToken{ 163 ClientId: authReq.ClientId, 164 ClientAuth: *clientAuth, 165 Parameters: authReq.Parameters, 166 ExpiresAt: eat, 167 DeviceId: "", 168 Sub: repo.Repo.Did, 169 Code: *authReq.Code, 170 Token: accessString, 171 RefreshToken: refreshToken, 172 Ip: authReq.Ip, 173 }, nil).Error; err != nil { 174 s.logger.Error("error creating token in db", "error", err) 175 return helpers.ServerError(e, nil) 176 } 177 178 // prob not needed 179 tokenType := "Bearer" 180 if authReq.Parameters.DpopJkt != nil { 181 tokenType = "DPoP" 182 } 183 184 e.Response().Header().Set("content-type", "application/json") 185 186 return e.JSON(200, OauthTokenResponse{ 187 AccessToken: accessString, 188 RefreshToken: refreshToken, 189 TokenType: tokenType, 190 Scope: authReq.Parameters.Scope, 191 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 192 Sub: repo.Repo.Did, 193 }) 194 } 195 196 if req.GrantType == "refresh_token" { 197 if req.RefreshToken == nil { 198 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`)) 199 } 200 201 var oauthToken provider.OauthToken 202 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 203 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 204 return helpers.ServerError(e, nil) 205 } 206 207 if client.Metadata.ClientID != oauthToken.ClientId { 208 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`)) 209 } 210 211 if clientAuth.Method != oauthToken.ClientAuth.Method { 212 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`)) 213 } 214 215 if *oauthToken.Parameters.DpopJkt != proof.JKT { 216 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 217 } 218 219 ageRes := oauth.GetSessionAgeFromToken(oauthToken) 220 221 if ageRes.SessionExpired { 222 return helpers.InputError(e, to.StringPtr("Session expired")) 223 } 224 225 if ageRes.RefreshExpired { 226 return helpers.InputError(e, to.StringPtr("Refresh token expired")) 227 } 228 229 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil { 230 // why? ref impl 231 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens")) 232 } 233 234 nextTokenId := oauth.GenerateTokenId() 235 nextRefreshToken := oauth.GenerateRefreshToken() 236 237 now := time.Now() 238 eat := now.Add(constants.TokenMaxAge) 239 240 accessClaims := jwt.MapClaims{ 241 "scope": oauthToken.Parameters.Scope, 242 "aud": s.config.Did, 243 "sub": oauthToken.Sub, 244 "iat": now.Unix(), 245 "exp": eat.Unix(), 246 "jti": nextTokenId, 247 "client_id": oauthToken.ClientId, 248 } 249 250 if oauthToken.Parameters.DpopJkt != nil { 251 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt 252 } 253 254 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 255 accessString, err := accessToken.SignedString(s.privateKey) 256 if err != nil { 257 return err 258 } 259 260 if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 261 s.logger.Error("error updating token", "error", err) 262 return helpers.ServerError(e, nil) 263 } 264 265 // prob not needed 266 tokenType := "Bearer" 267 if oauthToken.Parameters.DpopJkt != nil { 268 tokenType = "DPoP" 269 } 270 271 return e.JSON(200, OauthTokenResponse{ 272 AccessToken: accessString, 273 RefreshToken: nextRefreshToken, 274 TokenType: tokenType, 275 Scope: oauthToken.Parameters.Scope, 276 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 277 Sub: oauthToken.Sub, 278 }) 279 } 280 281 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType))) 282}