A community-based topic aggregation platform built on atproto
1package middleware
2
import (
	"context"
	"encoding/json"
	"log"
	"net/http"
	"strings"

	"Coves/internal/atproto/auth"
)
10
// contextKey is a package-private string type used for request-context keys.
// Because the type is unexported, values stored under these keys cannot
// collide with keys of plain string type set by other packages.
type contextKey string

// Keys under which the auth middleware stores per-request values.
const (
	// UserDIDKey holds the authenticated user's DID (the token's subject).
	UserDIDKey = contextKey("user_did")
	// JWTClaimsKey holds the parsed *auth.Claims of the request's token.
	JWTClaimsKey = contextKey("jwt_claims")
	// UserAccessToken holds the raw bearer token string from the request.
	UserAccessToken = contextKey("user_access_token")
)
19
20// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes
21// Validates JWT Bearer tokens from the Authorization header
22type AtProtoAuthMiddleware struct {
23 jwksFetcher auth.JWKSFetcher
24 skipVerify bool // For Phase 1 testing only
25}
26
27// NewAtProtoAuthMiddleware creates a new atProto auth middleware
28// skipVerify: if true, only parses JWT without signature verification (Phase 1)
29//
30// if false, performs full signature verification (Phase 2)
31func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware {
32 return &AtProtoAuthMiddleware{
33 jwksFetcher: jwksFetcher,
34 skipVerify: skipVerify,
35 }
36}
37
38// RequireAuth middleware ensures the user is authenticated with a valid JWT
39// If not authenticated, returns 401
40// If authenticated, injects user DID and JWT claims into context
41func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
42 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43 // Extract Authorization header
44 authHeader := r.Header.Get("Authorization")
45 if authHeader == "" {
46 writeAuthError(w, "Missing Authorization header")
47 return
48 }
49
50 // Must be Bearer token
51 if !strings.HasPrefix(authHeader, "Bearer ") {
52 writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>")
53 return
54 }
55
56 token := strings.TrimPrefix(authHeader, "Bearer ")
57 token = strings.TrimSpace(token)
58
59 var claims *auth.Claims
60 var err error
61
62 if m.skipVerify {
63 // Phase 1: Parse only (no signature verification)
64 claims, err = auth.ParseJWT(token)
65 if err != nil {
66 log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v",
67 r.RemoteAddr, r.Method, r.URL.Path, err)
68 writeAuthError(w, "Invalid token")
69 return
70 }
71 } else {
72 // Phase 2: Full verification with signature check
73 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
74 if err != nil {
75 // Try to extract issuer for better logging
76 issuer := "unknown"
77 if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil {
78 issuer = parsedClaims.Issuer
79 }
80 log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v",
81 r.RemoteAddr, r.Method, r.URL.Path, issuer, err)
82 writeAuthError(w, "Invalid or expired token")
83 return
84 }
85 }
86
87 // Extract user DID from 'sub' claim
88 userDID := claims.Subject
89 if userDID == "" {
90 writeAuthError(w, "Missing user DID in token")
91 return
92 }
93
94 // Inject user info and access token into context
95 ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
96 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
97 ctx = context.WithValue(ctx, UserAccessToken, token)
98
99 // Call next handler
100 next.ServeHTTP(w, r.WithContext(ctx))
101 })
102}
103
104// OptionalAuth middleware loads user info if authenticated, but doesn't require it
105// Useful for endpoints that work for both authenticated and anonymous users
106func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
107 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
108 // Extract Authorization header
109 authHeader := r.Header.Get("Authorization")
110 if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
111 // Not authenticated - continue without user context
112 next.ServeHTTP(w, r)
113 return
114 }
115
116 token := strings.TrimPrefix(authHeader, "Bearer ")
117 token = strings.TrimSpace(token)
118
119 var claims *auth.Claims
120 var err error
121
122 if m.skipVerify {
123 // Phase 1: Parse only
124 claims, err = auth.ParseJWT(token)
125 } else {
126 // Phase 2: Full verification
127 claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
128 }
129
130 if err != nil {
131 // Invalid token - continue without user context
132 log.Printf("Optional auth failed: %v", err)
133 next.ServeHTTP(w, r)
134 return
135 }
136
137 // Inject user info and access token into context
138 ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
139 ctx = context.WithValue(ctx, JWTClaimsKey, claims)
140 ctx = context.WithValue(ctx, UserAccessToken, token)
141
142 // Call next handler
143 next.ServeHTTP(w, r.WithContext(ctx))
144 })
145}
146
147// GetUserDID extracts the user's DID from the request context
148// Returns empty string if not authenticated
149func GetUserDID(r *http.Request) string {
150 did, _ := r.Context().Value(UserDIDKey).(string)
151 return did
152}
153
154// GetJWTClaims extracts the JWT claims from the request context
155// Returns nil if not authenticated
156func GetJWTClaims(r *http.Request) *auth.Claims {
157 claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims)
158 return claims
159}
160
161// GetUserAccessToken extracts the user's access token from the request context
162// Returns empty string if not authenticated
163func GetUserAccessToken(r *http.Request) string {
164 token, _ := r.Context().Value(UserAccessToken).(string)
165 return token
166}
167
// writeAuthError sends a 401 Unauthorized response whose JSON body has the
// shape {"error":"AuthenticationRequired","message":"<message>"}.
//
// The body is produced by encoding/json rather than string concatenation, so
// a message containing quotes, backslashes, or control characters still
// yields valid JSON. HTML escaping is disabled so that messages containing
// '<' or '>' are emitted verbatim instead of as \u003c / \u003e escapes.
//
// A failure to encode or to write the response is logged, not returned.
func writeAuthError(w http.ResponseWriter, message string) {
	payload := struct {
		Error   string `json:"error"`
		Message string `json:"message"`
	}{
		Error:   "AuthenticationRequired",
		Message: message,
	}

	// Fallback body in case encoding fails; the 401 is sent regardless.
	body := `{"error":"AuthenticationRequired"}`

	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(payload); err != nil {
		log.Printf("Failed to encode auth error response: %v", err)
	} else {
		// Encode appends a trailing newline; drop it so the body is just the
		// JSON object.
		body = strings.TrimSuffix(sb.String(), "\n")
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	if _, err := w.Write([]byte(body)); err != nil {
		log.Printf("Failed to write auth error response: %v", err)
	}
}