A community-based topic aggregation platform built on atproto
1package middleware 2 3import ( 4 "Coves/internal/atproto/oauth" 5 "context" 6 "encoding/json" 7 "log" 8 "net/http" 9 "strings" 10 11 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13) 14 15// Context keys for storing user information 16type contextKey string 17 18const ( 19 UserDIDKey contextKey = "user_did" 20 OAuthSessionKey contextKey = "oauth_session" 21 UserAccessToken contextKey = "user_access_token" // Kept for backward compatibility 22) 23 24// SessionUnsealer is an interface for unsealing session tokens 25// This allows for mocking in tests 26type SessionUnsealer interface { 27 UnsealSession(token string) (*oauth.SealedSession, error) 28} 29 30// OAuthAuthMiddleware enforces OAuth authentication using sealed session tokens. 31type OAuthAuthMiddleware struct { 32 unsealer SessionUnsealer 33 store oauthlib.ClientAuthStore 34} 35 36// NewOAuthAuthMiddleware creates a new OAuth auth middleware using sealed session tokens. 37func NewOAuthAuthMiddleware(unsealer SessionUnsealer, store oauthlib.ClientAuthStore) *OAuthAuthMiddleware { 38 return &OAuthAuthMiddleware{ 39 unsealer: unsealer, 40 store: store, 41 } 42} 43 44// RequireAuth middleware ensures the user is authenticated. 45// Supports sealed session tokens via: 46// - Authorization: Bearer <sealed_token> 47// - Cookie: coves_session=<sealed_token> 48// 49// If not authenticated, returns 401. 50// If authenticated, injects user DID into context. 51func (m *OAuthAuthMiddleware) RequireAuth(next http.Handler) http.Handler { 52 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 var token string 54 55 // Try Authorization header first (for mobile/API clients) 56 authHeader := r.Header.Get("Authorization") 57 if authHeader != "" { 58 var ok bool 59 token, ok = extractBearerToken(authHeader) 60 if !ok { 61 writeAuthError(w, "Invalid Authorization header format. 
Expected: Bearer <token>") 62 return 63 } 64 } 65 66 // If no header, try session cookie (for web clients) 67 if token == "" { 68 if cookie, err := r.Cookie("coves_session"); err == nil { 69 token = cookie.Value 70 } 71 } 72 73 // Must have authentication from either source 74 if token == "" { 75 writeAuthError(w, "Missing authentication") 76 return 77 } 78 79 // Authenticate using sealed token 80 sealedSession, err := m.unsealer.UnsealSession(token) 81 if err != nil { 82 log.Printf("[AUTH_FAILURE] type=unseal_failed ip=%s method=%s path=%s error=%v", 83 r.RemoteAddr, r.Method, r.URL.Path, err) 84 writeAuthError(w, "Invalid or expired token") 85 return 86 } 87 88 // Parse DID 89 did, err := syntax.ParseDID(sealedSession.DID) 90 if err != nil { 91 log.Printf("[AUTH_FAILURE] type=invalid_did ip=%s method=%s path=%s did=%s error=%v", 92 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, err) 93 writeAuthError(w, "Invalid DID in token") 94 return 95 } 96 97 // Load full OAuth session from database 98 session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID) 99 if err != nil { 100 log.Printf("[AUTH_FAILURE] type=session_not_found ip=%s method=%s path=%s did=%s session_id=%s error=%v", 101 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID, err) 102 writeAuthError(w, "Session not found or expired") 103 return 104 } 105 106 // Verify session DID matches token DID 107 if session.AccountDID.String() != sealedSession.DID { 108 log.Printf("[AUTH_FAILURE] type=did_mismatch ip=%s method=%s path=%s token_did=%s session_did=%s", 109 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, session.AccountDID.String()) 110 writeAuthError(w, "Session DID mismatch") 111 return 112 } 113 114 log.Printf("[AUTH_SUCCESS] ip=%s method=%s path=%s did=%s session_id=%s", 115 r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID) 116 117 // Inject user info and session into context 118 ctx := 
context.WithValue(r.Context(), UserDIDKey, sealedSession.DID) 119 ctx = context.WithValue(ctx, OAuthSessionKey, session) 120 // Store access token for backward compatibility 121 ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken) 122 123 // Call next handler 124 next.ServeHTTP(w, r.WithContext(ctx)) 125 }) 126} 127 128// OptionalAuth middleware loads user info if authenticated, but doesn't require it. 129// Useful for endpoints that work for both authenticated and anonymous users. 130// 131// Supports sealed session tokens via: 132// - Authorization: Bearer <sealed_token> 133// - Cookie: coves_session=<sealed_token> 134// 135// If authentication fails, continues without user context (does not return error). 136func (m *OAuthAuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 137 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 138 var token string 139 140 // Try Authorization header first (for mobile/API clients) 141 authHeader := r.Header.Get("Authorization") 142 if authHeader != "" { 143 var ok bool 144 token, ok = extractBearerToken(authHeader) 145 if !ok { 146 // Invalid format - continue without user context 147 next.ServeHTTP(w, r) 148 return 149 } 150 } 151 152 // If no header, try session cookie (for web clients) 153 if token == "" { 154 if cookie, err := r.Cookie("coves_session"); err == nil { 155 token = cookie.Value 156 } 157 } 158 159 // If still no token, continue without authentication 160 if token == "" { 161 next.ServeHTTP(w, r) 162 return 163 } 164 165 // Try to authenticate (don't write errors, just continue without user context on failure) 166 sealedSession, err := m.unsealer.UnsealSession(token) 167 if err != nil { 168 next.ServeHTTP(w, r) 169 return 170 } 171 172 // Parse DID 173 did, err := syntax.ParseDID(sealedSession.DID) 174 if err != nil { 175 log.Printf("[AUTH_WARNING] Optional auth: invalid DID: %v", err) 176 next.ServeHTTP(w, r) 177 return 178 } 179 180 // Load full OAuth session from 
database 181 session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID) 182 if err != nil { 183 log.Printf("[AUTH_WARNING] Optional auth: session not found: %v", err) 184 next.ServeHTTP(w, r) 185 return 186 } 187 188 // Verify session DID matches token DID 189 if session.AccountDID.String() != sealedSession.DID { 190 log.Printf("[AUTH_WARNING] Optional auth: DID mismatch") 191 next.ServeHTTP(w, r) 192 return 193 } 194 195 // Build authenticated context 196 ctx := context.WithValue(r.Context(), UserDIDKey, sealedSession.DID) 197 ctx = context.WithValue(ctx, OAuthSessionKey, session) 198 ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken) 199 200 next.ServeHTTP(w, r.WithContext(ctx)) 201 }) 202} 203 204// GetUserDID extracts the user's DID from the request context 205// Returns empty string if not authenticated 206func GetUserDID(r *http.Request) string { 207 did, _ := r.Context().Value(UserDIDKey).(string) 208 return did 209} 210 211// GetAuthenticatedDID extracts the authenticated user's DID from the context 212// This is used by service layers for defense-in-depth validation 213// Returns empty string if not authenticated 214func GetAuthenticatedDID(ctx context.Context) string { 215 did, _ := ctx.Value(UserDIDKey).(string) 216 return did 217} 218 219// GetOAuthSession extracts the OAuth session from the request context 220// Returns nil if not authenticated 221// Handlers can use this to make authenticated PDS calls 222func GetOAuthSession(r *http.Request) *oauthlib.ClientSessionData { 223 session, _ := r.Context().Value(OAuthSessionKey).(*oauthlib.ClientSessionData) 224 return session 225} 226 227// GetUserAccessToken extracts the user's access token from the request context 228// Returns empty string if not authenticated 229func GetUserAccessToken(r *http.Request) string { 230 token, _ := r.Context().Value(UserAccessToken).(string) 231 return token 232} 233 234// SetTestUserDID sets the user DID in the context for testing 
purposes 235// This function should ONLY be used in tests to mock authenticated users 236func SetTestUserDID(ctx context.Context, userDID string) context.Context { 237 return context.WithValue(ctx, UserDIDKey, userDID) 238} 239 240// extractBearerToken extracts the token from a Bearer Authorization header. 241// HTTP auth schemes are case-insensitive per RFC 7235, so "Bearer", "bearer", "BEARER" are all valid. 242// Returns the token and true if valid Bearer scheme, empty string and false otherwise. 243func extractBearerToken(authHeader string) (string, bool) { 244 if authHeader == "" { 245 return "", false 246 } 247 248 // Split on first space: "Bearer <token>" -> ["Bearer", "<token>"] 249 parts := strings.SplitN(authHeader, " ", 2) 250 if len(parts) != 2 { 251 return "", false 252 } 253 254 // Case-insensitive scheme comparison per RFC 7235 255 if !strings.EqualFold(parts[0], "Bearer") { 256 return "", false 257 } 258 259 token := strings.TrimSpace(parts[1]) 260 if token == "" { 261 return "", false 262 } 263 264 return token, true 265} 266 267// writeAuthError writes a JSON error response for authentication failures 268func writeAuthError(w http.ResponseWriter, message string) { 269 w.Header().Set("Content-Type", "application/json") 270 w.WriteHeader(http.StatusUnauthorized) 271 // Use json.NewEncoder to properly escape the message and prevent injection 272 if err := json.NewEncoder(w).Encode(map[string]string{ 273 "error": "AuthenticationRequired", 274 "message": message, 275 }); err != nil { 276 log.Printf("Failed to write auth error response: %v", err) 277 } 278}