An atproto PDS written in Go
1package provider 2 3import ( 4 "context" 5 "crypto" 6 "database/sql/driver" 7 "encoding/base64" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "time" 12 13 "github.com/golang-jwt/jwt/v4" 14 "github.com/haileyok/cocoon/oauth" 15 "github.com/haileyok/cocoon/oauth/constants" 16 "github.com/haileyok/cocoon/oauth/dpop" 17) 18 19type ClientAuth struct { 20 Method string 21 Alg string 22 Kid string 23 Jkt string 24 Jti string 25 Exp *float64 26} 27 28func (ca *ClientAuth) Scan(value any) error { 29 b, ok := value.([]byte) 30 if !ok { 31 return fmt.Errorf("failed to unmarshal OauthParRequest value") 32 } 33 return json.Unmarshal(b, ca) 34} 35 36func (ca ClientAuth) Value() (driver.Value, error) { 37 return json.Marshal(ca) 38} 39 40type AuthenticateClientOptions struct { 41 AllowMissingDpopProof bool 42} 43 44type AuthenticateClientRequestBase struct { 45 ClientID string `form:"client_id" json:"client_id" validate:"required"` 46 ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty"` 47 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 48} 49 50func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) { 51 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 52 if err != nil { 53 return nil, nil, fmt.Errorf("failed to get client: %w", err) 54 } 55 56 if client.Metadata.DpopBoundAccessTokens && proof == nil && (opts == nil || !opts.AllowMissingDpopProof) { 57 return nil, nil, errors.New("dpop proof required") 58 } 59 60 if proof != nil && !client.Metadata.DpopBoundAccessTokens { 61 return nil, nil, errors.New("dpop proof not allowed for this client") 62 } 63 64 clientAuth, err := p.Authenticate(ctx, req, client) 65 if err != nil { 66 return nil, nil, err 67 } 68 69 return client, clientAuth, nil 70} 71 72func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) { 73 metadata := client.Metadata 74 75 if metadata.TokenEndpointAuthMethod == "none" { 76 return &ClientAuth{ 77 Method: "none", 78 }, nil 79 } 80 81 if metadata.TokenEndpointAuthMethod == "private_key_jwt" { 82 if req.ClientAssertion == nil { 83 return nil, errors.New(`client authentication method "private_key_jwt" requires a "client_assertion`) 84 } 85 86 if req.ClientAssertionType == nil || *req.ClientAssertionType != constants.ClientAssertionTypeJwtBearer { 87 return nil, fmt.Errorf("unsupported client_assertion_type %s", *req.ClientAssertionType) 88 } 89 90 token, _, err := jwt.NewParser().ParseUnverified(*req.ClientAssertion, jwt.MapClaims{}) 91 if err != nil { 92 return nil, fmt.Errorf("error parsing client assertion: %w", err) 93 } 94 95 kid, ok := token.Header["kid"].(string) 96 if !ok || kid == "" { 97 return nil, errors.New(`"kid" required in client_assertion`) 98 } 99 100 var rawKey any 101 if err := client.JWKS.Raw(&rawKey); err != nil { 102 return nil, fmt.Errorf("failed to extract raw key: %w", err) 103 } 104 105 token, err = jwt.Parse(*req.ClientAssertion, func(token *jwt.Token) (any, error) { 106 if token.Method.Alg() != jwt.SigningMethodES256.Alg() { 107 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 108 } 109 110 return rawKey, nil 111 }) 112 if err != nil { 113 return nil, fmt.Errorf(`unable to verify "client_assertion" jwt: %w`, err) 114 } 115 116 if !token.Valid { 117 return nil, errors.New("client_assertion jwt is invalid") 118 } 119 120 claims, ok := token.Claims.(jwt.MapClaims) 121 if !ok { 122 return nil, errors.New("no claims in client_assertion jwt") 123 } 124 125 sub, _ := claims["sub"].(string) 126 if sub != metadata.ClientID { 127 return nil, errors.New("subject must be client_id") 128 } 129 130 aud, _ := claims["aud"].(string) 131 if aud != "" && aud != "https://"+p.hostname { 132 return nil, fmt.Errorf("audience must be %s, got %s", "https://"+p.hostname, aud) 133 } 134 135 iat, iatOk := claims["iat"].(float64) 136 if !iatOk { 137 return nil, errors.New(`invalid client_assertion jwt: "iat" is missing`) 138 } 139 140 iatTime := time.Unix(int64(iat), 0) 141 if time.Since(iatTime) > constants.ClientAssertionMaxAge { 142 return nil, errors.New("client_assertion jwt too old") 143 } 144 145 jti, _ := claims["jti"].(string) 146 if jti == "" { 147 return nil, errors.New(`invalid client_assertion jwt: "jti" is missing`) 148 } 149 150 var exp *float64 151 if maybeExp, ok := claims["exp"].(float64); ok { 152 exp = &maybeExp 153 } 154 155 alg := token.Header["alg"].(string) 156 157 thumbBytes, err := client.JWKS.Thumbprint(crypto.SHA256) 158 if err != nil { 159 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 160 } 161 162 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 163 164 return &ClientAuth{ 165 Method: "private_key_jwt", 166 Jti: jti, 167 Exp: exp, 168 Jkt: thumb, 169 Alg: alg, 170 Kid: kid, 171 }, nil 172 } 173 174 return nil, fmt.Errorf("auth method %s is not implemented in this pds", metadata.TokenEndpointAuthMethod) 175}