A community-based topic aggregation platform built on atproto
1package middleware 2 3import ( 4 "Coves/internal/atproto/auth" 5 "context" 6 "fmt" 7 "log" 8 "net/http" 9 "strings" 10) 11 12// Context keys for storing user information 13type contextKey string 14 15const ( 16 UserDIDKey contextKey = "user_did" 17 JWTClaimsKey contextKey = "jwt_claims" 18 UserAccessToken contextKey = "user_access_token" 19 DPoPProofKey contextKey = "dpop_proof" 20) 21 22// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes 23// Validates JWT Bearer tokens from the Authorization header 24// Supports DPoP (RFC 9449) for token binding verification 25type AtProtoAuthMiddleware struct { 26 jwksFetcher auth.JWKSFetcher 27 dpopVerifier *auth.DPoPVerifier 28 skipVerify bool // For Phase 1 testing only 29} 30 31// NewAtProtoAuthMiddleware creates a new atProto auth middleware 32// skipVerify: if true, only parses JWT without signature verification (Phase 1) 33// 34// if false, performs full signature verification (Phase 2) 35// 36// IMPORTANT: Call Stop() when shutting down to clean up background goroutines. 37func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware { 38 return &AtProtoAuthMiddleware{ 39 jwksFetcher: jwksFetcher, 40 dpopVerifier: auth.NewDPoPVerifier(), 41 skipVerify: skipVerify, 42 } 43} 44 45// Stop stops background goroutines. Call this when shutting down the server. 46// This prevents goroutine leaks from the DPoP verifier's replay protection cache. 
47func (m *AtProtoAuthMiddleware) Stop() { 48 if m.dpopVerifier != nil { 49 m.dpopVerifier.Stop() 50 } 51} 52 53// RequireAuth middleware ensures the user is authenticated with a valid JWT 54// If not authenticated, returns 401 55// If authenticated, injects user DID and JWT claims into context 56func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler { 57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 // Extract Authorization header 59 authHeader := r.Header.Get("Authorization") 60 if authHeader == "" { 61 writeAuthError(w, "Missing Authorization header") 62 return 63 } 64 65 // Must be Bearer token 66 if !strings.HasPrefix(authHeader, "Bearer ") { 67 writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>") 68 return 69 } 70 71 token := strings.TrimPrefix(authHeader, "Bearer ") 72 token = strings.TrimSpace(token) 73 74 var claims *auth.Claims 75 var err error 76 77 if m.skipVerify { 78 // Phase 1: Parse only (no signature verification) 79 claims, err = auth.ParseJWT(token) 80 if err != nil { 81 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v", 82 r.RemoteAddr, r.Method, r.URL.Path, err) 83 writeAuthError(w, "Invalid token") 84 return 85 } 86 } else { 87 // Phase 2: Full verification with signature check 88 // 89 // SECURITY: The access token MUST be verified before trusting any claims. 90 // DPoP is an ADDITIONAL security layer, not a replacement for signature verification. 
91 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 92 if err != nil { 93 // Token verification failed - REJECT 94 // DO NOT fall back to DPoP-only verification, as that would trust unverified claims 95 issuer := "unknown" 96 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil { 97 issuer = parsedClaims.Issuer 98 } 99 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v", 100 r.RemoteAddr, r.Method, r.URL.Path, issuer, err) 101 writeAuthError(w, "Invalid or expired token") 102 return 103 } 104 105 // Token signature verified - now check if DPoP binding is required 106 // If the token has a cnf.jkt claim, DPoP proof is REQUIRED 107 dpopHeader := r.Header.Get("DPoP") 108 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil 109 110 if hasCnfJkt { 111 // Token has DPoP binding - REQUIRE valid DPoP proof 112 if dpopHeader == "" { 113 log.Printf("[AUTH_FAILURE] type=missing_dpop ip=%s method=%s path=%s error=token has cnf.jkt but no DPoP header", 114 r.RemoteAddr, r.Method, r.URL.Path) 115 writeAuthError(w, "DPoP proof required") 116 return 117 } 118 119 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader) 120 if err != nil { 121 log.Printf("[AUTH_FAILURE] type=dpop_verification_failed ip=%s method=%s path=%s error=%v", 122 r.RemoteAddr, r.Method, r.URL.Path, err) 123 writeAuthError(w, "Invalid DPoP proof") 124 return 125 } 126 127 // Store verified DPoP proof in context 128 ctx := context.WithValue(r.Context(), DPoPProofKey, proof) 129 r = r.WithContext(ctx) 130 } else if dpopHeader != "" { 131 // DPoP header present but token doesn't have cnf.jkt - this is suspicious 132 // Log warning but don't reject (could be a misconfigured client) 133 log.Printf("[AUTH_WARNING] type=unexpected_dpop ip=%s method=%s path=%s warning=DPoP header present but token has no cnf.jkt", 134 r.RemoteAddr, r.Method, r.URL.Path) 135 } 136 } 137 138 // Extract user DID from 'sub' claim 139 userDID 
:= claims.Subject 140 if userDID == "" { 141 writeAuthError(w, "Missing user DID in token") 142 return 143 } 144 145 // Inject user info and access token into context 146 ctx := context.WithValue(r.Context(), UserDIDKey, userDID) 147 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 148 ctx = context.WithValue(ctx, UserAccessToken, token) 149 150 // Call next handler 151 next.ServeHTTP(w, r.WithContext(ctx)) 152 }) 153} 154 155// OptionalAuth middleware loads user info if authenticated, but doesn't require it 156// Useful for endpoints that work for both authenticated and anonymous users 157func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 158 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 159 // Extract Authorization header 160 authHeader := r.Header.Get("Authorization") 161 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { 162 // Not authenticated - continue without user context 163 next.ServeHTTP(w, r) 164 return 165 } 166 167 token := strings.TrimPrefix(authHeader, "Bearer ") 168 token = strings.TrimSpace(token) 169 170 var claims *auth.Claims 171 var err error 172 173 if m.skipVerify { 174 // Phase 1: Parse only 175 claims, err = auth.ParseJWT(token) 176 } else { 177 // Phase 2: Full verification 178 // SECURITY: Token MUST be verified before trusting claims 179 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 180 } 181 182 if err != nil { 183 // Invalid token - continue without user context 184 log.Printf("Optional auth failed: %v", err) 185 next.ServeHTTP(w, r) 186 return 187 } 188 189 // Check DPoP binding if token has cnf.jkt (after successful verification) 190 // SECURITY: If token has cnf.jkt but no DPoP header, we cannot trust it 191 // (could be a stolen token). Continue as unauthenticated. 
192 if !m.skipVerify { 193 dpopHeader := r.Header.Get("DPoP") 194 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil 195 196 if hasCnfJkt { 197 if dpopHeader == "" { 198 // Token requires DPoP binding but no proof provided 199 // Cannot trust this token - continue without auth 200 log.Printf("[AUTH_WARNING] Optional auth: token has cnf.jkt but no DPoP header - treating as unauthenticated (potential token theft)") 201 next.ServeHTTP(w, r) 202 return 203 } 204 205 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader) 206 if err != nil { 207 // DPoP verification failed - cannot trust this token 208 log.Printf("[AUTH_WARNING] Optional auth: DPoP verification failed - treating as unauthenticated: %v", err) 209 next.ServeHTTP(w, r) 210 return 211 } 212 213 // DPoP verified - inject proof into context 214 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 215 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 216 ctx = context.WithValue(ctx, UserAccessToken, token) 217 ctx = context.WithValue(ctx, DPoPProofKey, proof) 218 next.ServeHTTP(w, r.WithContext(ctx)) 219 return 220 } 221 } 222 223 // No DPoP binding required - inject user info and access token into context 224 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 225 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 226 ctx = context.WithValue(ctx, UserAccessToken, token) 227 228 // Call next handler 229 next.ServeHTTP(w, r.WithContext(ctx)) 230 }) 231} 232 233// GetUserDID extracts the user's DID from the request context 234// Returns empty string if not authenticated 235func GetUserDID(r *http.Request) string { 236 did, _ := r.Context().Value(UserDIDKey).(string) 237 return did 238} 239 240// GetAuthenticatedDID extracts the authenticated user's DID from the context 241// This is used by service layers for defense-in-depth validation 242// Returns empty string if not authenticated 243func GetAuthenticatedDID(ctx context.Context) string { 244 did, _ := 
ctx.Value(UserDIDKey).(string) 245 return did 246} 247 248// GetJWTClaims extracts the JWT claims from the request context 249// Returns nil if not authenticated 250func GetJWTClaims(r *http.Request) *auth.Claims { 251 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims) 252 return claims 253} 254 255// SetTestUserDID sets the user DID in the context for testing purposes 256// This function should ONLY be used in tests to mock authenticated users 257func SetTestUserDID(ctx context.Context, userDID string) context.Context { 258 return context.WithValue(ctx, UserDIDKey, userDID) 259} 260 261// GetUserAccessToken extracts the user's access token from the request context 262// Returns empty string if not authenticated 263func GetUserAccessToken(r *http.Request) string { 264 token, _ := r.Context().Value(UserAccessToken).(string) 265 return token 266} 267 268// GetDPoPProof extracts the DPoP proof from the request context 269// Returns nil if no DPoP proof was verified 270func GetDPoPProof(r *http.Request) *auth.DPoPProof { 271 proof, _ := r.Context().Value(DPoPProofKey).(*auth.DPoPProof) 272 return proof 273} 274 275// verifyDPoPBinding verifies DPoP proof binding for an ALREADY VERIFIED token. 276// 277// SECURITY: This function ONLY verifies the DPoP proof and its binding to the token. 278// The access token MUST be signature-verified BEFORE calling this function. 279// DPoP is an ADDITIONAL security layer, not a replacement for signature verification. 280// 281// This prevents token theft attacks by proving the client possesses the private key 282// corresponding to the public key thumbprint in the token's cnf.jkt claim. 
283func (m *AtProtoAuthMiddleware) verifyDPoPBinding(r *http.Request, claims *auth.Claims, dpopProofHeader string) (*auth.DPoPProof, error) { 284 // Extract the cnf.jkt claim from the already-verified token 285 jkt, err := auth.ExtractCnfJkt(claims) 286 if err != nil { 287 return nil, fmt.Errorf("token requires DPoP but missing cnf.jkt: %w", err) 288 } 289 290 // Build the HTTP URI for DPoP verification 291 // Use the full URL including scheme and host 292 scheme := strings.TrimSpace(r.URL.Scheme) 293 if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { 294 // Forwarded proto may contain a comma-separated list; use the first entry 295 parts := strings.Split(forwardedProto, ",") 296 if len(parts) > 0 && strings.TrimSpace(parts[0]) != "" { 297 scheme = strings.ToLower(strings.TrimSpace(parts[0])) 298 } 299 } 300 if scheme == "" { 301 if r.TLS != nil { 302 scheme = "https" 303 } else { 304 scheme = "http" 305 } 306 } 307 scheme = strings.ToLower(scheme) 308 httpURI := scheme + "://" + r.Host + r.URL.Path 309 310 // Verify the DPoP proof 311 proof, err := m.dpopVerifier.VerifyDPoPProof(dpopProofHeader, r.Method, httpURI) 312 if err != nil { 313 return nil, fmt.Errorf("DPoP proof verification failed: %w", err) 314 } 315 316 // Verify the binding between the proof and the token 317 if err := m.dpopVerifier.VerifyTokenBinding(proof, jkt); err != nil { 318 return nil, fmt.Errorf("DPoP binding verification failed: %w", err) 319 } 320 321 return proof, nil 322} 323 324// writeAuthError writes a JSON error response for authentication failures 325func writeAuthError(w http.ResponseWriter, message string) { 326 w.Header().Set("Content-Type", "application/json") 327 w.WriteHeader(http.StatusUnauthorized) 328 // Simple error response matching XRPC error format 329 response := `{"error":"AuthenticationRequired","message":"` + message + `"}` 330 if _, err := w.Write([]byte(response)); err != nil { 331 log.Printf("Failed to write auth error response: %v", 
err) 332 } 333}