1package dpop
2
3import (
4 "crypto"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "log/slog"
11 "net/http"
12 "net/url"
13 "strings"
14 "time"
15
16 "github.com/golang-jwt/jwt/v4"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/haileyok/cocoon/oauth/constants"
19 "github.com/lestrrat-go/jwx/v2/jwa"
20 "github.com/lestrrat-go/jwx/v2/jwk"
21)
22
23type Manager struct {
24 nonce *Nonce
25 jtiCache *jtiCache
26 logger *slog.Logger
27 hostname string
28}
29
// ManagerArgs carries the settings NewManager uses to build a Manager.
type ManagerArgs struct {
	// NonceSecret is forwarded to NewNonce as its Secret. NewManager logs a
	// warning when it is nil.
	NonceSecret []byte
	// NonceRotationInterval is forwarded to NewNonce as its RotationInterval.
	NonceRotationInterval time.Duration
	// OnNonceSecretCreated is forwarded to NewNonce as its OnSecretCreated
	// callback (presumably invoked with a freshly generated secret — confirm
	// against the Nonce implementation).
	OnNonceSecretCreated func([]byte)
	// JTICacheSize is the size handed to the jti replay cache. Zero selects
	// the default of 100_000.
	JTICacheSize int
	// Logger is the logger the Manager uses. Nil selects slog.Default().
	Logger *slog.Logger
	// Hostname is the host used to turn request URLs that lack an "https://"
	// prefix into absolute URLs before they are compared with a proof's "htu".
	Hostname string
}
38
// ErrUseDpopNonce is returned by CheckProof when a proof's "nonce" claim is
// missing or fails the nonce check. Its message is exactly "use_dpop_nonce",
// which clients rely on to know they must retry with a fresh nonce.
var ErrUseDpopNonce = errors.New("use_dpop_nonce")
42
43func NewManager(args ManagerArgs) *Manager {
44 if args.Logger == nil {
45 args.Logger = slog.Default()
46 }
47
48 if args.JTICacheSize == 0 {
49 args.JTICacheSize = 100_000
50 }
51
52 if args.NonceSecret == nil {
53 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
54 }
55
56 return &Manager{
57 nonce: NewNonce(NonceArgs{
58 RotationInterval: args.NonceRotationInterval,
59 Secret: args.NonceSecret,
60 OnSecretCreated: args.OnNonceSecretCreated,
61 }),
62 jtiCache: newJTICache(args.JTICacheSize),
63 logger: args.Logger,
64 hostname: args.Hostname,
65 }
66}
67
68func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) {
69 if reqMethod == "" {
70 return nil, errors.New("HTTP method is required")
71 }
72
73 if !strings.HasPrefix(reqUrl, "https://") {
74 reqUrl = "https://" + dm.hostname + reqUrl
75 }
76
77 proof := extractProof(headers)
78
79 if proof == "" {
80 return nil, nil
81 }
82
83 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
84 var token *jwt.Token
85
86 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
87 if err != nil {
88 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
89 }
90
91 typ, _ := token.Header["typ"].(string)
92 if typ != "dpop+jwt" {
93 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
94 }
95
96 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
97 if !jwkOk {
98 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
99 }
100
101 jwkb, err := json.Marshal(dpopJwk)
102 if err != nil {
103 return nil, fmt.Errorf("failed to marshal jwk: %w", err)
104 }
105
106 key, err := jwk.ParseKey(jwkb)
107 if err != nil {
108 return nil, fmt.Errorf("failed to parse jwk: %w", err)
109 }
110
111 var pubKey any
112 if err := key.Raw(&pubKey); err != nil {
113 return nil, fmt.Errorf("failed to get raw public key: %w", err)
114 }
115
116 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
117 alg := t.Header["alg"].(string)
118
119 switch key.KeyType() {
120 case jwa.EC:
121 if !strings.HasPrefix(alg, "ES") {
122 return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
123 }
124 case jwa.RSA:
125 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
126 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
127 }
128 case jwa.OKP:
129 if alg != "EdDSA" {
130 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
131 }
132 }
133
134 return pubKey, nil
135 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
136 if err != nil {
137 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
138 }
139
140 if !token.Valid {
141 return nil, errors.New("dpop proof jwt is invalid")
142 }
143
144 claims, ok := token.Claims.(jwt.MapClaims)
145 if !ok {
146 return nil, errors.New("no claims in dpop proof jwt")
147 }
148
149 iat, iatOk := claims["iat"].(float64)
150 if !iatOk {
151 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
152 }
153
154 iatTime := time.Unix(int64(iat), 0)
155 now := time.Now()
156
157 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
158 return nil, errors.New("dpop proof too old")
159 }
160
161 if iatTime.Sub(now) > constants.DpopCheckTolerance {
162 return nil, errors.New("dpop proof iat is in the future")
163 }
164
165 jti, _ := claims["jti"].(string)
166 if jti == "" {
167 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
168 }
169
170 if dm.jtiCache.add(jti) {
171 return nil, errors.New("dpop proof replay detected")
172 }
173
174 htm, _ := claims["htm"].(string)
175 if htm == "" {
176 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
177 }
178
179 if htm != reqMethod {
180 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
181 }
182
183 htu, _ := claims["htu"].(string)
184 if htu == "" {
185 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
186 }
187
188 parsedHtu, err := helpers.OauthParseHtu(htu)
189 if err != nil {
190 return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
191 }
192
193 u, _ := url.Parse(reqUrl)
194 if parsedHtu != helpers.OauthNormalizeHtu(u) {
195 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
196 }
197
198 nonce, _ := claims["nonce"].(string)
199 if nonce == "" {
200 // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
201 return nil, ErrUseDpopNonce
202 }
203
204 if nonce != "" && !dm.nonce.Check(nonce) {
205 // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
206 return nil, ErrUseDpopNonce
207 }
208
209 ath, _ := claims["ath"].(string)
210
211 if accessToken != nil && *accessToken != "" {
212 if ath == "" {
213 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
214 }
215
216 hash := sha256.Sum256([]byte(*accessToken))
217 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
218 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
219 }
220 } else if ath != "" {
221 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
222 }
223
224 thumbBytes, err := key.Thumbprint(crypto.SHA256)
225 if err != nil {
226 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
227 }
228
229 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
230
231 return &Proof{
232 JTI: jti,
233 JKT: thumb,
234 HTM: htm,
235 HTU: htu,
236 }, nil
237}
238
// extractProof returns the value stored under the "Dpop" key of headers when
// there is exactly one. With no value, or with more than one, it returns the
// empty string.
func extractProof(headers http.Header) string {
	if values := headers["Dpop"]; len(values) == 1 {
		return values[0]
	}
	return ""
}
250
251func (dm *Manager) NextNonce() string {
252 return dm.nonce.NextNonce()
253}