An atproto PDS written in Go
at main 6.3 kB view raw
1package dpop 2 3import ( 4 "crypto" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "log/slog" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15 16 "github.com/golang-jwt/jwt/v4" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth/constants" 19 "github.com/lestrrat-go/jwx/v2/jwa" 20 "github.com/lestrrat-go/jwx/v2/jwk" 21) 22 23type Manager struct { 24 nonce *Nonce 25 jtiCache *jtiCache 26 logger *slog.Logger 27 hostname string 28} 29 30type ManagerArgs struct { 31 NonceSecret []byte 32 NonceRotationInterval time.Duration 33 OnNonceSecretCreated func([]byte) 34 JTICacheSize int 35 Logger *slog.Logger 36 Hostname string 37} 38 39var ( 40 ErrUseDpopNonce = errors.New("use_dpop_nonce") 41) 42 43func NewManager(args ManagerArgs) *Manager { 44 if args.Logger == nil { 45 args.Logger = slog.Default() 46 } 47 48 if args.JTICacheSize == 0 { 49 args.JTICacheSize = 100_000 50 } 51 52 if args.NonceSecret == nil { 53 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 54 } 55 56 return &Manager{ 57 nonce: NewNonce(NonceArgs{ 58 RotationInterval: args.NonceRotationInterval, 59 Secret: args.NonceSecret, 60 OnSecretCreated: args.OnNonceSecretCreated, 61 }), 62 jtiCache: newJTICache(args.JTICacheSize), 63 logger: args.Logger, 64 hostname: args.Hostname, 65 } 66} 67 68func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) { 69 if reqMethod == "" { 70 return nil, errors.New("HTTP method is required") 71 } 72 73 if !strings.HasPrefix(reqUrl, "https://") { 74 reqUrl = "https://" + dm.hostname + reqUrl 75 } 76 77 proof := extractProof(headers) 78 79 if proof == "" { 80 return nil, nil 81 } 82 83 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 84 var token *jwt.Token 85 86 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 87 if err != nil { 88 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 89 } 90 91 typ, _ := token.Header["typ"].(string) 92 if typ != "dpop+jwt" { 93 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 94 } 95 96 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 97 if !jwkOk { 98 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 99 } 100 101 jwkb, err := json.Marshal(dpopJwk) 102 if err != nil { 103 return nil, fmt.Errorf("failed to marshal jwk: %w", err) 104 } 105 106 key, err := jwk.ParseKey(jwkb) 107 if err != nil { 108 return nil, fmt.Errorf("failed to parse jwk: %w", err) 109 } 110 111 var pubKey any 112 if err := key.Raw(&pubKey); err != nil { 113 return nil, fmt.Errorf("failed to get raw public key: %w", err) 114 } 115 116 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 117 alg := t.Header["alg"].(string) 118 119 switch key.KeyType() { 120 case jwa.EC: 121 if !strings.HasPrefix(alg, "ES") { 122 return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 123 } 124 case jwa.RSA: 125 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 126 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 127 } 128 case jwa.OKP: 129 if alg != "EdDSA" { 130 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 131 } 132 } 133 134 return pubKey, nil 135 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 136 if err != nil { 137 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 138 } 139 140 if !token.Valid { 141 return nil, errors.New("dpop proof jwt is invalid") 142 } 143 144 claims, ok := token.Claims.(jwt.MapClaims) 145 if !ok { 146 return nil, errors.New("no claims in dpop proof jwt") 147 } 148 149 iat, iatOk := claims["iat"].(float64) 150 if !iatOk { 151 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 152 } 153 154 iatTime := time.Unix(int64(iat), 0) 155 now := time.Now() 156 157 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 158 return nil, errors.New("dpop proof too old") 159 } 160 161 if iatTime.Sub(now) > constants.DpopCheckTolerance { 162 return nil, errors.New("dpop proof iat is in the future") 163 } 164 165 jti, _ := claims["jti"].(string) 166 if jti == "" { 167 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 168 } 169 170 if dm.jtiCache.add(jti) { 171 return nil, errors.New("dpop proof replay detected") 172 } 173 174 htm, _ := claims["htm"].(string) 175 if htm == "" { 176 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 177 } 178 179 if htm != reqMethod { 180 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 181 } 182 183 htu, _ := claims["htu"].(string) 184 if htu == "" { 185 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 186 } 187 188 parsedHtu, err := helpers.OauthParseHtu(htu) 189 if err != nil { 190 return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 191 } 192 193 u, _ := url.Parse(reqUrl) 194 if parsedHtu != helpers.OauthNormalizeHtu(u) { 195 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 196 } 197 198 nonce, _ := claims["nonce"].(string) 199 if nonce == "" { 200 // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 201 return nil, ErrUseDpopNonce 202 } 203 204 if nonce != "" && !dm.nonce.Check(nonce) { 205 // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 206 return nil, ErrUseDpopNonce 207 } 208 209 ath, _ := claims["ath"].(string) 210 211 if accessToken != nil && *accessToken != "" { 212 if ath == "" { 213 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 214 } 215 216 hash := sha256.Sum256([]byte(*accessToken)) 217 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 218 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 219 } 220 } else if ath != "" { 221 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 222 } 223 224 thumbBytes, err := key.Thumbprint(crypto.SHA256) 225 if err != nil { 226 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 227 } 228 229 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 230 231 return &Proof{ 232 JTI: jti, 233 JKT: thumb, 234 HTM: htm, 235 HTU: htu, 236 }, nil 237} 238 239func extractProof(headers http.Header) string { 240 dpopHeaders := headers["Dpop"] 241 switch len(dpopHeaders) { 242 case 0: 243 return "" 244 case 1: 245 return dpopHeaders[0] 246 default: 247 return "" 248 } 249} 250 251func (dm *Manager) NextNonce() string { 252 return dm.nonce.NextNonce() 253}