A community-based topic aggregation platform built on atproto
1package middleware
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net/http"
8 "os"
9 "strings"
10
11 "Coves/internal/api/handlers/oauth"
12 atprotoOAuth "Coves/internal/atproto/oauth"
13 oauthCore "Coves/internal/core/oauth"
14)
15
// contextKey is a package-private type for request-context keys. Using a
// distinct named type keeps these keys from colliding with context values
// stored by other packages under plain string keys.
type contextKey string

// Request-context keys populated by RequireAuth and OptionalAuth.
const (
	// UserDIDKey maps to the authenticated user's DID (a string).
	UserDIDKey = contextKey("user_did")
	// OAuthSessionKey maps to the user's *oauthCore.OAuthSession.
	OAuthSessionKey = contextKey("oauth_session")
)

// Cookie-session identifiers used when reading the caller's HTTP session.
const (
	// sessionName is the name passed to the cookie store when loading the session.
	sessionName = "coves_session"
	// sessionDID is the key in the session's Values map that holds the user's DID.
	sessionDID = "did"
)
28
29// AuthMiddleware enforces OAuth authentication for protected routes
30type AuthMiddleware struct {
31 authService *oauthCore.AuthService
32}
33
34// NewAuthMiddleware creates a new auth middleware
35func NewAuthMiddleware(sessionStore oauthCore.SessionStore) (*AuthMiddleware, error) {
36 privateJWK := os.Getenv("OAUTH_PRIVATE_JWK")
37 if privateJWK == "" {
38 return nil, fmt.Errorf("OAUTH_PRIVATE_JWK not configured")
39 }
40
41 // Parse OAuth client key
42 privateKey, err := atprotoOAuth.ParseJWKFromJSON([]byte(privateJWK))
43 if err != nil {
44 return nil, fmt.Errorf("failed to parse OAuth private key: %w", err)
45 }
46
47 // Get AppView URL
48 appviewURL := os.Getenv("APPVIEW_PUBLIC_URL")
49 if appviewURL == "" {
50 appviewURL = "http://localhost:8081"
51 }
52
53 // Determine client ID
54 var clientID string
55 if strings.HasPrefix(appviewURL, "http://localhost") || strings.HasPrefix(appviewURL, "http://127.0.0.1") {
56 clientID = "http://localhost?redirect_uri=" + appviewURL + "/oauth/callback&scope=atproto%20transition:generic"
57 } else {
58 clientID = appviewURL + "/oauth/client-metadata.json"
59 }
60
61 redirectURI := appviewURL + "/oauth/callback"
62
63 oauthClient := atprotoOAuth.NewClient(clientID, privateKey, redirectURI)
64 authService := oauthCore.NewAuthService(sessionStore, oauthClient)
65
66 return &AuthMiddleware{
67 authService: authService,
68 }, nil
69}
70
// RequireAuth wraps next so that it only runs for authenticated requests.
//
// A request is rejected with 401 and next is not called when any of these hold:
//   - the cookie session cannot be loaded or is new ("Unauthorized");
//   - the session holds no non-empty string DID ("Unauthorized");
//   - ValidateSession fails for that DID ("Session expired", logged);
//   - RefreshTokenIfNeeded fails ("Session expired", logged).
//
// On success the DID is stored under UserDIDKey and the refreshed OAuth
// session under OAuthSessionKey in the request context, and then next is called.
func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reject := func(msg string) {
			http.Error(w, msg, http.StatusUnauthorized)
		}

		httpSession, cookieErr := oauth.GetCookieStore().Get(r, sessionName)
		if cookieErr != nil || httpSession.IsNew {
			reject("Unauthorized")
			return
		}

		// A missing or non-string value yields "", which is rejected below.
		userDID, _ := httpSession.Values[sessionDID].(string)
		if userDID == "" {
			reject("Unauthorized")
			return
		}

		ctx := r.Context()

		oauthSession, err := m.authService.ValidateSession(ctx, userDID)
		if err != nil {
			log.Printf("Failed to load OAuth session for DID %s: %v", userDID, err)
			reject("Session expired")
			return
		}

		oauthSession, err = m.authService.RefreshTokenIfNeeded(ctx, oauthSession, oauth.TokenRefreshThreshold)
		if err != nil {
			log.Printf("Failed to refresh token for DID %s: %v", userDID, err)
			reject("Session expired")
			return
		}

		authedCtx := context.WithValue(context.WithValue(ctx, UserDIDKey, userDID), OAuthSessionKey, oauthSession)
		next.ServeHTTP(w, r.WithContext(authedCtx))
	})
}
115
// OptionalAuth wraps next so that it always runs, attaching user info to the
// request context when the caller has a usable session. It is meant for
// endpoints that serve both authenticated and anonymous users.
//
// If the cookie session holds a non-empty DID and ValidateSession succeeds
// for it, the DID is stored under UserDIDKey and the OAuth session under
// OAuthSessionKey. Any earlier failure (no cookie session, no DID, validation
// error) calls next with the unmodified request, so handlers must treat a
// missing DID as an anonymous caller.
//
// Token refresh is best effort. If RefreshTokenIfNeeded fails, the failure
// is logged and the session returned by ValidateSession is used unchanged.
func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookieStore := oauth.GetCookieStore()
		httpSession, err := cookieStore.Get(r, sessionName)
		if err != nil || httpSession.IsNew {
			// Not authenticated: continue without user context.
			next.ServeHTTP(w, r)
			return
		}

		did, ok := httpSession.Values[sessionDID].(string)
		if !ok || did == "" {
			// No DID: continue without user context.
			next.ServeHTTP(w, r)
			return
		}

		session, err := m.authService.ValidateSession(r.Context(), did)
		if err != nil {
			// Session unusable: continue without user context.
			next.ServeHTTP(w, r)
			return
		}

		// Refresh is best effort. On failure, keep the validated session rather
		// than downgrading the request to anonymous.
		// NOTE(review): presumably refresh is only attempted when the token is
		// within oauth.TokenRefreshThreshold of expiry, so the old token may still
		// be briefly usable. Confirm against RefreshTokenIfNeeded.
		if refreshed, refreshErr := m.authService.RefreshTokenIfNeeded(r.Context(), session, oauth.TokenRefreshThreshold); refreshErr != nil {
			log.Printf("Failed to refresh token for DID %s (continuing with existing session): %v", did, refreshErr)
		} else {
			session = refreshed
		}

		ctx := context.WithValue(r.Context(), UserDIDKey, did)
		ctx = context.WithValue(ctx, OAuthSessionKey, session)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
162
// GetUserDID returns the DID that RequireAuth or OptionalAuth stored in the
// request context under UserDIDKey. It returns "" when no DID is present,
// for example because the request was not authenticated.
func GetUserDID(r *http.Request) string {
	if did, ok := r.Context().Value(UserDIDKey).(string); ok {
		return did
	}
	return ""
}
169
// GetOAuthSession returns the OAuth session that RequireAuth or OptionalAuth
// stored in the request context under OAuthSessionKey. It returns nil when no
// session is present, for example because the request was not authenticated.
func GetOAuthSession(r *http.Request) *oauthCore.OAuthSession {
	if session, ok := r.Context().Value(OAuthSessionKey).(*oauthCore.OAuthSession); ok {
		return session
	}
	return nil
}