A community-based topic aggregation platform built on atproto.
1package middleware
2
import (
	"bytes"
	"context"
	"encoding/json"
	"log"
	"net/http"
	"strings"

	"Coves/internal/atproto/auth"
)
11
// contextKey is the unexported key type for values this package stores in a
// request context. A distinct named type keeps these keys from colliding with
// keys (including plain strings) that other packages put into the same context.
type contextKey string

// Context keys under which the auth middleware stores per-request user data.
const (
	// UserDIDKey holds the authenticated user's DID as a string.
	UserDIDKey = contextKey("user_did")

	// JWTClaimsKey holds the token's parsed claims as *auth.Claims.
	JWTClaimsKey = contextKey("jwt_claims")

	// UserAccessToken holds the raw bearer token string from the request.
	UserAccessToken = contextKey("user_access_token")
)
20
// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected
// routes by validating JWT Bearer tokens taken from the Authorization header.
//
// Build it with NewAtProtoAuthMiddleware. RequireAuth rejects requests that
// lack a valid token; OptionalAuth attaches the caller's identity when a valid
// token is present and otherwise lets the request through anonymously.
type AtProtoAuthMiddleware struct {
	// jwksFetcher is passed to auth.VerifyJWT when skipVerify is false
	// (presumably it supplies the issuer's signing keys — confirm in package auth).
	jwksFetcher auth.JWKSFetcher
	// skipVerify, when true, makes the middleware only parse tokens with
	// auth.ParseJWT; signatures are not verified.
	skipVerify bool // For Phase 1 testing only
}
27
28// NewAtProtoAuthMiddleware creates a new atProto auth middleware
29// skipVerify: if true, only parses JWT without signature verification (Phase 1)
30//
31// if false, performs full signature verification (Phase 2)
32func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware {
33 return &AtProtoAuthMiddleware{
34 jwksFetcher: jwksFetcher,
35 skipVerify: skipVerify,
36 }
37}
38
39// RequireAuth middleware ensures the user is authenticated with a valid JWT
40// If not authenticated, returns 401
41// If authenticated, injects user DID and JWT claims into context
42func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
43 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44 // Extract Authorization header
45 authHeader := r.Header.Get("Authorization")
46 if authHeader == "" {
47 writeAuthError(w, "Missing Authorization header")
48 return
49 }
50
51 // Must be Bearer token
52 if !strings.HasPrefix(authHeader, "Bearer ") {
53 writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>")
54 return
55 }
56
57 token := strings.TrimPrefix(authHeader, "Bearer ")
58 token = strings.TrimSpace(token)
59
60 var claims *auth.Claims
61 var err error
62
63 if m.skipVerify {
64 // Phase 1: Parse only (no signature verification)
65 claims, err = auth.ParseJWT(token)
66 if err != nil {
67 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v",
68 r.RemoteAddr, r.Method, r.URL.Path, err)
69 writeAuthError(w, "Invalid token")
70 return
71 }
72 } else {
73 // Phase 2: Full verification with signature check
74 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
75 if err != nil {
76 // Try to extract issuer for better logging
77 issuer := "unknown"
78 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil {
79 issuer = parsedClaims.Issuer
80 }
81 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v",
82 r.RemoteAddr, r.Method, r.URL.Path, issuer, err)
83 writeAuthError(w, "Invalid or expired token")
84 return
85 }
86 }
87
88 // Extract user DID from 'sub' claim
89 userDID := claims.Subject
90 if userDID == "" {
91 writeAuthError(w, "Missing user DID in token")
92 return
93 }
94
95 // Inject user info and access token into context
96 ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
97 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
98 ctx = context.WithValue(ctx, UserAccessToken, token)
99
100 // Call next handler
101 next.ServeHTTP(w, r.WithContext(ctx))
102 })
103}
104
105// OptionalAuth middleware loads user info if authenticated, but doesn't require it
106// Useful for endpoints that work for both authenticated and anonymous users
107func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
108 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109 // Extract Authorization header
110 authHeader := r.Header.Get("Authorization")
111 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
112 // Not authenticated - continue without user context
113 next.ServeHTTP(w, r)
114 return
115 }
116
117 token := strings.TrimPrefix(authHeader, "Bearer ")
118 token = strings.TrimSpace(token)
119
120 var claims *auth.Claims
121 var err error
122
123 if m.skipVerify {
124 // Phase 1: Parse only
125 claims, err = auth.ParseJWT(token)
126 } else {
127 // Phase 2: Full verification
128 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
129 }
130
131 if err != nil {
132 // Invalid token - continue without user context
133 log.Printf("Optional auth failed: %v", err)
134 next.ServeHTTP(w, r)
135 return
136 }
137
138 // Inject user info and access token into context
139 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
140 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
141 ctx = context.WithValue(ctx, UserAccessToken, token)
142
143 // Call next handler
144 next.ServeHTTP(w, r.WithContext(ctx))
145 })
146}
147
148// GetUserDID extracts the user's DID from the request context
149// Returns empty string if not authenticated
150func GetUserDID(r *http.Request) string {
151 did, _ := r.Context().Value(UserDIDKey).(string)
152 return did
153}
154
155// GetAuthenticatedDID extracts the authenticated user's DID from the context
156// This is used by service layers for defense-in-depth validation
157// Returns empty string if not authenticated
158func GetAuthenticatedDID(ctx context.Context) string {
159 did, _ := ctx.Value(UserDIDKey).(string)
160 return did
161}
162
163// GetJWTClaims extracts the JWT claims from the request context
164// Returns nil if not authenticated
165func GetJWTClaims(r *http.Request) *auth.Claims {
166 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims)
167 return claims
168}
169
170// SetTestUserDID sets the user DID in the context for testing purposes
171// This function should ONLY be used in tests to mock authenticated users
172func SetTestUserDID(ctx context.Context, userDID string) context.Context {
173 return context.WithValue(ctx, UserDIDKey, userDID)
174}
175
176// GetUserAccessToken extracts the user's access token from the request context
177// Returns empty string if not authenticated
178func GetUserAccessToken(r *http.Request) string {
179 token, _ := r.Context().Value(UserAccessToken).(string)
180 return token
181}
182
// writeAuthError writes a 401 Unauthorized response with a JSON body in the
// XRPC error shape: {"error":"AuthenticationRequired","message":<message>}.
//
// The body is serialized with encoding/json, so quotes, backslashes and
// control characters in message are escaped correctly. Splicing message into
// a JSON literal by concatenation would yield invalid JSON for such input.
// HTML escaping is disabled and the encoder's trailing newline is dropped, so
// ordinary messages produce exactly the bytes of the hand-built literal.
// A failure to write the body (e.g. the client went away) is logged, not returned.
func writeAuthError(w http.ResponseWriter, message string) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false) // keep "<token>" literal instead of \u003ctoken\u003e
	encodeErr := enc.Encode(struct {
		Error   string `json:"error"`
		Message string `json:"message"`
	}{
		Error:   "AuthenticationRequired",
		Message: message,
	})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	if encodeErr != nil {
		// Not expected for two string fields; never emit a partial body.
		log.Printf("Failed to encode auth error response: %v", encodeErr)
		return
	}

	body := bytes.TrimSuffix(buf.Bytes(), []byte("\n"))
	if _, err := w.Write(body); err != nil {
		log.Printf("Failed to write auth error response: %v", err)
	}
}