A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "Coves/internal/api/handlers/oauth" 5 "context" 6 "fmt" 7 "log" 8 "net/http" 9 "os" 10 "strings" 11 12 atprotoOAuth "Coves/internal/atproto/oauth" 13 oauthCore "Coves/internal/core/oauth" 14) 15 16// Context keys for storing user information 17type contextKey string 18 19const ( 20 UserDIDKey contextKey = "user_did" 21 OAuthSessionKey contextKey = "oauth_session" 22) 23 24const ( 25 sessionName = "coves_session" 26 sessionDID = "did" 27) 28 29// AuthMiddleware enforces OAuth authentication for protected routes 30type AuthMiddleware struct { 31 authService *oauthCore.AuthService 32} 33 34// NewAuthMiddleware creates a new auth middleware 35func NewAuthMiddleware(sessionStore oauthCore.SessionStore) (*AuthMiddleware, error) { 36 privateJWK := os.Getenv("OAUTH_PRIVATE_JWK") 37 if privateJWK == "" { 38 return nil, fmt.Errorf("OAUTH_PRIVATE_JWK not configured") 39 } 40 41 // Parse OAuth client key 42 privateKey, err := atprotoOAuth.ParseJWKFromJSON([]byte(privateJWK)) 43 if err != nil { 44 return nil, fmt.Errorf("failed to parse OAuth private key: %w", err) 45 } 46 47 // Get AppView URL 48 appviewURL := os.Getenv("APPVIEW_PUBLIC_URL") 49 if appviewURL == "" { 50 appviewURL = "http://localhost:8081" 51 } 52 53 // Determine client ID 54 var clientID string 55 if strings.HasPrefix(appviewURL, "http://localhost") || strings.HasPrefix(appviewURL, "http://127.0.0.1") { 56 clientID = "http://localhost?redirect_uri=" + appviewURL + "/oauth/callback&scope=atproto%20transition:generic" 57 } else { 58 clientID = appviewURL + "/oauth/client-metadata.json" 59 } 60 61 redirectURI := appviewURL + "/oauth/callback" 62 63 oauthClient := atprotoOAuth.NewClient(clientID, privateKey, redirectURI) 64 authService := oauthCore.NewAuthService(sessionStore, oauthClient) 65 66 return &AuthMiddleware{ 67 authService: authService, 68 }, nil 69} 70 71// RequireAuth middleware ensures the user is authenticated 72// If not authenticated, returns 401 73// If 
authenticated, injects user DID and OAuth session into context 74func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler { 75 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 76 // Get HTTP session 77 cookieStore := oauth.GetCookieStore() 78 httpSession, err := cookieStore.Get(r, sessionName) 79 if err != nil || httpSession.IsNew { 80 http.Error(w, "Unauthorized", http.StatusUnauthorized) 81 return 82 } 83 84 // Get DID from session 85 did, ok := httpSession.Values[sessionDID].(string) 86 if !ok || did == "" { 87 http.Error(w, "Unauthorized", http.StatusUnauthorized) 88 return 89 } 90 91 // Load OAuth session from database 92 session, err := m.authService.ValidateSession(r.Context(), did) 93 if err != nil { 94 log.Printf("Failed to load OAuth session for DID %s: %v", did, err) 95 http.Error(w, "Session expired", http.StatusUnauthorized) 96 return 97 } 98 99 // Check if token needs refresh and refresh if necessary 100 session, err = m.authService.RefreshTokenIfNeeded(r.Context(), session, oauth.TokenRefreshThreshold) 101 if err != nil { 102 log.Printf("Failed to refresh token for DID %s: %v", did, err) 103 http.Error(w, "Session expired", http.StatusUnauthorized) 104 return 105 } 106 107 // Inject user info into context 108 ctx := context.WithValue(r.Context(), UserDIDKey, did) 109 ctx = context.WithValue(ctx, OAuthSessionKey, session) 110 111 // Call next handler 112 next.ServeHTTP(w, r.WithContext(ctx)) 113 }) 114} 115 116// OptionalAuth middleware loads user info if authenticated, but doesn't require it 117// Useful for endpoints that work for both authenticated and anonymous users 118func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler { 119 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 120 // Get HTTP session 121 cookieStore := oauth.GetCookieStore() 122 httpSession, err := cookieStore.Get(r, sessionName) 123 if err != nil || httpSession.IsNew { 124 // Not authenticated - continue 
without user context 125 next.ServeHTTP(w, r) 126 return 127 } 128 129 // Get DID from session 130 did, ok := httpSession.Values[sessionDID].(string) 131 if !ok || did == "" { 132 // No DID - continue without user context 133 next.ServeHTTP(w, r) 134 return 135 } 136 137 // Load OAuth session from database 138 session, err := m.authService.ValidateSession(r.Context(), did) 139 if err != nil { 140 // Session expired - continue without user context 141 next.ServeHTTP(w, r) 142 return 143 } 144 145 // Try to refresh token if needed (best effort) 146 refreshedSession, err := m.authService.RefreshTokenIfNeeded(r.Context(), session, oauth.TokenRefreshThreshold) 147 if err != nil { 148 // If refresh fails, continue with old session (best effort) 149 // Session will still be valid for a few more minutes 150 } else { 151 session = refreshedSession 152 } 153 154 // Inject user info into context 155 ctx := context.WithValue(r.Context(), UserDIDKey, did) 156 ctx = context.WithValue(ctx, OAuthSessionKey, session) 157 158 // Call next handler 159 next.ServeHTTP(w, r.WithContext(ctx)) 160 }) 161} 162 163// GetUserDID extracts the user's DID from the request context 164// Returns empty string if not authenticated 165func GetUserDID(r *http.Request) string { 166 did, _ := r.Context().Value(UserDIDKey).(string) 167 return did 168} 169 170// GetOAuthSession extracts the OAuth session from the request context 171// Returns nil if not authenticated 172func GetOAuthSession(r *http.Request) *oauthCore.OAuthSession { 173 session, _ := r.Context().Value(OAuthSessionKey).(*oauthCore.OAuthSession) 174 return session 175}