A community-based topic aggregation platform built on atproto
1package oauth
2
3import (
4 "crypto/ecdsa"
5 "crypto/elliptic"
6 "crypto/rand"
7 "crypto/sha256"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "time"
12
13 "github.com/lestrrat-go/jwx/v2/jwa"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "github.com/lestrrat-go/jwx/v2/jws"
16 "github.com/lestrrat-go/jwx/v2/jwt"
17)
18
// DPoP (Demonstrating Proof of Possession), RFC 9449.
// Binds access tokens to a specific client: each request carries a proof JWT
// signed with a private key that only that client holds.
21
22// GenerateDPoPKey generates a new ES256 (NIST P-256) keypair for DPoP
23// Each OAuth session should have its own unique DPoP key
24func GenerateDPoPKey() (jwk.Key, error) {
25 // Generate ES256 private key
26 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
27 if err != nil {
28 return nil, fmt.Errorf("failed to generate ECDSA key: %w", err)
29 }
30
31 // Convert to JWK
32 jwkKey, err := jwk.FromRaw(privateKey)
33 if err != nil {
34 return nil, fmt.Errorf("failed to create JWK from private key: %w", err)
35 }
36
37 // Set JWK parameters
38 if err := jwkKey.Set(jwk.AlgorithmKey, jwa.ES256); err != nil {
39 return nil, fmt.Errorf("failed to set algorithm: %w", err)
40 }
41 if err := jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
42 return nil, fmt.Errorf("failed to set key usage: %w", err)
43 }
44
45 return jwkKey, nil
46}
47
48// CreateDPoPProof creates a DPoP proof JWT for HTTP requests
49// Parameters:
50// - privateKey: The DPoP private key (ES256) as JWK
51// - method: HTTP method (e.g., "POST", "GET")
52// - uri: Full HTTP URI (e.g., "https://pds.example.com/xrpc/com.atproto.server.getSession")
53// - nonce: Optional server-provided nonce (empty on first request, use nonce from 401 response on retry)
54// - accessToken: Optional access token hash (required when using access token)
55func CreateDPoPProof(privateKey jwk.Key, method, uri, nonce, accessToken string) (string, error) {
56 // Get public key for JWK thumbprint
57 pubKey, err := privateKey.PublicKey()
58 if err != nil {
59 return "", fmt.Errorf("failed to get public key: %w", err)
60 }
61
62 // Create JWT builder
63 builder := jwt.NewBuilder().
64 Claim("htm", method). // HTTP method
65 Claim("htu", uri). // HTTP URI
66 Claim("iat", time.Now().Unix()). // Issued at
67 Claim("jti", generateJTI()) // Unique JWT ID
68
69 // Add nonce if provided (required after first DPoP request)
70 if nonce != "" {
71 builder = builder.Claim("nonce", nonce)
72 }
73
74 // Add access token hash if provided (required when using access token)
75 if accessToken != "" {
76 ath := hashAccessToken(accessToken)
77 builder = builder.Claim("ath", ath)
78 }
79
80 // Build the token
81 token, err := builder.Build()
82 if err != nil {
83 return "", fmt.Errorf("failed to build JWT: %w", err)
84 }
85
86 // Serialize the token payload to JSON
87 payloadBytes, err := json.Marshal(token)
88 if err != nil {
89 return "", fmt.Errorf("failed to marshal token: %w", err)
90 }
91
92 // Create headers with DPoP-specific fields
93 // RFC 9449 requires the "jwk" header to contain the public key as a JSON object
94 headers := jws.NewHeaders()
95 if err := headers.Set(jws.AlgorithmKey, jwa.ES256); err != nil {
96 return "", fmt.Errorf("failed to set algorithm: %w", err)
97 }
98 if err := headers.Set(jws.TypeKey, "dpop+jwt"); err != nil {
99 return "", fmt.Errorf("failed to set type: %w", err)
100 }
101 // Set the public JWK directly - jwx library will handle serialization
102 if err := headers.Set(jws.JWKKey, pubKey); err != nil {
103 return "", fmt.Errorf("failed to set JWK: %w", err)
104 }
105
106 // Sign using jws.Sign to preserve custom headers
107 // (jwt.Sign() overrides headers, so we use jws.Sign() directly)
108 signed, err := jws.Sign(payloadBytes, jws.WithKey(jwa.ES256, privateKey, jws.WithProtectedHeaders(headers)))
109 if err != nil {
110 return "", fmt.Errorf("failed to sign JWT: %w", err)
111 }
112
113 return string(signed), nil
114}
115
// generateJTI returns a fresh value for the "jti" (JWT ID) claim of a DPoP
// proof: 16 bytes from crypto/rand, encoded as unpadded base64url (22 chars).
//
// If crypto/rand reports an error, it falls back to the current Unix time in
// nanoseconds as a decimal string. That fallback is predictable and may
// collide, so it is only a last resort.
func generateJTI() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err == nil {
		return base64.RawURLEncoding.EncodeToString(raw[:])
	}
	// Fallback to timestamp-based ID.
	return fmt.Sprintf("%d", time.Now().UnixNano())
}
126
// hashAccessToken computes the DPoP "ath" (access token hash) claim value:
// the unpadded base64url encoding of the SHA-256 digest of the token bytes,
// i.e. ath = base64url(SHA-256(access_token)).
func hashAccessToken(accessToken string) string {
	digester := sha256.New()
	// hash.Hash's Write is documented never to return an error.
	digester.Write([]byte(accessToken))
	return base64.RawURLEncoding.EncodeToString(digester.Sum(nil))
}
133
134// ParseJWKFromJSON parses a JWK from JSON bytes
135func ParseJWKFromJSON(data []byte) (jwk.Key, error) {
136 key, err := jwk.ParseKey(data)
137 if err != nil {
138 return nil, fmt.Errorf("failed to parse JWK: %w", err)
139 }
140 return key, nil
141}
142
143// JWKToJSON converts a JWK to JSON bytes
144func JWKToJSON(key jwk.Key) ([]byte, error) {
145 data, err := json.Marshal(key)
146 if err != nil {
147 return nil, fmt.Errorf("failed to marshal JWK: %w", err)
148 }
149 return data, nil
150}
151
152// GetPublicJWKS creates a JWKS (JSON Web Key Set) response for the public key
153// This is served at /oauth/jwks.json
154func GetPublicJWKS(privateKey jwk.Key) (jwk.Set, error) {
155 pubKey, err := privateKey.PublicKey()
156 if err != nil {
157 return nil, fmt.Errorf("failed to get public key: %w", err)
158 }
159
160 // Create JWK Set
161 set := jwk.NewSet()
162 if err := set.AddKey(pubKey); err != nil {
163 return nil, fmt.Errorf("failed to add key to set: %w", err)
164 }
165
166 return set, nil
167}