A community-based topic aggregation platform built on atproto.
1package middleware
2
3import (
4 "Coves/internal/atproto/oauth"
5 "context"
6 "encoding/json"
7 "log"
8 "net/http"
9 "strings"
10
11 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13)
14
// contextKey is the unexported type used for every context key this package
// defines, so its keys cannot collide with plain-string keys from elsewhere.
type contextKey string

// Keys under which the auth middleware stores per-request user information.
const (
	// UserDIDKey holds the authenticated user's DID as a string.
	UserDIDKey contextKey = "user_did"
	// OAuthSessionKey holds the loaded *oauthlib.ClientSessionData.
	OAuthSessionKey contextKey = "oauth_session"
	// UserAccessToken holds the session's access token string. It is kept for
	// backward compatibility; the same token is also on the session stored
	// under OAuthSessionKey.
	UserAccessToken contextKey = "user_access_token"
)
23
24// SessionUnsealer is an interface for unsealing session tokens
25// This allows for mocking in tests
26type SessionUnsealer interface {
27 UnsealSession(token string) (*oauth.SealedSession, error)
28}
29
30// OAuthAuthMiddleware enforces OAuth authentication using sealed session tokens.
31type OAuthAuthMiddleware struct {
32 unsealer SessionUnsealer
33 store oauthlib.ClientAuthStore
34}
35
36// NewOAuthAuthMiddleware creates a new OAuth auth middleware using sealed session tokens.
37func NewOAuthAuthMiddleware(unsealer SessionUnsealer, store oauthlib.ClientAuthStore) *OAuthAuthMiddleware {
38 return &OAuthAuthMiddleware{
39 unsealer: unsealer,
40 store: store,
41 }
42}
43
44// RequireAuth middleware ensures the user is authenticated.
45// Supports sealed session tokens via:
46// - Authorization: Bearer <sealed_token>
47// - Cookie: coves_session=<sealed_token>
48//
49// If not authenticated, returns 401.
50// If authenticated, injects user DID into context.
51func (m *OAuthAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
52 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53 var token string
54
55 // Try Authorization header first (for mobile/API clients)
56 authHeader := r.Header.Get("Authorization")
57 if authHeader != "" {
58 var ok bool
59 token, ok = extractBearerToken(authHeader)
60 if !ok {
61 writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>")
62 return
63 }
64 }
65
66 // If no header, try session cookie (for web clients)
67 if token == "" {
68 if cookie, err := r.Cookie("coves_session"); err == nil {
69 token = cookie.Value
70 }
71 }
72
73 // Must have authentication from either source
74 if token == "" {
75 writeAuthError(w, "Missing authentication")
76 return
77 }
78
79 // Authenticate using sealed token
80 sealedSession, err := m.unsealer.UnsealSession(token)
81 if err != nil {
82 log.Printf("[AUTH_FAILURE] type=unseal_failed ip=%s method=%s path=%s error=%v",
83 r.RemoteAddr, r.Method, r.URL.Path, err)
84 writeAuthError(w, "Invalid or expired token")
85 return
86 }
87
88 // Parse DID
89 did, err := syntax.ParseDID(sealedSession.DID)
90 if err != nil {
91 log.Printf("[AUTH_FAILURE] type=invalid_did ip=%s method=%s path=%s did=%s error=%v",
92 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, err)
93 writeAuthError(w, "Invalid DID in token")
94 return
95 }
96
97 // Load full OAuth session from database
98 session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID)
99 if err != nil {
100 log.Printf("[AUTH_FAILURE] type=session_not_found ip=%s method=%s path=%s did=%s session_id=%s error=%v",
101 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID, err)
102 writeAuthError(w, "Session not found or expired")
103 return
104 }
105
106 // Verify session DID matches token DID
107 if session.AccountDID.String() != sealedSession.DID {
108 log.Printf("[AUTH_FAILURE] type=did_mismatch ip=%s method=%s path=%s token_did=%s session_did=%s",
109 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, session.AccountDID.String())
110 writeAuthError(w, "Session DID mismatch")
111 return
112 }
113
114 log.Printf("[AUTH_SUCCESS] ip=%s method=%s path=%s did=%s session_id=%s",
115 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID)
116
117 // Inject user info and session into context
118 ctx := context.WithValue(r.Context(), UserDIDKey, sealedSession.DID)
119 ctx = context.WithValue(ctx, OAuthSessionKey, session)
120 // Store access token for backward compatibility
121 ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken)
122
123 // Call next handler
124 next.ServeHTTP(w, r.WithContext(ctx))
125 })
126}
127
128// OptionalAuth middleware loads user info if authenticated, but doesn't require it.
129// Useful for endpoints that work for both authenticated and anonymous users.
130//
131// Supports sealed session tokens via:
132// - Authorization: Bearer <sealed_token>
133// - Cookie: coves_session=<sealed_token>
134//
135// If authentication fails, continues without user context (does not return error).
136func (m *OAuthAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
137 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
138 var token string
139
140 // Try Authorization header first (for mobile/API clients)
141 authHeader := r.Header.Get("Authorization")
142 if authHeader != "" {
143 var ok bool
144 token, ok = extractBearerToken(authHeader)
145 if !ok {
146 // Invalid format - continue without user context
147 next.ServeHTTP(w, r)
148 return
149 }
150 }
151
152 // If no header, try session cookie (for web clients)
153 if token == "" {
154 if cookie, err := r.Cookie("coves_session"); err == nil {
155 token = cookie.Value
156 }
157 }
158
159 // If still no token, continue without authentication
160 if token == "" {
161 next.ServeHTTP(w, r)
162 return
163 }
164
165 // Try to authenticate (don't write errors, just continue without user context on failure)
166 sealedSession, err := m.unsealer.UnsealSession(token)
167 if err != nil {
168 next.ServeHTTP(w, r)
169 return
170 }
171
172 // Parse DID
173 did, err := syntax.ParseDID(sealedSession.DID)
174 if err != nil {
175 log.Printf("[AUTH_WARNING] Optional auth: invalid DID: %v", err)
176 next.ServeHTTP(w, r)
177 return
178 }
179
180 // Load full OAuth session from database
181 session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID)
182 if err != nil {
183 log.Printf("[AUTH_WARNING] Optional auth: session not found: %v", err)
184 next.ServeHTTP(w, r)
185 return
186 }
187
188 // Verify session DID matches token DID
189 if session.AccountDID.String() != sealedSession.DID {
190 log.Printf("[AUTH_WARNING] Optional auth: DID mismatch")
191 next.ServeHTTP(w, r)
192 return
193 }
194
195 // Build authenticated context
196 ctx := context.WithValue(r.Context(), UserDIDKey, sealedSession.DID)
197 ctx = context.WithValue(ctx, OAuthSessionKey, session)
198 ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken)
199
200 next.ServeHTTP(w, r.WithContext(ctx))
201 })
202}
203
204// GetUserDID extracts the user's DID from the request context
205// Returns empty string if not authenticated
206func GetUserDID(r *http.Request) string {
207 did, _ := r.Context().Value(UserDIDKey).(string)
208 return did
209}
210
211// GetAuthenticatedDID extracts the authenticated user's DID from the context
212// This is used by service layers for defense-in-depth validation
213// Returns empty string if not authenticated
214func GetAuthenticatedDID(ctx context.Context) string {
215 did, _ := ctx.Value(UserDIDKey).(string)
216 return did
217}
218
219// GetOAuthSession extracts the OAuth session from the request context
220// Returns nil if not authenticated
221// Handlers can use this to make authenticated PDS calls
222func GetOAuthSession(r *http.Request) *oauthlib.ClientSessionData {
223 session, _ := r.Context().Value(OAuthSessionKey).(*oauthlib.ClientSessionData)
224 return session
225}
226
227// GetUserAccessToken extracts the user's access token from the request context
228// Returns empty string if not authenticated
229func GetUserAccessToken(r *http.Request) string {
230 token, _ := r.Context().Value(UserAccessToken).(string)
231 return token
232}
233
234// SetTestUserDID sets the user DID in the context for testing purposes
235// This function should ONLY be used in tests to mock authenticated users
236func SetTestUserDID(ctx context.Context, userDID string) context.Context {
237 return context.WithValue(ctx, UserDIDKey, userDID)
238}
239
// extractBearerToken parses an Authorization header of the form
// "<scheme> <token>", splitting at the first space, and returns the token when
// the scheme is Bearer. Scheme matching is case-insensitive per RFC 7235, so
// "Bearer", "bearer" and "BEARER" are all accepted. Surrounding whitespace is
// trimmed from the token.
//
// It returns ("", false) for an empty header, a header without a space, a
// non-Bearer scheme, or a blank token.
func extractBearerToken(authHeader string) (string, bool) {
	scheme, credential, found := strings.Cut(authHeader, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	if credential = strings.TrimSpace(credential); credential == "" {
		return "", false
	}
	return credential, true
}
266
// writeAuthError writes a 401 Unauthorized response with a JSON body of the
// form {"error":"AuthenticationRequired","message":<message>}.
//
// RFC 7235 §3.1 requires every 401 response to carry a WWW-Authenticate
// challenge, so a Bearer challenge is advertised, matching the scheme accepted
// in the Authorization header. Headers must be set before WriteHeader, because
// later changes are ignored. The body goes through encoding/json so message is
// properly escaped. An encode failure can only be logged, since the status
// line has already been sent.
func writeAuthError(w http.ResponseWriter, message string) {
	h := w.Header()
	h.Set("Content-Type", "application/json")
	h.Set("WWW-Authenticate", "Bearer")
	w.WriteHeader(http.StatusUnauthorized)

	body := map[string]string{
		"error":   "AuthenticationRequired",
		"message": message,
	}
	if err := json.NewEncoder(w).Encode(body); err != nil {
		log.Printf("Failed to write auth error response: %v", err)
	}
}