An atproto PDS written in Go
1package dpop_manager 2 3import ( 4 "crypto" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "log/slog" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15 16 "github.com/golang-jwt/jwt/v4" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth/constants" 19 "github.com/haileyok/cocoon/oauth/dpop" 20 "github.com/haileyok/cocoon/oauth/dpop/nonce" 21 "github.com/lestrrat-go/jwx/v2/jwa" 22 "github.com/lestrrat-go/jwx/v2/jwk" 23) 24 25type DpopManager struct { 26 nonce *nonce.Nonce 27 jtiCache *jtiCache 28 logger *slog.Logger 29 hostname string 30} 31 32type Args struct { 33 NonceSecret []byte 34 NonceRotationInterval time.Duration 35 OnNonceSecretCreated func([]byte) 36 JTICacheSize int 37 Logger *slog.Logger 38 Hostname string 39} 40 41func New(args Args) *DpopManager { 42 if args.Logger == nil { 43 args.Logger = slog.Default() 44 } 45 46 if args.JTICacheSize == 0 { 47 args.JTICacheSize = 100_000 48 } 49 50 if args.NonceSecret == nil { 51 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 52 } 53 54 return &DpopManager{ 55 nonce: nonce.NewNonce(nonce.Args{ 56 RotationInterval: args.NonceRotationInterval, 57 Secret: args.NonceSecret, 58 OnSecretCreated: args.OnNonceSecretCreated, 59 }), 60 jtiCache: newJTICache(args.JTICacheSize), 61 logger: args.Logger, 62 hostname: args.Hostname, 63 } 64} 65 66func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) { 67 if reqMethod == "" { 68 return nil, errors.New("HTTP method is required") 69 } 70 71 if !strings.HasPrefix(reqUrl, "https://") { 72 reqUrl = "https://" + dm.hostname + reqUrl 73 } 74 75 proof := extractProof(headers) 76 77 if proof == "" { 78 return nil, nil 79 } 80 81 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 82 var token *jwt.Token 83 84 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 85 if err != nil { 86 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 87 } 88 89 typ, _ := token.Header["typ"].(string) 90 if typ != "dpop+jwt" { 91 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 92 } 93 94 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 95 if !jwkOk { 96 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 97 } 98 99 jwkb, err := json.Marshal(dpopJwk) 100 if err != nil { 101 return nil, fmt.Errorf("failed to marshal jwk: %w", err) 102 } 103 104 key, err := jwk.ParseKey(jwkb) 105 if err != nil { 106 return nil, fmt.Errorf("failed to parse jwk: %w", err) 107 } 108 109 var pubKey any 110 if err := key.Raw(&pubKey); err != nil { 111 return nil, fmt.Errorf("failed to get raw public key: %w", err) 112 } 113 114 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 115 alg := t.Header["alg"].(string) 116 117 switch key.KeyType() { 118 case jwa.EC: 119 if !strings.HasPrefix(alg, "ES") { 120 return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 121 } 122 case jwa.RSA: 123 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 124 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 125 } 126 case jwa.OKP: 127 if alg != "EdDSA" { 128 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 129 } 130 } 131 132 return pubKey, nil 133 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 134 if err != nil { 135 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 136 } 137 138 if !token.Valid { 139 return nil, errors.New("dpop proof jwt is invalid") 140 } 141 142 claims, ok := token.Claims.(jwt.MapClaims) 143 if !ok { 144 return nil, errors.New("no claims in dpop proof jwt") 145 } 146 147 iat, iatOk := claims["iat"].(float64) 148 if !iatOk { 149 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 150 } 151 152 iatTime := time.Unix(int64(iat), 0) 153 now := time.Now() 154 155 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 156 return nil, errors.New("dpop proof too old") 157 } 158 159 if iatTime.Sub(now) > constants.DpopCheckTolerance { 160 return nil, errors.New("dpop proof iat is in the future") 161 } 162 163 jti, _ := claims["jti"].(string) 164 if jti == "" { 165 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 166 } 167 168 if dm.jtiCache.add(jti) { 169 return nil, errors.New("dpop proof replay detected") 170 } 171 172 htm, _ := claims["htm"].(string) 173 if htm == "" { 174 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 175 } 176 177 if htm != reqMethod { 178 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 179 } 180 181 htu, _ := claims["htu"].(string) 182 if htu == "" { 183 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 184 } 185 186 parsedHtu, err := helpers.OauthParseHtu(htu) 187 if err != nil { 188 return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 189 } 190 191 u, _ := url.Parse(reqUrl) 192 if parsedHtu != helpers.OauthNormalizeHtu(u) { 193 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 194 } 195 196 nonce, _ := claims["nonce"].(string) 197 if nonce == "" { 198 // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 199 return nil, errors.New("use_dpop_nonce") 200 } 201 202 if nonce != "" && !dm.nonce.Check(nonce) { 203 // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 204 return nil, errors.New("use_dpop_nonce") 205 } 206 207 ath, _ := claims["ath"].(string) 208 209 if accessToken != nil && *accessToken != "" { 210 if ath == "" { 211 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 212 } 213 214 hash := sha256.Sum256([]byte(*accessToken)) 215 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 216 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 217 } 218 } else if ath != "" { 219 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 220 } 221 222 thumbBytes, err := key.Thumbprint(crypto.SHA256) 223 if err != nil { 224 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 225 } 226 227 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 228 229 return &dpop.Proof{ 230 JTI: jti, 231 JKT: thumb, 232 HTM: htm, 233 HTU: htu, 234 }, nil 235} 236 237func extractProof(headers http.Header) string { 238 dpopHeaders := headers["Dpop"] 239 switch len(dpopHeaders) { 240 case 0: 241 return "" 242 case 1: 243 return dpopHeaders[0] 244 default: 245 return "" 246 } 247} 248 249func (dm *DpopManager) NextNonce() string { 250 return dm.nonce.NextNonce() 251}