A community-based topic aggregation platform built on atproto.
1package middleware
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"

	"Coves/internal/atproto/auth"
)
11
// contextKey is the unexported type used for every value this package stores in a
// request context. A dedicated type keeps these keys distinct from keys defined by
// other packages, even when the underlying strings happen to be equal.
type contextKey string

// Request-context keys populated by the authentication middleware.
const (
	// UserDIDKey holds the authenticated user's DID (string).
	UserDIDKey = contextKey("user_did")
	// JWTClaimsKey holds the token claims (*auth.Claims).
	JWTClaimsKey = contextKey("jwt_claims")
	// UserAccessToken holds the raw access token (string).
	UserAccessToken = contextKey("user_access_token")
	// DPoPProofKey holds the verified DPoP proof (*auth.DPoPProof).
	DPoPProofKey = contextKey("dpop_proof")
)
21
22// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes
23// Validates JWT Bearer tokens from the Authorization header
24// Supports DPoP (RFC 9449) for token binding verification
25type AtProtoAuthMiddleware struct {
26 jwksFetcher auth.JWKSFetcher
27 dpopVerifier *auth.DPoPVerifier
28 skipVerify bool // For Phase 1 testing only
29}
30
31// NewAtProtoAuthMiddleware creates a new atProto auth middleware
32// skipVerify: if true, only parses JWT without signature verification (Phase 1)
33//
34// if false, performs full signature verification (Phase 2)
35//
36// IMPORTANT: Call Stop() when shutting down to clean up background goroutines.
37func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware {
38 return &AtProtoAuthMiddleware{
39 jwksFetcher: jwksFetcher,
40 dpopVerifier: auth.NewDPoPVerifier(),
41 skipVerify: skipVerify,
42 }
43}
44
45// Stop stops background goroutines. Call this when shutting down the server.
46// This prevents goroutine leaks from the DPoP verifier's replay protection cache.
47func (m *AtProtoAuthMiddleware) Stop() {
48 if m.dpopVerifier != nil {
49 m.dpopVerifier.Stop()
50 }
51}
52
// RequireAuth returns middleware that only lets authenticated requests through.
// Any failure is answered with a 401 (via writeAuthError), and next is not called.
//
// The request must carry "Authorization: DPoP <token>" (RFC 9449). The scheme name
// is matched case-insensitively, per RFC 7235.
//
// When skipVerify is false, the token signature is verified with auth.VerifyJWT.
// A token carrying a cnf.jkt claim must also come with a valid DPoP proof header.
// When skipVerify is true, the token is only parsed (Phase 1 testing).
//
// On success, the request context gains:
//   - UserDIDKey (the "sub" claim);
//   - JWTClaimsKey;
//   - UserAccessToken;
//   - DPoPProofKey, when a DPoP proof was verified.
//
// Client-controlled values are logged with %q: the decoded request path, and the
// issuer parsed from an UNVERIFIED token. This stops embedded newlines or control
// characters from forging extra log lines.
func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			writeAuthError(w, "Missing Authorization header")
			return
		}

		// Only the DPoP scheme is accepted (RFC 9449). Scheme matching is
		// case-insensitive per RFC 7235.
		token, ok := extractDPoPToken(authHeader)
		if !ok {
			writeAuthError(w, "Invalid Authorization header format. Expected: DPoP <token>")
			return
		}

		var claims *auth.Claims
		var err error

		if m.skipVerify {
			// Phase 1: parse only (no signature verification).
			claims, err = auth.ParseJWT(token)
			if err != nil {
				log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%q error=%v",
					r.RemoteAddr, r.Method, r.URL.Path, err)
				writeAuthError(w, "Invalid token")
				return
			}
		} else {
			// Phase 2: full verification with signature check.
			//
			// SECURITY: the access token MUST be verified before any claim is trusted.
			// DPoP is an ADDITIONAL security layer, not a replacement for signature
			// verification.
			claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
			if err != nil {
				// Reject. Never fall back to DPoP-only verification, which would trust
				// unverified claims. The issuer is parsed from the UNVERIFIED token for
				// diagnostics only, so it is attacker-controlled; %q keeps it inert.
				issuer := "unknown"
				if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil {
					issuer = parsedClaims.Issuer
				}
				log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%q issuer=%q error=%v",
					r.RemoteAddr, r.Method, r.URL.Path, issuer, err)
				writeAuthError(w, "Invalid or expired token")
				return
			}

			// The signature is verified. A cnf.jkt claim binds the token to a key,
			// so a valid DPoP proof for that key is then REQUIRED.
			dpopHeader := r.Header.Get("DPoP")
			hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil

			if hasCnfJkt {
				if dpopHeader == "" {
					log.Printf("[AUTH_FAILURE] type=missing_dpop ip=%s method=%s path=%q error=token has cnf.jkt but no DPoP header",
						r.RemoteAddr, r.Method, r.URL.Path)
					writeAuthError(w, "DPoP proof required")
					return
				}

				proof, err := m.verifyDPoPBinding(r, claims, dpopHeader, token)
				if err != nil {
					log.Printf("[AUTH_FAILURE] type=dpop_verification_failed ip=%s method=%s path=%q error=%v",
						r.RemoteAddr, r.Method, r.URL.Path, err)
					writeAuthError(w, "Invalid DPoP proof")
					return
				}

				// Expose the verified proof to downstream handlers.
				r = r.WithContext(context.WithValue(r.Context(), DPoPProofKey, proof))
			} else if dpopHeader != "" {
				// A DPoP header without a cnf.jkt-bound token is suspicious. It is
				// tolerated (possibly a misconfigured client) but logged.
				log.Printf("[AUTH_WARNING] type=unexpected_dpop ip=%s method=%s path=%q warning=DPoP header present but token has no cnf.jkt",
					r.RemoteAddr, r.Method, r.URL.Path)
			}
		}

		// The user's DID is the token subject; a token without one is unusable.
		userDID := claims.Subject
		if userDID == "" {
			writeAuthError(w, "Missing user DID in token")
			return
		}

		ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
		ctx = context.WithValue(ctx, JWTClaimsKey, claims)
		ctx = context.WithValue(ctx, UserAccessToken, token)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
