An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/base64" 7 "fmt" 8 "slices" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/golang-jwt/jwt/v4" 13 "github.com/haileyok/cocoon/internal/helpers" 14 "github.com/haileyok/cocoon/oauth" 15 "github.com/haileyok/cocoon/oauth/constants" 16 "github.com/haileyok/cocoon/oauth/provider" 17 "github.com/labstack/echo/v4" 18) 19 20type OauthTokenRequest struct { 21 provider.AuthenticateClientRequestBase 22 GrantType string `form:"grant_type" json:"grant_type"` 23 Code *string `form:"code" json:"code,omitempty"` 24 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"` 25 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"` 26 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"` 27} 28 29type OauthTokenResponse struct { 30 AccessToken string `json:"access_token"` 31 TokenType string `json:"token_type"` 32 RefreshToken string `json:"refresh_token"` 33 Scope string `json:"scope"` 34 ExpiresIn int64 `json:"expires_in"` 35 Sub string `json:"sub"` 36} 37 38func (s *Server) handleOauthToken(e echo.Context) error { 39 var req OauthTokenRequest 40 if err := e.Bind(&req); err != nil { 41 s.logger.Error("error binding token request", "error", err) 42 return helpers.ServerError(e, nil) 43 } 44 45 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil) 46 if err != nil { 47 s.logger.Error("error getting dpop proof", "error", err) 48 return helpers.InputError(e, to.StringPtr(err.Error())) 49 } 50 51 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 52 AllowMissingDpopProof: true, 53 }) 54 if err != nil { 55 s.logger.Error("error authenticating client", "error", err) 56 return helpers.InputError(e, to.StringPtr(err.Error())) 57 } 58 59 // TODO: this should come from an oauth provier config 60 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) { 61 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType))) 62 } 63 64 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) { 65 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType))) 66 } 67 68 if req.GrantType == "authorization_code" { 69 if req.Code == nil { 70 return helpers.InputError(e, to.StringPtr(`"code" is required"`)) 71 } 72 73 var authReq provider.OauthAuthorizationRequest 74 // get the lil guy and delete him 75 if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 76 s.logger.Error("error finding authorization request", "error", err) 77 return helpers.ServerError(e, nil) 78 } 79 80 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI { 81 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`)) 82 } 83 84 if authReq.Parameters.CodeChallenge != nil { 85 if req.CodeVerifier == nil { 86 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`)) 87 } 88 89 if len(*req.CodeVerifier) < 43 { 90 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`)) 91 } 92 93 switch *&authReq.Parameters.CodeChallengeMethod { 94 case "", "plain": 95 if authReq.Parameters.CodeChallenge != req.CodeVerifier { 96 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 97 } 98 case "S256": 99 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge) 100 if err != nil { 101 s.logger.Error("error decoding code challenge", "error", err) 102 return helpers.ServerError(e, nil) 103 } 104 105 h := sha256.New() 106 h.Write([]byte(*req.CodeVerifier)) 107 compdChal := h.Sum(nil) 108 109 if !bytes.Equal(inputChal, compdChal) { 110 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 111 } 112 default: 113 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod)) 114 } 115 } else if req.CodeVerifier != nil { 116 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 117 } 118 119 repo, err := s.getRepoActorByDid(*authReq.Sub) 120 if err != nil { 121 helpers.InputError(e, to.StringPtr("unable to find actor")) 122 } 123 124 now := time.Now() 125 eat := now.Add(constants.TokenMaxAge) 126 id := oauth.GenerateTokenId() 127 128 refreshToken := oauth.GenerateRefreshToken() 129 130 accessClaims := jwt.MapClaims{ 131 "scope": authReq.Parameters.Scope, 132 "aud": s.config.Did, 133 "sub": repo.Repo.Did, 134 "iat": now.Unix(), 135 "exp": eat.Unix(), 136 "jti": id, 137 "client_id": authReq.ClientId, 138 } 139 140 if authReq.Parameters.DpopJkt != nil { 141 accessClaims["cnf"] = *authReq.Parameters.DpopJkt 142 } 143 144 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 145 accessString, err := accessToken.SignedString(s.privateKey) 146 if err != nil { 147 return err 148 } 149 150 if err := s.db.Create(&provider.OauthToken{ 151 ClientId: authReq.ClientId, 152 ClientAuth: *clientAuth, 153 Parameters: authReq.Parameters, 154 ExpiresAt: eat, 155 DeviceId: "", 156 Sub: repo.Repo.Did, 157 Code: *authReq.Code, 158 Token: accessString, 159 RefreshToken: refreshToken, 160 }, nil).Error; err != nil { 161 s.logger.Error("error creating token in db", "error", err) 162 return helpers.ServerError(e, nil) 163 } 164 165 // prob not needed 166 tokenType := "Bearer" 167 if authReq.Parameters.DpopJkt != nil { 168 tokenType = "DPoP" 169 } 170 171 e.Response().Header().Set("content-type", "application/json") 172 173 return e.JSON(200, OauthTokenResponse{ 174 AccessToken: accessString, 175 RefreshToken: refreshToken, 176 TokenType: tokenType, 177 Scope: authReq.Parameters.Scope, 178 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 179 Sub: repo.Repo.Did, 180 }) 181 } 182 183 if req.GrantType == "refresh_token" { 184 if req.RefreshToken == nil { 185 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`)) 186 } 187 188 var oauthToken provider.OauthToken 189 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 190 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 191 return helpers.ServerError(e, nil) 192 } 193 194 if client.Metadata.ClientID != oauthToken.ClientId { 195 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`)) 196 } 197 198 if clientAuth.Method != oauthToken.ClientAuth.Method { 199 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`)) 200 } 201 202 if *oauthToken.Parameters.DpopJkt != proof.JKT { 203 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 204 } 205 206 sessionLifetime := constants.PublicClientSessionLifetime 207 refreshLifetime := constants.PublicClientRefreshLifetime 208 if clientAuth.Method != "none" { 209 sessionLifetime = constants.ConfidentialClientSessionLifetime 210 refreshLifetime = constants.ConfidentialClientRefreshLifetime 211 } 212 213 sessionAge := time.Since(oauthToken.CreatedAt) 214 if sessionAge > sessionLifetime { 215 return helpers.InputError(e, to.StringPtr("Session expired")) 216 } 217 218 refreshAge := time.Since(oauthToken.UpdatedAt) 219 if refreshAge > refreshLifetime { 220 return helpers.InputError(e, to.StringPtr("Refresh token expired")) 221 } 222 223 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil { 224 // why? ref impl 225 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens")) 226 } 227 228 nextTokenId := oauth.GenerateTokenId() 229 nextRefreshToken := oauth.GenerateRefreshToken() 230 231 now := time.Now() 232 eat := now.Add(constants.TokenMaxAge) 233 234 accessClaims := jwt.MapClaims{ 235 "scope": oauthToken.Parameters.Scope, 236 "aud": s.config.Did, 237 "sub": oauthToken.Sub, 238 "iat": now.Unix(), 239 "exp": eat.Unix(), 240 "jti": nextTokenId, 241 "client_id": oauthToken.ClientId, 242 } 243 244 if oauthToken.Parameters.DpopJkt != nil { 245 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt 246 } 247 248 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 249 accessString, err := accessToken.SignedString(s.privateKey) 250 if err != nil { 251 return err 252 } 253 254 if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 255 s.logger.Error("error updating token", "error", err) 256 return helpers.ServerError(e, nil) 257 } 258 259 // prob not needed 260 tokenType := "Bearer" 261 if oauthToken.Parameters.DpopJkt != nil { 262 tokenType = "DPoP" 263 } 264 265 return e.JSON(200, OauthTokenResponse{ 266 AccessToken: accessString, 267 RefreshToken: nextRefreshToken, 268 TokenType: tokenType, 269 Scope: oauthToken.Parameters.Scope, 270 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 271 Sub: oauthToken.Sub, 272 }) 273 } 274 275 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType))) 276}