A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "Coves/internal/atproto/auth" 5 "context" 6 "fmt" 7 "log" 8 "net/http" 9 "strings" 10) 11 12// Context keys for storing user information 13type contextKey string 14 15const ( 16 UserDIDKey contextKey = "user_did" 17 JWTClaimsKey contextKey = "jwt_claims" 18 UserAccessToken contextKey = "user_access_token" 19 DPoPProofKey contextKey = "dpop_proof" 20) 21 22// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes 23// Validates JWT Bearer tokens from the Authorization header 24// Supports DPoP (RFC 9449) for token binding verification 25type AtProtoAuthMiddleware struct { 26 jwksFetcher auth.JWKSFetcher 27 dpopVerifier *auth.DPoPVerifier 28 skipVerify bool // For Phase 1 testing only 29} 30 31// NewAtProtoAuthMiddleware creates a new atProto auth middleware 32// skipVerify: if true, only parses JWT without signature verification (Phase 1) 33// 34// if false, performs full signature verification (Phase 2) 35// 36// IMPORTANT: Call Stop() when shutting down to clean up background goroutines. 37func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware { 38 return &AtProtoAuthMiddleware{ 39 jwksFetcher: jwksFetcher, 40 dpopVerifier: auth.NewDPoPVerifier(), 41 skipVerify: skipVerify, 42 } 43} 44 45// Stop stops background goroutines. Call this when shutting down the server. 46// This prevents goroutine leaks from the DPoP verifier's replay protection cache. 
47func (m *AtProtoAuthMiddleware) Stop() { 48 if m.dpopVerifier != nil { 49 m.dpopVerifier.Stop() 50 } 51} 52 53// RequireAuth middleware ensures the user is authenticated with a valid JWT 54// If not authenticated, returns 401 55// If authenticated, injects user DID and JWT claims into context 56// 57// Only accepts DPoP authorization scheme per RFC 9449: 58// - Authorization: DPoP <token> (DPoP-bound tokens) 59func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler { 60 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 // Extract Authorization header 62 authHeader := r.Header.Get("Authorization") 63 if authHeader == "" { 64 writeAuthError(w, "Missing Authorization header") 65 return 66 } 67 68 // Only accept DPoP scheme per RFC 9449 69 // HTTP auth schemes are case-insensitive per RFC 7235 70 token, ok := extractDPoPToken(authHeader) 71 if !ok { 72 writeAuthError(w, "Invalid Authorization header format. Expected: DPoP <token>") 73 return 74 } 75 76 var claims *auth.Claims 77 var err error 78 79 if m.skipVerify { 80 // Phase 1: Parse only (no signature verification) 81 claims, err = auth.ParseJWT(token) 82 if err != nil { 83 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v", 84 r.RemoteAddr, r.Method, r.URL.Path, err) 85 writeAuthError(w, "Invalid token") 86 return 87 } 88 } else { 89 // Phase 2: Full verification with signature check 90 // 91 // SECURITY: The access token MUST be verified before trusting any claims. 92 // DPoP is an ADDITIONAL security layer, not a replacement for signature verification. 
93 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 94 if err != nil { 95 // Token verification failed - REJECT 96 // DO NOT fall back to DPoP-only verification, as that would trust unverified claims 97 issuer := "unknown" 98 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil { 99 issuer = parsedClaims.Issuer 100 } 101 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v", 102 r.RemoteAddr, r.Method, r.URL.Path, issuer, err) 103 writeAuthError(w, "Invalid or expired token") 104 return 105 } 106 107 // Token signature verified - now check if DPoP binding is required 108 // If the token has a cnf.jkt claim, DPoP proof is REQUIRED 109 dpopHeader := r.Header.Get("DPoP") 110 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil 111 112 if hasCnfJkt { 113 // Token has DPoP binding - REQUIRE valid DPoP proof 114 if dpopHeader == "" { 115 log.Printf("[AUTH_FAILURE] type=missing_dpop ip=%s method=%s path=%s error=token has cnf.jkt but no DPoP header", 116 r.RemoteAddr, r.Method, r.URL.Path) 117 writeAuthError(w, "DPoP proof required") 118 return 119 } 120 121 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader, token) 122 if err != nil { 123 log.Printf("[AUTH_FAILURE] type=dpop_verification_failed ip=%s method=%s path=%s error=%v", 124 r.RemoteAddr, r.Method, r.URL.Path, err) 125 writeAuthError(w, "Invalid DPoP proof") 126 return 127 } 128 129 // Store verified DPoP proof in context 130 ctx := context.WithValue(r.Context(), DPoPProofKey, proof) 131 r = r.WithContext(ctx) 132 } else if dpopHeader != "" { 133 // DPoP header present but token doesn't have cnf.jkt - this is suspicious 134 // Log warning but don't reject (could be a misconfigured client) 135 log.Printf("[AUTH_WARNING] type=unexpected_dpop ip=%s method=%s path=%s warning=DPoP header present but token has no cnf.jkt", 136 r.RemoteAddr, r.Method, r.URL.Path) 137 } 138 } 139 140 // Extract user DID from 'sub' claim 
141 userDID := claims.Subject 142 if userDID == "" { 143 writeAuthError(w, "Missing user DID in token") 144 return 145 } 146 147 // Inject user info and access token into context 148 ctx := context.WithValue(r.Context(), UserDIDKey, userDID) 149 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 150 ctx = context.WithValue(ctx, UserAccessToken, token) 151 152 // Call next handler 153 next.ServeHTTP(w, r.WithContext(ctx)) 154 }) 155} 156 157// OptionalAuth middleware loads user info if authenticated, but doesn't require it 158// Useful for endpoints that work for both authenticated and anonymous users 159// 160// Only accepts DPoP authorization scheme per RFC 9449: 161// - Authorization: DPoP <token> (DPoP-bound tokens) 162func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 163 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 164 // Extract Authorization header 165 authHeader := r.Header.Get("Authorization") 166 167 // Only accept DPoP scheme per RFC 9449 168 // HTTP auth schemes are case-insensitive per RFC 7235 169 token, ok := extractDPoPToken(authHeader) 170 if !ok { 171 // Not authenticated or invalid format - continue without user context 172 next.ServeHTTP(w, r) 173 return 174 } 175 176 var claims *auth.Claims 177 var err error 178 179 if m.skipVerify { 180 // Phase 1: Parse only 181 claims, err = auth.ParseJWT(token) 182 } else { 183 // Phase 2: Full verification 184 // SECURITY: Token MUST be verified before trusting claims 185 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 186 } 187 188 if err != nil { 189 // Invalid token - continue without user context 190 log.Printf("Optional auth failed: %v", err) 191 next.ServeHTTP(w, r) 192 return 193 } 194 195 // Check DPoP binding if token has cnf.jkt (after successful verification) 196 // SECURITY: If token has cnf.jkt but no DPoP header, we cannot trust it 197 // (could be a stolen token). Continue as unauthenticated. 
198 if !m.skipVerify { 199 dpopHeader := r.Header.Get("DPoP") 200 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil 201 202 if hasCnfJkt { 203 if dpopHeader == "" { 204 // Token requires DPoP binding but no proof provided 205 // Cannot trust this token - continue without auth 206 log.Printf("[AUTH_WARNING] Optional auth: token has cnf.jkt but no DPoP header - treating as unauthenticated (potential token theft)") 207 next.ServeHTTP(w, r) 208 return 209 } 210 211 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader, token) 212 if err != nil { 213 // DPoP verification failed - cannot trust this token 214 log.Printf("[AUTH_WARNING] Optional auth: DPoP verification failed - treating as unauthenticated: %v", err) 215 next.ServeHTTP(w, r) 216 return 217 } 218 219 // DPoP verified - inject proof into context 220 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 221 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 222 ctx = context.WithValue(ctx, UserAccessToken, token) 223 ctx = context.WithValue(ctx, DPoPProofKey, proof) 224 next.ServeHTTP(w, r.WithContext(ctx)) 225 return 226 } 227 } 228 229 // No DPoP binding required - inject user info and access token into context 230 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 231 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 232 ctx = context.WithValue(ctx, UserAccessToken, token) 233 234 // Call next handler 235 next.ServeHTTP(w, r.WithContext(ctx)) 236 }) 237} 238 239// GetUserDID extracts the user's DID from the request context 240// Returns empty string if not authenticated 241func GetUserDID(r *http.Request) string { 242 did, _ := r.Context().Value(UserDIDKey).(string) 243 return did 244} 245 246// GetAuthenticatedDID extracts the authenticated user's DID from the context 247// This is used by service layers for defense-in-depth validation 248// Returns empty string if not authenticated 249func GetAuthenticatedDID(ctx context.Context) string { 250 did, _ 
:= ctx.Value(UserDIDKey).(string) 251 return did 252} 253 254// GetJWTClaims extracts the JWT claims from the request context 255// Returns nil if not authenticated 256func GetJWTClaims(r *http.Request) *auth.Claims { 257 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims) 258 return claims 259} 260 261// SetTestUserDID sets the user DID in the context for testing purposes 262// This function should ONLY be used in tests to mock authenticated users 263func SetTestUserDID(ctx context.Context, userDID string) context.Context { 264 return context.WithValue(ctx, UserDIDKey, userDID) 265} 266 267// GetUserAccessToken extracts the user's access token from the request context 268// Returns empty string if not authenticated 269func GetUserAccessToken(r *http.Request) string { 270 token, _ := r.Context().Value(UserAccessToken).(string) 271 return token 272} 273 274// GetDPoPProof extracts the DPoP proof from the request context 275// Returns nil if no DPoP proof was verified 276func GetDPoPProof(r *http.Request) *auth.DPoPProof { 277 proof, _ := r.Context().Value(DPoPProofKey).(*auth.DPoPProof) 278 return proof 279} 280 281// verifyDPoPBinding verifies DPoP proof binding for an ALREADY VERIFIED token. 282// 283// SECURITY: This function ONLY verifies the DPoP proof and its binding to the token. 284// The access token MUST be signature-verified BEFORE calling this function. 285// DPoP is an ADDITIONAL security layer, not a replacement for signature verification. 286// 287// This prevents token theft attacks by proving the client possesses the private key 288// corresponding to the public key thumbprint in the token's cnf.jkt claim. 
289func (m *AtProtoAuthMiddleware) verifyDPoPBinding(r *http.Request, claims *auth.Claims, dpopProofHeader, accessToken string) (*auth.DPoPProof, error) { 290 // Extract the cnf.jkt claim from the already-verified token 291 jkt, err := auth.ExtractCnfJkt(claims) 292 if err != nil { 293 return nil, fmt.Errorf("token requires DPoP but missing cnf.jkt: %w", err) 294 } 295 296 // Build the HTTP URI for DPoP verification 297 // Use the full URL including scheme and host, respecting proxy headers 298 scheme, host := extractSchemeAndHost(r) 299 300 // Use EscapedPath to preserve percent-encoding (P3 fix) 301 // r.URL.Path is decoded, but DPoP proofs contain the raw encoded path 302 path := r.URL.EscapedPath() 303 if path == "" { 304 path = r.URL.Path // Fallback if EscapedPath returns empty 305 } 306 307 httpURI := scheme + "://" + host + path 308 309 // Verify the DPoP proof 310 proof, err := m.dpopVerifier.VerifyDPoPProof(dpopProofHeader, r.Method, httpURI) 311 if err != nil { 312 return nil, fmt.Errorf("DPoP proof verification failed: %w", err) 313 } 314 315 // Verify the binding between the proof and the token (cnf.jkt) 316 if err := m.dpopVerifier.VerifyTokenBinding(proof, jkt); err != nil { 317 return nil, fmt.Errorf("DPoP binding verification failed: %w", err) 318 } 319 320 // Verify the access token hash (ath) if present in the proof 321 // Per RFC 9449 section 4.2, if ath is present, it MUST match the access token 322 if err := m.dpopVerifier.VerifyAccessTokenHash(proof, accessToken); err != nil { 323 return nil, fmt.Errorf("DPoP ath verification failed: %w", err) 324 } 325 326 return proof, nil 327} 328 329// extractSchemeAndHost extracts the scheme and host from the request, 330// respecting proxy headers (X-Forwarded-Proto, X-Forwarded-Host, Forwarded). 331// This is critical for DPoP verification when behind TLS-terminating proxies. 
// extractSchemeAndHost returns the scheme and host the client used to reach
// this server, so the DPoP htu can be rebuilt behind TLS-terminating proxies.
//
// Sources, lowest to highest precedence:
//   - r.URL.Scheme and r.Host;
//   - the first comma-separated value of X-Forwarded-Proto / X-Forwarded-Host;
//   - the first element of the RFC 7239 Forwarded header (proto= / host=).
//
// Forwarded parameter names are matched case-insensitively and optional
// surrounding quotes are stripped (RFC 7239 section 4). Empty values never
// override a value found earlier. If no scheme is found it is "https" when
// r.TLS is set and "http" otherwise. The returned scheme is lower-cased.
//
// SECURITY NOTE(review): these headers are accepted from any client, not only
// from a trusted proxy, so a caller can choose the scheme/host used in the
// htu comparison (e.g. to replay a proof minted for another host). Consider
// honoring them only when the request comes from a known proxy address.
func extractSchemeAndHost(r *http.Request) (scheme, host string) {
	scheme = r.URL.Scheme
	host = r.Host

	// X-Forwarded-* may hold a comma-separated list; only the first value is used.
	if proto := strings.TrimSpace(strings.SplitN(r.Header.Get("X-Forwarded-Proto"), ",", 2)[0]); proto != "" {
		scheme = proto
	}
	if fwdHost := strings.TrimSpace(strings.SplitN(r.Header.Get("X-Forwarded-Host"), ",", 2)[0]); fwdHost != "" {
		host = fwdHost
	}

	// Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43;host=example.com
	// It takes precedence over X-Forwarded-*; only its first element is read.
	if forwarded := r.Header.Get("Forwarded"); forwarded != "" {
		firstElement := strings.SplitN(forwarded, ",", 2)[0]
		for _, pair := range strings.Split(firstElement, ";") {
			eq := strings.Index(pair, "=")
			if eq == -1 {
				continue
			}
			key := strings.ToLower(strings.TrimSpace(pair[:eq]))
			value := strings.Trim(strings.TrimSpace(pair[eq+1:]), "\"")
			if value == "" {
				// An empty proto=/host="" must not erase a value found earlier
				// (the X-Forwarded-* branches above skip empties the same way).
				continue
			}
			switch key {
			case "proto":
				scheme = value
			case "host":
				host = value
			}
		}
	}

	if scheme == "" {
		if r.TLS != nil {
			scheme = "https"
		} else {
			scheme = "http"
		}
	}

	return strings.ToLower(scheme), host
}
// writeAuthError writes a 401 Unauthorized response whose JSON body follows
// the XRPC error format:
//
//	{"error":"AuthenticationRequired","message":"<message>"}
//
// message is JSON-escaped (quotes, backslashes and control characters), so the
// body is always valid JSON whatever text a caller passes. A
// "WWW-Authenticate: DPoP" challenge is sent because RFC 7235 section 3.1
// requires one on every 401, and DPoP is the only scheme this middleware accepts.
// A failure to write the body is logged, not returned.
func writeAuthError(w http.ResponseWriter, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("WWW-Authenticate", "DPoP")
	w.WriteHeader(http.StatusUnauthorized)

	var body strings.Builder
	body.WriteString(`{"error":"AuthenticationRequired","message":"`)
	for _, r := range message {
		switch {
		case r == '"' || r == '\\':
			body.WriteByte('\\')
			body.WriteRune(r)
		case r < 0x20:
			// JSON forbids raw control characters inside strings.
			fmt.Fprintf(&body, `\u%04x`, r)
		default:
			// Invalid UTF-8 arrives here as U+FFFD, which is valid JSON text.
			body.WriteRune(r)
		}
	}
	body.WriteString(`"}`)

	if _, err := w.Write([]byte(body.String())); err != nil {
		log.Printf("Failed to write auth error response: %v", err)
	}
}

// extractDPoPToken returns the credentials from an "Authorization: DPoP <token>"
// header value. The scheme is compared case-insensitively (RFC 7235), so
// "DPoP", "dpop" and "DPOP" are all accepted. Surrounding whitespace is trimmed
// from the token. ok is false for an empty header, any other scheme, or a
// missing/blank token.
func extractDPoPToken(authHeader string) (string, bool) {
	if authHeader == "" {
		return "", false
	}

	// Split only on the first space: "DPoP <token>" -> ["DPoP", "<token>"].
	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 {
		return "", false
	}

	if !strings.EqualFold(parts[0], "DPoP") {
		return "", false
	}

	token := strings.TrimSpace(parts[1])
	if token == "" {
		return "", false
	}

	return token, true
}