A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "context" 5 "log" 6 "net/http" 7 "strings" 8 9 "Coves/internal/atproto/auth" 10) 11 12// Context keys for storing user information 13type contextKey string 14 15const ( 16 UserDIDKey contextKey = "user_did" 17 JWTClaimsKey contextKey = "jwt_claims" 18 UserAccessToken contextKey = "user_access_token" 19) 20 21// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes 22// Validates JWT Bearer tokens from the Authorization header 23type AtProtoAuthMiddleware struct { 24 jwksFetcher auth.JWKSFetcher 25 skipVerify bool // For Phase 1 testing only 26} 27 28// NewAtProtoAuthMiddleware creates a new atProto auth middleware 29// skipVerify: if true, only parses JWT without signature verification (Phase 1) 30// 31// if false, performs full signature verification (Phase 2) 32func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware { 33 return &AtProtoAuthMiddleware{ 34 jwksFetcher: jwksFetcher, 35 skipVerify: skipVerify, 36 } 37} 38 39// RequireAuth middleware ensures the user is authenticated with a valid JWT 40// If not authenticated, returns 401 41// If authenticated, injects user DID and JWT claims into context 42func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler { 43 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 // Extract Authorization header 45 authHeader := r.Header.Get("Authorization") 46 if authHeader == "" { 47 writeAuthError(w, "Missing Authorization header") 48 return 49 } 50 51 // Must be Bearer token 52 if !strings.HasPrefix(authHeader, "Bearer ") { 53 writeAuthError(w, "Invalid Authorization header format. 
Expected: Bearer <token>") 54 return 55 } 56 57 token := strings.TrimPrefix(authHeader, "Bearer ") 58 token = strings.TrimSpace(token) 59 60 var claims *auth.Claims 61 var err error 62 63 if m.skipVerify { 64 // Phase 1: Parse only (no signature verification) 65 claims, err = auth.ParseJWT(token) 66 if err != nil { 67 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v", 68 r.RemoteAddr, r.Method, r.URL.Path, err) 69 writeAuthError(w, "Invalid token") 70 return 71 } 72 } else { 73 // Phase 2: Full verification with signature check 74 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 75 if err != nil { 76 // Try to extract issuer for better logging 77 issuer := "unknown" 78 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil { 79 issuer = parsedClaims.Issuer 80 } 81 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v", 82 r.RemoteAddr, r.Method, r.URL.Path, issuer, err) 83 writeAuthError(w, "Invalid or expired token") 84 return 85 } 86 } 87 88 // Extract user DID from 'sub' claim 89 userDID := claims.Subject 90 if userDID == "" { 91 writeAuthError(w, "Missing user DID in token") 92 return 93 } 94 95 // Inject user info and access token into context 96 ctx := context.WithValue(r.Context(), UserDIDKey, userDID) 97 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 98 ctx = context.WithValue(ctx, UserAccessToken, token) 99 100 // Call next handler 101 next.ServeHTTP(w, r.WithContext(ctx)) 102 }) 103} 104 105// OptionalAuth middleware loads user info if authenticated, but doesn't require it 106// Useful for endpoints that work for both authenticated and anonymous users 107func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 108 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 // Extract Authorization header 110 authHeader := r.Header.Get("Authorization") 111 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { 
112 // Not authenticated - continue without user context 113 next.ServeHTTP(w, r) 114 return 115 } 116 117 token := strings.TrimPrefix(authHeader, "Bearer ") 118 token = strings.TrimSpace(token) 119 120 var claims *auth.Claims 121 var err error 122 123 if m.skipVerify { 124 // Phase 1: Parse only 125 claims, err = auth.ParseJWT(token) 126 } else { 127 // Phase 2: Full verification 128 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 129 } 130 131 if err != nil { 132 // Invalid token - continue without user context 133 log.Printf("Optional auth failed: %v", err) 134 next.ServeHTTP(w, r) 135 return 136 } 137 138 // Inject user info and access token into context 139 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 140 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 141 ctx = context.WithValue(ctx, UserAccessToken, token) 142 143 // Call next handler 144 next.ServeHTTP(w, r.WithContext(ctx)) 145 }) 146} 147 148// GetUserDID extracts the user's DID from the request context 149// Returns empty string if not authenticated 150func GetUserDID(r *http.Request) string { 151 did, _ := r.Context().Value(UserDIDKey).(string) 152 return did 153} 154 155// GetAuthenticatedDID extracts the authenticated user's DID from the context 156// This is used by service layers for defense-in-depth validation 157// Returns empty string if not authenticated 158func GetAuthenticatedDID(ctx context.Context) string { 159 did, _ := ctx.Value(UserDIDKey).(string) 160 return did 161} 162 163// GetJWTClaims extracts the JWT claims from the request context 164// Returns nil if not authenticated 165func GetJWTClaims(r *http.Request) *auth.Claims { 166 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims) 167 return claims 168} 169 170// SetTestUserDID sets the user DID in the context for testing purposes 171// This function should ONLY be used in tests to mock authenticated users 172func SetTestUserDID(ctx context.Context, userDID string) context.Context { 173 
return context.WithValue(ctx, UserDIDKey, userDID) 174} 175 176// GetUserAccessToken extracts the user's access token from the request context 177// Returns empty string if not authenticated 178func GetUserAccessToken(r *http.Request) string { 179 token, _ := r.Context().Value(UserAccessToken).(string) 180 return token 181} 182 183// writeAuthError writes a JSON error response for authentication failures 184func writeAuthError(w http.ResponseWriter, message string) { 185 w.Header().Set("Content-Type", "application/json") 186 w.WriteHeader(http.StatusUnauthorized) 187 // Simple error response matching XRPC error format 188 response := `{"error":"AuthenticationRequired","message":"` + message + `"}` 189 if _, err := w.Write([]byte(response)); err != nil { 190 log.Printf("Failed to write auth error response: %v", err) 191 } 192}