A community based topic aggregation platform built on atproto
1package middleware
2
3import (
4 "Coves/internal/atproto/auth"
5 "context"
6 "fmt"
7 "log"
8 "net/http"
9 "strings"
10)
11
12// Context keys for storing user information
13type contextKey string
14
15const (
16 UserDIDKey contextKey = "user_did"
17 JWTClaimsKey contextKey = "jwt_claims"
18 UserAccessToken contextKey = "user_access_token"
19 DPoPProofKey contextKey = "dpop_proof"
20)
21
22// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes
23// Validates JWT Bearer tokens from the Authorization header
24// Supports DPoP (RFC 9449) for token binding verification
25type AtProtoAuthMiddleware struct {
26 jwksFetcher auth.JWKSFetcher
27 dpopVerifier *auth.DPoPVerifier
28 skipVerify bool // For Phase 1 testing only
29}
30
31// NewAtProtoAuthMiddleware creates a new atProto auth middleware
32// skipVerify: if true, only parses JWT without signature verification (Phase 1)
33//
34// if false, performs full signature verification (Phase 2)
35//
36// IMPORTANT: Call Stop() when shutting down to clean up background goroutines.
37func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware {
38 return &AtProtoAuthMiddleware{
39 jwksFetcher: jwksFetcher,
40 dpopVerifier: auth.NewDPoPVerifier(),
41 skipVerify: skipVerify,
42 }
43}
44
45// Stop stops background goroutines. Call this when shutting down the server.
46// This prevents goroutine leaks from the DPoP verifier's replay protection cache.
47func (m *AtProtoAuthMiddleware) Stop() {
48 if m.dpopVerifier != nil {
49 m.dpopVerifier.Stop()
50 }
51}
52
53// RequireAuth middleware ensures the user is authenticated with a valid JWT
54// If not authenticated, returns 401
55// If authenticated, injects user DID and JWT claims into context
56func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58 // Extract Authorization header
59 authHeader := r.Header.Get("Authorization")
60 if authHeader == "" {
61 writeAuthError(w, "Missing Authorization header")
62 return
63 }
64
65 // Must be Bearer token
66 if !strings.HasPrefix(authHeader, "Bearer ") {
67 writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>")
68 return
69 }
70
71 token := strings.TrimPrefix(authHeader, "Bearer ")
72 token = strings.TrimSpace(token)
73
74 var claims *auth.Claims
75 var err error
76
77 if m.skipVerify {
78 // Phase 1: Parse only (no signature verification)
79 claims, err = auth.ParseJWT(token)
80 if err != nil {
81 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v",
82 r.RemoteAddr, r.Method, r.URL.Path, err)
83 writeAuthError(w, "Invalid token")
84 return
85 }
86 } else {
87 // Phase 2: Full verification with signature check
88 //
89 // SECURITY: The access token MUST be verified before trusting any claims.
90 // DPoP is an ADDITIONAL security layer, not a replacement for signature verification.
91 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
92 if err != nil {
93 // Token verification failed - REJECT
94 // DO NOT fall back to DPoP-only verification, as that would trust unverified claims
95 issuer := "unknown"
96 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil {
97 issuer = parsedClaims.Issuer
98 }
99 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v",
100 r.RemoteAddr, r.Method, r.URL.Path, issuer, err)
101 writeAuthError(w, "Invalid or expired token")
102 return
103 }
104
105 // Token signature verified - now check if DPoP binding is required
106 // If the token has a cnf.jkt claim, DPoP proof is REQUIRED
107 dpopHeader := r.Header.Get("DPoP")
108 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil
109
110 if hasCnfJkt {
111 // Token has DPoP binding - REQUIRE valid DPoP proof
112 if dpopHeader == "" {
113 log.Printf("[AUTH_FAILURE] type=missing_dpop ip=%s method=%s path=%s error=token has cnf.jkt but no DPoP header",
114 r.RemoteAddr, r.Method, r.URL.Path)
115 writeAuthError(w, "DPoP proof required")
116 return
117 }
118
119 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader)
120 if err != nil {
121 log.Printf("[AUTH_FAILURE] type=dpop_verification_failed ip=%s method=%s path=%s error=%v",
122 r.RemoteAddr, r.Method, r.URL.Path, err)
123 writeAuthError(w, "Invalid DPoP proof")
124 return
125 }
126
127 // Store verified DPoP proof in context
128 ctx := context.WithValue(r.Context(), DPoPProofKey, proof)
129 r = r.WithContext(ctx)
130 } else if dpopHeader != "" {
131 // DPoP header present but token doesn't have cnf.jkt - this is suspicious
132 // Log warning but don't reject (could be a misconfigured client)
133 log.Printf("[AUTH_WARNING] type=unexpected_dpop ip=%s method=%s path=%s warning=DPoP header present but token has no cnf.jkt",
134 r.RemoteAddr, r.Method, r.URL.Path)
135 }
136 }
137
138 // Extract user DID from 'sub' claim
139 userDID := claims.Subject
140 if userDID == "" {
141 writeAuthError(w, "Missing user DID in token")
142 return
143 }
144
145 // Inject user info and access token into context
146 ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
147 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
148 ctx = context.WithValue(ctx, UserAccessToken, token)
149
150 // Call next handler
151 next.ServeHTTP(w, r.WithContext(ctx))
152 })
153}
154
155// OptionalAuth middleware loads user info if authenticated, but doesn't require it
156// Useful for endpoints that work for both authenticated and anonymous users
157func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
158 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159 // Extract Authorization header
160 authHeader := r.Header.Get("Authorization")
161 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
162 // Not authenticated - continue without user context
163 next.ServeHTTP(w, r)
164 return
165 }
166
167 token := strings.TrimPrefix(authHeader, "Bearer ")
168 token = strings.TrimSpace(token)
169
170 var claims *auth.Claims
171 var err error
172
173 if m.skipVerify {
174 // Phase 1: Parse only
175 claims, err = auth.ParseJWT(token)
176 } else {
177 // Phase 2: Full verification
178 // SECURITY: Token MUST be verified before trusting claims
179 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
180 }
181
182 if err != nil {
183 // Invalid token - continue without user context
184 log.Printf("Optional auth failed: %v", err)
185 next.ServeHTTP(w, r)
186 return
187 }
188
189 // Check DPoP binding if token has cnf.jkt (after successful verification)
190 // SECURITY: If token has cnf.jkt but no DPoP header, we cannot trust it
191 // (could be a stolen token). Continue as unauthenticated.
192 if !m.skipVerify {
193 dpopHeader := r.Header.Get("DPoP")
194 hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil
195
196 if hasCnfJkt {
197 if dpopHeader == "" {
198 // Token requires DPoP binding but no proof provided
199 // Cannot trust this token - continue without auth
200 log.Printf("[AUTH_WARNING] Optional auth: token has cnf.jkt but no DPoP header - treating as unauthenticated (potential token theft)")
201 next.ServeHTTP(w, r)
202 return
203 }
204
205 proof, err := m.verifyDPoPBinding(r, claims, dpopHeader)
206 if err != nil {
207 // DPoP verification failed - cannot trust this token
208 log.Printf("[AUTH_WARNING] Optional auth: DPoP verification failed - treating as unauthenticated: %v", err)
209 next.ServeHTTP(w, r)
210 return
211 }
212
213 // DPoP verified - inject proof into context
214 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
215 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
216 ctx = context.WithValue(ctx, UserAccessToken, token)
217 ctx = context.WithValue(ctx, DPoPProofKey, proof)
218 next.ServeHTTP(w, r.WithContext(ctx))
219 return
220 }
221 }
222
223 // No DPoP binding required - inject user info and access token into context
224 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
225 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
226 ctx = context.WithValue(ctx, UserAccessToken, token)
227
228 // Call next handler
229 next.ServeHTTP(w, r.WithContext(ctx))
230 })
231}
232
233// GetUserDID extracts the user's DID from the request context
234// Returns empty string if not authenticated
235func GetUserDID(r *http.Request) string {
236 did, _ := r.Context().Value(UserDIDKey).(string)
237 return did
238}
239
240// GetAuthenticatedDID extracts the authenticated user's DID from the context
241// This is used by service layers for defense-in-depth validation
242// Returns empty string if not authenticated
243func GetAuthenticatedDID(ctx context.Context) string {
244 did, _ := ctx.Value(UserDIDKey).(string)
245 return did
246}
247
248// GetJWTClaims extracts the JWT claims from the request context
249// Returns nil if not authenticated
250func GetJWTClaims(r *http.Request) *auth.Claims {
251 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims)
252 return claims
253}
254
255// SetTestUserDID sets the user DID in the context for testing purposes
256// This function should ONLY be used in tests to mock authenticated users
257func SetTestUserDID(ctx context.Context, userDID string) context.Context {
258 return context.WithValue(ctx, UserDIDKey, userDID)
259}
260
261// GetUserAccessToken extracts the user's access token from the request context
262// Returns empty string if not authenticated
263func GetUserAccessToken(r *http.Request) string {
264 token, _ := r.Context().Value(UserAccessToken).(string)
265 return token
266}
267
268// GetDPoPProof extracts the DPoP proof from the request context
269// Returns nil if no DPoP proof was verified
270func GetDPoPProof(r *http.Request) *auth.DPoPProof {
271 proof, _ := r.Context().Value(DPoPProofKey).(*auth.DPoPProof)
272 return proof
273}
274
275// verifyDPoPBinding verifies DPoP proof binding for an ALREADY VERIFIED token.
276//
277// SECURITY: This function ONLY verifies the DPoP proof and its binding to the token.
278// The access token MUST be signature-verified BEFORE calling this function.
279// DPoP is an ADDITIONAL security layer, not a replacement for signature verification.
280//
281// This prevents token theft attacks by proving the client possesses the private key
282// corresponding to the public key thumbprint in the token's cnf.jkt claim.
283func (m *AtProtoAuthMiddleware) verifyDPoPBinding(r *http.Request, claims *auth.Claims, dpopProofHeader string) (*auth.DPoPProof, error) {
284 // Extract the cnf.jkt claim from the already-verified token
285 jkt, err := auth.ExtractCnfJkt(claims)
286 if err != nil {
287 return nil, fmt.Errorf("token requires DPoP but missing cnf.jkt: %w", err)
288 }
289
290 // Build the HTTP URI for DPoP verification
291 // Use the full URL including scheme and host
292 scheme := strings.TrimSpace(r.URL.Scheme)
293 if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
294 // Forwarded proto may contain a comma-separated list; use the first entry
295 parts := strings.Split(forwardedProto, ",")
296 if len(parts) > 0 && strings.TrimSpace(parts[0]) != "" {
297 scheme = strings.ToLower(strings.TrimSpace(parts[0]))
298 }
299 }
300 if scheme == "" {
301 if r.TLS != nil {
302 scheme = "https"
303 } else {
304 scheme = "http"
305 }
306 }
307 scheme = strings.ToLower(scheme)
308 httpURI := scheme + "://" + r.Host + r.URL.Path
309
310 // Verify the DPoP proof
311 proof, err := m.dpopVerifier.VerifyDPoPProof(dpopProofHeader, r.Method, httpURI)
312 if err != nil {
313 return nil, fmt.Errorf("DPoP proof verification failed: %w", err)
314 }
315
316 // Verify the binding between the proof and the token
317 if err := m.dpopVerifier.VerifyTokenBinding(proof, jkt); err != nil {
318 return nil, fmt.Errorf("DPoP binding verification failed: %w", err)
319 }
320
321 return proof, nil
322}
323
324// writeAuthError writes a JSON error response for authentication failures
325func writeAuthError(w http.ResponseWriter, message string) {
326 w.Header().Set("Content-Type", "application/json")
327 w.WriteHeader(http.StatusUnauthorized)
328 // Simple error response matching XRPC error format
329 response := `{"error":"AuthenticationRequired","message":"` + message + `"}`
330 if _, err := w.Write([]byte(response)); err != nil {
331 log.Printf("Failed to write auth error response: %v", err)
332 }
333}