A community-based topic aggregation platform built on atproto
1package middleware 2 3import ( 4 "Coves/internal/atproto/auth" 5 "context" 6 "log" 7 "net/http" 8 "strings" 9) 10 11// Context keys for storing user information 12type contextKey string 13 14const ( 15 UserDIDKey contextKey = "user_did" 16 JWTClaimsKey contextKey = "jwt_claims" 17 UserAccessToken contextKey = "user_access_token" 18) 19 20// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes 21// Validates JWT Bearer tokens from the Authorization header 22type AtProtoAuthMiddleware struct { 23 jwksFetcher auth.JWKSFetcher 24 skipVerify bool // For Phase 1 testing only 25} 26 27// NewAtProtoAuthMiddleware creates a new atProto auth middleware 28// skipVerify: if true, only parses JWT without signature verification (Phase 1) 29// 30// if false, performs full signature verification (Phase 2) 31func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware { 32 return &AtProtoAuthMiddleware{ 33 jwksFetcher: jwksFetcher, 34 skipVerify: skipVerify, 35 } 36} 37 38// RequireAuth middleware ensures the user is authenticated with a valid JWT 39// If not authenticated, returns 401 40// If authenticated, injects user DID and JWT claims into context 41func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler { 42 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 43 // Extract Authorization header 44 authHeader := r.Header.Get("Authorization") 45 if authHeader == "" { 46 writeAuthError(w, "Missing Authorization header") 47 return 48 } 49 50 // Must be Bearer token 51 if !strings.HasPrefix(authHeader, "Bearer ") { 52 writeAuthError(w, "Invalid Authorization header format. 
Expected: Bearer <token>") 53 return 54 } 55 56 token := strings.TrimPrefix(authHeader, "Bearer ") 57 token = strings.TrimSpace(token) 58 59 var claims *auth.Claims 60 var err error 61 62 if m.skipVerify { 63 // Phase 1: Parse only (no signature verification) 64 claims, err = auth.ParseJWT(token) 65 if err != nil { 66 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v", 67 r.RemoteAddr, r.Method, r.URL.Path, err) 68 writeAuthError(w, "Invalid token") 69 return 70 } 71 } else { 72 // Phase 2: Full verification with signature check 73 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 74 if err != nil { 75 // Try to extract issuer for better logging 76 issuer := "unknown" 77 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil { 78 issuer = parsedClaims.Issuer 79 } 80 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v", 81 r.RemoteAddr, r.Method, r.URL.Path, issuer, err) 82 writeAuthError(w, "Invalid or expired token") 83 return 84 } 85 } 86 87 // Extract user DID from 'sub' claim 88 userDID := claims.Subject 89 if userDID == "" { 90 writeAuthError(w, "Missing user DID in token") 91 return 92 } 93 94 // Inject user info and access token into context 95 ctx := context.WithValue(r.Context(), UserDIDKey, userDID) 96 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 97 ctx = context.WithValue(ctx, UserAccessToken, token) 98 99 // Call next handler 100 next.ServeHTTP(w, r.WithContext(ctx)) 101 }) 102} 103 104// OptionalAuth middleware loads user info if authenticated, but doesn't require it 105// Useful for endpoints that work for both authenticated and anonymous users 106func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 107 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 108 // Extract Authorization header 109 authHeader := r.Header.Get("Authorization") 110 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { 
111 // Not authenticated - continue without user context 112 next.ServeHTTP(w, r) 113 return 114 } 115 116 token := strings.TrimPrefix(authHeader, "Bearer ") 117 token = strings.TrimSpace(token) 118 119 var claims *auth.Claims 120 var err error 121 122 if m.skipVerify { 123 // Phase 1: Parse only 124 claims, err = auth.ParseJWT(token) 125 } else { 126 // Phase 2: Full verification 127 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher) 128 } 129 130 if err != nil { 131 // Invalid token - continue without user context 132 log.Printf("Optional auth failed: %v", err) 133 next.ServeHTTP(w, r) 134 return 135 } 136 137 // Inject user info and access token into context 138 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject) 139 ctx = context.WithValue(ctx, JWTClaimsKey, claims) 140 ctx = context.WithValue(ctx, UserAccessToken, token) 141 142 // Call next handler 143 next.ServeHTTP(w, r.WithContext(ctx)) 144 }) 145} 146 147// GetUserDID extracts the user's DID from the request context 148// Returns empty string if not authenticated 149func GetUserDID(r *http.Request) string { 150 did, _ := r.Context().Value(UserDIDKey).(string) 151 return did 152} 153 154// GetJWTClaims extracts the JWT claims from the request context 155// Returns nil if not authenticated 156func GetJWTClaims(r *http.Request) *auth.Claims { 157 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims) 158 return claims 159} 160 161// GetUserAccessToken extracts the user's access token from the request context 162// Returns empty string if not authenticated 163func GetUserAccessToken(r *http.Request) string { 164 token, _ := r.Context().Value(UserAccessToken).(string) 165 return token 166} 167 168// writeAuthError writes a JSON error response for authentication failures 169func writeAuthError(w http.ResponseWriter, message string) { 170 w.Header().Set("Content-Type", "application/json") 171 w.WriteHeader(http.StatusUnauthorized) 172 // Simple error response matching XRPC error 
format 173 response := `{"error":"AuthenticationRequired","message":"` + message + `"}` 174 if _, err := w.Write([]byte(response)); err != nil { 175 log.Printf("Failed to write auth error response: %v", err) 176 } 177}