1package dpop_manager
2
3import (
4 "crypto"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "log/slog"
11 "net/http"
12 "net/url"
13 "strings"
14 "time"
15
16 "github.com/golang-jwt/jwt/v4"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/haileyok/cocoon/oauth/constants"
19 "github.com/haileyok/cocoon/oauth/dpop"
20 "github.com/haileyok/cocoon/oauth/dpop/nonce"
21 "github.com/lestrrat-go/jwx/v2/jwa"
22 "github.com/lestrrat-go/jwx/v2/jwk"
23)
24
// DpopManager validates DPoP proof JWTs (CheckProof) and hands out DPoP
// nonces (NextNonce). Construct it with New.
type DpopManager struct {
	// nonce is the nonce source built by New; CheckProof uses it to check the
	// proof's "nonce" claim and NextNonce delegates to it.
	nonce *nonce.Nonce
	// jtiCache receives the "jti" of each proof via add; CheckProof rejects
	// the proof as a replay when add reports true.
	jtiCache *jtiCache
	// logger is the logger taken from Args (slog.Default() when none was set).
	logger *slog.Logger
	// hostname is used by CheckProof to build an absolute "https://" URL when
	// the request URL it is given does not already start with "https://".
	hostname string
}
31
// Args carries the configuration consumed by New.
type Args struct {
	// NonceSecret is forwarded to the nonce source as its secret. When nil,
	// New logs a warning that existing sessions may break.
	NonceSecret []byte
	// NonceRotationInterval is forwarded to the nonce source as its rotation
	// interval.
	NonceRotationInterval time.Duration
	// OnNonceSecretCreated is forwarded to the nonce source as its
	// OnSecretCreated callback.
	// NOTE(review): presumably invoked with a freshly generated secret so the
	// caller can persist it — confirm against the nonce package.
	OnNonceSecretCreated func([]byte)
	// JTICacheSize is the size passed to the jti cache; 0 selects the default
	// of 100_000.
	JTICacheSize int
	// Logger is the logger to use; nil selects slog.Default().
	Logger *slog.Logger
	// Hostname is stored on the manager and used to complete request URLs
	// that do not start with "https://".
	Hostname string
}
40
41func New(args Args) *DpopManager {
42 if args.Logger == nil {
43 args.Logger = slog.Default()
44 }
45
46 if args.JTICacheSize == 0 {
47 args.JTICacheSize = 100_000
48 }
49
50 if args.NonceSecret == nil {
51 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
52 }
53
54 return &DpopManager{
55 nonce: nonce.NewNonce(nonce.Args{
56 RotationInterval: args.NonceRotationInterval,
57 Secret: args.NonceSecret,
58 OnSecretCreated: args.OnNonceSecretCreated,
59 }),
60 jtiCache: newJTICache(args.JTICacheSize),
61 logger: args.Logger,
62 hostname: args.Hostname,
63 }
64}
65
66func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) {
67 if reqMethod == "" {
68 return nil, errors.New("HTTP method is required")
69 }
70
71 if !strings.HasPrefix(reqUrl, "https://") {
72 reqUrl = "https://" + dm.hostname + reqUrl
73 }
74
75 proof := extractProof(headers)
76
77 if proof == "" {
78 return nil, nil
79 }
80
81 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
82 var token *jwt.Token
83
84 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
85 if err != nil {
86 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
87 }
88
89 typ, _ := token.Header["typ"].(string)
90 if typ != "dpop+jwt" {
91 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
92 }
93
94 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
95 if !jwkOk {
96 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
97 }
98
99 jwkb, err := json.Marshal(dpopJwk)
100 if err != nil {
101 return nil, fmt.Errorf("failed to marshal jwk: %w", err)
102 }
103
104 key, err := jwk.ParseKey(jwkb)
105 if err != nil {
106 return nil, fmt.Errorf("failed to parse jwk: %w", err)
107 }
108
109 var pubKey any
110 if err := key.Raw(&pubKey); err != nil {
111 return nil, fmt.Errorf("failed to get raw public key: %w", err)
112 }
113
114 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
115 alg := t.Header["alg"].(string)
116
117 switch key.KeyType() {
118 case jwa.EC:
119 if !strings.HasPrefix(alg, "ES") {
120 return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
121 }
122 case jwa.RSA:
123 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
124 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
125 }
126 case jwa.OKP:
127 if alg != "EdDSA" {
128 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
129 }
130 }
131
132 return pubKey, nil
133 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
134 if err != nil {
135 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
136 }
137
138 if !token.Valid {
139 return nil, errors.New("dpop proof jwt is invalid")
140 }
141
142 claims, ok := token.Claims.(jwt.MapClaims)
143 if !ok {
144 return nil, errors.New("no claims in dpop proof jwt")
145 }
146
147 iat, iatOk := claims["iat"].(float64)
148 if !iatOk {
149 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
150 }
151
152 iatTime := time.Unix(int64(iat), 0)
153 now := time.Now()
154
155 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
156 return nil, errors.New("dpop proof too old")
157 }
158
159 if iatTime.Sub(now) > constants.DpopCheckTolerance {
160 return nil, errors.New("dpop proof iat is in the future")
161 }
162
163 jti, _ := claims["jti"].(string)
164 if jti == "" {
165 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
166 }
167
168 if dm.jtiCache.add(jti) {
169 return nil, errors.New("dpop proof replay detected")
170 }
171
172 htm, _ := claims["htm"].(string)
173 if htm == "" {
174 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
175 }
176
177 if htm != reqMethod {
178 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
179 }
180
181 htu, _ := claims["htu"].(string)
182 if htu == "" {
183 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
184 }
185
186 parsedHtu, err := helpers.OauthParseHtu(htu)
187 if err != nil {
188 return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
189 }
190
191 u, _ := url.Parse(reqUrl)
192 if parsedHtu != helpers.OauthNormalizeHtu(u) {
193 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
194 }
195
196 nonce, _ := claims["nonce"].(string)
197 if nonce == "" {
198 // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
199 return nil, errors.New("use_dpop_nonce")
200 }
201
202 if nonce != "" && !dm.nonce.Check(nonce) {
203 // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
204 return nil, errors.New("use_dpop_nonce")
205 }
206
207 ath, _ := claims["ath"].(string)
208
209 if accessToken != nil && *accessToken != "" {
210 if ath == "" {
211 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
212 }
213
214 hash := sha256.Sum256([]byte(*accessToken))
215 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
216 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
217 }
218 } else if ath != "" {
219 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
220 }
221
222 thumbBytes, err := key.Thumbprint(crypto.SHA256)
223 if err != nil {
224 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
225 }
226
227 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
228
229 return &dpop.Proof{
230 JTI: jti,
231 JKT: thumb,
232 HTM: htm,
233 HTU: htu,
234 }, nil
235}
236
// extractProof returns the value stored under the "Dpop" key of headers when
// exactly one value is present, and "" when there are none or more than one.
func extractProof(headers http.Header) string {
	if values := headers["Dpop"]; len(values) == 1 {
		return values[0]
	}
	return ""
}
248
249func (dm *DpopManager) NextNonce() string {
250 return dm.nonce.NextNonce()
251}