156
// OptionalAuth returns middleware that attaches user information when a request
// carries a usable token, and otherwise passes the request through untouched.
// It never writes an error response: any problem simply yields an anonymous
// request. Useful for endpoints that serve both authenticated and anonymous users.
//
// Only "Authorization: DPoP <token>" (RFC 9449) is recognised. The token is either
// parsed (skipVerify) or signature-verified. A verified token carrying cnf.jkt must
// also come with a valid DPoP proof; otherwise it is ignored as potentially stolen.
//
// A token without a "sub" claim is also ignored. This matches RequireAuth, which
// rejects such tokens, so handlers never see claims and an access token paired
// with an empty user DID.
//
// On success, the context gains UserDIDKey, JWTClaimsKey and UserAccessToken, plus
// DPoPProofKey when a DPoP proof was verified.
func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Missing header, or a scheme other than DPoP (matched case-insensitively
		// per RFC 7235): continue without user context.
		token, ok := extractDPoPToken(r.Header.Get("Authorization"))
		if !ok {
			next.ServeHTTP(w, r)
			return
		}

		var claims *auth.Claims
		var err error
		if m.skipVerify {
			// Phase 1: parse only.
			claims, err = auth.ParseJWT(token)
		} else {
			// Phase 2: full verification.
			// SECURITY: the token MUST be verified before trusting claims.
			claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
		}
		if err != nil {
			log.Printf("Optional auth failed: %v", err)
			next.ServeHTTP(w, r)
			return
		}

		// DPoP binding is only enforced after successful signature verification.
		// A token with cnf.jkt but no (valid) proof could be stolen, so it is not
		// trusted and the request continues as unauthenticated.
		dpopBound := !m.skipVerify && claims.Confirmation != nil && claims.Confirmation["jkt"] != nil
		var proof *auth.DPoPProof
		if dpopBound {
			dpopHeader := r.Header.Get("DPoP")
			if dpopHeader == "" {
				log.Printf("[AUTH_WARNING] Optional auth: token has cnf.jkt but no DPoP header - treating as unauthenticated (potential token theft)")
				next.ServeHTTP(w, r)
				return
			}

			proof, err = m.verifyDPoPBinding(r, claims, dpopHeader, token)
			if err != nil {
				log.Printf("[AUTH_WARNING] Optional auth: DPoP verification failed - treating as unauthenticated: %v", err)
				next.ServeHTTP(w, r)
				return
			}
		}

		// Without a subject there is no user to attach (RequireAuth rejects these).
		if claims.Subject == "" {
			log.Printf("[AUTH_WARNING] Optional auth: token has no sub claim - treating as unauthenticated")
			next.ServeHTTP(w, r)
			return
		}

		ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
		ctx = context.WithValue(ctx, JWTClaimsKey, claims)
		ctx = context.WithValue(ctx, UserAccessToken, token)
		if dpopBound {
			ctx = context.WithValue(ctx, DPoPProofKey, proof)
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
238
// GetUserDID returns the user DID stored in the request context by the auth
// middleware, or "" when the request is not authenticated.
func GetUserDID(r *http.Request) string {
	if did, ok := r.Context().Value(UserDIDKey).(string); ok {
		return did
	}
	return ""
}
245
246// GetAuthenticatedDID extracts the authenticated user's DID from the context
247// This is used by service layers for defense-in-depth validation
248// Returns empty string if not authenticated
249func GetAuthenticatedDID(ctx context.Context) string {
250 did, _ := ctx.Value(UserDIDKey).(string)
251 return did
252}
253
// GetJWTClaims returns the token claims stored in the request context by the auth
// middleware, or nil when the request is not authenticated.
func GetJWTClaims(r *http.Request) *auth.Claims {
	stored := r.Context().Value(JWTClaimsKey)
	if claims, ok := stored.(*auth.Claims); ok {
		return claims
	}
	return nil
}
260
// SetTestUserDID returns a copy of ctx carrying userDID under the same key the auth
// middleware uses, so GetUserDID/GetAuthenticatedDID will report it.
// Test-only helper for mocking an authenticated user; do not use in production code.
func SetTestUserDID(ctx context.Context, userDID string) context.Context {
	withUser := context.WithValue(ctx, UserDIDKey, userDID)
	return withUser
}
266
// GetUserAccessToken returns the raw access token stored in the request context by
// the auth middleware, or "" when the request is not authenticated.
func GetUserAccessToken(r *http.Request) string {
	ctx := r.Context()
	token, ok := ctx.Value(UserAccessToken).(string)
	if !ok {
		return ""
	}
	return token
}
273
274// GetDPoPProof extracts the DPoP proof from the request context
275// Returns nil if no DPoP proof was verified
276func GetDPoPProof(r *http.Request) *auth.DPoPProof {
277 proof, _ := r.Context().Value(DPoPProofKey).(*auth.DPoPProof)
278 return proof
279}
280
// verifyDPoPBinding verifies the DPoP proof accompanying an ALREADY VERIFIED access
// token, and returns the parsed proof.
//
// SECURITY: this checks only the proof and its binding to the token. The access
// token's signature MUST be verified before calling this function. DPoP is an
// ADDITIONAL security layer, not a replacement for signature verification. It
// defeats stolen-token use by proving possession of the private key whose
// thumbprint is in the token's cnf.jkt claim.
//
// Checks, in order:
//   - exactly one DPoP header field is present (RFC 9449 §4.3);
//   - the token carries a cnf.jkt claim (auth.ExtractCnfJkt);
//   - the proof is valid for this request's method and target URI;
//   - the proof key matches cnf.jkt (VerifyTokenBinding);
//   - the proof's ath matches accessToken (VerifyAccessTokenHash).
//
// The target URI is built from the externally visible scheme and host (see
// extractSchemeAndHost) plus the percent-encoded request path. Query and fragment
// are excluded, as the htu claim excludes them (RFC 9449 §4.2).
func (m *AtProtoAuthMiddleware) verifyDPoPBinding(r *http.Request, claims *auth.Claims, dpopProofHeader, accessToken string) (*auth.DPoPProof, error) {
	// Callers read the proof with Header.Get, which silently picks the first of
	// several DPoP fields. RFC 9449 §4.3 requires rejecting such requests outright.
	if n := len(r.Header.Values("DPoP")); n > 1 {
		return nil, fmt.Errorf("expected a single DPoP header, got %d", n)
	}

	// The cnf.jkt claim comes from the already-verified token.
	jkt, err := auth.ExtractCnfJkt(claims)
	if err != nil {
		return nil, fmt.Errorf("token requires DPoP but missing cnf.jkt: %w", err)
	}

	// Scheme and host honour proxy headers, so htu still matches behind a
	// TLS-terminating proxy.
	scheme, host := extractSchemeAndHost(r)

	// EscapedPath preserves percent-encoding. r.URL.Path is decoded, whereas the
	// proof's htu carries the encoded path.
	path := r.URL.EscapedPath()
	if path == "" {
		path = r.URL.Path
	}

	httpURI := scheme + "://" + host + path

	proof, err := m.dpopVerifier.VerifyDPoPProof(dpopProofHeader, r.Method, httpURI)
	if err != nil {
		return nil, fmt.Errorf("DPoP proof verification failed: %w", err)
	}

	// Bind the proof to the token: the proof key must match cnf.jkt.
	if err := m.dpopVerifier.VerifyTokenBinding(proof, jkt); err != nil {
		return nil, fmt.Errorf("DPoP binding verification failed: %w", err)
	}

	// RFC 9449 §4.2 makes ath REQUIRED when a proof accompanies an access token,
	// and §4.3 requires it to equal the token's hash.
	// NOTE(review): confirm VerifyAccessTokenHash rejects proofs that omit ath,
	// rather than only checking ath when it is present.
	if err := m.dpopVerifier.VerifyAccessTokenHash(proof, accessToken); err != nil {
		return nil, fmt.Errorf("DPoP ath verification failed: %w", err)
	}

	return proof, nil
}
328
// extractSchemeAndHost returns the scheme and host the client used to reach this
// server. It honours proxy headers, so DPoP htu checks still match behind a
// TLS-terminating reverse proxy.
//
// Precedence (later sources override earlier ones):
//  1. r.URL.Scheme and r.Host;
//  2. X-Forwarded-Proto / X-Forwarded-Host (first element of a comma-separated list);
//  3. the first entry of the RFC 7239 Forwarded header (proto= / host=; keys are
//     case-insensitive and values may be quoted).
//
// Empty values are ignored at every level, so e.g. "Forwarded: host=" cannot blank
// out a host obtained elsewhere. With no scheme from any source, it falls back to
// "https" for TLS connections and "http" otherwise. The scheme is returned
// lower-cased.
//
// NOTE(review): these headers are trusted unconditionally. That is only safe when a
// proxy in front of this server overwrites them. Otherwise a client can choose the
// scheme/host its DPoP proof is checked against.
func extractSchemeAndHost(r *http.Request) (scheme, host string) {
	scheme = r.URL.Scheme
	host = r.Host

	// X-Forwarded-Proto (most common). Chained proxies may append, so use the
	// first element.
	if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
		if first := strings.TrimSpace(strings.Split(forwardedProto, ",")[0]); first != "" {
			scheme = strings.ToLower(first)
		}
	}

	// X-Forwarded-Host (common with nginx/traefik).
	if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
		if first := strings.TrimSpace(strings.Split(forwardedHost, ",")[0]); first != "" {
			host = first
		}
	}

	// The standard Forwarded header (RFC 7239) takes precedence when present, e.g.
	//   Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43;host=example.com
	if forwarded := r.Header.Get("Forwarded"); forwarded != "" {
		firstEntry := strings.Split(forwarded, ",")[0]
		for _, pair := range strings.Split(firstEntry, ";") {
			// Split on the first '=' only.
			idx := strings.Index(pair, "=")
			if idx == -1 {
				continue
			}
			key := strings.ToLower(strings.TrimSpace(pair[:idx]))
			// Strip optional quotes (RFC 7239 §4).
			value := strings.Trim(strings.TrimSpace(pair[idx+1:]), "\"")
			if value == "" {
				// An empty value must not clobber a scheme/host found earlier.
				continue
			}

			switch key {
			case "proto":
				scheme = strings.ToLower(value)
			case "host":
				host = value
			}
		}
	}

	// Fall back to the connection itself when no scheme was provided.
	if scheme == "" {
		if r.TLS != nil {
			scheme = "https"
		} else {
			scheme = "http"
		}
	}

	return strings.ToLower(scheme), host
}
389
390// writeAuthError writes a JSON error response for authentication failures
391func writeAuthError(w http.ResponseWriter, message string) {
392 w.Header().Set("Content-Type", "application/json")
393 w.WriteHeader(http.StatusUnauthorized)
394 // Simple error response matching XRPC error format
395 response := `{"error":"AuthenticationRequired","message":"` + message + `"}`
396 if _, err := w.Write([]byte(response)); err != nil {
397 log.Printf("Failed to write auth error response: %v", err)
398 }
399}
400
// extractDPoPToken returns the token from an "Authorization: DPoP <token>" header
// value. The header is split at its first space. The scheme is compared
// case-insensitively (RFC 7235), so "DPoP", "dpop" and "DPOP" are all accepted.
// Surrounding whitespace is trimmed from the token.
//
// It reports false, with an empty token, when the header is empty, has no space,
// uses another scheme, or carries an empty token.
func extractDPoPToken(authHeader string) (string, bool) {
	sp := strings.IndexByte(authHeader, ' ')
	if sp < 0 {
		return "", false
	}

	scheme, rest := authHeader[:sp], authHeader[sp+1:]
	if !strings.EqualFold(scheme, "DPoP") {
		return "", false
	}

	if token := strings.TrimSpace(rest); token != "" {
		return token, true
	}
	return "", false
}