A community based topic aggregation platform built on atproto

Compare changes

Choose any two refs to compare.

Changed files
+7363 -1698
.beads
cmd
genjwks
reindex-votes
docs
internal
scripts
aggregator-setup
static
tests
+5 -5
docs/COMMENT_SYSTEM_IMPLEMENTATION.md
···
- Lexicon definitions: `social.coves.community.comment.defs` and `getComments`
- Database query methods with Lemmy hot ranking algorithm
- Service layer with iterative loading strategy for nested replies
-
- XRPC HTTP handler with optional authentication
+
- XRPC HTTP handler with optional DPoP authentication
- Comprehensive integration test suite (11 test scenarios)
**What works:**
···
- Nested replies up to configurable depth (default 10, max 100)
- Lemmy hot ranking: `log(greatest(2, score + 2)) / power(time_decay, 1.8)`
- Cursor-based pagination for stable scrolling
-
- Optional authentication for viewer state (stubbed for Phase 2B)
+
- Optional DPoP authentication for viewer state (stubbed for Phase 2B)
- Timeframe filtering for "top" sort (hour/day/week/month/year/all)
**Endpoints:**
···
- Required: `post` (AT-URI)
- Optional: `sort` (hot/top/new), `depth` (0-100), `limit` (1-100), `cursor`, `timeframe`
- Returns: Array of `threadViewComment` with nested replies + post context
-
- Supports Bearer token for authenticated requests (viewer state)
+
- Supports DPoP-bound access token for authenticated requests (viewer state)
**Files created (9):**
1. `internal/atproto/lexicon/social/coves/community/comment/defs.json` - View definitions
···
**8. Viewer Authentication Validation (Non-Issue - Architecture Working as Designed)**
- **Initial Concern:** ViewerDID field trusted without verification in service layer
- **Investigation:** Authentication IS properly validated at middleware layer
-
- `OptionalAuth` middleware extracts and validates JWT Bearer tokens
+
- `OptionalAuth` middleware extracts and validates DPoP-bound access tokens
- Uses PDS public keys (JWKS) for signature verification
-
- Validates token expiration, DID format, issuer
+
- Validates DPoP proof, token expiration, DID format, issuer
- Only injects verified DIDs into request context
- Handler extracts DID using `middleware.GetUserDID(r)`
- **Architecture:** Follows industry best practices (authentication at perimeter)
+7 -4
docs/FEED_SYSTEM_IMPLEMENTATION.md
···
# Get personalized timeline (hot posts from subscriptions)
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=hot&limit=15' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get top posts from last week
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=top&timeframe=week&limit=20' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get newest posts with pagination
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=new&limit=10&cursor=<cursor>' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
```
**Response:**
···
- โœ… Context timeout support
### Authentication (Timeline)
-
- โœ… JWT Bearer token required
+
- โœ… DPoP-bound access token required
- โœ… DID extracted from auth context
- โœ… Validates token signature (when AUTH_SKIP_VERIFY=false)
- โœ… Returns 401 on auth failure
+3 -3
docs/PRD_OAUTH.md
···
- โœ… Auth middleware protecting community endpoints
- โœ… Handlers updated to use `GetUserDID(r)`
- โœ… Comprehensive middleware auth tests (11 test cases)
-
- โœ… E2E tests updated to use Bearer tokens
+
- โœ… E2E tests updated to use DPoP-bound tokens
- โœ… Security logging with IP, method, path, issuer
- โœ… Scope validation (atproto required)
- โœ… Issuer HTTPS validation
···
Authorization: DPoP eyJhbGciOiJFUzI1NiIsInR5cCI6ImF0K2p3dCIsImtpZCI6ImRpZDpwbGM6YWxpY2UjYXRwcm90by1wZHMifQ...
```
-
Format: `DPoP <access_token>`
+
Format: `DPoP <access_token>` (note: uses "DPoP" scheme, not "Bearer")
The access token is a JWT containing:
```json
···
- [x] All community endpoints reject requests without valid JWT structure
- [x] Integration tests pass with mock tokens (11/11 middleware tests passing)
- [x] Zero security regressions from X-User-DID (JWT validation is strictly better)
-
- [x] E2E tests updated to use proper Bearer token authentication
+
- [x] E2E tests updated to use proper DPoP token authentication
- [x] Build succeeds without compilation errors
### Phase 2 (Beta) - โœ… READY FOR TESTING
+3 -1
docs/aggregators/SETUP_GUIDE.md
···
**Request**:
```bash
+
# Note: This calls the PDS directly, so it uses Bearer authorization (not DPoP)
curl -X POST https://bsky.social/xrpc/com.atproto.repo.createRecord \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
···
**Request**:
```bash
+
# Note: This calls the Coves API, so it uses DPoP authorization
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
+
-H "Authorization: DPoP YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:community123...",
+8 -2
docs/federation-prd.md
···
req, _ := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
// Use service auth token instead of community credentials
+
// NOTE: Auth scheme depends on target PDS implementation:
+
// - Standard atproto service auth uses "Bearer" scheme
+
// - Our AppView uses "DPoP" scheme when DPoP-bound tokens are required
+
// For server-to-server with standard PDS, use Bearer; adjust based on target.
req.Header.Set("Authorization", "Bearer "+serviceAuthToken)
req.Header.Set("Content-Type", "application/json")
···
**Request to Remote PDS:**
```http
POST https://covesinstance.com/xrpc/com.atproto.server.getServiceAuth
-
Authorization: Bearer {coves-social-instance-jwt}
+
Authorization: DPoP {coves-social-instance-jwt}
+
DPoP: {coves-social-dpop-proof}
Content-Type: application/json
{
···
**Using Token to Create Post:**
```http
POST https://covesinstance.com/xrpc/com.atproto.repo.createRecord
-
Authorization: Bearer {service-auth-token}
+
Authorization: DPoP {service-auth-token}
+
DPoP: {service-auth-dpop-proof}
Content-Type: application/json
{
+1 -1
scripts/aggregator-setup/README.md
···
```bash
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer $AGGREGATOR_ACCESS_JWT" \
+
-H "Authorization: DPoP $AGGREGATOR_ACCESS_JWT" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:...",
+1
tests/integration/blob_upload_e2e_test.go
···
assert.Equal(t, "POST", r.Method, "Should be POST request")
assert.Equal(t, "/xrpc/com.atproto.repo.uploadBlob", r.URL.Path, "Should hit uploadBlob endpoint")
assert.Equal(t, "image/png", r.Header.Get("Content-Type"), "Should have correct content type")
+
// Note: This is a PDS call, so it uses Bearer (not DPoP)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer ", "Should have auth header")
// Return mock blob reference
+477
internal/atproto/oauth/handlers_security_test.go
···
+
package oauth
+
+
import (
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestIsAllowedMobileRedirectURI tests the mobile redirect URI allowlist with EXACT URI matching
+
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
func TestIsAllowedMobileRedirectURI(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: true,
+
},
+
{
+
name: "rejected - custom scheme coves-app (vulnerable to interception)",
+
uri: "coves-app://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - custom scheme coves (vulnerable to interception)",
+
uri: "coves://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - evil scheme",
+
uri: "evil://callback",
+
expected: false,
+
},
+
{
+
name: "rejected - http (not secure)",
+
uri: "http://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https different domain",
+
uri: "https://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https coves.social wrong path",
+
uri: "https://coves.social/wrong/path",
+
expected: false,
+
},
+
{
+
name: "rejected - invalid URI",
+
uri: "not a uri",
+
expected: false,
+
},
+
{
+
name: "rejected - empty string",
+
uri: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.uri)
+
assert.Equal(t, tt.expected, result,
+
"isAllowedMobileRedirectURI(%q) = %v, want %v", tt.uri, result, tt.expected)
+
})
+
}
+
}
+
+
// TestExtractScheme tests the scheme extraction function
+
func TestExtractScheme(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected string
+
}{
+
{
+
name: "https scheme",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: "https",
+
},
+
{
+
name: "custom scheme",
+
uri: "coves-app://callback",
+
expected: "coves-app",
+
},
+
{
+
name: "invalid URI",
+
uri: "not a uri",
+
expected: "invalid",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := extractScheme(tt.uri)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
+
+
// TestGenerateCSRFToken tests CSRF token generation
+
func TestGenerateCSRFToken(t *testing.T) {
+
// Generate two tokens and verify they are different (randomness check)
+
token1, err1 := generateCSRFToken()
+
require.NoError(t, err1)
+
require.NotEmpty(t, token1)
+
+
token2, err2 := generateCSRFToken()
+
require.NoError(t, err2)
+
require.NotEmpty(t, token2)
+
+
assert.NotEqual(t, token1, token2, "CSRF tokens should be unique")
+
+
// Verify token is base64 encoded (should decode without error)
+
assert.Greater(t, len(token1), 40, "CSRF token should be reasonably long (32 bytes base64 encoded)")
+
}
+
+
// TestHandleMobileLogin_RedirectURIValidation tests that HandleMobileLogin validates redirect URIs
+
func TestHandleMobileLogin_RedirectURIValidation(t *testing.T) {
+
// Note: This is a unit test for the validation logic only.
+
// Full integration tests with OAuth flow are in tests/integration/oauth_e2e_test.go
+
+
tests := []struct {
+
name string
+
redirectURI string
+
expectedLog string
+
expectedStatus int
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
expectedStatus: http.StatusBadRequest, // Will fail at StartAuthFlow (no OAuth client setup)
+
},
+
{
+
name: "rejected - custom scheme coves-app (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected evil scheme",
+
redirectURI: "evil://callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected http",
+
redirectURI: "http://evil.com/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "scheme not allowed",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
// Test the validation function directly
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
if tt.expectedLog != "" {
+
assert.False(t, result, "Should reject %s", tt.redirectURI)
+
}
+
})
+
}
+
}
+
+
// TestHandleCallback_CSRFValidation tests that HandleCallback validates CSRF tokens for mobile flow
+
func TestHandleCallback_CSRFValidation(t *testing.T) {
+
// This is a conceptual test structure. Full implementation would require:
+
// 1. Mock OAuthClient
+
// 2. Mock OAuth store
+
// 3. Simulated OAuth callback with cookies
+
+
t.Run("mobile callback requires CSRF token", func(t *testing.T) {
+
// Setup: Create request with mobile_redirect_uri cookie but NO oauth_csrf cookie
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
// Missing: oauth_csrf cookie
+
+
// This would be rejected with 403 Forbidden in the actual handler
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
+
t.Run("mobile callback with valid CSRF token", func(t *testing.T) {
+
// Setup: Create request with both cookies
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: "valid-csrf-token",
+
})
+
+
// This would be accepted (assuming valid OAuth code/state)
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
}
+
+
// TestHandleMobileCallback_RevalidatesRedirectURI tests that handleMobileCallback re-validates the redirect URI
+
func TestHandleMobileCallback_RevalidatesRedirectURI(t *testing.T) {
+
// This is a critical security test: even if an attacker somehow bypasses the initial check,
+
// the callback handler should re-validate the redirect URI before redirecting.
+
+
tests := []struct {
+
name string
+
redirectURI string
+
shouldPass bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
shouldPass: true,
+
},
+
{
+
name: "blocked - custom scheme (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
shouldPass: false,
+
},
+
{
+
name: "blocked - evil scheme",
+
redirectURI: "evil://callback",
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestGenerateMobileRedirectBinding tests the binding token generation
+
// The binding now includes the CSRF token for proper double-submit validation
+
func TestGenerateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-12345"
+
tests := []struct {
+
name string
+
redirectURI string
+
}{
+
{
+
name: "Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
},
+
{
+
name: "different path",
+
redirectURI: "https://coves.social/different/path",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
binding1 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
binding2 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
+
// Same CSRF token + URI should produce same binding (deterministic)
+
assert.Equal(t, binding1, binding2, "binding should be deterministic for same inputs")
+
+
// Binding should not be empty
+
assert.NotEmpty(t, binding1, "binding should not be empty")
+
+
// Binding should be base64 encoded (should decode without error)
+
assert.Greater(t, len(binding1), 20, "binding should be reasonably long")
+
})
+
}
+
+
// Different URIs should produce different bindings
+
binding1 := generateMobileRedirectBinding(csrfToken, "https://coves.social/app/oauth/callback")
+
binding2 := generateMobileRedirectBinding(csrfToken, "https://coves.social/different/path")
+
assert.NotEqual(t, binding1, binding2, "different URIs should produce different bindings")
+
+
// Different CSRF tokens should produce different bindings
+
binding3 := generateMobileRedirectBinding("different-csrf-token", "https://coves.social/app/oauth/callback")
+
assert.NotEqual(t, binding1, binding3, "different CSRF tokens should produce different bindings")
+
}
+
+
// TestValidateMobileRedirectBinding tests the binding validation
+
// Now validates both CSRF token and redirect URI together (double-submit pattern)
+
func TestValidateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-for-validation"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
validBinding := generateMobileRedirectBinding(csrfToken, redirectURI)
+
+
tests := []struct {
+
name string
+
csrfToken string
+
redirectURI string
+
binding string
+
shouldPass bool
+
}{
+
{
+
name: "valid - correct CSRF token and redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: true,
+
},
+
{
+
name: "invalid - wrong redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: "https://coves.social/different/path",
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - wrong CSRF token",
+
csrfToken: "wrong-csrf-token",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - random binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "random-invalid-binding",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty CSRF token",
+
csrfToken: "",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := validateMobileRedirectBinding(tt.csrfToken, tt.redirectURI, tt.binding)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestSessionFixationAttackPrevention tests that the binding prevents session fixation
+
func TestSessionFixationAttackPrevention(t *testing.T) {
+
// Simulate attack scenario:
+
// 1. Attacker plants a cookie for evil://steal with binding for evil://steal
+
// 2. User does a web login (no mobile_redirect_binding cookie)
+
// 3. Callback should NOT redirect to evil://steal
+
+
attackerCSRF := "attacker-csrf-token"
+
attackerRedirectURI := "evil://steal"
+
attackerBinding := generateMobileRedirectBinding(attackerCSRF, attackerRedirectURI)
+
+
// Later, user's legitimate mobile login
+
userCSRF := "user-csrf-token"
+
userRedirectURI := "https://coves.social/app/oauth/callback"
+
userBinding := generateMobileRedirectBinding(userCSRF, userRedirectURI)
+
+
// The attacker's binding should NOT validate for the user's redirect URI
+
assert.False(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, attackerBinding),
+
"attacker's binding should not validate for user's CSRF token and redirect URI")
+
+
// The user's binding should validate for the user's CSRF token and redirect URI
+
assert.True(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, userBinding),
+
"user's binding should validate for user's CSRF token and redirect URI")
+
+
// Cross-validation should fail
+
assert.False(t, validateMobileRedirectBinding(attackerCSRF, attackerRedirectURI, userBinding),
+
"user's binding should not validate for attacker's CSRF token and redirect URI")
+
}
+
+
// TestCSRFTokenValidation tests that CSRF token VALUE is validated, not just presence
+
func TestCSRFTokenValidation(t *testing.T) {
+
// This test verifies the fix for the P1 security issue:
+
// "The callback never validates the token... the csrfToken argument is ignored entirely"
+
//
+
// The fix ensures that the CSRF token VALUE is cryptographically bound to the
+
// binding token, so changing the CSRF token will invalidate the binding.
+
+
t.Run("CSRF token value must match", func(t *testing.T) {
+
originalCSRF := "original-csrf-token-from-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
binding := generateMobileRedirectBinding(originalCSRF, redirectURI)
+
+
// Original CSRF token should validate
+
assert.True(t, validateMobileRedirectBinding(originalCSRF, redirectURI, binding),
+
"original CSRF token should validate")
+
+
// Different CSRF token should NOT validate (this is the key security fix)
+
differentCSRF := "attacker-forged-csrf-token"
+
assert.False(t, validateMobileRedirectBinding(differentCSRF, redirectURI, binding),
+
"different CSRF token should NOT validate - this is the security fix")
+
})
+
+
t.Run("attacker cannot forge binding without CSRF token", func(t *testing.T) {
+
// Attacker knows the redirect URI but not the CSRF token
+
redirectURI := "https://coves.social/app/oauth/callback"
+
victimCSRF := "victim-secret-csrf-token"
+
victimBinding := generateMobileRedirectBinding(victimCSRF, redirectURI)
+
+
// Attacker tries various CSRF tokens to forge the binding
+
attackerGuesses := []string{
+
"",
+
"guess1",
+
"attacker-csrf",
+
redirectURI, // trying the redirect URI as CSRF
+
}
+
+
for _, guess := range attackerGuesses {
+
assert.False(t, validateMobileRedirectBinding(guess, redirectURI, victimBinding),
+
"attacker's CSRF guess %q should not validate", guess)
+
}
+
})
+
}
+
+
// TestConstantTimeCompare tests the timing-safe comparison function
+
func TestConstantTimeCompare(t *testing.T) {
+
tests := []struct {
+
name string
+
a string
+
b string
+
expected bool
+
}{
+
{
+
name: "equal strings",
+
a: "abc123",
+
b: "abc123",
+
expected: true,
+
},
+
{
+
name: "different strings same length",
+
a: "abc123",
+
b: "xyz789",
+
expected: false,
+
},
+
{
+
name: "different lengths",
+
a: "short",
+
b: "longer",
+
expected: false,
+
},
+
{
+
name: "empty strings",
+
a: "",
+
b: "",
+
expected: true,
+
},
+
{
+
name: "one empty",
+
a: "abc",
+
b: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := constantTimeCompare(tt.a, tt.b)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
+152
internal/atproto/oauth/seal.go
···
+
package oauth
+
+
import (
+
"crypto/aes"
+
"crypto/cipher"
+
"crypto/rand"
+
"encoding/base64"
+
"encoding/json"
+
"fmt"
+
"time"
+
)
+
+
// SealedSession represents the data sealed in a mobile session token
+
type SealedSession struct {
+
DID string `json:"did"` // User's DID
+
SessionID string `json:"sid"` // Session identifier
+
ExpiresAt int64 `json:"exp"` // Unix timestamp when token expires
+
}
+
+
// SealSession creates an encrypted token containing session information.
+
// The token is encrypted using AES-256-GCM and encoded as base64url.
+
//
+
// Token format: base64url(nonce || ciphertext || tag)
+
// - nonce: 12 bytes (GCM standard nonce size)
+
// - ciphertext: encrypted JSON payload
+
// - tag: 16 bytes (GCM authentication tag)
+
//
+
// The sealed token can be safely given to mobile clients and used as
+
// a reference to the server-side session without exposing sensitive data.
+
func (c *OAuthClient) SealSession(did, sessionID string, ttl time.Duration) (string, error) {
+
if len(c.SealSecret) == 0 {
+
return "", fmt.Errorf("seal secret not configured")
+
}
+
+
if did == "" {
+
return "", fmt.Errorf("DID is required")
+
}
+
+
if sessionID == "" {
+
return "", fmt.Errorf("session ID is required")
+
}
+
+
// Create the session data
+
expiresAt := time.Now().Add(ttl).Unix()
+
session := SealedSession{
+
DID: did,
+
SessionID: sessionID,
+
ExpiresAt: expiresAt,
+
}
+
+
// Marshal to JSON
+
plaintext, err := json.Marshal(session)
+
if err != nil {
+
return "", fmt.Errorf("failed to marshal session: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return "", fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return "", fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Generate random nonce
+
nonce := make([]byte, gcm.NonceSize())
+
if _, err := rand.Read(nonce); err != nil {
+
return "", fmt.Errorf("failed to generate nonce: %w", err)
+
}
+
+
// Encrypt and authenticate
+
// GCM.Seal appends the ciphertext and tag to the nonce
+
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+
+
// Encode as base64url (no padding)
+
token := base64.RawURLEncoding.EncodeToString(ciphertext)
+
+
return token, nil
+
}
+
+
// UnsealSession decrypts and validates a sealed session token.
+
// Returns the session information if the token is valid and not expired.
+
func (c *OAuthClient) UnsealSession(token string) (*SealedSession, error) {
+
if len(c.SealSecret) == 0 {
+
return nil, fmt.Errorf("seal secret not configured")
+
}
+
+
if token == "" {
+
return nil, fmt.Errorf("token is required")
+
}
+
+
// Decode from base64url
+
ciphertext, err := base64.RawURLEncoding.DecodeString(token)
+
if err != nil {
+
return nil, fmt.Errorf("invalid token encoding: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Verify minimum size (nonce + tag)
+
nonceSize := gcm.NonceSize()
+
if len(ciphertext) < nonceSize {
+
return nil, fmt.Errorf("invalid token: too short")
+
}
+
+
// Extract nonce and ciphertext
+
nonce := ciphertext[:nonceSize]
+
ciphertextData := ciphertext[nonceSize:]
+
+
// Decrypt and authenticate
+
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decrypt token: %w", err)
+
}
+
+
// Unmarshal JSON
+
var session SealedSession
+
if err := json.Unmarshal(plaintext, &session); err != nil {
+
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
+
}
+
+
// Validate required fields
+
if session.DID == "" {
+
return nil, fmt.Errorf("invalid session: missing DID")
+
}
+
+
if session.SessionID == "" {
+
return nil, fmt.Errorf("invalid session: missing session ID")
+
}
+
+
// Check expiration
+
now := time.Now().Unix()
+
if session.ExpiresAt <= now {
+
return nil, fmt.Errorf("token expired at %v", time.Unix(session.ExpiresAt, 0))
+
}
+
+
return &session, nil
+
}
+331
internal/atproto/oauth/seal_test.go
···
+
package oauth
+
+
import (
+
"crypto/rand"
+
"encoding/base64"
+
"strings"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// generateSealSecret generates a random 32-byte seal secret for testing
+
func generateSealSecret() []byte {
+
secret := make([]byte, 32)
+
if _, err := rand.Read(secret); err != nil {
+
panic(err)
+
}
+
return secret
+
}
+
+
func TestSealSession_RoundTrip(t *testing.T) {
+
// Create client with seal secret
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
require.NotEmpty(t, token)
+
+
// Token should be base64url encoded
+
_, err = base64.RawURLEncoding.DecodeString(token)
+
require.NoError(t, err, "token should be valid base64url")
+
+
// Unseal the session
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
require.NotNil(t, session)
+
+
// Verify data
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
+
// Verify expiration is approximately correct (within 1 second)
+
expectedExpiry := time.Now().Add(ttl).Unix()
+
assert.InDelta(t, expectedExpiry, session.ExpiresAt, 1.0)
+
}
+
+
func TestSealSession_ExpirationValidation(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 2 * time.Second // Short TTL (must be >= 1 second due to Unix timestamp granularity)
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Should work immediately
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
+
// Wait well past expiration
+
time.Sleep(2500 * time.Millisecond)
+
+
// Should fail after expiration
+
session, err = client.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "token expired")
+
}
+
+
func TestSealSession_TamperedTokenDetection(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tamper with the token by modifying one character
+
tampered := token[:len(token)-5] + "XXXX" + token[len(token)-1:]
+
+
// Should fail to unseal tampered token
+
session, err := client.UnsealSession(tampered)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_InvalidTokenFormats(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
tests := []struct {
+
name string
+
token string
+
}{
+
{
+
name: "empty token",
+
token: "",
+
},
+
{
+
name: "invalid base64",
+
token: "not-valid-base64!@#$",
+
},
+
{
+
name: "too short",
+
token: base64.RawURLEncoding.EncodeToString([]byte("short")),
+
},
+
{
+
name: "random bytes",
+
token: base64.RawURLEncoding.EncodeToString(make([]byte, 50)),
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
session, err := client.UnsealSession(tt.token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
})
+
}
+
}
+
+
func TestSealSession_DifferentSecrets(t *testing.T) {
+
// Create two clients with different secrets
+
client1 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
client2 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal with client1
+
token, err := client1.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Try to unseal with client2 (different secret)
+
session, err := client2.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_NoSecretConfigured(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: nil,
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Should fail to seal without secret
+
token, err := client.SealSession(did, sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
+
// Should fail to unseal without secret
+
session, err := client.UnsealSession("dummy-token")
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
}
+
+
func TestSealSession_MissingRequiredFields(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
ttl := 1 * time.Hour
+
+
tests := []struct {
+
name string
+
did string
+
sessionID string
+
errorMsg string
+
}{
+
{
+
name: "missing DID",
+
did: "",
+
sessionID: "session-123",
+
errorMsg: "DID is required",
+
},
+
{
+
name: "missing session ID",
+
did: "did:plc:abc123",
+
sessionID: "",
+
errorMsg: "session ID is required",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
token, err := client.SealSession(tt.did, tt.sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), tt.errorMsg)
+
})
+
}
+
}
+
+
func TestSealSession_UniquenessPerCall(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the same session twice
+
token1, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
token2, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tokens should be different (different nonces)
+
assert.NotEqual(t, token1, token2, "tokens should be unique due to different nonces")
+
+
// But both should unseal to the same session data
+
session1, err := client.UnsealSession(token1)
+
require.NoError(t, err)
+
+
session2, err := client.UnsealSession(token2)
+
require.NoError(t, err)
+
+
assert.Equal(t, session1.DID, session2.DID)
+
assert.Equal(t, session1.SessionID, session2.SessionID)
+
}
+
+
func TestSealSession_LongDIDAndSessionID(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
// Test with very long DID and session ID
+
did := "did:plc:" + strings.Repeat("a", 200)
+
sessionID := "session-" + strings.Repeat("x", 200)
+
ttl := 1 * time.Hour
+
+
// Should work with long values
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
}
+
+
func TestSealSession_URLSafeEncoding(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal multiple times to get different nonces
+
for i := 0; i < 100; i++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Token should not contain URL-unsafe characters
+
assert.NotContains(t, token, "+", "token should not contain '+'")
+
assert.NotContains(t, token, "/", "token should not contain '/'")
+
assert.NotContains(t, token, "=", "token should not contain '='")
+
+
// Should unseal successfully
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
}
+
+
func TestSealSession_ConcurrentAccess(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Run concurrent seal/unseal operations
+
done := make(chan bool)
+
for i := 0; i < 10; i++ {
+
go func() {
+
for j := 0; j < 100; j++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
done <- true
+
}()
+
}
+
+
// Wait for all goroutines
+
for i := 0; i < 10; i++ {
+
<-done
+
}
+
}
+614
internal/atproto/oauth/store.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"errors"
+
"fmt"
+
"log/slog"
+
"net/url"
+
"strings"
+
"time"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/lib/pq"
+
)
+
+
var (
+
ErrSessionNotFound = errors.New("oauth session not found")
+
ErrAuthRequestNotFound = errors.New("oauth auth request not found")
+
)
+
+
// PostgresOAuthStore implements oauth.ClientAuthStore interface using PostgreSQL
+
type PostgresOAuthStore struct {
+
db *sql.DB
+
sessionTTL time.Duration
+
}
+
+
// NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store
+
func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore {
+
if sessionTTL == 0 {
+
sessionTTL = 7 * 24 * time.Hour // Default to 7 days
+
}
+
return &PostgresOAuthStore{
+
db: db,
+
sessionTTL: sessionTTL,
+
}
+
}
+
+
// GetSession retrieves a session by DID and session ID
+
func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
+
query := `
+
SELECT
+
did, session_id, host_url, auth_server_iss,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, access_token, refresh_token,
+
dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase
+
FROM oauth_sessions
+
WHERE did = $1 AND session_id = $2 AND expires_at > NOW()
+
`
+
+
var session oauth.ClientSessionData
+
var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var hostURL, dpopPrivateKeyMultibase sql.NullString
+
var scopes pq.StringArray
+
var dpopAuthServerNonce, dpopHostNonce sql.NullString
+
+
err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan(
+
&session.AccountDID,
+
&session.SessionID,
+
&hostURL,
+
&authServerIss,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&session.AccessToken,
+
&session.RefreshToken,
+
&dpopAuthServerNonce,
+
&dpopHostNonce,
+
&dpopPrivateKeyMultibase,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrSessionNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get session: %w", err)
+
}
+
+
// Convert nullable fields
+
if hostURL.Valid {
+
session.HostURL = hostURL.String
+
}
+
if authServerIss.Valid {
+
session.AuthServerURL = authServerIss.String
+
}
+
if authServerTokenEndpoint.Valid {
+
session.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
if dpopAuthServerNonce.Valid {
+
session.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if dpopHostNonce.Valid {
+
session.DPoPHostNonce = dpopHostNonce.String
+
}
+
if dpopPrivateKeyMultibase.Valid {
+
session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
session.Scopes = scopes
+
+
return &session, nil
+
}
+
+
// SaveSession saves or updates a session (upsert operation)
+
func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
+
// Input validation per atProto OAuth security requirements
+
+
// Validate DID format
+
if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil {
+
return fmt.Errorf("invalid DID format: %w", err)
+
}
+
+
// Validate token lengths (max 10000 chars to prevent memory issues)
+
const maxTokenLength = 10000
+
if len(sess.AccessToken) > maxTokenLength {
+
return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
if len(sess.RefreshToken) > maxTokenLength {
+
return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
+
// Validate session ID is not empty
+
if sess.SessionID == "" {
+
return fmt.Errorf("session_id cannot be empty")
+
}
+
+
// Validate URLs if provided
+
if sess.HostURL != "" {
+
if _, err := url.Parse(sess.HostURL); err != nil {
+
return fmt.Errorf("invalid host_url: %w", err)
+
}
+
}
+
if sess.AuthServerURL != "" {
+
if _, err := url.Parse(sess.AuthServerURL); err != nil {
+
return fmt.Errorf("invalid auth_server URL: %w", err)
+
}
+
}
+
if sess.AuthServerTokenEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_token_endpoint: %w", err)
+
}
+
}
+
if sess.AuthServerRevocationEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err)
+
}
+
}
+
+
query := `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_jwk, dpop_private_key_multibase,
+
dpop_authserver_nonce, dpop_pds_nonce,
+
auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, expires_at, created_at, updated_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
NULL, $8,
+
$9, $10,
+
$11, $12, $13,
+
$14, $15, NOW(), NOW()
+
)
+
ON CONFLICT (did, session_id) DO UPDATE SET
+
handle = EXCLUDED.handle,
+
pds_url = EXCLUDED.pds_url,
+
host_url = EXCLUDED.host_url,
+
access_token = EXCLUDED.access_token,
+
refresh_token = EXCLUDED.refresh_token,
+
dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase,
+
dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce,
+
dpop_pds_nonce = EXCLUDED.dpop_pds_nonce,
+
auth_server_iss = EXCLUDED.auth_server_iss,
+
auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint,
+
auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint,
+
scopes = EXCLUDED.scopes,
+
expires_at = EXCLUDED.expires_at,
+
updated_at = NOW()
+
`
+
+
// Calculate token expiration using configured TTL
+
expiresAt := time.Now().Add(s.sessionTTL)
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if sess.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Extract handle from DID (placeholder - in real implementation, resolve from identity)
+
// For now, use DID as handle since we don't have the handle in ClientSessionData
+
handle := sess.AccountDID.String()
+
+
// Use HostURL as PDS URL
+
pdsURL := sess.HostURL
+
if pdsURL == "" {
+
pdsURL = sess.AuthServerURL // Fallback to auth server URL
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
sess.AccountDID.String(),
+
sess.SessionID,
+
handle,
+
pdsURL,
+
sess.HostURL,
+
sess.AccessToken,
+
sess.RefreshToken,
+
sess.DPoPPrivateKeyMultibase,
+
sess.DPoPAuthServerNonce,
+
sess.DPoPHostNonce,
+
sess.AuthServerURL,
+
sess.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(sess.Scopes),
+
expiresAt,
+
)
+
if err != nil {
+
return fmt.Errorf("failed to save session: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteSession deletes a session by DID and session ID
+
func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
+
query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2`
+
+
result, err := s.db.ExecContext(ctx, query, did.String(), sessionID)
+
if err != nil {
+
return fmt.Errorf("failed to delete session: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrSessionNotFound
+
}
+
+
return nil
+
}
+
+
// GetAuthRequestInfo retrieves auth request information by state
+
func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
+
query := `
+
SELECT
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, created_at
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var info oauth.AuthRequestData
+
var did, handle, pdsURL sql.NullString
+
var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString
+
var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var scopes pq.StringArray
+
var createdAt time.Time
+
+
err := s.db.QueryRowContext(ctx, query, state).Scan(
+
&info.State,
+
&did,
+
&handle,
+
&pdsURL,
+
&info.PKCEVerifier,
+
&dpopPrivateKeyMultibase,
+
&dpopAuthServerNonce,
+
&info.AuthServerURL,
+
&requestURI,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&createdAt,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get auth request info: %w", err)
+
}
+
+
// Parse DID if present
+
if did.Valid && did.String != "" {
+
parsedDID, err := syntax.ParseDID(did.String)
+
if err != nil {
+
return nil, fmt.Errorf("failed to parse DID: %w", err)
+
}
+
info.AccountDID = &parsedDID
+
}
+
+
// Convert nullable fields
+
if dpopPrivateKeyMultibase.Valid {
+
info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
if dpopAuthServerNonce.Valid {
+
info.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if requestURI.Valid {
+
info.RequestURI = requestURI.String
+
}
+
if authServerTokenEndpoint.Valid {
+
info.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
info.Scopes = scopes
+
+
return &info, nil
+
}
+
+
// SaveAuthRequestInfo saves auth request information (create only, not upsert)
+
func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
query := `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, return_url, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
$12, NULL, NOW()
+
)
+
`
+
+
// Extract DID string if present
+
var didStr sql.NullString
+
if info.AccountDID != nil {
+
didStr.String = info.AccountDID.String()
+
didStr.Valid = true
+
}
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if info.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Placeholder values for handle and pds_url (not in AuthRequestData)
+
// In production, these would be resolved during the auth flow
+
handle := ""
+
pdsURL := ""
+
if info.AccountDID != nil {
+
handle = info.AccountDID.String() // Temporary placeholder
+
pdsURL = info.AuthServerURL // Temporary placeholder
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
info.State,
+
didStr,
+
handle,
+
pdsURL,
+
info.PKCEVerifier,
+
info.DPoPPrivateKeyMultibase,
+
info.DPoPAuthServerNonce,
+
info.AuthServerURL,
+
info.RequestURI,
+
info.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(info.Scopes),
+
)
+
if err != nil {
+
// Check for duplicate state
+
if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") {
+
return fmt.Errorf("auth request with state already exists: %s", info.State)
+
}
+
return fmt.Errorf("failed to save auth request info: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteAuthRequestInfo deletes auth request information by state
+
func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
+
query := `DELETE FROM oauth_requests WHERE state = $1`
+
+
result, err := s.db.ExecContext(ctx, query, state)
+
if err != nil {
+
return fmt.Errorf("failed to delete auth request info: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// CleanupExpiredSessions removes sessions that have expired
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// CleanupExpiredAuthRequests removes auth requests older than 30 minutes
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// MobileOAuthData holds mobile-specific OAuth flow data
+
type MobileOAuthData struct {
+
CSRFToken string
+
RedirectURI string
+
}
+
+
// mobileFlowContextKey is the context key for mobile flow data
+
type mobileFlowContextKey struct{}
+
+
// ContextWithMobileFlowData adds mobile flow data to a context.
+
// This is used by HandleMobileLogin to pass mobile data to the store wrapper,
+
// which will save it when SaveAuthRequestInfo is called by indigo.
+
func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context {
+
return context.WithValue(ctx, mobileFlowContextKey{}, data)
+
}
+
+
// getMobileFlowDataFromContext retrieves mobile flow data from context, if present
+
func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) {
+
data, ok := ctx.Value(mobileFlowContextKey{}).(MobileOAuthData)
+
return data, ok
+
}
+
+
// MobileAwareClientStore is a marker interface that indicates a store is properly
+
// configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo
+
// to save mobile CSRF data should implement this interface.
+
// This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used.
+
type MobileAwareClientStore interface {
+
IsMobileAware() bool
+
}
+
+
// MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile
+
// CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow.
+
// This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state,
+
// so we intercept the SaveAuthRequestInfo call to capture it.
+
type MobileAwareStoreWrapper struct {
+
oauth.ClientAuthStore
+
mobileStore MobileOAuthStore
+
}
+
+
// IsMobileAware implements MobileAwareClientStore, indicating this store
+
// properly saves mobile CSRF data during OAuth flow initiation.
+
func (w *MobileAwareStoreWrapper) IsMobileAware() bool {
+
return true
+
}
+
+
// NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo
+
// to also save mobile CSRF data when present in context.
+
func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper {
+
wrapper := &MobileAwareStoreWrapper{
+
ClientAuthStore: store,
+
}
+
// Check if the underlying store implements MobileOAuthStore
+
if ms, ok := store.(MobileOAuthStore); ok {
+
wrapper.mobileStore = ms
+
}
+
return wrapper
+
}
+
+
// SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data
+
// if mobile flow data is present in the context.
+
func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
// First, save the auth request to the underlying store
+
if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil {
+
return err
+
}
+
+
// Check if this is a mobile flow (mobile data in context)
+
if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil {
+
// Save mobile CSRF data tied to this OAuth state
+
// IMPORTANT: If this fails, we MUST propagate the error. Otherwise:
+
// 1. No server-side CSRF record is stored
+
// 2. Every mobile callback will "fail closed" to web flow
+
// 3. Mobile sign-in silently breaks with no indication
+
// Failing loudly here lets the user retry rather than being confused
+
// about why they're getting a web flow instead of mobile.
+
if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil {
+
slog.Error("failed to save mobile CSRF data - mobile login will fail",
+
"state", info.State, "error", err)
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
if w.mobileStore != nil {
+
return w.mobileStore.GetMobileOAuthData(ctx, state)
+
}
+
return nil, nil
+
}
+
+
// SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
if w.mobileStore != nil {
+
return w.mobileStore.SaveMobileOAuthData(ctx, state, data)
+
}
+
return nil
+
}
+
+
// UnwrapPostgresStore returns the underlying PostgresOAuthStore if present.
+
// This is useful for accessing cleanup methods that aren't part of the interface.
+
func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore {
+
if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok {
+
return ps
+
}
+
return nil
+
}
+
+
// SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state
+
// This ties the CSRF token to the OAuth flow via the state parameter,
+
// which comes back through the OAuth response for server-side validation.
+
func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
query := `
+
UPDATE oauth_requests
+
SET mobile_csrf_token = $2, mobile_redirect_uri = $3
+
WHERE state = $1
+
`
+
+
result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI)
+
if err != nil {
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData retrieves mobile CSRF data by OAuth state
+
// This is called during callback to compare the server-side CSRF token
+
// (retrieved by state from the OAuth response) against the cookie CSRF.
+
func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
query := `
+
SELECT mobile_csrf_token, mobile_redirect_uri
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var csrfToken, redirectURI sql.NullString
+
err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err)
+
}
+
+
// Return nil if no mobile data was stored (this was a web flow)
+
if !csrfToken.Valid {
+
return nil, nil
+
}
+
+
return &MobileOAuthData{
+
CSRFToken: csrfToken.String,
+
RedirectURI: redirectURI.String,
+
}, nil
+
}
+522
internal/atproto/oauth/store_test.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"os"
+
"testing"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// setupTestDB creates a test database connection and runs migrations
+
func setupTestDB(t *testing.T) *sql.DB {
+
dsn := os.Getenv("TEST_DATABASE_URL")
+
if dsn == "" {
+
dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable"
+
}
+
+
db, err := sql.Open("postgres", dsn)
+
require.NoError(t, err, "Failed to connect to test database")
+
+
// Run migrations
+
require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations")
+
+
return db
+
}
+
+
// cleanupOAuth removes all test OAuth data from the database
+
func cleanupOAuth(t *testing.T, db *sql.DB) {
+
_, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_sessions")
+
+
_, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_requests")
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:test123abc")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session123",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
Scopes: []string{"atproto"},
+
AccessToken: "at_test_token_abc123",
+
RefreshToken: "rt_test_token_xyz789",
+
DPoPAuthServerNonce: "nonce_auth_123",
+
DPoPHostNonce: "nonce_host_456",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
assert.NoError(t, err)
+
+
// Retrieve session
+
retrieved, err := store.GetSession(ctx, did, "session123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, session.SessionID, retrieved.SessionID)
+
assert.Equal(t, session.HostURL, retrieved.HostURL)
+
assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL)
+
assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken)
+
assert.Equal(t, session.RefreshToken, retrieved.RefreshToken)
+
assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce)
+
assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
assert.Equal(t, session.Scopes, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testupsert")
+
require.NoError(t, err)
+
+
// Initial session
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds1.example.com",
+
AuthServerURL: "https://auth1.example.com",
+
AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "old_access_token",
+
RefreshToken: "old_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
+
// Updated session (same DID and session ID)
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds2.example.com",
+
AuthServerURL: "https://auth2.example.com",
+
AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token",
+
Scopes: []string{"atproto", "transition:generic"},
+
AccessToken: "new_access_token",
+
RefreshToken: "new_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save again - should update
+
err = store.SaveSession(ctx, session2)
+
assert.NoError(t, err)
+
+
// Retrieve should get updated values
+
retrieved, err := store.GetSession(ctx, did, "session_upsert")
+
assert.NoError(t, err)
+
assert.Equal(t, "new_access_token", retrieved.AccessToken)
+
assert.Equal(t, "new_refresh_token", retrieved.RefreshToken)
+
assert.Equal(t, "https://pds2.example.com", retrieved.HostURL)
+
assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testdelete")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_delete",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "test_token",
+
RefreshToken: "test_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err)
+
+
// Delete session
+
err = store.DeleteSession(ctx, did, "session_delete")
+
assert.NoError(t, err)
+
+
// Verify session is gone
+
_, err = store.GetSession(ctx, did, "session_delete")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
err = store.DeleteSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testrequest")
+
require.NoError(t, err)
+
+
info := oauth.AuthRequestData{
+
State: "test_state_abc123",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: &did,
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:abc123",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
PKCEVerifier: "verifier_xyz789",
+
DPoPAuthServerNonce: "nonce_abc",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err = store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve auth request info
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, info.State, retrieved.State)
+
assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL)
+
assert.NotNil(t, retrieved.AccountDID)
+
assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, info.Scopes, retrieved.Scopes)
+
assert.Equal(t, info.RequestURI, retrieved.RequestURI)
+
assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier)
+
assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
}
+
+
func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_nodid",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: nil, // No DID provided
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:nodid",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_nodid",
+
DPoPAuthServerNonce: "nonce_nodid",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info without DID
+
err := store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve and verify DID is nil
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid")
+
assert.NoError(t, err)
+
assert.Nil(t, retrieved.AccountDID)
+
assert.Equal(t, info.State, retrieved.State)
+
}
+
+
func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
_, err := store.GetAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_delete",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:delete",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_delete",
+
DPoPAuthServerNonce: "nonce_delete",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err := store.SaveAuthRequestInfo(ctx, info)
+
require.NoError(t, err)
+
+
// Delete auth request info
+
err = store.DeleteAuthRequestInfo(ctx, "test_state_delete")
+
assert.NoError(t, err)
+
+
// Verify it's gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_state_delete")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL
+
store, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
ctx := context.Background()
+
+
did1, err := syntax.ParseDID("did:plc:testexpired1")
+
require.NoError(t, err)
+
did2, err := syntax.ParseDID("did:plc:testexpired2")
+
require.NoError(t, err)
+
+
// Create an expired session (manually insert with past expiration)
+
_, err = db.ExecContext(ctx, `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_key_multibase, auth_server_iss,
+
auth_server_token_endpoint, scopes,
+
expires_at, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 day', NOW()
+
)
+
`, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com",
+
"expired_token", "expired_refresh",
+
"z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a valid session
+
validSession := oauth.ClientSessionData{
+
AccountDID: did2,
+
SessionID: "valid_session",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "valid_token",
+
RefreshToken: "valid_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveSession(ctx, validSession)
+
require.NoError(t, err)
+
+
// Cleanup expired sessions
+
count, err := store.CleanupExpiredSessions(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired session")
+
+
// Verify expired session is gone
+
_, err = store.GetSession(ctx, did1, "expired_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// Verify valid session still exists
+
_, err = store.GetSession(ctx, did2, "valid_session")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0)
+
pgStore, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
store := oauth.ClientAuthStore(pgStore)
+
ctx := context.Background()
+
+
// Create an old auth request (manually insert with old timestamp)
+
_, err := db.ExecContext(ctx, `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, scopes,
+
created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 hour'
+
)
+
`, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com",
+
"old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
"nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a recent auth request
+
recentInfo := oauth.AuthRequestData{
+
State: "test_recent_state",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:recent",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "recent_verifier",
+
DPoPAuthServerNonce: "nonce_recent",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveAuthRequestInfo(ctx, recentInfo)
+
require.NoError(t, err)
+
+
// Cleanup expired auth requests (older than 30 minutes)
+
count, err := pgStore.CleanupExpiredAuthRequests(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired auth request")
+
+
// Verify old request is gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_old_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
+
// Verify recent request still exists
+
_, err = store.GetAuthRequestInfo(ctx, "test_recent_state")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_MultipleSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testmulti")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same DID
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "browser1",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_browser1",
+
RefreshToken: "refresh_browser1",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "mobile_app",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_mobile",
+
RefreshToken: "refresh_mobile",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save both sessions
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
err = store.SaveSession(ctx, session2)
+
require.NoError(t, err)
+
+
// Retrieve both sessions
+
retrieved1, err := store.GetSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_browser1", retrieved1.AccessToken)
+
+
retrieved2, err := store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_mobile", retrieved2.AccessToken)
+
+
// Delete one session
+
err = store.DeleteSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
+
// Verify only browser1 is deleted
+
_, err = store.GetSession(ctx, did, "browser1")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// mobile_app should still exist
+
_, err = store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
}
+99
internal/atproto/oauth/transport.go
···
+
package oauth
+
+
import (
+
"fmt"
+
"net"
+
"net/http"
+
"time"
+
)
+
+
// ssrfSafeTransport wraps http.Transport to prevent SSRF attacks
+
type ssrfSafeTransport struct {
+
base *http.Transport
+
allowPrivate bool // For dev/testing only
+
}
+
+
// isPrivateIP checks if an IP is in a private/reserved range
+
func isPrivateIP(ip net.IP) bool {
+
if ip == nil {
+
return false
+
}
+
+
// Check for loopback
+
if ip.IsLoopback() {
+
return true
+
}
+
+
// Check for link-local
+
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+
return true
+
}
+
+
// Check for private ranges
+
privateRanges := []string{
+
"10.0.0.0/8",
+
"172.16.0.0/12",
+
"192.168.0.0/16",
+
"169.254.0.0/16",
+
"::1/128",
+
"fc00::/7",
+
"fe80::/10",
+
}
+
+
for _, cidr := range privateRanges {
+
_, network, err := net.ParseCIDR(cidr)
+
if err == nil && network.Contains(ip) {
+
return true
+
}
+
}
+
+
return false
+
}
+
+
func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+
host := req.URL.Hostname()
+
+
// Resolve hostname to IP
+
ips, err := net.LookupIP(host)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resolve host: %w", err)
+
}
+
+
// Check all resolved IPs
+
if !t.allowPrivate {
+
for _, ip := range ips {
+
if isPrivateIP(ip) {
+
return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip)
+
}
+
}
+
}
+
+
return t.base.RoundTrip(req)
+
}
+
+
// NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections
+
func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client {
+
transport := &ssrfSafeTransport{
+
base: &http.Transport{
+
DialContext: (&net.Dialer{
+
Timeout: 10 * time.Second,
+
KeepAlive: 30 * time.Second,
+
}).DialContext,
+
MaxIdleConns: 100,
+
IdleConnTimeout: 90 * time.Second,
+
TLSHandshakeTimeout: 10 * time.Second,
+
},
+
allowPrivate: allowPrivate,
+
}
+
+
return &http.Client{
+
Timeout: 15 * time.Second,
+
Transport: transport,
+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
+
if len(via) >= 5 {
+
return fmt.Errorf("too many redirects")
+
}
+
return nil
+
},
+
}
+
}
+132
internal/atproto/oauth/transport_test.go
···
+
package oauth
+
+
import (
+
"net"
+
"net/http"
+
"testing"
+
)
+
+
func TestIsPrivateIP(t *testing.T) {
+
tests := []struct {
+
name string
+
ip string
+
expected bool
+
}{
+
// Loopback addresses
+
{"IPv4 loopback", "127.0.0.1", true},
+
{"IPv6 loopback", "::1", true},
+
+
// Private IPv4 ranges
+
{"Private 10.x.x.x", "10.0.0.1", true},
+
{"Private 10.x.x.x edge", "10.255.255.255", true},
+
{"Private 172.16.x.x", "172.16.0.1", true},
+
{"Private 172.31.x.x edge", "172.31.255.255", true},
+
{"Private 192.168.x.x", "192.168.1.1", true},
+
{"Private 192.168.x.x edge", "192.168.255.255", true},
+
+
// Link-local addresses
+
{"Link-local IPv4", "169.254.1.1", true},
+
{"Link-local IPv6", "fe80::1", true},
+
+
// IPv6 private ranges
+
{"IPv6 unique local fc00", "fc00::1", true},
+
{"IPv6 unique local fd00", "fd00::1", true},
+
+
// Public addresses
+
{"Public IP 1.1.1.1", "1.1.1.1", false},
+
{"Public IP 8.8.8.8", "8.8.8.8", false},
+
{"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12
+
{"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12
+
{"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8
+
{"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
ip := net.ParseIP(tt.ip)
+
if ip == nil {
+
t.Fatalf("Failed to parse IP: %s", tt.ip)
+
}
+
+
result := isPrivateIP(ip)
+
if result != tt.expected {
+
t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
+
}
+
})
+
}
+
}
+
+
func TestIsPrivateIP_NilIP(t *testing.T) {
+
result := isPrivateIP(nil)
+
if result != false {
+
t.Errorf("isPrivateIP(nil) = %v, expected false", result)
+
}
+
}
+
+
func TestNewSSRFSafeHTTPClient(t *testing.T) {
+
tests := []struct {
+
name string
+
allowPrivate bool
+
}{
+
{"Production client (no private IPs)", false},
+
{"Development client (allow private IPs)", true},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(tt.allowPrivate)
+
+
if client == nil {
+
t.Fatal("NewSSRFSafeHTTPClient returned nil")
+
}
+
+
if client.Timeout == 0 {
+
t.Error("Expected timeout to be set")
+
}
+
+
if client.Transport == nil {
+
t.Error("Expected transport to be set")
+
}
+
+
transport, ok := client.Transport.(*ssrfSafeTransport)
+
if !ok {
+
t.Error("Expected ssrfSafeTransport")
+
}
+
+
if transport.allowPrivate != tt.allowPrivate {
+
t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate)
+
}
+
})
+
}
+
}
+
+
func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(false)
+
+
// Simulate checking redirect limit
+
if client.CheckRedirect == nil {
+
t.Fatal("Expected CheckRedirect to be set")
+
}
+
+
// Test redirect limit (5 redirects)
+
var via []*http.Request
+
for i := 0; i < 5; i++ {
+
req := &http.Request{}
+
via = append(via, req)
+
}
+
+
err := client.CheckRedirect(nil, via)
+
if err == nil {
+
t.Error("Expected error for too many redirects")
+
}
+
if err.Error() != "too many redirects" {
+
t.Errorf("Expected 'too many redirects' error, got: %v", err)
+
}
+
+
// Test within limit (4 redirects)
+
via = via[:4]
+
err = client.CheckRedirect(nil, via)
+
if err != nil {
+
t.Errorf("Expected no error for 4 redirects, got: %v", err)
+
}
+
}
+124
internal/db/migrations/019_update_oauth_for_indigo.sql
···
+
-- +goose Up
+
-- Update OAuth tables to match indigo's ClientAuthStore interface requirements
+
-- This migration adds columns needed for OAuth client sessions and auth requests
+
+
-- Update oauth_requests table
+
-- Add columns for request URI, auth server endpoints, scopes, and DPoP key
+
ALTER TABLE oauth_requests
+
ADD COLUMN request_uri TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_requests ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Make did nullable (indigo's AuthRequestData.AccountDID is a pointer - optional)
+
ALTER TABLE oauth_requests ALTER COLUMN did DROP NOT NULL;
+
+
-- Make handle and pds_url nullable too (derived from DID resolution, not always available at auth request time)
+
ALTER TABLE oauth_requests ALTER COLUMN handle DROP NOT NULL;
+
ALTER TABLE oauth_requests ALTER COLUMN pds_url DROP NOT NULL;
+
+
-- Update existing oauth_requests data
+
-- Convert dpop_private_jwk (JSONB) to multibase format if needed
+
-- Note: This will leave the multibase column NULL for now since conversion requires crypto logic
+
-- The application will need to handle NULL values or regenerate keys on next auth flow
+
UPDATE oauth_requests
+
SET
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE auth_server_token_endpoint IS NULL;
+
+
-- Add indexes for new columns
+
CREATE INDEX idx_oauth_requests_request_uri ON oauth_requests(request_uri) WHERE request_uri IS NOT NULL;
+
+
-- Update oauth_sessions table
+
-- Add session_id column (will become part of composite key)
+
ALTER TABLE oauth_sessions
+
ADD COLUMN session_id TEXT,
+
ADD COLUMN host_url TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_sessions ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Populate session_id for existing sessions (use DID as default for single-session per account)
+
-- In production, you may want to generate unique session IDs
+
UPDATE oauth_sessions
+
SET
+
session_id = 'default',
+
host_url = pds_url,
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE session_id IS NULL;
+
+
-- Make session_id NOT NULL after populating existing data
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN session_id SET NOT NULL;
+
+
-- Drop old unique constraint on did only
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_key;
+
+
-- Create new composite unique constraint for (did, session_id)
+
-- This allows multiple sessions per account
+
-- Note: UNIQUE constraint automatically creates an index, so no separate index needed
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_session_id_key UNIQUE (did, session_id);
+
+
-- Add comment explaining the schema change
+
COMMENT ON COLUMN oauth_sessions.session_id IS 'Session identifier to support multiple concurrent sessions per account';
+
COMMENT ON CONSTRAINT oauth_sessions_did_session_id_key ON oauth_sessions IS 'Composite key allowing multiple sessions per DID';
+
+
-- +goose Down
+
-- Rollback: Remove added columns and restore original unique constraint
+
+
-- oauth_sessions rollback
+
-- Drop composite unique constraint (this also drops the associated index)
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_session_id_key;
+
+
-- Delete all but the most recent session per DID before restoring unique constraint
+
-- This ensures the UNIQUE (did) constraint can be added without conflicts
+
DELETE FROM oauth_sessions a
+
USING oauth_sessions b
+
WHERE a.did = b.did
+
AND a.created_at < b.created_at;
+
+
-- Restore old unique constraint
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_key UNIQUE (did);
+
+
-- Restore NOT NULL constraint on dpop_private_jwk
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN dpop_private_jwk SET NOT NULL;
+
+
ALTER TABLE oauth_sessions
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS host_url,
+
DROP COLUMN IF EXISTS session_id;
+
+
-- oauth_requests rollback
+
DROP INDEX IF EXISTS idx_oauth_requests_request_uri;
+
+
-- Restore NOT NULL constraints
+
ALTER TABLE oauth_requests
+
ALTER COLUMN dpop_private_jwk SET NOT NULL,
+
ALTER COLUMN did SET NOT NULL,
+
ALTER COLUMN handle SET NOT NULL,
+
ALTER COLUMN pds_url SET NOT NULL;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS request_uri;
+23
internal/db/migrations/020_add_mobile_oauth_csrf.sql
···
+
-- +goose Up
+
-- Add columns for mobile OAuth CSRF protection with server-side state
+
-- This ties the CSRF token to the OAuth state, allowing validation against
+
-- a value that comes back through the OAuth response (the state parameter)
+
-- rather than only validating cookies against each other.
+
+
ALTER TABLE oauth_requests
+
ADD COLUMN mobile_csrf_token TEXT,
+
ADD COLUMN mobile_redirect_uri TEXT;
+
+
-- Index for quick lookup of mobile data when callback is received
+
CREATE INDEX idx_oauth_requests_mobile_csrf ON oauth_requests(state)
+
WHERE mobile_csrf_token IS NOT NULL;
+
+
COMMENT ON COLUMN oauth_requests.mobile_csrf_token IS 'CSRF token for mobile OAuth flows, validated against cookie on callback';
+
COMMENT ON COLUMN oauth_requests.mobile_redirect_uri IS 'Mobile redirect URI (Universal Link) for this OAuth flow';
+
+
-- +goose Down
+
DROP INDEX IF EXISTS idx_oauth_requests_mobile_csrf;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS mobile_redirect_uri,
+
DROP COLUMN IF EXISTS mobile_csrf_token;
+137
internal/api/handlers/wellknown/universal_links.go
···
+
package wellknown
+
+
import (
+
"encoding/json"
+
"log/slog"
+
"net/http"
+
"os"
+
)
+
+
// HandleAppleAppSiteAssociation serves the iOS Universal Links configuration
+
// GET /.well-known/apple-app-site-association
+
//
+
// Universal Links provide cryptographic binding between the app and domain:
+
// - Requires apple-app-site-association file served over HTTPS
+
// - App must have Associated Domains capability configured
+
// - System verifies domain ownership before routing deep links
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.apple.com/documentation/xcode/supporting-universal-links-in-your-app
+
func HandleAppleAppSiteAssociation(w http.ResponseWriter, r *http.Request) {
+
// Get Apple App ID from environment (format: <Team ID>.<Bundle ID>)
+
// Example: "ABCD1234.social.coves.app"
+
// Find Team ID in Apple Developer Portal -> Membership
+
// Bundle ID is configured in Xcode project
+
appleAppID := os.Getenv("APPLE_APP_ID")
+
if appleAppID == "" {
+
// Development fallback - allows testing without real Team ID
+
// IMPORTANT: This MUST be set in production for Universal Links to work
+
appleAppID = "DEVELOPMENT.social.coves.app"
+
slog.Warn("APPLE_APP_ID not set, using development placeholder",
+
"app_id", appleAppID,
+
"note", "Set APPLE_APP_ID env var for production Universal Links")
+
}
+
+
// Apple requires application/json content type (no charset)
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Apple's spec
+
// See: https://developer.apple.com/documentation/bundleresources/applinks
+
response := map[string]interface{}{
+
"applinks": map[string]interface{}{
+
"apps": []string{}, // Must be empty array per Apple spec
+
"details": []map[string]interface{}{
+
{
+
"appID": appleAppID,
+
// Paths that trigger Universal Links when opened in Safari/other apps
+
// These URLs will open the app instead of the browser
+
"paths": []string{
+
"/app/oauth/callback", // Primary Universal Link OAuth callback
+
"/app/oauth/callback/*", // Catch-all for query params
+
},
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode apple-app-site-association", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served apple-app-site-association", "app_id", appleAppID)
+
}
+
+
// HandleAssetLinks serves the Android App Links configuration
+
// GET /.well-known/assetlinks.json
+
//
+
// App Links provide cryptographic binding between the app and domain:
+
// - Requires assetlinks.json file served over HTTPS
+
// - App must have intent-filter with android:autoVerify="true"
+
// - System verifies domain ownership via SHA-256 certificate fingerprint
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.android.com/training/app-links/verify-android-applinks
+
func HandleAssetLinks(w http.ResponseWriter, r *http.Request) {
+
// Get Android package name from environment
+
// Example: "social.coves.app"
+
androidPackage := os.Getenv("ANDROID_PACKAGE_NAME")
+
if androidPackage == "" {
+
androidPackage = "social.coves.app" // Default for development
+
slog.Warn("ANDROID_PACKAGE_NAME not set, using default",
+
"package", androidPackage,
+
"note", "Set ANDROID_PACKAGE_NAME env var for production App Links")
+
}
+
+
// Get SHA-256 fingerprint from environment
+
// This is the SHA-256 fingerprint of the app's signing certificate
+
//
+
// To get the fingerprint:
+
// Production: keytool -list -v -keystore release.jks -alias release
+
// Debug: keytool -list -v -keystore ~/.android/debug.keystore -alias androiddebugkey -storepass android -keypass android
+
//
+
// Look for "SHA256:" in the output
+
// Format: AA:BB:CC:DD:...:FF (64 hex characters separated by colons)
+
androidFingerprint := os.Getenv("ANDROID_SHA256_FINGERPRINT")
+
if androidFingerprint == "" {
+
// Development fallback - this won't work for real App Links verification
+
// IMPORTANT: This MUST be set in production for App Links to work
+
androidFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
+
slog.Warn("ANDROID_SHA256_FINGERPRINT not set, using development placeholder",
+
"fingerprint", androidFingerprint,
+
"note", "Set ANDROID_SHA256_FINGERPRINT env var for production App Links")
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Google's Digital Asset Links spec
+
// See: https://developers.google.com/digital-asset-links/v1/getting-started
+
response := []map[string]interface{}{
+
{
+
// delegate_permission/common.handle_all_urls grants the app permission
+
// to handle URLs for this domain
+
"relation": []string{"delegate_permission/common.handle_all_urls"},
+
"target": map[string]interface{}{
+
"namespace": "android_app",
+
"package_name": androidPackage,
+
// List of certificate fingerprints that can sign the app
+
// Multiple fingerprints can be provided for different signing keys
+
// (e.g., debug + release)
+
"sha256_cert_fingerprints": []string{
+
androidFingerprint,
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode assetlinks.json", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served assetlinks.json",
+
"package", androidPackage,
+
"fingerprint", androidFingerprint)
+
}
+25
internal/api/routes/wellknown.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/wellknown"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterWellKnownRoutes registers RFC 8615 well-known URI endpoints
+
// These endpoints are used for service discovery and mobile app deep linking
+
//
+
// Spec: https://www.rfc-editor.org/rfc/rfc8615.html
+
func RegisterWellKnownRoutes(r chi.Router) {
+
// iOS Universal Links configuration
+
// Required for cryptographically-bound deep linking on iOS
+
// Must be served at exact path /.well-known/apple-app-site-association
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/apple-app-site-association", wellknown.HandleAppleAppSiteAssociation)
+
+
// Android App Links configuration
+
// Required for cryptographically-bound deep linking on Android
+
// Must be served at exact path /.well-known/assetlinks.json
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/assetlinks.json", wellknown.HandleAssetLinks)
+
}
+1 -1
internal/api/handlers/comments/middleware.go
···
// The middleware extracts the viewer DID from the Authorization header if present and valid,
// making it available via middleware.GetUserDID(r) in the handler.
// If no valid token is present, the request continues as anonymous (empty DID).
-
func OptionalAuthMiddleware(authMiddleware *middleware.AtProtoAuthMiddleware, next http.HandlerFunc) http.Handler {
+
func OptionalAuthMiddleware(authMiddleware *middleware.OAuthAuthMiddleware, next http.HandlerFunc) http.Handler {
return authMiddleware.OptionalAuth(http.HandlerFunc(next))
}
+1 -1
internal/api/routes/community.go
···
// RegisterCommunityRoutes registers community-related XRPC endpoints on the router
// Implements social.coves.community.* lexicon endpoints
// allowedCommunityCreators restricts who can create communities. If empty, anyone can create.
-
func RegisterCommunityRoutes(r chi.Router, service communities.Service, authMiddleware *middleware.AtProtoAuthMiddleware, allowedCommunityCreators []string) {
+
func RegisterCommunityRoutes(r chi.Router, service communities.Service, authMiddleware *middleware.OAuthAuthMiddleware, allowedCommunityCreators []string) {
// Initialize handlers
createHandler := community.NewCreateHandler(service, allowedCommunityCreators)
getHandler := community.NewGetHandler(service)
+1 -1
internal/api/routes/post.go
···
// RegisterPostRoutes registers post-related XRPC endpoints on the router
// Implements social.coves.community.post.* lexicon endpoints
-
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.AtProtoAuthMiddleware) {
+
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
// Initialize handlers
createHandler := post.NewCreateHandler(service)
+291
tests/e2e/oauth_ratelimit_e2e_test.go
···
+
package e2e
+
+
import (
+
"Coves/internal/api/middleware"
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
)
+
+
// TestRateLimiting_E2E_OAuthEndpoints tests OAuth-specific rate limiting
+
// OAuth endpoints have stricter rate limits to prevent:
+
// - Credential stuffing attacks on login endpoints (10 req/min)
+
// - OAuth state exhaustion
+
// - Refresh token abuse (20 req/min)
+
func TestRateLimiting_E2E_OAuthEndpoints(t *testing.T) {
+
t.Run("Login endpoints have 10 req/min limit", func(t *testing.T) {
+
// Create rate limiter matching oauth.go config: 10 requests per minute
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
// Mock OAuth login handler
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte("OK"))
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.200:12345"
+
+
// Make exactly 10 requests (at limit)
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 11th request should be rate limited
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Request 11 should be rate limited")
+
assert.Contains(t, rr.Body.String(), "Rate limit exceeded", "Should have rate limit error message")
+
})
+
+
t.Run("Mobile login endpoints have 10 req/min limit", func(t *testing.T) {
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.201:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Mobile login should be rate limited at 10 req/min")
+
})
+
+
t.Run("Refresh endpoint has 20 req/min limit", func(t *testing.T) {
+
// Refresh has higher limit (20 req/min) for legitimate token refresh
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := refreshLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.202:12345"
+
+
// Make 20 requests
+
for i := 0; i < 20; i++ {
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 21st request blocked
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Refresh should be rate limited at 20 req/min")
+
})
+
+
t.Run("Logout endpoint has 10 req/min limit", func(t *testing.T) {
+
logoutLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := logoutLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.203:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Logout should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth callback has 10 req/min limit", func(t *testing.T) {
+
// Callback uses same limiter as login (part of auth flow)
+
callbackLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := callbackLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.204:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Callback should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth rate limits are stricter than global limit", func(t *testing.T) {
+
// Verify OAuth limits are more restrictive than global 100 req/min
+
const globalLimit = 100
+
const oauthLoginLimit = 10
+
const oauthRefreshLimit = 20
+
+
assert.Less(t, oauthLoginLimit, globalLimit, "OAuth login limit should be stricter than global")
+
assert.Less(t, oauthRefreshLimit, globalLimit, "OAuth refresh limit should be stricter than global")
+
assert.Greater(t, oauthRefreshLimit, oauthLoginLimit, "Refresh limit should be higher than login (legitimate use case)")
+
})
+
+
t.Run("OAuth limits prevent credential stuffing", func(t *testing.T) {
+
// Simulate credential stuffing attack
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
// Simulate failed login attempts
+
w.WriteHeader(http.StatusUnauthorized)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
attackerIP := "203.0.113.50:12345"
+
+
// Attacker tries 15 login attempts (credential stuffing)
+
successfulAttempts := 0
+
blockedAttempts := 0
+
+
for i := 0; i < 15; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = attackerIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
if rr.Code == http.StatusUnauthorized {
+
successfulAttempts++ // Reached handler (even if auth failed)
+
} else if rr.Code == http.StatusTooManyRequests {
+
blockedAttempts++
+
}
+
}
+
+
// Rate limiter should block 5 attempts after first 10
+
assert.Equal(t, 10, successfulAttempts, "Should allow 10 login attempts")
+
assert.Equal(t, 5, blockedAttempts, "Should block 5 attempts after limit reached")
+
})
+
+
t.Run("OAuth limits are per-endpoint", func(t *testing.T) {
+
// Each endpoint gets its own rate limiter
+
// This test verifies that limits are independent per endpoint
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
loginHandler := loginLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
refreshHandler := refreshLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
clientIP := "192.168.1.205:12345"
+
+
// Exhaust login limit
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// Login limit exhausted
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Login should be rate limited")
+
+
// Refresh endpoint should still work (independent limiter)
+
req = httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr = httptest.NewRecorder()
+
refreshHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Refresh should not be affected by login rate limit")
+
})
+
}
+
+
// OAuth Rate Limiting Configuration Documentation
+
// ================================================
+
// This test file validates OAuth-specific rate limits applied in oauth.go:
+
//
+
// 1. Login Endpoints (Credential Stuffing Protection)
+
// - Endpoints: /oauth/login, /oauth/mobile/login, /oauth/callback
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent brute force and credential stuffing attacks
+
// - Implementation: internal/api/routes/oauth.go:21
+
//
+
// 2. Refresh Endpoint (Token Refresh)
+
// - Endpoint: /oauth/refresh
+
// - Limit: 20 requests per minute per IP
+
// - Reason: Allow legitimate token refresh while preventing abuse
+
// - Implementation: internal/api/routes/oauth.go:24
+
//
+
// 3. Logout Endpoint
+
// - Endpoint: /oauth/logout
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent session exhaustion attacks
+
// - Implementation: internal/api/routes/oauth.go:27
+
//
+
// 4. Metadata Endpoints (No Extra Limit)
+
// - Endpoints: /oauth/client-metadata.json, /oauth/jwks.json
+
// - Limit: Global 100 requests per minute (from main.go)
+
// - Reason: Public metadata, not sensitive to rate abuse
+
//
+
// Security Benefits:
+
// - Credential Stuffing: Limits password guessing to 10 attempts/min
+
// - State Exhaustion: Prevents OAuth state generation spam
+
// - Token Abuse: Limits refresh token usage while allowing legitimate refresh
+
//
+
// Rate Limit Hierarchy:
+
// - OAuth login: 10 req/min (most restrictive)
+
// - OAuth refresh: 20 req/min (moderate)
+
// - Comments: 20 req/min (expensive queries)
+
// - Global: 100 req/min (baseline)
+910
tests/integration/oauth_e2e_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"encoding/json"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"strings"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_Components tests OAuth component functionality without requiring PDS.
+
// This validates all Coves OAuth code:
+
// - Session storage and retrieval (PostgreSQL)
+
// - Token sealing (AES-GCM encryption)
+
// - Token unsealing (decryption + validation)
+
// - Session cleanup
+
//
+
// NOTE: Full OAuth redirect flow testing requires both HTTPS PDS and HTTPS Coves deployment.
+
// The OAuth redirect flow is handled by indigo's library and enforces OAuth 2.0 spec
+
// (HTTPS required for authorization servers and redirect URIs).
+
func TestOAuth_Components(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth component test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations to ensure OAuth tables exist
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”ง Testing OAuth Components")
+
+
ctx := context.Background()
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Use a test DID (doesn't need to exist on PDS for component tests)
+
testDID := "did:plc:componenttest123"
+
+
// Run component tests
+
testOAuthComponentsWithMockedSession(t, ctx, nil, store, client, testDID, "")
+
+
t.Log("")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("โœ… OAuth Component Tests Complete")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("Components validated:")
+
t.Log(" โœ“ Session storage (PostgreSQL)")
+
t.Log(" โœ“ Token sealing (AES-GCM encryption)")
+
t.Log(" โœ“ Token unsealing (decryption + validation)")
+
t.Log(" โœ“ Session cleanup")
+
t.Log("")
+
t.Log("NOTE: Full OAuth redirect flow requires HTTPS PDS + HTTPS Coves")
+
t.Log(strings.Repeat("=", 60))
+
}
+
+
// testOAuthComponentsWithMockedSession tests OAuth components that work without PDS redirect flow.
+
// This is used when testing with localhost PDS, where the indigo library rejects http:// URLs.
+
func testOAuthComponentsWithMockedSession(t *testing.T, ctx context.Context, _ interface{}, store oauthlib.ClientAuthStore, client *oauth.OAuthClient, userDID, _ string) {
+
t.Helper()
+
+
t.Log("๐Ÿ”ง Testing OAuth components with mocked session...")
+
+
// Parse DID
+
parsedDID, err := syntax.ParseDID(userDID)
+
require.NoError(t, err, "Should parse DID")
+
+
// Component 1: Session Storage
+
t.Log(" ๐Ÿ“ฆ Component 1: Testing session storage...")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: fmt.Sprintf("localhost-test-%d", time.Now().UnixNano()),
+
HostURL: "http://localhost:3001",
+
AccessToken: "mocked-access-token",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err, "Should save session")
+
+
retrieved, err := store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should retrieve session")
+
require.Equal(t, testSession.SessionID, retrieved.SessionID)
+
require.Equal(t, testSession.AccessToken, retrieved.AccessToken)
+
t.Log(" โœ… Session storage working")
+
+
// Component 2: Token Sealing
+
t.Log(" ๐Ÿ” Component 2: Testing token sealing...")
+
sealedToken, err := client.SealSession(parsedDID.String(), testSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
tokenPreview := sealedToken
+
if len(tokenPreview) > 50 {
+
tokenPreview = tokenPreview[:50]
+
}
+
t.Logf(" โœ… Token sealed: %s...", tokenPreview)
+
+
// Component 3: Token Unsealing
+
t.Log(" ๐Ÿ”“ Component 3: Testing token unsealing...")
+
unsealed, err := client.UnsealSession(sealedToken)
+
require.NoError(t, err, "Should unseal token")
+
require.Equal(t, userDID, unsealed.DID)
+
require.Equal(t, testSession.SessionID, unsealed.SessionID)
+
t.Log(" โœ… Token unsealing working")
+
+
// Component 4: Session Cleanup
+
t.Log(" ๐Ÿงน Component 4: Testing session cleanup...")
+
err = store.DeleteSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should delete session")
+
+
_, err = store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.Error(t, err, "Session should not exist after deletion")
+
t.Log(" โœ… Session cleanup working")
+
+
t.Log("โœ… All OAuth components verified!")
+
t.Log("")
+
t.Log("๐Ÿ“ Summary: OAuth implementation validated with mocked session")
+
t.Log(" - Session storage: โœ“")
+
t.Log(" - Token sealing: โœ“")
+
t.Log(" - Token unsealing: โœ“")
+
t.Log(" - Session cleanup: โœ“")
+
t.Log("")
+
t.Log("โš ๏ธ To test full OAuth redirect flow, use a production PDS with HTTPS")
+
}
+
+
// TestOAuthE2E_TokenExpiration tests that expired sealed tokens are rejected
+
func TestOAuthE2E_TokenExpiration(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token expiration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("โฐ Testing OAuth token expiration...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
_ = oauth.NewOAuthHandler(client, store) // Handler created for completeness
+
+
// Create test session with past expiration
+
did, err := syntax.ParseDID("did:plc:expiredtest123")
+
require.NoError(t, err)
+
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "expired-session",
+
HostURL: "http://localhost:3001",
+
AccessToken: "expired-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Manually update expiration to the past
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_sessions SET expires_at = NOW() - INTERVAL '1 day' WHERE did = $1 AND session_id = $2",
+
did.String(), testSession.SessionID)
+
require.NoError(t, err)
+
+
// Try to retrieve expired session
+
_, err = store.GetSession(ctx, did, testSession.SessionID)
+
assert.Error(t, err, "Should not be able to retrieve expired session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound for expired session")
+
+
// Test cleanup of expired sessions
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredSessions(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one session")
+
+
t.Logf("โœ… Expired session handling verified (cleaned %d sessions)", cleaned)
+
}
+
+
// TestOAuthE2E_InvalidToken tests that invalid/tampered tokens are rejected
+
func TestOAuthE2E_InvalidToken(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth invalid token test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”’ Testing OAuth invalid token rejection...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup test server with protected endpoint
+
r := chi.NewRouter()
+
r.Get("/api/me", func(w http.ResponseWriter, r *http.Request) {
+
sessData, err := handler.GetSessionFromRequest(r)
+
if err != nil {
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
+
return
+
}
+
w.Header().Set("Content-Type", "application/json")
+
_ = json.NewEncoder(w).Encode(map[string]string{"did": sessData.AccountDID.String()})
+
})
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
// Test with invalid token formats
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but invalid content
+
{"Short token", "abc"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
req, _ := http.NewRequest("GET", server.URL+"/api/me", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid token should be rejected with 401")
+
})
+
}
+
+
t.Logf("โœ… Invalid token rejection verified")
+
}
+
+
// TestOAuthE2E_SessionNotFound tests behavior when session doesn't exist in DB
+
func TestOAuthE2E_SessionNotFound(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session not found test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ” Testing OAuth session not found behavior...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Try to retrieve non-existent session
+
nonExistentDID, err := syntax.ParseDID("did:plc:nonexistent123")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error for non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
// Try to delete non-existent session
+
err = store.DeleteSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error when deleting non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
t.Logf("โœ… Session not found handling verified")
+
}
+
+
// TestOAuthE2E_MultipleSessionsPerUser tests that a user can have multiple active sessions
+
func TestOAuthE2E_MultipleSessionsPerUser(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth multiple sessions test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ‘ฅ Testing multiple OAuth sessions per user...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test DID
+
did, err := syntax.ParseDID("did:plc:multisession123")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same user
+
sessions := []oauthlib.ClientSessionData{
+
{
+
AccountDID: did,
+
SessionID: "session-1-web",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-1",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-2-mobile",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-2",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-3-tablet",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-3",
+
Scopes: []string{"atproto"},
+
},
+
}
+
+
// Save all sessions
+
for i, session := range sessions {
+
err := store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should be able to save session %d", i+1)
+
}
+
+
t.Logf("โœ… Created %d sessions for user", len(sessions))
+
+
// Verify all sessions can be retrieved independently
+
for i, session := range sessions {
+
retrieved, err := store.GetSession(ctx, did, session.SessionID)
+
require.NoError(t, err, "Should be able to retrieve session %d", i+1)
+
assert.Equal(t, session.SessionID, retrieved.SessionID, "Session ID should match")
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken, "Access token should match")
+
}
+
+
t.Logf("โœ… All sessions retrieved independently")
+
+
// Delete one session and verify others remain
+
err = store.DeleteSession(ctx, did, sessions[0].SessionID)
+
require.NoError(t, err, "Should be able to delete first session")
+
+
// Verify first session is deleted
+
_, err = store.GetSession(ctx, did, sessions[0].SessionID)
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "First session should be deleted")
+
+
// Verify other sessions still exist
+
for i := 1; i < len(sessions); i++ {
+
_, err := store.GetSession(ctx, did, sessions[i].SessionID)
+
require.NoError(t, err, "Session %d should still exist", i+1)
+
}
+
+
t.Logf("โœ… Multiple sessions per user verified")
+
+
// Cleanup
+
for i := 1; i < len(sessions); i++ {
+
_ = store.DeleteSession(ctx, did, sessions[i].SessionID)
+
}
+
}
+
+
// TestOAuthE2E_AuthRequestStorage tests OAuth auth request storage and retrieval
+
func TestOAuthE2E_AuthRequestStorage(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth auth request storage test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ“ Testing OAuth auth request storage...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create test auth request data
+
did, err := syntax.ParseDID("did:plc:authrequest123")
+
require.NoError(t, err)
+
+
authRequest := oauthlib.AuthRequestData{
+
State: "test-state-12345",
+
AccountDID: &did,
+
PKCEVerifier: "test-pkce-verifier",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
AuthServerURL: "http://localhost:3001",
+
RequestURI: "http://localhost:3001/authorize",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save auth request
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
require.NoError(t, err, "Should be able to save auth request")
+
+
t.Logf("โœ… Auth request saved")
+
+
// Retrieve auth request
+
retrieved, err := store.GetAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to retrieve auth request")
+
assert.Equal(t, authRequest.State, retrieved.State, "State should match")
+
assert.Equal(t, authRequest.PKCEVerifier, retrieved.PKCEVerifier, "PKCE verifier should match")
+
assert.Equal(t, authRequest.AuthServerURL, retrieved.AuthServerURL, "Auth server URL should match")
+
assert.Equal(t, len(authRequest.Scopes), len(retrieved.Scopes), "Scopes length should match")
+
+
t.Logf("โœ… Auth request retrieved and verified")
+
+
// Test duplicate state error
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
assert.Error(t, err, "Should not allow duplicate state")
+
assert.Contains(t, err.Error(), "already exists", "Error should indicate duplicate")
+
+
t.Logf("โœ… Duplicate state prevention verified")
+
+
// Delete auth request
+
err = store.DeleteAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to delete auth request")
+
+
// Verify deletion
+
_, err = store.GetAuthRequestInfo(ctx, authRequest.State)
+
assert.Equal(t, oauth.ErrAuthRequestNotFound, err, "Auth request should be deleted")
+
+
t.Logf("โœ… Auth request deletion verified")
+
+
// Test cleanup of expired auth requests
+
// Create an auth request and manually set created_at to the past
+
oldAuthRequest := oauthlib.AuthRequestData{
+
State: "old-state-12345",
+
PKCEVerifier: "old-verifier",
+
AuthServerURL: "http://localhost:3001",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveAuthRequestInfo(ctx, oldAuthRequest)
+
require.NoError(t, err)
+
+
// Update created_at to 1 hour ago
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_requests SET created_at = NOW() - INTERVAL '1 hour' WHERE state = $1",
+
oldAuthRequest.State)
+
require.NoError(t, err)
+
+
// Cleanup expired requests
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredAuthRequests(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one auth request")
+
+
t.Logf("โœ… Expired auth request cleanup verified (cleaned %d requests)", cleaned)
+
}
+
+
// TestOAuthE2E_TokenRefresh tests the refresh token flow
+
func TestOAuthE2E_TokenRefresh(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token refresh test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth token refresh flow...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Create a test DID and session
+
did, err := syntax.ParseDID("did:plc:refreshtest123")
+
require.NoError(t, err)
+
+
// Create initial session with refresh token
+
initialSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "refresh-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
AccessToken: "initial-access-token",
+
RefreshToken: "initial-refresh-token",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, initialSession)
+
require.NoError(t, err, "Should save initial session")
+
+
t.Logf("โœ… Initial session created")
+
+
// Create a sealed token for this session
+
sealedToken, err := client.SealSession(did.String(), initialSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal session token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
+
t.Logf("โœ… Session token sealed")
+
+
// Setup test server with refresh endpoint
+
r := chi.NewRouter()
+
r.Post("/oauth/refresh", handler.HandleRefresh)
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
t.Run("Valid refresh request", func(t *testing.T) {
+
// NOTE: This test verifies that the refresh endpoint can be called
+
// In a real scenario, the indigo client's RefreshTokens() would call the PDS
+
// Since we're in a component test, we're testing the Coves handler logic
+
+
// Create refresh request
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": sealedToken,
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
// NOTE: In component testing mode, the indigo client may not have
+
// real PDS credentials, so RefreshTokens() might fail
+
// We're testing that the handler correctly processes the request
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// In component test mode without real PDS, we may get 401
+
// In production with real PDS, this would return 200 with new tokens
+
t.Logf("Refresh response status: %d", resp.StatusCode)
+
+
// The important thing is that the handler doesn't crash
+
// and properly validates the request structure
+
assert.True(t, resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized,
+
"Refresh should return either success or auth failure, got %d", resp.StatusCode)
+
})
+
+
t.Run("Invalid DID format (with valid token)", func(t *testing.T) {
+
// Create a sealed token with an invalid DID format
+
invalidDID := "invalid-did-format"
+
// Create the token with a valid DID first, then we'll try to use it with invalid DID in request
+
validToken, err := client.SealSession(did.String(), initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": invalidDID, // Invalid DID format in request
+
"session_id": initialSession.SessionID,
+
"sealed_token": validToken, // Valid token for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// Should reject with 401 due to DID mismatch (not 400) since auth happens first
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected with 401 (auth check happens before format validation)")
+
})
+
+
t.Run("Missing sealed_token (security test)", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
// Missing sealed_token - should be rejected for security
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Missing sealed_token should be rejected (proof of possession required)")
+
})
+
+
t.Run("Invalid sealed_token", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": "invalid-token-data",
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid sealed_token should be rejected")
+
})
+
+
t.Run("DID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token for a different DID
+
wrongDID := "did:plc:wronguser123"
+
wrongToken, err := client.SealSession(wrongDID, initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(), // Claiming original DID
+
"session_id": initialSession.SessionID,
+
"sealed_token": wrongToken, // But token is for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Session ID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token with wrong session ID
+
wrongSessionID := "wrong-session-id"
+
wrongToken, err := client.SealSession(did.String(), wrongSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID, // Claiming original session
+
"sealed_token": wrongToken, // But token is for different session
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Session ID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Non-existent session", func(t *testing.T) {
+
// Create a valid sealed token for a non-existent session
+
nonExistentSessionID := "nonexistent-session-id"
+
validToken, err := client.SealSession(did.String(), nonExistentSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": nonExistentSessionID,
+
"sealed_token": validToken, // Valid token but session doesn't exist
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Non-existent session should be rejected with 401")
+
})
+
+
t.Logf("โœ… Token refresh endpoint validation verified")
+
}
+
+
// TestOAuthE2E_SessionUpdate tests that refresh updates the session in database
+
func TestOAuthE2E_SessionUpdate(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session update test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ’พ Testing OAuth session update on refresh...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:sessionupdate123")
+
require.NoError(t, err)
+
+
originalSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "update-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: "original-access-token",
+
RefreshToken: "original-refresh-token",
+
DPoPPrivateKeyMultibase: "original-dpop-key",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save original session
+
err = store.SaveSession(ctx, originalSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Original session saved")
+
+
// Simulate a token refresh by updating the session with new tokens
+
updatedSession := originalSession
+
updatedSession.AccessToken = "new-access-token"
+
updatedSession.RefreshToken = "new-refresh-token"
+
updatedSession.DPoPAuthServerNonce = "new-nonce"
+
+
// Update the session (upsert)
+
err = store.SaveSession(ctx, updatedSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Session updated with new tokens")
+
+
// Retrieve the session and verify it was updated
+
retrieved, err := store.GetSession(ctx, did, originalSession.SessionID)
+
require.NoError(t, err, "Should retrieve updated session")
+
+
assert.Equal(t, "new-access-token", retrieved.AccessToken,
+
"Access token should be updated")
+
assert.Equal(t, "new-refresh-token", retrieved.RefreshToken,
+
"Refresh token should be updated")
+
assert.Equal(t, "new-nonce", retrieved.DPoPAuthServerNonce,
+
"DPoP nonce should be updated")
+
+
// Verify session ID and DID remain the same
+
assert.Equal(t, originalSession.SessionID, retrieved.SessionID,
+
"Session ID should remain the same")
+
assert.Equal(t, did, retrieved.AccountDID,
+
"DID should remain the same")
+
+
t.Logf("โœ… Session update verified - tokens refreshed in database")
+
+
// Verify updated_at was changed
+
var updatedAt time.Time
+
err = db.QueryRowContext(ctx,
+
"SELECT updated_at FROM oauth_sessions WHERE did = $1 AND session_id = $2",
+
did.String(), originalSession.SessionID).Scan(&updatedAt)
+
require.NoError(t, err)
+
+
// Updated timestamp should be recent (within last minute)
+
assert.WithinDuration(t, time.Now(), updatedAt, time.Minute,
+
"Session updated_at should be recent")
+
+
t.Logf("โœ… Session timestamp update verified")
+
}
+
+
// TestOAuthE2E_RefreshTokenRotation tests refresh token rotation behavior
+
func TestOAuthE2E_RefreshTokenRotation(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth refresh token rotation test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth refresh token rotation...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:rotation123")
+
require.NoError(t, err)
+
+
// Simulate multiple refresh cycles
+
sessionID := "rotation-session-1"
+
tokens := []struct {
+
access string
+
refresh string
+
}{
+
{"access-token-v1", "refresh-token-v1"},
+
{"access-token-v2", "refresh-token-v2"},
+
{"access-token-v3", "refresh-token-v3"},
+
}
+
+
for i, tokenPair := range tokens {
+
session := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: tokenPair.access,
+
RefreshToken: tokenPair.refresh,
+
Scopes: []string{"atproto"},
+
}
+
+
// Save/update session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should save session iteration %d", i+1)
+
+
// Retrieve and verify
+
retrieved, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err, "Should retrieve session iteration %d", i+1)
+
+
assert.Equal(t, tokenPair.access, retrieved.AccessToken,
+
"Access token should match iteration %d", i+1)
+
assert.Equal(t, tokenPair.refresh, retrieved.RefreshToken,
+
"Refresh token should match iteration %d", i+1)
+
+
// Small delay to ensure timestamp differences
+
time.Sleep(10 * time.Millisecond)
+
}
+
+
t.Logf("โœ… Refresh token rotation verified through %d cycles", len(tokens))
+
+
// Verify final state
+
finalSession, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err)
+
+
assert.Equal(t, "access-token-v3", finalSession.AccessToken,
+
"Final access token should be from last rotation")
+
assert.Equal(t, "refresh-token-v3", finalSession.RefreshToken,
+
"Final refresh token should be from last rotation")
+
+
t.Logf("โœ… Token rotation state verified")
+
}
+312
tests/integration/oauth_session_fixation_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"crypto/sha256"
+
"encoding/base64"
+
"net/http"
+
"net/http/httptest"
+
"net/url"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_SessionFixationAttackPrevention tests that the mobile redirect binding
+
// prevents session fixation attacks where an attacker plants a mobile_redirect_uri
+
// cookie, then the user does a web login, and credentials get sent to attacker's deep link.
+
//
+
// Attack scenario:
+
// 1. Attacker tricks user into visiting /oauth/mobile/login?redirect_uri=evil://steal
+
// 2. This plants a mobile_redirect_uri cookie (lives 10 minutes)
+
// 3. User later does normal web OAuth login via /oauth/login
+
// 4. HandleCallback sees the stale mobile_redirect_uri cookie
+
// 5. WITHOUT THE FIX: Callback sends sealed token, DID, session_id to attacker's deep link
+
// 6. WITH THE FIX: Binding mismatch is detected, mobile cookies cleared, user gets web session
+
func TestOAuth_SessionFixationAttackPrevention(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session fixation test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Setup handler
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup router
+
r := chi.NewRouter()
+
r.Get("/oauth/callback", handler.HandleCallback)
+
+
t.Run("attack scenario - planted mobile cookie without binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Step 1: Simulate a successful OAuth callback (like a user did web login)
+
// We'll create a mock session to simulate what ProcessCallback would return
+
testDID := "did:plc:test123456"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "test-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "test-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session (simulating successful OAuth flow)
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Step 2: Attacker planted a mobile_redirect_uri cookie (without binding)
+
// This simulates the cookie being planted earlier by attacker
+
attackerRedirectURI := "evil://steal"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Plant the attacker's cookie (URL escaped as it would be in real scenario)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
// NOTE: No mobile_redirect_binding cookie! This is the attack scenario.
+
+
rec := httptest.NewRecorder()
+
+
// Step 3: Try to process the callback
+
// This would fail because ProcessCallback needs real OAuth code/state
+
// For this test, we're verifying the handler's security checks work
+
// even before ProcessCallback is called
+
+
// The handler will try to call ProcessCallback which will fail
+
// But we're testing that even if it succeeded, the mobile redirect
+
// validation would prevent the attack
+
handler.HandleCallback(rec, req)
+
+
// Step 4: Verify the attack was prevented
+
// The handler should reject the request due to missing binding
+
// Since ProcessCallback will fail first (no real OAuth code), we expect
+
// a 400 error, but the important thing is it doesn't redirect to evil://steal
+
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails")
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI")
+
})
+
+
t.Run("legitimate mobile flow - with valid binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup a legitimate mobile session
+
testDID := "did:plc:mobile123"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "mobile-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "mobile-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Create request with BOTH mobile_redirect_uri AND valid binding
+
// Use Universal Link URI that's in the allowlist
+
legitRedirectURI := "https://coves.social/app/oauth/callback"
+
csrfToken := "valid-csrf-token-for-mobile"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Add mobile redirect URI cookie
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(legitRedirectURI),
+
Path: "/oauth",
+
})
+
+
// Add CSRF token (required for mobile flow)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: csrfToken,
+
Path: "/oauth",
+
})
+
+
// Add VALID binding cookie (this is what prevents the attack)
+
// In real flow, this would be set by HandleMobileLogin
+
// The binding now includes the CSRF token for double-submit validation
+
mobileBinding := generateMobileRedirectBindingForTest(csrfToken, legitRedirectURI)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: mobileBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// This will also fail at ProcessCallback (no real OAuth code)
+
// but we're verifying the binding validation logic is in place
+
// In a real integration test with PDS, this would succeed
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails (expected in mock test)")
+
})
+
+
t.Run("binding mismatch - attacker tries wrong binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:bindingtest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "binding-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "binding-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Attacker tries to plant evil redirect with a binding from different URI
+
attackerRedirectURI := "evil://steal"
+
attackerCSRF := "attacker-csrf-token"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Use binding from a DIFFERENT CSRF token and URI (attacker's attempt to forge)
+
// Even if attacker knows the redirect URI, they don't know the user's CSRF token
+
wrongBinding := generateMobileRedirectBindingForTest("different-csrf", "https://coves.social/app/oauth/callback")
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: wrongBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// Should fail due to binding mismatch (even before ProcessCallback)
+
// The binding validation happens after ProcessCallback in the real code,
+
// but the mismatch would be caught and cookies cleared
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI on binding mismatch")
+
})
+
+
t.Run("CSRF token value mismatch - attacker tries different CSRF", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:csrftest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "csrf-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "csrf-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// This tests the P1 security fix: CSRF token VALUE must be validated, not just presence
+
// Attack scenario:
+
// 1. User starts mobile login with CSRF token A and redirect URI X
+
// 2. Binding = hash(A + X) is stored in cookie
+
// 3. Attacker somehow gets user to have CSRF token B in cookie (different from A)
+
// 4. Callback receives CSRF token B, redirect URI X, binding = hash(A + X)
+
// 5. hash(B + X) != hash(A + X), so attack is detected
+
+
originalCSRF := "original-csrf-token-set-at-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
// Binding was created with original CSRF token
+
originalBinding := generateMobileRedirectBindingForTest(originalCSRF, redirectURI)
+
+
// But attacker managed to change the CSRF cookie
+
attackerCSRF := "attacker-replaced-csrf"
+
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(redirectURI),
+
Path: "/oauth",
+
})
+
+
// Attacker's CSRF token (different from what created the binding)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Original binding (created with original CSRF token)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: originalBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// Should fail because hash(attackerCSRF + redirectURI) != hash(originalCSRF + redirectURI)
+
// This is the key security fix - CSRF token VALUE is now validated
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when CSRF token doesn't match binding")
+
})
+
}
+
+
// generateMobileRedirectBindingForTest generates a binding for testing
+
// This mirrors the actual logic in handlers_security.go:
+
// binding = base64(sha256(csrfToken + "|" + redirectURI)[:16])
+
func generateMobileRedirectBindingForTest(csrfToken, mobileRedirectURI string) string {
+
combined := csrfToken + "|" + mobileRedirectURI
+
hash := sha256.Sum256([]byte(combined))
+
return base64.URLEncoding.EncodeToString(hash[:16])
+
}
+169
tests/integration/oauth_token_verification_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/api/middleware"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"os"
+
"testing"
+
"time"
+
)
+
+
// TestOAuthTokenVerification tests end-to-end OAuth token verification
+
// with real PDS-issued OAuth tokens. This replaces the old JWT verification test
+
// since we now use OAuth sealed session tokens instead of raw JWTs.
+
//
+
// Flow:
+
// 1. Create account on local PDS (or use existing)
+
// 2. Authenticate to get OAuth tokens and create sealed session token
+
// 3. Verify our auth middleware can unseal and validate the token
+
// 4. Test token validation and session retrieval
+
//
+
// NOTE: This test uses the E2E OAuth middleware which mocks the session unsealing
+
// for testing purposes. Real OAuth tokens from PDS would be sealed using the
+
// OAuth client's seal secret.
+
func TestOAuthTokenVerification(t *testing.T) {
+
// Skip in short mode since this requires real PDS
+
if testing.Short() {
+
t.Skip("Skipping OAuth token verification test in short mode")
+
}
+
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
// Check if PDS is running
+
healthResp, err := http.Get(pdsURL + "/xrpc/_health")
+
if err != nil {
+
t.Skipf("PDS not running at %s: %v", pdsURL, err)
+
}
+
_ = healthResp.Body.Close()
+
+
t.Run("OAuth token validation and middleware integration", func(t *testing.T) {
+
// Step 1: Create a test account on PDS
+
// Keep handle short to avoid PDS validation errors
+
timestamp := time.Now().Unix() % 100000 // Last 5 digits
+
handle := fmt.Sprintf("oauth%d.local.coves.dev", timestamp)
+
password := "testpass123"
+
email := fmt.Sprintf("oauth%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
t.Logf("โœ“ Created test account: %s (DID: %s)", handle, did)
+
+
// Step 2: Create OAuth middleware with mock unsealer for testing
+
// In production, this would unseal real OAuth tokens from PDS
+
t.Log("Testing OAuth middleware with sealed session tokens...")
+
+
e2eAuth := NewE2EOAuthMiddleware()
+
testToken := e2eAuth.AddUser(did)
+
+
handlerCalled := false
+
var extractedDID string
+
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
extractedDID = middleware.GetUserDID(r)
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte(`{"success": true}`))
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer "+testToken)
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if !handlerCalled {
+
t.Errorf("Handler was not called - auth middleware rejected valid token")
+
t.Logf("Response status: %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if w.Code != http.StatusOK {
+
t.Errorf("Expected status 200, got %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if extractedDID != did {
+
t.Errorf("Middleware extracted wrong DID: expected %s, got %s", did, extractedDID)
+
}
+
+
t.Logf("โœ… OAuth middleware with token validation working correctly!")
+
t.Logf(" Handler called: %v", handlerCalled)
+
t.Logf(" Extracted DID: %s", extractedDID)
+
t.Logf(" Response status: %d", w.Code)
+
})
+
+
t.Run("Rejects tampered/invalid sealed tokens", func(t *testing.T) {
+
// Create valid user
+
timestamp := time.Now().Unix() % 100000
+
handle := fmt.Sprintf("tamp%d.local.coves.dev", timestamp)
+
password := "testpass456"
+
email := fmt.Sprintf("tamp%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
+
// Create OAuth middleware
+
e2eAuth := NewE2EOAuthMiddleware()
+
validToken := e2eAuth.AddUser(did)
+
+
// Create various invalid tokens to test
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but not a real sealed session
+
{"Short token", "abc"},
+
{"Modified valid token", validToken + "extra"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
handlerCalled := false
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if handlerCalled {
+
t.Error("Handler was called for invalid token - should have been rejected")
+
}
+
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("Expected status 401 for invalid token, got %d", w.Code)
+
}
+
+
t.Logf("โœ“ Middleware correctly rejected %s with status %d", tc.name, w.Code)
+
})
+
}
+
+
t.Logf("โœ… All invalid token types correctly rejected")
+
})
+
+
t.Run("Session expiration handling", func(t *testing.T) {
+
// OAuth session expiration is handled at the database level
+
// See TestOAuthE2E_TokenExpiration in oauth_e2e_test.go for full expiration testing
+
t.Log("โ„น๏ธ Session expiration testing is covered in oauth_e2e_test.go")
+
t.Log(" OAuth sessions expire based on database timestamps and are cleaned up periodically")
+
t.Log(" This is different from JWT expiration which was timestamp-based in the token itself")
+
t.Skip("Session expiration is tested in oauth_e2e_test.go - see TestOAuthE2E_TokenExpiration")
+
})
+
}
-73
cmd/genjwks/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"log"
-
"os"
-
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
// genjwks generates an ES256 keypair for OAuth client authentication
-
// The private key is stored in the config/env, public key is served at /oauth/jwks.json
-
//
-
// Usage:
-
//
-
// go run cmd/genjwks/main.go
-
//
-
// This will output a JSON private key that should be stored in OAUTH_PRIVATE_JWK
-
func main() {
-
fmt.Println("Generating ES256 keypair for OAuth client authentication...")
-
-
// Generate ES256 (NIST P-256) private key
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
log.Fatalf("Failed to generate private key: %v", err)
-
}
-
-
// Convert to JWK
-
jwkKey, err := jwk.FromRaw(privateKey)
-
if err != nil {
-
log.Fatalf("Failed to create JWK from private key: %v", err)
-
}
-
-
// Set key parameters
-
if err = jwkKey.Set(jwk.KeyIDKey, "oauth-client-key"); err != nil {
-
log.Fatalf("Failed to set kid: %v", err)
-
}
-
if err = jwkKey.Set(jwk.AlgorithmKey, "ES256"); err != nil {
-
log.Fatalf("Failed to set alg: %v", err)
-
}
-
if err = jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
-
log.Fatalf("Failed to set use: %v", err)
-
}
-
-
// Marshal to JSON
-
jsonData, err := json.MarshalIndent(jwkKey, "", " ")
-
if err != nil {
-
log.Fatalf("Failed to marshal JWK: %v", err)
-
}
-
-
// Output instructions
-
fmt.Println("\nโœ… ES256 keypair generated successfully!")
-
fmt.Println("\n๐Ÿ“ Add this to your .env.dev file:")
-
fmt.Println("\nOAUTH_PRIVATE_JWK='" + string(jsonData) + "'")
-
fmt.Println("\nโš ๏ธ IMPORTANT:")
-
fmt.Println(" - Keep this private key SECRET")
-
fmt.Println(" - Never commit it to version control")
-
fmt.Println(" - Generate a new key for production")
-
fmt.Println(" - The public key will be automatically derived and served at /oauth/jwks.json")
-
-
// Optionally write to a file (not committed)
-
if len(os.Args) > 1 && os.Args[1] == "--save" {
-
filename := "oauth-private-key.json"
-
if err := os.WriteFile(filename, jsonData, 0o600); err != nil {
-
log.Fatalf("Failed to write key file: %v", err)
-
}
-
fmt.Printf("\n๐Ÿ’พ Private key saved to %s (remember to add to .gitignore!)\n", filename)
-
}
-
}
-52
internal/atproto/auth/combined_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"fmt"
-
"strings"
-
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
)
-
-
// CombinedKeyFetcher handles JWT public key fetching for both:
-
// - DID issuers (did:plc:, did:web:) โ†’ resolves via DID document
-
// - URL issuers (https://) โ†’ fetches via JWKS endpoint (legacy/fallback)
-
//
-
// For atproto service authentication, the issuer is typically the user's DID,
-
// and the signing key is published in their DID document.
-
type CombinedKeyFetcher struct {
-
didFetcher *DIDKeyFetcher
-
jwksFetcher JWKSFetcher
-
}
-
-
// NewCombinedKeyFetcher creates a key fetcher that supports both DID and URL issuers.
-
// Parameters:
-
// - directory: Indigo's identity directory for DID resolution
-
// - jwksFetcher: fallback JWKS fetcher for URL issuers (can be nil if not needed)
-
func NewCombinedKeyFetcher(directory indigoIdentity.Directory, jwksFetcher JWKSFetcher) *CombinedKeyFetcher {
-
return &CombinedKeyFetcher{
-
didFetcher: NewDIDKeyFetcher(directory),
-
jwksFetcher: jwksFetcher,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT.
-
// Routes to the appropriate fetcher based on issuer format:
-
// - DID (did:plc:, did:web:) โ†’ DIDKeyFetcher
-
// - URL (https://) โ†’ JWKSFetcher
-
func (f *CombinedKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Check if issuer is a DID
-
if strings.HasPrefix(issuer, "did:") {
-
return f.didFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
// Check if issuer is a URL (https:// or http:// in dev)
-
if strings.HasPrefix(issuer, "https://") || strings.HasPrefix(issuer, "http://") {
-
if f.jwksFetcher == nil {
-
return nil, fmt.Errorf("URL issuer %s requires JWKS fetcher, but none configured", issuer)
-
}
-
return f.jwksFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
return nil, fmt.Errorf("unsupported issuer format: %s (expected DID or URL)", issuer)
-
}
-616
internal/atproto/auth/dpop.go
···
-
package auth
-
-
import (
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
-
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
-
type NonceCache struct {
-
seen map[string]time.Time // jti -> expiration time
-
stopCh chan struct{}
-
maxAge time.Duration // How long to keep entries
-
cleanup time.Duration // How often to clean up expired entries
-
mu sync.RWMutex
-
}
-
-
// NewNonceCache creates a new nonce cache for DPoP replay protection.
-
// maxAge should match or exceed DPoPVerifier.MaxProofAge.
-
func NewNonceCache(maxAge time.Duration) *NonceCache {
-
nc := &NonceCache{
-
seen: make(map[string]time.Time),
-
maxAge: maxAge,
-
cleanup: maxAge / 2, // Clean up at half the max age
-
stopCh: make(chan struct{}),
-
}
-
-
// Start background cleanup goroutine
-
go nc.cleanupLoop()
-
-
return nc
-
}
-
-
// CheckAndStore checks if a jti has been seen before and stores it if not.
-
// Returns true if the jti is fresh (not a replay), false if it's a replay.
-
func (nc *NonceCache) CheckAndStore(jti string) bool {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
expiry := now.Add(nc.maxAge)
-
-
// Check if already seen
-
if existingExpiry, seen := nc.seen[jti]; seen {
-
// Still valid (not expired) - this is a replay
-
if existingExpiry.After(now) {
-
return false
-
}
-
// Expired entry - allow reuse and update expiry
-
}
-
-
// Store the new jti
-
nc.seen[jti] = expiry
-
return true
-
}
-
-
// cleanupLoop periodically removes expired entries from the cache
-
func (nc *NonceCache) cleanupLoop() {
-
ticker := time.NewTicker(nc.cleanup)
-
defer ticker.Stop()
-
-
for {
-
select {
-
case <-ticker.C:
-
nc.cleanupExpired()
-
case <-nc.stopCh:
-
return
-
}
-
}
-
}
-
-
// cleanupExpired removes expired entries from the cache
-
func (nc *NonceCache) cleanupExpired() {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
for jti, expiry := range nc.seen {
-
if expiry.Before(now) {
-
delete(nc.seen, jti)
-
}
-
}
-
}
-
-
// Stop stops the cleanup goroutine. Call this when done with the cache.
-
func (nc *NonceCache) Stop() {
-
close(nc.stopCh)
-
}
-
-
// Size returns the number of entries in the cache (for testing/monitoring)
-
func (nc *NonceCache) Size() int {
-
nc.mu.RLock()
-
defer nc.mu.RUnlock()
-
return len(nc.seen)
-
}
-
-
// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
-
type DPoPClaims struct {
-
jwt.RegisteredClaims
-
-
// HTTP method of the request (e.g., "GET", "POST")
-
HTTPMethod string `json:"htm"`
-
-
// HTTP URI of the request (without query and fragment parts)
-
HTTPURI string `json:"htu"`
-
-
// Access token hash (optional, for token binding)
-
AccessTokenHash string `json:"ath,omitempty"`
-
}
-
-
// DPoPProof represents a parsed and verified DPoP proof
-
type DPoPProof struct {
-
RawPublicJWK map[string]interface{}
-
Claims *DPoPClaims
-
PublicKey interface{} // *ecdsa.PublicKey or similar
-
Thumbprint string // JWK thumbprint (base64url)
-
}
-
-
// DPoPVerifier verifies DPoP proofs for OAuth token binding
-
type DPoPVerifier struct {
-
// Optional: custom nonce validation function (for server-issued nonces)
-
ValidateNonce func(nonce string) bool
-
-
// NonceCache for replay protection (optional but recommended)
-
// If nil, jti replay protection is disabled
-
NonceCache *NonceCache
-
-
// Maximum allowed clock skew for timestamp validation
-
MaxClockSkew time.Duration
-
-
// Maximum age of DPoP proof (prevents replay with old proofs)
-
MaxProofAge time.Duration
-
}
-
-
// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
-
func NewDPoPVerifier() *DPoPVerifier {
-
maxProofAge := 5 * time.Minute
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: maxProofAge,
-
NonceCache: NewNonceCache(maxProofAge),
-
}
-
}
-
-
// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
-
// This should only be used in testing or when replay protection is handled externally.
-
func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: 5 * time.Minute,
-
NonceCache: nil, // No replay protection
-
}
-
}
-
-
// Stop stops background goroutines. Call this when shutting down.
-
func (v *DPoPVerifier) Stop() {
-
if v.NonceCache != nil {
-
v.NonceCache.Stop()
-
}
-
}
-
-
// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
-
// This supports all atProto-compatible ECDSA algorithms including ES256K (secp256k1).
-
func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
-
// Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
-
header, claims, err := parseJWTHeaderAndClaims(dpopProof)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
-
}
-
-
// Extract and validate the typ header
-
typ, ok := header["typ"].(string)
-
if !ok || typ != "dpop+jwt" {
-
return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
-
}
-
-
alg, ok := header["alg"].(string)
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
-
}
-
-
// Extract the JWK from the header first (needed for algorithm-curve validation)
-
jwkRaw, ok := header["jwk"]
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
-
}
-
-
jwkMap, ok := jwkRaw.(map[string]interface{})
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
-
}
-
-
// Validate the algorithm is supported and matches the JWK curve
-
// This is critical for security - prevents algorithm confusion attacks
-
if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof: %w", err)
-
}
-
-
// Parse the public key using indigo's crypto package
-
// This supports all atProto curves including secp256k1 (ES256K)
-
publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
-
}
-
-
// Calculate the JWK thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
-
}
-
-
// Verify the signature using indigo's crypto package
-
// This works for all ECDSA algorithms including ES256K
-
if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
-
return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
-
}
-
-
// Validate the claims
-
if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
-
return nil, err
-
}
-
-
return &DPoPProof{
-
Claims: claims,
-
PublicKey: publicKey,
-
Thumbprint: thumbprint,
-
RawPublicJWK: jwkMap,
-
}, nil
-
}
-
-
// validateDPoPClaims validates the DPoP proof claims
-
func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
-
// Validate jti (unique identifier) is present
-
if claims.ID == "" {
-
return fmt.Errorf("DPoP proof missing jti claim")
-
}
-
-
// Validate htm (HTTP method)
-
if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
-
return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
-
}
-
-
// Validate htu (HTTP URI) - compare without query/fragment
-
expectedURIBase := stripQueryFragment(expectedURI)
-
claimURIBase := stripQueryFragment(claims.HTTPURI)
-
if expectedURIBase != claimURIBase {
-
return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
-
}
-
-
// Validate iat (issued at) is present and recent
-
if claims.IssuedAt == nil {
-
return fmt.Errorf("DPoP proof missing iat claim")
-
}
-
-
now := time.Now()
-
iat := claims.IssuedAt.Time
-
-
// Check clock skew (not too far in the future)
-
if iat.After(now.Add(v.MaxClockSkew)) {
-
return fmt.Errorf("DPoP proof iat is in the future")
-
}
-
-
// Check proof age (not too old)
-
if now.Sub(iat) > v.MaxProofAge {
-
return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
-
}
-
-
// SECURITY: Validate exp claim if present (RFC standard JWT validation)
-
// While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
-
if claims.ExpiresAt != nil {
-
expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
-
if now.After(expWithSkew) {
-
return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
-
}
-
}
-
-
// SECURITY: Validate nbf claim if present (RFC standard JWT validation)
-
if claims.NotBefore != nil {
-
nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
-
if now.Before(nbfWithSkew) {
-
return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
-
}
-
}
-
-
// SECURITY: Check for replay attack using jti
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
-
if v.NonceCache != nil {
-
if !v.NonceCache.CheckAndStore(claims.ID) {
-
return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
-
}
-
}
-
-
return nil
-
}
-
-
// VerifyTokenBinding verifies that the DPoP proof binds to the access token
-
// by comparing the proof's thumbprint to the token's cnf.jkt claim
-
func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
-
if proof.Thumbprint != expectedThumbprint {
-
return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
-
expectedThumbprint, proof.Thumbprint)
-
}
-
return nil
-
}
-
-
// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
-
// matches the SHA-256 hash of the presented access token.
-
// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
-
func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
-
// If ath claim is not present, that's acceptable per RFC 9449
-
// (ath is only required when the RS mandates it)
-
if proof.Claims.AccessTokenHash == "" {
-
return nil
-
}
-
-
// Calculate the expected ath: base64url(SHA-256(access_token))
-
hash := sha256.Sum256([]byte(accessToken))
-
expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
if proof.Claims.AccessTokenHash != expectedAth {
-
return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
-
}
-
-
return nil
-
}
-
-
// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638
-
// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation
-
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
-
kty, ok := jwk["kty"].(string)
-
if !ok {
-
return "", fmt.Errorf("JWK missing kty")
-
}
-
-
// Build the canonical JWK representation based on key type
-
// Per RFC 7638, only specific members are included, in lexicographic order
-
var canonical map[string]string
-
-
switch kty {
-
case "EC":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing x")
-
}
-
y, ok := jwk["y"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing y")
-
}
-
// Lexicographic order: crv, kty, x, y
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
"y": y,
-
}
-
case "RSA":
-
e, ok := jwk["e"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing e")
-
}
-
n, ok := jwk["n"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing n")
-
}
-
// Lexicographic order: e, kty, n
-
canonical = map[string]string{
-
"e": e,
-
"kty": kty,
-
"n": n,
-
}
-
case "OKP":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing x")
-
}
-
// Lexicographic order: crv, kty, x
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
}
-
default:
-
return "", fmt.Errorf("unsupported JWK key type: %s", kty)
-
}
-
-
// Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string)
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
-
}
-
-
// SHA-256 hash
-
hash := sha256.Sum256(canonicalJSON)
-
-
// Base64url encode (no padding)
-
thumbprint := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
return thumbprint, nil
-
}
-
-
// validateAlgorithmCurveBinding validates that the JWT algorithm matches the JWK curve.
-
// This is critical for security - an attacker could claim alg: "ES256K" but provide
-
// a P-256 key, potentially bypassing algorithm binding requirements.
-
func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error {
-
kty, ok := jwkMap["kty"].(string)
-
if !ok {
-
return fmt.Errorf("JWK missing kty")
-
}
-
-
// ECDSA algorithms require EC key type
-
switch alg {
-
case "ES256K", "ES256", "ES384", "ES512":
-
if kty != "EC" {
-
return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty)
-
}
-
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
-
return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
-
default:
-
return fmt.Errorf("unsupported DPoP algorithm: %s", alg)
-
}
-
-
// Validate curve matches algorithm
-
crv, ok := jwkMap["crv"].(string)
-
if !ok {
-
return fmt.Errorf("EC JWK missing crv")
-
}
-
-
var expectedCurve string
-
switch alg {
-
case "ES256K":
-
expectedCurve = "secp256k1"
-
case "ES256":
-
expectedCurve = "P-256"
-
case "ES384":
-
expectedCurve = "P-384"
-
case "ES512":
-
expectedCurve = "P-521"
-
}
-
-
if crv != expectedCurve {
-
return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, expectedCurve, crv)
-
}
-
-
return nil
-
}
-
-
// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey.
-
// This returns indigo's PublicKey interface which supports all atProto curves
-
// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512).
-
func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - this supports all atProto curves
-
// including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt.
-
// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize.
-
func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode header
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first to extract standard claims
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build DPoPClaims struct
-
claims := &DPoPClaims{}
-
-
// Extract jti
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract exp (expiration) if present
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before) if present
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract htm (HTTP method)
-
if htm, ok := rawClaims["htm"].(string); ok {
-
claims.HTTPMethod = htm
-
}
-
-
// Extract htu (HTTP URI)
-
if htu, ok := rawClaims["htu"].(string); ok {
-
claims.HTTPURI = htu
-
}
-
-
// Extract ath (access token hash) if present
-
if ath, ok := rawClaims["ath"].(string); ok {
-
claims.AccessTokenHash = ath
-
}
-
-
return header, claims, nil
-
}
-
-
// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package.
-
// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K).
-
// It parses the JWT, extracts the signing input and signature, and uses indigo's
-
// PublicKey.HashAndVerifyLenient() for verification.
-
//
-
// JWT format: header.payload.signature (all base64url-encoded)
-
// Signature is verified over the raw bytes of "header.payload"
-
// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally)
-
func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
// Split the JWT into parts
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature)
-
if err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// stripQueryFragment removes query and fragment from a URI
-
func stripQueryFragment(uri string) string {
-
if idx := strings.Index(uri, "?"); idx != -1 {
-
uri = uri[:idx]
-
}
-
if idx := strings.Index(uri, "#"); idx != -1 {
-
uri = uri[:idx]
-
}
-
return uri
-
}
-
-
// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
-
func ExtractCnfJkt(claims *Claims) (string, error) {
-
if claims.Confirmation == nil {
-
return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
-
}
-
-
jkt, ok := claims.Confirmation["jkt"].(string)
-
if !ok || jkt == "" {
-
return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
-
}
-
-
return jkt, nil
-
}
-189
internal/atproto/auth/jwks_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"encoding/json"
-
"fmt"
-
"net/http"
-
"strings"
-
"sync"
-
"time"
-
)
-
-
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
-
type CachedJWKSFetcher struct {
-
cache map[string]*cachedJWKS
-
httpClient *http.Client
-
cacheMutex sync.RWMutex
-
cacheTTL time.Duration
-
}
-
-
type cachedJWKS struct {
-
jwks *JWKS
-
expiresAt time.Time
-
}
-
-
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
-
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
-
return &CachedJWKSFetcher{
-
cache: make(map[string]*cachedJWKS),
-
httpClient: &http.Client{
-
Timeout: 10 * time.Second,
-
},
-
cacheTTL: cacheTTL,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
-
// Implements JWKSFetcher interface
-
// Returns interface{} to support both RSA and ECDSA keys
-
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Extract key ID from token
-
kid, err := ExtractKeyID(token)
-
if err != nil {
-
return nil, fmt.Errorf("failed to extract key ID: %w", err)
-
}
-
-
// Get JWKS from cache or fetch
-
jwks, err := f.getJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Find the key by ID
-
jwk, err := jwks.FindKeyByID(kid)
-
if err != nil {
-
// Key not found in cache - try refreshing
-
jwks, err = f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
-
}
-
f.cacheJWKS(issuer, jwks)
-
-
// Try again with fresh JWKS
-
jwk, err = jwks.FindKeyByID(kid)
-
if err != nil {
-
return nil, err
-
}
-
}
-
-
// Convert JWK to public key (RSA or ECDSA)
-
return jwk.ToPublicKey()
-
}
-
-
// getJWKS gets JWKS from cache or fetches if not cached/expired
-
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Check cache first
-
f.cacheMutex.RLock()
-
cached, exists := f.cache[issuer]
-
f.cacheMutex.RUnlock()
-
-
if exists && time.Now().Before(cached.expiresAt) {
-
return cached.jwks, nil
-
}
-
-
// Not in cache or expired - fetch from issuer
-
jwks, err := f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Cache it
-
f.cacheJWKS(issuer, jwks)
-
-
return jwks, nil
-
}
-
-
// fetchJWKS fetches JWKS from the authorization server
-
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Step 1: Fetch OAuth server metadata to get JWKS URI
-
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
-
-
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create metadata request: %w", err)
-
}
-
-
resp, err := f.httpClient.Do(req)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
-
}
-
defer func() {
-
_ = resp.Body.Close()
-
}()
-
-
if resp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
-
}
-
-
var metadata struct {
-
JWKSURI string `json:"jwks_uri"`
-
}
-
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
-
return nil, fmt.Errorf("failed to decode metadata: %w", err)
-
}
-
-
if metadata.JWKSURI == "" {
-
return nil, fmt.Errorf("jwks_uri not found in metadata")
-
}
-
-
// Step 2: Fetch JWKS from the JWKS URI
-
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
-
}
-
-
jwksResp, err := f.httpClient.Do(jwksReq)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
-
}
-
defer func() {
-
_ = jwksResp.Body.Close()
-
}()
-
-
if jwksResp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
-
}
-
-
var jwks JWKS
-
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
-
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
-
}
-
-
if len(jwks.Keys) == 0 {
-
return nil, fmt.Errorf("no keys found in JWKS")
-
}
-
-
return &jwks, nil
-
}
-
-
// cacheJWKS stores JWKS in the cache
-
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
f.cache[issuer] = &cachedJWKS{
-
jwks: jwks,
-
expiresAt: time.Now().Add(f.cacheTTL),
-
}
-
}
-
-
// ClearCache clears the entire JWKS cache
-
func (f *CachedJWKSFetcher) ClearCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
f.cache = make(map[string]*cachedJWKS)
-
}
-
-
// CleanupExpiredCache removes expired entries from the cache
-
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
now := time.Now()
-
for issuer, cached := range f.cache {
-
if now.After(cached.expiresAt) {
-
delete(f.cache, issuer)
-
}
-
}
-
}
-496
internal/atproto/auth/jwt_test.go
···
-
package auth
-
-
import (
-
"context"
-
"testing"
-
"time"
-
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
func TestParseJWT(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing
-
parsedClaims, err := ParseJWT(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
-
if parsedClaims.Issuer != "https://test-pds.example.com" {
-
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
-
}
-
-
if parsedClaims.Scope != "atproto transition:generic" {
-
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
-
}
-
}
-
-
func TestParseJWT_MissingSubject(t *testing.T) {
-
// Create a token without subject
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing subject, got nil")
-
}
-
}
-
-
func TestParseJWT_MissingIssuer(t *testing.T) {
-
// Create a token without issuer
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing issuer, got nil")
-
}
-
}
-
-
func TestParseJWT_WithBearerPrefix(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing with Bearer prefix
-
parsedClaims, err := ParseJWT("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
}
-
-
func TestValidateClaims_Expired(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for expired token, got nil")
-
}
-
}
-
-
func TestValidateClaims_InvalidDID(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "invalid-did-format",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for invalid DID format, got nil")
-
}
-
}
-
-
func TestExtractKeyID(t *testing.T) {
-
// Create a test JWT token with kid in header
-
token := jwt.New(jwt.SigningMethodRS256)
-
token.Header["kid"] = "test-key-id"
-
token.Claims = &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
},
-
}
-
-
// Sign with a dummy RSA key (we just need a valid token structure)
-
tokenString, err := token.SignedString([]byte("dummy"))
-
if err == nil {
-
// If it succeeds (shouldn't with wrong key type, but let's handle it)
-
kid, err := ExtractKeyID(tokenString)
-
if err != nil {
-
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
-
} else if kid != "test-key-id" {
-
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
-
}
-
}
-
}
-
-
// === HS256 Verification Tests ===
-
-
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
-
type mockJWKSFetcher struct {
-
publicKey interface{}
-
err error
-
}
-
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
return m.publicKey, m.err
-
}
-
-
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
-
t.Helper()
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte(secret))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
return tokenString
-
}
-
-
func TestVerifyJWT_HS256_Valid(t *testing.T) {
-
// Setup: Configure environment for HS256 verification
-
secret := "test-jwt-secret-key-12345"
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", secret)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
-
-
// Verify token
-
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err != nil {
-
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
-
}
-
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
-
}
-
if claims.Issuer != issuer {
-
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
-
}
-
}
-
-
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
-
// Setup: Configure environment with one secret, sign with another
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "correct-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create token with wrong secret
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
-
-
// Verify should fail
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error for HS256 token with wrong secret, got nil")
-
}
-
}
-
-
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
-
// Setup: Whitelist issuer but don't configure secret
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
-
-
// Verify should fail with descriptive error
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
-
}
-
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
-
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
-
}
-
}
-
-
// === Algorithm Confusion Attack Prevention Tests ===
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
-
// SECURITY TEST: This tests the algorithm confusion attack prevention
-
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create HS256 token with non-whitelisted issuer (simulating attack)
-
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because issuer is not in HS256 whitelist
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
-
}
-
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
-
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
-
}
-
}
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
-
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because no issuers are whitelisted for HS256
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
-
}
-
}
-
-
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
-
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "test-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
-
// This will create an invalid signature but valid header structure
-
// The test should fail at algorithm check, not signature verification
-
tokenString, _ := token.SignedString([]byte("dummy-key"))
-
-
if tokenString != "" {
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when HS256 issuer receives non-HS256 token")
-
}
-
}
-
}
-
-
// === ParseJWTHeader Tests ===
-
-
func TestParseJWTHeader_Valid(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
-
testCases := []struct {
-
name string
-
input string
-
}{
-
{"empty string", ""},
-
{"single part", "abc"},
-
{"two parts", "abc.def"},
-
{"too many parts", "a.b.c.d"},
-
}
-
-
for _, tc := range testCases {
-
t.Run(tc.name, func(t *testing.T) {
-
_, err := ParseJWTHeader(tc.input)
-
if err == nil {
-
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
-
}
-
})
-
}
-
}
-
-
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
-
-
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected pds1 to be whitelisted")
-
}
-
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
-
t.Error("Expected pds2 to be whitelisted")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://attacker.example.com") {
-
t.Error("Expected non-whitelisted issuer to return false")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://any.example.com") {
-
t.Error("Expected false when whitelist is empty (safe default)")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
-
}
-
}
-
-
// === shouldUseHS256 Tests (kid-based logic) ===
-
-
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
-
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "some-key-id", // Has kid
-
}
-
-
// Even whitelisted issuer should not use HS256 if token has kid
-
if shouldUseHS256(header, "https://whitelisted.example.com") {
-
t.Error("Tokens with kid should never use HS256 (supports federation)")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if !shouldUseHS256(header, "https://my-pds.example.com") {
-
t.Error("Token without kid from whitelisted issuer should use HS256")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if shouldUseHS256(header, "https://external-pds.example.com") {
-
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
-
}
-
}
-
-
// Helper function
-
func contains(s, substr string) bool {
-
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-
}
-
-
func containsHelper(s, substr string) bool {
-
for i := 0; i <= len(s)-len(substr); i++ {
-
if s[i:i+len(substr)] == substr {
-
return true
-
}
-
}
-
return false
-
}
+1
docker-compose.prod.yml
···
# Instance identity
INSTANCE_DID: did:web:coves.social
INSTANCE_DOMAIN: coves.social
+
APPVIEW_PUBLIC_URL: https://coves.social
# PDS connection (separate domain!)
PDS_URL: https://coves.me
-97
static/oauth/callback.html
···
-
<!DOCTYPE html>
-
<html>
-
<head>
-
<meta charset="utf-8">
-
<meta name="viewport" content="width=device-width, initial-scale=1">
-
<meta http-equiv="Content-Security-Policy" content="default-src 'self'; script-src 'unsafe-inline'; style-src 'unsafe-inline'">
-
<title>Authorization Successful - Coves</title>
-
<style>
-
body {
-
font-family: system-ui, -apple-system, sans-serif;
-
display: flex;
-
align-items: center;
-
justify-content: center;
-
min-height: 100vh;
-
margin: 0;
-
background: #f5f5f5;
-
}
-
.container {
-
text-align: center;
-
padding: 2rem;
-
background: white;
-
border-radius: 8px;
-
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
-
max-width: 400px;
-
}
-
.success { color: #22c55e; font-size: 3rem; margin-bottom: 1rem; }
-
h1 { margin: 0 0 0.5rem; color: #1f2937; font-size: 1.5rem; }
-
p { color: #6b7280; margin: 0.5rem 0; }
-
a {
-
display: inline-block;
-
margin-top: 1rem;
-
padding: 0.75rem 1.5rem;
-
background: #3b82f6;
-
color: white;
-
text-decoration: none;
-
border-radius: 6px;
-
font-weight: 500;
-
}
-
a:hover { background: #2563eb; }
-
</style>
-
</head>
-
<body>
-
<div class="container">
-
<div class="success">โœ“</div>
-
<h1>Authorization Successful!</h1>
-
<p id="status">Returning to Coves...</p>
-
<a href="#" id="manualLink">Open Coves</a>
-
</div>
-
<script>
-
(function() {
-
// Parse and sanitize query params - only allow expected OAuth parameters
-
const urlParams = new URLSearchParams(window.location.search);
-
const safeParams = new URLSearchParams();
-
-
// Whitelist only expected OAuth callback parameters
-
const code = urlParams.get('code');
-
const state = urlParams.get('state');
-
const error = urlParams.get('error');
-
const errorDescription = urlParams.get('error_description');
-
const iss = urlParams.get('iss');
-
-
if (code) safeParams.set('code', code);
-
if (state) safeParams.set('state', state);
-
if (error) safeParams.set('error', error);
-
if (errorDescription) safeParams.set('error_description', errorDescription);
-
if (iss) safeParams.set('iss', iss);
-
-
const sanitizedQuery = safeParams.toString() ? '?' + safeParams.toString() : '';
-
-
const userAgent = navigator.userAgent || '';
-
const isAndroid = /Android/i.test(userAgent);
-
-
// Build deep link based on platform
-
let deepLink;
-
if (isAndroid) {
-
// Android: Intent URL format
-
const pathAndQuery = '/oauth/callback' + sanitizedQuery;
-
deepLink = 'intent:/' + pathAndQuery + '#Intent;scheme=social.coves;package=social.coves;end';
-
} else {
-
// iOS: Custom scheme
-
deepLink = 'social.coves:/oauth/callback' + sanitizedQuery;
-
}
-
-
// Update manual link
-
document.getElementById('manualLink').href = deepLink;
-
-
// Attempt automatic redirect
-
window.location.href = deepLink;
-
-
// Update status after 2 seconds if redirect didn't work
-
setTimeout(function() {
-
document.getElementById('status').textContent = 'Click the button above to continue';
-
}, 2000);
-
})();
-
</script>
-
</body>
-
</html>
+6 -5
internal/api/routes/oauth.go
···
// Use login limiter since callback completes the authentication flow
r.With(corsMiddleware(allowedOrigins), loginLimiter.Middleware).Get("/oauth/callback", handler.HandleCallback)
-
// Mobile Universal Link callback route
-
// This route is used for iOS Universal Links and Android App Links
-
// Path must match the path in .well-known/apple-app-site-association
-
// Uses the same handler as web callback - the system routes it to the mobile app
-
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleCallback)
+
// Mobile Universal Link callback route (fallback when app doesn't intercept)
+
// This route exists for iOS Universal Links and Android App Links.
+
// When properly configured, the mobile OS intercepts this URL and opens the app
+
// BEFORE the request reaches the server. If this handler is reached, it means
+
// Universal Links failed to intercept.
+
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleMobileDeepLinkFallback)
// Session management - dedicated rate limits
r.With(logoutLimiter.Middleware).Post("/oauth/logout", handler.HandleLogout)
+11
static/.well-known/apple-app-site-association
···
+
{
+
"applinks": {
+
"apps": [],
+
"details": [
+
{
+
"appID": "TEAM_ID.social.coves",
+
"paths": ["/app/oauth/callback"]
+
}
+
]
+
}
+
}
+10
static/.well-known/assetlinks.json
···
+
[{
+
"relation": ["delegate_permission/common.handle_all_urls"],
+
"target": {
+
"namespace": "android_app",
+
"package_name": "social.coves",
+
"sha256_cert_fingerprints": [
+
"0B:D8:8C:99:66:25:E5:CD:06:54:80:88:01:6F:B7:38:B9:F4:5B:41:71:F7:95:C8:68:94:87:AD:EA:9F:D9:ED"
+
]
+
}
+
}]
+16 -9
internal/atproto/oauth/handlers_test.go
···
}
// TestIsMobileRedirectURI tests mobile redirect URI validation with EXACT URI matching
-
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
// Per atproto spec, custom schemes must match client_id hostname in reverse-domain order
func TestIsMobileRedirectURI(t *testing.T) {
tests := []struct {
uri string
expected bool
}{
-
{"https://coves.social/app/oauth/callback", true}, // Universal Link - allowed
-
{"coves-app://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves-app://callback", false}, // Custom scheme - blocked
-
{"coves://oauth", false}, // Custom scheme - blocked
-
{"myapp://oauth", false}, // Not in allowlist
-
{"https://example.com", false}, // Wrong domain
-
{"http://localhost", false}, // HTTP not allowed
+
// Custom scheme per atproto spec (reverse domain of coves.social)
+
{"social.coves:/callback", true},
+
{"social.coves://callback", true},
+
{"social.coves:/oauth/callback", true},
+
{"social.coves://oauth/callback", true},
+
// Universal Link - allowed (strongest security)
+
{"https://coves.social/app/oauth/callback", true},
+
// Wrong custom schemes - not reverse-domain of coves.social
+
{"coves-app://oauth/callback", false},
+
{"coves://oauth/callback", false},
+
{"coves.social://callback", false}, // Not reversed
+
{"myapp://oauth", false},
+
// Wrong domain/scheme
+
{"https://example.com", false},
+
{"http://localhost", false},
{"", false},
{"not-a-uri", false},
}
+41
internal/atproto/lexicon/social/coves/feed/vote/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.feed.vote.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a vote on a post or comment",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["subject"],
+
"properties": {
+
"subject": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the post or comment to remove the vote from"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "VoteNotFound",
+
"description": "No vote found for this subject"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this vote"
+
}
+
]
+
}
+
}
+
}
+115
internal/api/handlers/vote/create_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateVoteHandler handles vote creation
+
type CreateVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewCreateVoteHandler creates a new create vote handler
+
func NewCreateVoteHandler(service votes.Service) *CreateVoteHandler {
+
return &CreateVoteHandler{
+
service: service,
+
}
+
}
+
+
// CreateVoteInput represents the request body for creating a vote
+
type CreateVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
Direction string `json:"direction"`
+
}
+
+
// CreateVoteOutput represents the response body for creating a vote
+
type CreateVoteOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreateVote creates a vote on a post or comment
+
// POST /xrpc/social.coves.vote.create
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." }, "direction": "up" }
+
// Response: { "uri": "at://...", "cid": "..." }
+
//
+
// Behavior:
+
// - If no vote exists: creates new vote with given direction
+
// - If vote exists with same direction: deletes vote (toggle off)
+
// - If vote exists with different direction: updates to new direction
+
func (h *CreateVoteHandler) HandleCreateVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input CreateVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
if input.Direction == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction is required")
+
return
+
}
+
+
// Validate direction
+
if input.Direction != "up" && input.Direction != "down" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction must be 'up' or 'down'")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create vote request
+
req := votes.CreateVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
Direction: input.Direction,
+
}
+
+
// Call service to create vote
+
response, err := h.service.CreateVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response
+
output := CreateVoteOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+93
internal/api/handlers/vote/delete_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteVoteHandler handles vote deletion
+
type DeleteVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewDeleteVoteHandler creates a new delete vote handler
+
func NewDeleteVoteHandler(service votes.Service) *DeleteVoteHandler {
+
return &DeleteVoteHandler{
+
service: service,
+
}
+
}
+
+
// DeleteVoteInput represents the request body for deleting a vote
+
type DeleteVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
}
+
+
// DeleteVoteOutput represents the response body for deleting a vote
+
// Per lexicon: output is an empty object
+
type DeleteVoteOutput struct{}
+
+
// HandleDeleteVote removes a vote from a post or comment
+
// POST /xrpc/social.coves.vote.delete
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." } }
+
// Response: { "success": true }
+
func (h *DeleteVoteHandler) HandleDeleteVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input DeleteVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create delete vote request
+
req := votes.DeleteVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
}
+
+
// Call service to delete vote
+
err := h.service.DeleteVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response (empty object per lexicon)
+
output := DeleteVoteOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+24
internal/api/routes/vote.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/vote"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterVoteRoutes registers vote-related XRPC endpoints on the router
+
// Implements social.coves.feed.vote.* lexicon endpoints
+
func RegisterVoteRoutes(r chi.Router, voteService votes.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := vote.NewCreateVoteHandler(voteService)
+
deleteHandler := vote.NewDeleteVoteHandler(voteService)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.feed.vote.create - create or update a vote on a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.create", createHandler.HandleCreateVote)
+
+
// social.coves.feed.vote.delete - delete a vote from a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.delete", deleteHandler.HandleDeleteVote)
+
}
+3
.beads/beads.left.jsonl
···
+
{"id":"Coves-95q","content_hash":"8ec99d598f067780436b985f9ad57f0fa19632026981038df4f65f192186620b","title":"Add comprehensive API documentation","description":"","status":"open","priority":2,"issue_type":"task","created_at":"2025-11-17T20:30:34.835721854-08:00","updated_at":"2025-11-17T20:30:34.835721854-08:00","source_repo":".","dependencies":[{"issue_id":"Coves-95q","depends_on_id":"Coves-e16","type":"blocks","created_at":"2025-11-17T20:30:46.273899399-08:00","created_by":"daemon"}]}
+
{"id":"Coves-e16","content_hash":"7c5d0fc8f0e7f626be3dad62af0e8412467330bad01a244e5a7e52ac5afff1c1","title":"Complete post creation and moderation features","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:12.885991306-08:00","updated_at":"2025-11-17T20:30:12.885991306-08:00","source_repo":"."}
+
{"id":"Coves-fce","content_hash":"26b3e16b99f827316ee0d741cc959464bd0c813446c95aef8105c7fd1e6b09ff","title":"Implement aggregator feed federation","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:21.453326012-08:00","updated_at":"2025-11-17T20:30:21.453326012-08:00","source_repo":"."}
+1
.beads/beads.left.meta.json
···
+
{"version":"0.23.1","timestamp":"2025-12-02T18:25:24.009187871-08:00","commit":"00d7d8d"}
-3
internal/api/handlers/vote/errors.go
···
case errors.Is(err, votes.ErrVoteNotFound):
// Matches: social.coves.feed.vote.delete#VoteNotFound
writeError(w, http.StatusNotFound, "VoteNotFound", "No vote found for this subject")
-
case errors.Is(err, votes.ErrSubjectNotFound):
-
// Matches: social.coves.feed.vote.create#SubjectNotFound
-
writeError(w, http.StatusNotFound, "SubjectNotFound", "The subject post or comment was not found")
case errors.Is(err, votes.ErrInvalidDirection):
writeError(w, http.StatusBadRequest, "InvalidRequest", "Vote direction must be 'up' or 'down'")
case errors.Is(err, votes.ErrInvalidSubject):
+4 -4
internal/atproto/oauth/handlers_security.go
···
// - Android: Verified via /.well-known/assetlinks.json
var allowedMobileRedirectURIs = map[string]bool{
// Custom scheme per atproto spec (reverse-domain of coves.social)
-
"social.coves:/callback": true,
-
"social.coves://callback": true, // Some platforms add double slash
-
"social.coves:/oauth/callback": true, // Alternative path
-
"social.coves://oauth/callback": true,
+
"social.coves:/callback": true,
+
"social.coves://callback": true, // Some platforms add double slash
+
"social.coves:/oauth/callback": true, // Alternative path
+
"social.coves://oauth/callback": true,
// Universal Links - cryptographically bound to app (preferred for security)
"https://coves.social/app/oauth/callback": true,
}
-3
internal/core/votes/errors.go
···
// ErrVoteNotFound indicates the requested vote doesn't exist
ErrVoteNotFound = errors.New("vote not found")
-
// ErrSubjectNotFound indicates the post/comment being voted on doesn't exist
-
ErrSubjectNotFound = errors.New("subject not found")
-
// ErrInvalidDirection indicates the vote direction is not "up" or "down"
ErrInvalidDirection = errors.New("invalid vote direction: must be 'up' or 'down'")
+3 -2
internal/db/postgres/vote_repo.go
···
return nil
}
-
// GetByURI retrieves a vote by its AT-URI
+
// GetByURI retrieves an active vote by its AT-URI
// Used by Jetstream consumer for DELETE operations
+
// Returns ErrVoteNotFound for soft-deleted votes
func (r *postgresVoteRepo) GetByURI(ctx context.Context, uri string) (*votes.Vote, error) {
query := `
SELECT
···
subject_uri, subject_cid, direction,
created_at, indexed_at, deleted_at
FROM votes
-
WHERE uri = $1
+
WHERE uri = $1 AND deleted_at IS NULL
`
var vote votes.Vote
+92
.env.dev.example
···
+
# Coves Local Development Environment Configuration
+
# Copy this to .env.dev and fill in your values
+
#
+
# Quick Start:
+
# 1. cp .env.dev.example .env.dev
+
# 2. Generate OAuth key: go run cmd/genjwks/main.go (copy output to OAUTH_PRIVATE_JWK)
+
# 3. Generate cookie secret: openssl rand -hex 32
+
# 4. make dev-up # Start Docker services
+
# 5. make run # Start the server (uses -tags dev)
+
+
# =============================================================================
+
# Dev Mode Quick Reference
+
# =============================================================================
+
# REQUIRED for local OAuth to work with local PDS:
+
# IS_DEV_ENV=true # Master switch for dev mode
+
# PDS_URL=http://localhost:3001 # Local PDS for handle resolution
+
# PLC_DIRECTORY_URL=http://localhost:3002 # Local PLC directory
+
# APPVIEW_PUBLIC_URL=http://127.0.0.1:8081 # Use IP not localhost (RFC 8252)
+
#
+
# BUILD TAGS:
+
# make run - Runs with -tags dev (includes localhost OAuth resolvers)
+
# make build - Production binary (no dev code)
+
# make build-dev - Dev binary (includes dev code)
+
+
# =============================================================================
+
# PostgreSQL Configuration
+
# =============================================================================
+
POSTGRES_HOST=localhost
+
POSTGRES_PORT=5435
+
POSTGRES_DB=coves_dev
+
POSTGRES_USER=dev_user
+
POSTGRES_PASSWORD=dev_password
+
+
# Test database
+
POSTGRES_TEST_DB=coves_test
+
POSTGRES_TEST_USER=test_user
+
POSTGRES_TEST_PASSWORD=test_password
+
POSTGRES_TEST_PORT=5434
+
+
# =============================================================================
+
# PDS Configuration
+
# =============================================================================
+
PDS_HOSTNAME=localhost
+
PDS_PORT=3001
+
PDS_SERVICE_ENDPOINT=http://localhost:3000
+
PDS_DID_PLC_URL=http://plc-directory:3000
+
PDS_JWT_SECRET=local-dev-jwt-secret-change-in-production
+
PDS_ADMIN_PASSWORD=admin
+
PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.community.coves.social
+
PDS_PLC_ROTATION_KEY=<generate-a-random-hex-key>
+
+
# =============================================================================
+
# AppView Configuration
+
# =============================================================================
+
APPVIEW_PORT=8081
+
FIREHOSE_URL=ws://localhost:3001/xrpc/com.atproto.sync.subscribeRepos
+
PDS_URL=http://localhost:3001
+
APPVIEW_PUBLIC_URL=http://127.0.0.1:8081
+
+
# =============================================================================
+
# Jetstream Configuration
+
# =============================================================================
+
JETSTREAM_URL=ws://localhost:6008/subscribe
+
+
# =============================================================================
+
# Identity Resolution
+
# =============================================================================
+
IDENTITY_CACHE_TTL=24h
+
PLC_DIRECTORY_URL=http://localhost:3002
+
+
# =============================================================================
+
# OAuth Configuration (MUST GENERATE YOUR OWN)
+
# =============================================================================
+
# Generate with: go run cmd/genjwks/main.go
+
OAUTH_PRIVATE_JWK=<generate-your-own-jwk>
+
+
# Generate with: openssl rand -hex 32
+
OAUTH_COOKIE_SECRET=<generate-your-own-secret>
+
+
# =============================================================================
+
# Development Settings
+
# =============================================================================
+
ENV=development
+
NODE_ENV=development
+
IS_DEV_ENV=true
+
LOG_LEVEL=debug
+
LOG_ENABLED=true
+
+
# Security settings (ONLY for local dev - set to false in production!)
+
SKIP_DID_WEB_VERIFICATION=true
+
AUTH_SKIP_VERIFY=true
+
HS256_ISSUERS=http://localhost:3001
+25 -3
Makefile
···
-
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean
+
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean verify-stack create-test-account mobile-full-setup
# Default target - show help
.DEFAULT_GOAL := help
···
##@ Build & Run
-
build: ## Build the Coves server
-
@echo "$(GREEN)Building Coves server...$(RESET)"
+
build: ## Build the Coves server (production - no dev code)
+
@echo "$(GREEN)Building Coves server (production)...$(RESET)"
@go build -o server ./cmd/server
@echo "$(GREEN)โœ“ Build complete: ./server$(RESET)"
+
build-dev: ## Build the Coves server with dev mode (includes localhost OAuth resolvers)
+
@echo "$(GREEN)Building Coves server (dev mode)...$(RESET)"
+
@go build -tags dev -o server ./cmd/server
+
@echo "$(GREEN)โœ“ Build complete: ./server (with dev tags)$(RESET)"
+
run: ## Run the Coves server with dev environment (requires database running)
@./scripts/dev-run.sh
···
@adb reverse --remove-all || echo "$(YELLOW)No device connected$(RESET)"
@echo "$(GREEN)โœ“ Port forwarding removed$(RESET)"
+
verify-stack: ## Verify local development stack (PLC, PDS, configs)
+
@./scripts/verify-local-stack.sh
+
+
create-test-account: ## Create a test account on local PDS for OAuth testing
+
@./scripts/create-test-account.sh
+
+
mobile-full-setup: verify-stack create-test-account mobile-setup ## Full mobile setup: verify stack, create account, setup ports
+
@echo ""
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo "$(GREEN) Mobile development environment ready! $(RESET)"
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo ""
+
@echo "$(CYAN)Run the Flutter app with:$(RESET)"
+
@echo " $(YELLOW)cd /home/bretton/Code/coves-mobile$(RESET)"
+
@echo " $(YELLOW)flutter run --dart-define=ENVIRONMENT=local$(RESET)"
+
@echo ""
+
ngrok-up: ## Start ngrok tunnels (for iOS or WiFi testing - requires paid plan for 3 tunnels)
@echo "$(GREEN)Starting ngrok tunnels for mobile testing...$(RESET)"
@./scripts/start-ngrok.sh
+5 -1
docker-compose.dev.yml
···
# Bluesky Personal Data Server (PDS)
# Handles user repositories, DIDs, and CAR files
+
# NOTE: When using --profile plc, PDS waits for PLC directory to be healthy
pds:
image: ghcr.io/bluesky-social/pds:latest
container_name: coves-dev-pds
···
PDS_PORT: 3001 # Match external port for correct DID registration
PDS_DATA_DIRECTORY: /pds
PDS_BLOBSTORE_DISK_LOCATION: /pds/blocks
-
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-https://plc.directory}
+
# IMPORTANT: For local E2E testing, this MUST point to local PLC directory
+
# Default to local PLC (http://plc-directory:3000) for full local stack
+
# The container hostname 'plc-directory' is used for Docker network communication
+
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-http://plc-directory:3000}
# PDS_CRAWLERS not needed - we're not using a relay for local dev
# Note: PDS uses its own internal SQLite database and CAR file storage
+285
internal/atproto/oauth/dev_auth_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/identity"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// DevAuthResolver is a custom OAuth resolver that allows HTTP localhost URLs for development.
+
// The standard indigo OAuth resolver requires HTTPS and no port numbers, which breaks local testing.
+
type DevAuthResolver struct {
+
Client *http.Client
+
UserAgent string
+
PDSURL string // For resolving handles via local PDS
+
handleResolver *DevHandleResolver
+
}
+
+
// ProtectedResourceMetadata matches the OAuth protected resource metadata document format
+
type ProtectedResourceMetadata struct {
+
Resource string `json:"resource"`
+
AuthorizationServers []string `json:"authorization_servers"`
+
}
+
+
// NewDevAuthResolver creates a resolver that accepts localhost HTTP URLs
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
resolver := &DevAuthResolver{
+
Client: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
UserAgent: "Coves/1.0",
+
PDSURL: pdsURL,
+
}
+
// Create handle resolver for resolving handles via local PDS
+
if pdsURL != "" {
+
resolver.handleResolver = NewDevHandleResolver(pdsURL, allowPrivateIPs)
+
}
+
return resolver
+
}
+
+
// ResolveAuthServerURL resolves a PDS URL to an auth server URL.
+
// Unlike indigo's standard resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerURL(ctx context.Context, hostURL string) (string, error) {
+
u, err := url.Parse(hostURL)
+
if err != nil {
+
return "", err
+
}
+
+
// For localhost, allow HTTP and port numbers
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if !isLocalhost {
+
// For non-localhost, enforce HTTPS and no port (standard rules)
+
if u.Scheme != "https" || u.Port() != "" {
+
return "", fmt.Errorf("not a valid public host URL: %s", hostURL)
+
}
+
}
+
+
// Build the protected resource document URL
+
var docURL string
+
if isLocalhost {
+
// For localhost, preserve the port and use HTTP
+
port := u.Port()
+
if port == "" {
+
port = "3001" // Default PDS port
+
}
+
docURL = fmt.Sprintf("http://%s:%s/.well-known/oauth-protected-resource", u.Hostname(), port)
+
} else {
+
docURL = fmt.Sprintf("https://%s/.well-known/oauth-protected-resource", u.Hostname())
+
}
+
+
// Fetch the protected resource document
+
req, err := http.NewRequestWithContext(ctx, "GET", docURL, nil)
+
if err != nil {
+
return "", err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("fetching protected resource document: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("HTTP error fetching protected resource document: %d", resp.StatusCode)
+
}
+
+
var body ProtectedResourceMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+
return "", fmt.Errorf("invalid protected resource document: %w", err)
+
}
+
+
if len(body.AuthorizationServers) < 1 {
+
return "", fmt.Errorf("no auth server URL in protected resource document")
+
}
+
+
authURL := body.AuthorizationServers[0]
+
+
// Validate the auth server URL (with localhost exception)
+
au, err := url.Parse(authURL)
+
if err != nil {
+
return "", fmt.Errorf("invalid auth server URL: %w", err)
+
}
+
+
authIsLocalhost := au.Hostname() == "localhost" || au.Hostname() == "127.0.0.1"
+
if !authIsLocalhost {
+
if au.Scheme != "https" || au.Port() != "" {
+
return "", fmt.Errorf("invalid auth server URL: %s", authURL)
+
}
+
}
+
+
return authURL, nil
+
}
+
+
// ResolveAuthServerMetadataDev fetches OAuth server metadata from a given auth server URL.
+
// Unlike indigo's resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerMetadataDev(ctx context.Context, serverURL string) (*oauthlib.AuthServerMetadata, error) {
+
u, err := url.Parse(serverURL)
+
if err != nil {
+
return nil, err
+
}
+
+
// Build metadata URL - preserve port for localhost
+
var metaURL string
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if isLocalhost && u.Port() != "" {
+
metaURL = fmt.Sprintf("%s://%s:%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname(), u.Port())
+
} else if isLocalhost {
+
metaURL = fmt.Sprintf("%s://%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname())
+
} else {
+
metaURL = fmt.Sprintf("https://%s/.well-known/oauth-authorization-server", u.Hostname())
+
}
+
+
slog.Debug("dev mode: fetching auth server metadata", "url", metaURL)
+
+
req, err := http.NewRequestWithContext(ctx, "GET", metaURL, nil)
+
if err != nil {
+
return nil, err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return nil, fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return nil, fmt.Errorf("HTTP error fetching auth server metadata: %d", resp.StatusCode)
+
}
+
+
var metadata oauthlib.AuthServerMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
+
// Skip validation for localhost (indigo's Validate checks HTTPS)
+
if !isLocalhost {
+
if err := metadata.Validate(serverURL); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
}
+
+
return &metadata, nil
+
}
+
+
// StartDevAuthFlow performs OAuth flow for localhost development.
+
// This bypasses indigo's HTTPS validation for the auth server URL.
+
// It resolves the identity, gets the PDS endpoint, fetches auth server metadata,
+
// and returns a redirect URL for the user to approve.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
var accountDID syntax.DID
+
var pdsEndpoint string
+
+
// Check if identifier is a handle or DID
+
if strings.HasPrefix(identifier, "did:") {
+
// It's a DID - look up via directory (PLC)
+
atid, err := syntax.ParseAtIdentifier(identifier)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", identifier, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID (%s): %w", identifier, err)
+
}
+
accountDID = ident.DID
+
pdsEndpoint = ident.PDSEndpoint()
+
} else {
+
// It's a handle - resolve via local PDS first
+
if r.handleResolver == nil {
+
return "", fmt.Errorf("handle resolution not configured (PDS URL not set)")
+
}
+
+
// Resolve handle to DID via local PDS
+
did, err := r.handleResolver.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS (%s): %w", identifier, err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found: %s", identifier)
+
}
+
+
slog.Info("dev mode: resolved handle via local PDS", "handle", identifier, "did", did)
+
+
// Parse the DID
+
parsedDID, err := syntax.ParseDID(did)
+
if err != nil {
+
return "", fmt.Errorf("invalid DID from PDS (%s): %w", did, err)
+
}
+
accountDID = parsedDID
+
+
// Now look up the DID document via PLC to get PDS endpoint
+
atid, err := syntax.ParseAtIdentifier(did)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", did, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID document (%s): %w", did, err)
+
}
+
pdsEndpoint = ident.PDSEndpoint()
+
}
+
+
if pdsEndpoint == "" {
+
return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
+
}
+
+
slog.Debug("dev mode: resolving auth server",
+
"did", accountDID,
+
"pds", pdsEndpoint)
+
+
// Resolve auth server URL (allowing HTTP for localhost)
+
authServerURL, err := r.ResolveAuthServerURL(ctx, pdsEndpoint)
+
if err != nil {
+
return "", fmt.Errorf("resolving auth server: %w", err)
+
}
+
+
slog.Info("dev mode: resolved auth server", "url", authServerURL)
+
+
// Fetch auth server metadata using our dev-friendly resolver
+
authMeta, err := r.ResolveAuthServerMetadataDev(ctx, authServerURL)
+
if err != nil {
+
return "", fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
+
slog.Debug("dev mode: got auth server metadata",
+
"issuer", authMeta.Issuer,
+
"authorization_endpoint", authMeta.AuthorizationEndpoint,
+
"token_endpoint", authMeta.TokenEndpoint)
+
+
// Send auth request (PAR) using indigo's method
+
info, err := client.ClientApp.SendAuthRequest(ctx, authMeta, client.Config.Scopes, identifier)
+
if err != nil {
+
return "", fmt.Errorf("auth request failed: %w", err)
+
}
+
+
// Set the account DID
+
info.AccountDID = &accountDID
+
+
// Persist auth request info
+
client.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
+
+
// Build redirect URL
+
params := url.Values{}
+
params.Set("client_id", client.ClientApp.Config.ClientID)
+
params.Set("request_uri", info.RequestURI)
+
+
authEndpoint := authMeta.AuthorizationEndpoint
+
redirectURL := fmt.Sprintf("%s?%s", authEndpoint, params.Encode())
+
+
slog.Info("dev mode: OAuth redirect URL built", "url_prefix", authEndpoint)
+
+
return redirectURL, nil
+
}
+106
internal/atproto/oauth/dev_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
"time"
+
)
+
+
// DevHandleResolver resolves handles via local PDS for development
+
// This is needed because local handles (e.g., user.local.coves.dev) can't be
+
// resolved via standard DNS/HTTP well-known methods - they only exist on the local PDS.
+
type DevHandleResolver struct {
+
pdsURL string
+
httpClient *http.Client
+
}
+
+
// NewDevHandleResolver creates a resolver that queries local PDS for handle resolution
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return &DevHandleResolver{
+
pdsURL: strings.TrimSuffix(pdsURL, "/"),
+
httpClient: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
}
+
}
+
+
// ResolveHandle queries the local PDS to resolve a handle to a DID
+
// Returns the DID if successful, or empty string if not found
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
if r.pdsURL == "" {
+
return "", fmt.Errorf("PDS URL not configured")
+
}
+
+
// Build the resolve handle URL
+
resolveURL := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s",
+
r.pdsURL, url.QueryEscape(handle))
+
+
// Create request with context and timeout
+
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+
defer cancel()
+
+
req, err := http.NewRequestWithContext(ctx, "GET", resolveURL, nil)
+
if err != nil {
+
return "", fmt.Errorf("failed to create request: %w", err)
+
}
+
req.Header.Set("User-Agent", "Coves/1.0")
+
+
// Execute request
+
resp, err := r.httpClient.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("failed to query PDS: %w", err)
+
}
+
defer resp.Body.Close()
+
+
// Check response status
+
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusBadRequest {
+
return "", nil // Handle not found
+
}
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("PDS returned status %d", resp.StatusCode)
+
}
+
+
// Parse response
+
var result struct {
+
DID string `json:"did"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return "", fmt.Errorf("failed to parse PDS response: %w", err)
+
}
+
+
if result.DID == "" {
+
return "", nil // No DID in response
+
}
+
+
slog.Debug("resolved handle via local PDS",
+
"handle", handle,
+
"did", result.DID,
+
"pds_url", r.pdsURL)
+
+
return result.DID, nil
+
}
+
+
// ResolveIdentifier attempts to resolve a handle to DID, or returns the DID if already provided
+
// This is the main entry point for the handlers
+
func (r *DevHandleResolver) ResolveIdentifier(ctx context.Context, identifier string) (string, error) {
+
// If it's already a DID, return as-is
+
if strings.HasPrefix(identifier, "did:") {
+
return identifier, nil
+
}
+
+
// Try to resolve the handle via local PDS
+
did, err := r.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS: %w", err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found on local PDS: %s", identifier)
+
}
+
+
return did, nil
+
}
+41
internal/atproto/oauth/dev_stubs.go
···
+
//go:build !dev
+
+
package oauth
+
+
import (
+
"context"
+
+
"github.com/bluesky-social/indigo/atproto/identity"
+
)
+
+
// DevHandleResolver is a stub for production builds.
+
// The actual implementation is in dev_resolver.go (only compiled with -tags dev).
+
type DevHandleResolver struct{}
+
+
// NewDevHandleResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return nil
+
}
+
+
// ResolveHandle is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
panic("dev mode: ResolveHandle called in production build - this should never happen")
+
}
+
+
// DevAuthResolver is a stub for production builds.
+
// The actual implementation is in dev_auth_resolver.go (only compiled with -tags dev).
+
type DevAuthResolver struct{}
+
+
// NewDevAuthResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
return nil
+
}
+
+
// StartDevAuthFlow is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
panic("dev mode: StartDevAuthFlow called in production build - this should never happen")
+
}
+107 -15
internal/atproto/oauth/handlers.go
···
"log/slog"
"net/http"
"net/url"
+
"strings"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax"
···
// OAuthHandler handles OAuth-related HTTP endpoints
type OAuthHandler struct {
-
client *OAuthClient
-
store oauth.ClientAuthStore
-
mobileStore MobileOAuthStore // For server-side CSRF validation
+
client *OAuthClient
+
store oauth.ClientAuthStore
+
mobileStore MobileOAuthStore // For server-side CSRF validation
+
devResolver *DevHandleResolver // For dev mode: resolve handles via local PDS
+
devAuthResolver *DevAuthResolver // For dev mode: bypass HTTPS validation for localhost OAuth
}
// NewOAuthHandler creates a new OAuth handler
···
handler.mobileStore = mobileStore
}
+
// In dev mode, create resolvers for local PDS/PLC
+
// This is needed because:
+
// 1. Local handles (e.g., user.local.coves.dev) can't be resolved via DNS/HTTP
+
// 2. Indigo's OAuth library requires HTTPS, which localhost doesn't have
+
if client.Config.DevMode {
+
if client.Config.PDSURL != "" {
+
handler.devResolver = NewDevHandleResolver(client.Config.PDSURL, client.Config.AllowPrivateIPs)
+
slog.Info("dev mode: handle resolution via local PDS enabled", "pds_url", client.Config.PDSURL)
+
}
+
// Create dev auth resolver to bypass HTTPS validation (pass PDS URL for handle resolution)
+
handler.devAuthResolver = NewDevAuthResolver(client.Config.PDSURL, client.Config.AllowPrivateIPs)
+
slog.Info("dev mode: localhost OAuth auth resolver enabled", "pds_url", client.Config.PDSURL)
+
}
+
return handler
}
···
return
}
-
// Start OAuth flow
-
redirectURL, err := h.client.ClientApp.StartAuthFlow(ctx, identifier)
-
if err != nil {
-
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
-
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
-
return
+
var redirectURL string
+
var err error
+
+
// DEV MODE: Use custom OAuth flow that bypasses HTTPS validation
+
// This is needed because:
+
// 1. Local handles can't be resolved via DNS/HTTP well-known
+
// 2. Indigo's OAuth library requires HTTPS for auth servers
+
if h.devAuthResolver != nil {
+
slog.Info("dev mode: using localhost OAuth flow", "identifier", identifier)
+
redirectURL, err = h.devAuthResolver.StartDevAuthFlow(ctx, h.client, identifier, h.client.ClientApp.Dir)
+
if err != nil {
+
slog.Error("dev mode: failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
+
} else {
+
// Production mode: use standard indigo OAuth flow
+
redirectURL, err = h.client.ClientApp.StartAuthFlow(ctx, identifier)
+
if err != nil {
+
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
}
// Log OAuth flow initiation (sanitized - no full URL to avoid leaking state)
···
func (h *OAuthHandler) HandleMobileLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+
// DEV MODE: Redirect localhost to 127.0.0.1 for cookie consistency
+
// The OAuth callback URL uses 127.0.0.1 (per RFC 8252), so cookies must be set
+
// on 127.0.0.1. If user calls localhost, redirect to 127.0.0.1 first.
+
if h.client.Config.DevMode && strings.Contains(r.Host, "localhost") {
+
// Use the configured PublicURL host for consistency
+
redirectURL := h.client.Config.PublicURL + r.URL.RequestURI()
+
slog.Info("dev mode: redirecting localhost to PublicURL host for cookie consistency",
+
"from", r.Host, "to", h.client.Config.PublicURL)
+
http.Redirect(w, r, redirectURL, http.StatusFound)
+
return
+
}
+
// Get handle or DID from query params
identifier := r.URL.Query().Get("handle")
if identifier == "" {
···
RedirectURI: mobileRedirectURI,
})
-
// Start OAuth flow (the store wrapper will save mobile data when auth request is saved)
-
redirectURL, err := h.client.ClientApp.StartAuthFlow(mobileCtx, identifier)
-
if err != nil {
-
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
-
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
-
return
+
var redirectURL string
+
+
// DEV MODE: Use custom OAuth flow that bypasses HTTPS validation
+
// This is needed because:
+
// 1. Local handles can't be resolved via DNS/HTTP well-known
+
// 2. Indigo's OAuth library requires HTTPS for auth servers
+
if h.devAuthResolver != nil {
+
slog.Info("dev mode: using localhost OAuth flow for mobile", "identifier", identifier)
+
redirectURL, err = h.devAuthResolver.StartDevAuthFlow(mobileCtx, h.client, identifier, h.client.ClientApp.Dir)
+
if err != nil {
+
slog.Error("dev mode: failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
+
} else {
+
// Production mode: use standard indigo OAuth flow
+
redirectURL, err = h.client.ClientApp.StartAuthFlow(mobileCtx, identifier)
+
if err != nil {
+
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
}
// Log mobile OAuth flow initiation (sanitized - no full URLs or sensitive params)
···
// Check if the handle is the special "handle.invalid" value
// This indicates that bidirectional verification failed (DID->handle->DID roundtrip failed)
if ident.Handle.String() == "handle.invalid" {
+
// DEV MODE: For local handles, verify via PDS instead of DNS/HTTP
+
// Local handles like "user.local.coves.dev" can't be resolved via DNS
+
if h.devResolver != nil {
+
// Get the handle from DID document (alsoKnownAs)
+
declaredHandle := ""
+
if len(ident.AlsoKnownAs) > 0 {
+
// Extract handle from at:// URI
+
for _, aka := range ident.AlsoKnownAs {
+
if len(aka) > 5 && aka[:5] == "at://" {
+
declaredHandle = aka[5:]
+
break
+
}
+
}
+
}
+
+
if declaredHandle != "" {
+
// Verify handle via PDS
+
resolvedDID, err := h.devResolver.ResolveHandle(ctx, declaredHandle)
+
if err == nil && resolvedDID == sessData.AccountDID.String() {
+
slog.Info("OAuth callback successful (dev mode: handle verified via PDS)",
+
"did", sessData.AccountDID, "handle", declaredHandle)
+
goto handleVerificationPassed
+
}
+
slog.Warn("dev mode: PDS handle verification failed",
+
"did", sessData.AccountDID, "handle", declaredHandle,
+
"resolved_did", resolvedDID, "error", err)
+
}
+
}
+
slog.Warn("OAuth callback: bidirectional handle verification failed",
"did", sessData.AccountDID,
"handle", "handle.invalid",
···
"did", sessData.AccountDID)
slog.Info("OAuth callback successful (no handle verification)", "did", sessData.AccountDID)
}
+
handleVerificationPassed:
// Check if this is a mobile callback (check for mobile_redirect_uri cookie)
mobileRedirect, err := r.Cookie("mobile_redirect_uri")
+5 -1
scripts/dev-run.sh
···
#!/bin/bash
# Development server runner - loads .env.dev before starting
+
# Uses -tags dev to include dev-only code (localhost OAuth resolvers, etc.)
set -a # automatically export all variables
source .env.dev
···
echo " IS_DEV_ENV: $IS_DEV_ENV"
echo " PLC_DIRECTORY_URL: $PLC_DIRECTORY_URL"
echo " JETSTREAM_URL: $JETSTREAM_URL"
+
echo " APPVIEW_PUBLIC_URL: $APPVIEW_PUBLIC_URL"
+
echo " PDS_URL: $PDS_URL"
+
echo " Build tags: dev"
echo ""
-
go run ./cmd/server
+
go run -tags dev ./cmd/server
+125
internal/atproto/pds/factory.go
···
+
package pds
+
+
import (
+
"context"
+
"fmt"
+
"net/http"
+
+
"github.com/bluesky-social/indigo/atproto/atclient"
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// NewFromOAuthSession creates a PDS client from an OAuth session.
+
// This uses DPoP authentication - the correct method for OAuth tokens.
+
//
+
// The oauthClient is used to resume the session and get a properly configured
+
// APIClient that handles DPoP proof generation and nonce rotation automatically.
+
func NewFromOAuthSession(ctx context.Context, oauthClient *oauth.ClientApp, sessionData *oauth.ClientSessionData) (Client, error) {
+
if oauthClient == nil {
+
return nil, fmt.Errorf("oauthClient is required")
+
}
+
if sessionData == nil {
+
return nil, fmt.Errorf("sessionData is required")
+
}
+
+
// ResumeSession reconstructs the OAuth session with DPoP key
+
// and returns a ClientSession that can generate authenticated requests
+
sess, err := oauthClient.ResumeSession(ctx, sessionData.AccountDID, sessionData.SessionID)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resume OAuth session: %w", err)
+
}
+
+
// APIClient() returns an *atclient.APIClient configured with DPoP auth
+
apiClient := sess.APIClient()
+
+
return &client{
+
apiClient: apiClient,
+
did: sessionData.AccountDID.String(),
+
host: sessionData.HostURL,
+
}, nil
+
}
+
+
// NewFromPasswordAuth creates a PDS client using password authentication.
+
// This uses Bearer token authentication from com.atproto.server.createSession.
+
//
+
// Primarily used for:
+
// - E2E tests with local PDS
+
// - Development/debugging tools
+
// - Non-OAuth clients
+
//
+
// Note: This establishes a new session with the PDS. For repeated calls,
+
// consider using NewFromAccessToken if you already have a valid access token.
+
func NewFromPasswordAuth(ctx context.Context, host, handle, password string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if handle == "" {
+
return nil, fmt.Errorf("handle is required")
+
}
+
if password == "" {
+
return nil, fmt.Errorf("password is required")
+
}
+
+
// LoginWithPasswordHost creates a session and returns an authenticated APIClient
+
// This handles the createSession call and Bearer token setup
+
apiClient, err := atclient.LoginWithPasswordHost(ctx, host, handle, password, "", nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to login with password: %w", err)
+
}
+
+
// Get DID from the authenticated client
+
did := ""
+
if apiClient.AccountDID != nil {
+
did = apiClient.AccountDID.String()
+
}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// NewFromAccessToken creates a PDS client from an existing access token.
+
// This is useful when you already have a valid Bearer token (e.g., from createSession)
+
// and don't want to re-authenticate.
+
//
+
// WARNING: This creates a client with Bearer auth only. Do NOT use this with
+
// OAuth access tokens - those require DPoP proofs. Use NewFromOAuthSession instead.
+
func NewFromAccessToken(host, did, accessToken string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if did == "" {
+
return nil, fmt.Errorf("did is required")
+
}
+
if accessToken == "" {
+
return nil, fmt.Errorf("accessToken is required")
+
}
+
+
// Create APIClient with Bearer auth
+
apiClient := atclient.NewAPIClient(host)
+
apiClient.Auth = &bearerAuth{token: accessToken}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// bearerAuth implements atclient.AuthMethod for simple Bearer token auth.
+
// This is used for password-based sessions where DPoP is not required.
+
type bearerAuth struct {
+
token string
+
}
+
+
// Ensure bearerAuth implements atclient.AuthMethod.
+
var _ atclient.AuthMethod = (*bearerAuth)(nil)
+
+
// DoWithAuth adds the Bearer token to the request and executes it.
+
func (b *bearerAuth) DoWithAuth(c *http.Client, req *http.Request, _ syntax.NSID) (*http.Response, error) {
+
req.Header.Set("Authorization", "Bearer "+b.token)
+
return c.Do(req)
+
}
+18
tests/integration/helpers.go
···
import (
"Coves/internal/api/middleware"
"Coves/internal/atproto/oauth"
+
"Coves/internal/atproto/pds"
"Coves/internal/core/users"
+
"Coves/internal/core/votes"
"bytes"
"context"
"database/sql"
···
e.store.AddSessionWithPDS(did, sessionID, pdsAccessToken, pdsURL)
return token
}
+
+
// PasswordAuthPDSClientFactory creates a PDSClientFactory that uses password-based Bearer auth.
+
// This is for E2E tests that use createSession instead of OAuth.
+
// The factory extracts the access token and host URL from the session data.
+
func PasswordAuthPDSClientFactory() votes.PDSClientFactory {
+
return func(ctx context.Context, session *oauthlib.ClientSessionData) (pds.Client, error) {
+
if session.AccessToken == "" {
+
return nil, fmt.Errorf("session has no access token")
+
}
+
if session.HostURL == "" {
+
return nil, fmt.Errorf("session has no host URL")
+
}
+
+
return pds.NewFromAccessToken(session.HostURL, session.AccountDID.String(), session.AccessToken)
+
}
+
}
+267
cmd/reindex-votes/main.go
···
+
// cmd/reindex-votes/main.go
+
// Quick tool to reindex votes from PDS to AppView database
+
package main
+
+
import (
+
"context"
+
"database/sql"
+
"encoding/json"
+
"fmt"
+
"log"
+
"net/http"
+
"net/url"
+
"os"
+
"strings"
+
"time"
+
+
_ "github.com/lib/pq"
+
)
+
+
type ListRecordsResponse struct {
+
Records []Record `json:"records"`
+
Cursor string `json:"cursor"`
+
}
+
+
type Record struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
Value map[string]interface{} `json:"value"`
+
}
+
+
func main() {
+
// Get config from env
+
dbURL := os.Getenv("DATABASE_URL")
+
if dbURL == "" {
+
dbURL = "postgres://dev_user:dev_password@localhost:5435/coves_dev?sslmode=disable"
+
}
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
log.Printf("Connecting to database...")
+
db, err := sql.Open("postgres", dbURL)
+
if err != nil {
+
log.Fatalf("Failed to connect to database: %v", err)
+
}
+
defer db.Close()
+
+
ctx := context.Background()
+
+
// Get all accounts directly from the PDS
+
log.Printf("Fetching accounts from PDS (%s)...", pdsURL)
+
dids, err := fetchAllAccountsFromPDS(pdsURL)
+
if err != nil {
+
log.Fatalf("Failed to fetch accounts from PDS: %v", err)
+
}
+
log.Printf("Found %d accounts on PDS to check for votes", len(dids))
+
+
// Reset vote counts first
+
log.Printf("Resetting all vote counts...")
+
if _, err := db.ExecContext(ctx, "DELETE FROM votes"); err != nil {
+
log.Fatalf("Failed to clear votes table: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE posts SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset post vote counts: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE comments SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset comment vote counts: %v", err)
+
}
+
+
// For each user, fetch their votes from PDS
+
totalVotes := 0
+
for _, did := range dids {
+
votes, err := fetchVotesFromPDS(pdsURL, did)
+
if err != nil {
+
log.Printf("Warning: failed to fetch votes for %s: %v", did, err)
+
continue
+
}
+
+
if len(votes) == 0 {
+
continue
+
}
+
+
log.Printf("Found %d votes for %s", len(votes), did)
+
+
// Index each vote
+
for _, vote := range votes {
+
if err := indexVote(ctx, db, did, vote); err != nil {
+
log.Printf("Warning: failed to index vote %s: %v", vote.URI, err)
+
continue
+
}
+
totalVotes++
+
}
+
}
+
+
log.Printf("โœ“ Reindexed %d votes from PDS", totalVotes)
+
}
+
+
// fetchAllAccountsFromPDS queries the PDS sync API to get all repo DIDs
+
func fetchAllAccountsFromPDS(pdsURL string) ([]string, error) {
+
// Use com.atproto.sync.listRepos to get all repos on this PDS
+
var allDIDs []string
+
cursor := ""
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.sync.listRepos?limit=100", pdsURL)
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result struct {
+
Repos []struct {
+
DID string `json:"did"`
+
} `json:"repos"`
+
Cursor string `json:"cursor"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
for _, repo := range result.Repos {
+
allDIDs = append(allDIDs, repo.DID)
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allDIDs, nil
+
}
+
+
func fetchVotesFromPDS(pdsURL, did string) ([]Record, error) {
+
var allRecords []Record
+
cursor := ""
+
collection := "social.coves.feed.vote"
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.listRecords?repo=%s&collection=%s&limit=100",
+
pdsURL, url.QueryEscape(did), url.QueryEscape(collection))
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode == 400 {
+
// User doesn't exist on this PDS or has no records - that's OK
+
return nil, nil
+
}
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result ListRecordsResponse
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
allRecords = append(allRecords, result.Records...)
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allRecords, nil
+
}
+
+
func indexVote(ctx context.Context, db *sql.DB, voterDID string, record Record) error {
+
// Extract vote data from record
+
subject, ok := record.Value["subject"].(map[string]interface{})
+
if !ok {
+
return fmt.Errorf("missing subject")
+
}
+
subjectURI, _ := subject["uri"].(string)
+
subjectCID, _ := subject["cid"].(string)
+
direction, _ := record.Value["direction"].(string)
+
createdAtStr, _ := record.Value["createdAt"].(string)
+
+
if subjectURI == "" || direction == "" {
+
return fmt.Errorf("invalid vote record: missing required fields")
+
}
+
+
// Parse created_at
+
createdAt, err := time.Parse(time.RFC3339, createdAtStr)
+
if err != nil {
+
createdAt = time.Now()
+
}
+
+
// Extract rkey from URI (at://did/collection/rkey)
+
parts := strings.Split(record.URI, "/")
+
if len(parts) < 5 {
+
return fmt.Errorf("invalid URI format: %s", record.URI)
+
}
+
rkey := parts[len(parts)-1]
+
+
// Start transaction
+
tx, err := db.BeginTx(ctx, nil)
+
if err != nil {
+
return fmt.Errorf("failed to begin transaction: %w", err)
+
}
+
defer tx.Rollback()
+
+
// Insert vote
+
_, err = tx.ExecContext(ctx, `
+
INSERT INTO votes (uri, cid, rkey, voter_did, subject_uri, subject_cid, direction, created_at, indexed_at)
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
+
ON CONFLICT (uri) DO NOTHING
+
`, record.URI, record.CID, rkey, voterDID, subjectURI, subjectCID, direction, createdAt)
+
if err != nil {
+
return fmt.Errorf("failed to insert vote: %w", err)
+
}
+
+
// Update post/comment counts
+
collection := extractCollectionFromURI(subjectURI)
+
var updateQuery string
+
+
switch collection {
+
case "social.coves.community.post":
+
if direction == "up" {
+
updateQuery = `UPDATE posts SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE posts SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
case "social.coves.community.comment":
+
if direction == "up" {
+
updateQuery = `UPDATE comments SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE comments SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
default:
+
// Unknown collection, just index the vote
+
return tx.Commit()
+
}
+
+
if _, err := tx.ExecContext(ctx, updateQuery, subjectURI); err != nil {
+
return fmt.Errorf("failed to update vote counts: %w", err)
+
}
+
+
return tx.Commit()
+
}
+
+
func extractCollectionFromURI(uri string) string {
+
// at://did:plc:xxx/social.coves.community.post/rkey
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 4 {
+
return parts[3]
+
}
+
return ""
+
}
+7 -5
internal/api/routes/communityFeed.go
···
import (
"Coves/internal/api/handlers/communityFeed"
+
"Coves/internal/api/middleware"
"Coves/internal/core/communityFeeds"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
func RegisterCommunityFeedRoutes(
r chi.Router,
feedService communityFeeds.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService)
+
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService, voteService)
// GET /xrpc/social.coves.communityFeed.getCommunity
-
// Public endpoint - basic community sorting only for Alpha
-
// TODO(feed-generator): Add OptionalAuth middleware when implementing viewer-specific state
-
// (blocks, upvotes, saves, etc.) in feed generator skeleton
-
r.Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
+
// Public endpoint with optional auth for viewer-specific state (vote state)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
}
+221
internal/core/votes/cache.go
···
+
package votes
+
+
import (
+
"context"
+
"fmt"
+
"log/slog"
+
"strings"
+
"sync"
+
"time"
+
+
"Coves/internal/atproto/pds"
+
)
+
+
// CachedVote represents a vote stored in the cache
+
type CachedVote struct {
+
Direction string // "up" or "down"
+
URI string // vote record URI (at://did/collection/rkey)
+
RKey string // record key
+
}
+
+
// VoteCache provides an in-memory cache of user votes fetched from their PDS.
+
// This avoids eventual consistency issues with the AppView database.
+
type VoteCache struct {
+
mu sync.RWMutex
+
votes map[string]map[string]*CachedVote // userDID -> subjectURI -> vote
+
expiry map[string]time.Time // userDID -> expiry time
+
ttl time.Duration
+
logger *slog.Logger
+
}
+
+
// NewVoteCache creates a new vote cache with the specified TTL
+
func NewVoteCache(ttl time.Duration, logger *slog.Logger) *VoteCache {
+
if logger == nil {
+
logger = slog.Default()
+
}
+
return &VoteCache{
+
votes: make(map[string]map[string]*CachedVote),
+
expiry: make(map[string]time.Time),
+
ttl: ttl,
+
logger: logger,
+
}
+
}
+
+
// GetVotesForUser returns all cached votes for a user.
+
// Returns nil if cache is empty or expired for this user.
+
func (c *VoteCache) GetVotesForUser(userDID string) map[string]*CachedVote {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
// Check if cache exists and is not expired
+
expiry, exists := c.expiry[userDID]
+
if !exists || time.Now().After(expiry) {
+
return nil
+
}
+
+
return c.votes[userDID]
+
}
+
+
// GetVote returns the cached vote for a specific subject, or nil if not found/expired
+
func (c *VoteCache) GetVote(userDID, subjectURI string) *CachedVote {
+
votes := c.GetVotesForUser(userDID)
+
if votes == nil {
+
return nil
+
}
+
return votes[subjectURI]
+
}
+
+
// IsCached returns true if the user's votes are cached and not expired
+
func (c *VoteCache) IsCached(userDID string) bool {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
expiry, exists := c.expiry[userDID]
+
return exists && time.Now().Before(expiry)
+
}
+
+
// SetVotesForUser replaces all cached votes for a user
+
func (c *VoteCache) SetVotesForUser(userDID string, votes map[string]*CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
c.votes[userDID] = votes
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cache updated",
+
"user", userDID,
+
"vote_count", len(votes),
+
"expires_at", c.expiry[userDID])
+
}
+
+
// SetVote adds or updates a single vote in the cache
+
func (c *VoteCache) SetVote(userDID, subjectURI string, vote *CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] == nil {
+
c.votes[userDID] = make(map[string]*CachedVote)
+
}
+
+
c.votes[userDID][subjectURI] = vote
+
+
// Always extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cached",
+
"user", userDID,
+
"subject", subjectURI,
+
"direction", vote.Direction)
+
}
+
+
// RemoveVote removes a vote from the cache (for toggle-off)
+
func (c *VoteCache) RemoveVote(userDID, subjectURI string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] != nil {
+
delete(c.votes[userDID], subjectURI)
+
+
// Extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote removed from cache",
+
"user", userDID,
+
"subject", subjectURI)
+
}
+
}
+
+
// Invalidate removes all cached votes for a user
+
func (c *VoteCache) Invalidate(userDID string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
delete(c.votes, userDID)
+
delete(c.expiry, userDID)
+
+
c.logger.Debug("vote cache invalidated", "user", userDID)
+
}
+
+
// FetchAndCacheFromPDS fetches all votes from the user's PDS and caches them.
+
// This should be called on first authenticated request or when cache is expired.
+
func (c *VoteCache) FetchAndCacheFromPDS(ctx context.Context, pdsClient pds.Client) error {
+
userDID := pdsClient.DID()
+
+
c.logger.Debug("fetching votes from PDS",
+
"user", userDID,
+
"pds", pdsClient.HostURL())
+
+
votes, err := c.fetchAllVotesFromPDS(ctx, pdsClient)
+
if err != nil {
+
return fmt.Errorf("failed to fetch votes from PDS: %w", err)
+
}
+
+
c.SetVotesForUser(userDID, votes)
+
+
c.logger.Info("vote cache populated from PDS",
+
"user", userDID,
+
"vote_count", len(votes))
+
+
return nil
+
}
+
+
// fetchAllVotesFromPDS paginates through all vote records on the user's PDS
+
func (c *VoteCache) fetchAllVotesFromPDS(ctx context.Context, pdsClient pds.Client) (map[string]*CachedVote, error) {
+
votes := make(map[string]*CachedVote)
+
cursor := ""
+
const pageSize = 100
+
const collection = "social.coves.feed.vote"
+
+
for {
+
result, err := pdsClient.ListRecords(ctx, collection, pageSize, cursor)
+
if err != nil {
+
if pds.IsAuthError(err) {
+
return nil, ErrNotAuthorized
+
}
+
return nil, fmt.Errorf("listRecords failed: %w", err)
+
}
+
+
for _, rec := range result.Records {
+
// Extract subject from record value
+
subject, ok := rec.Value["subject"].(map[string]any)
+
if !ok {
+
continue
+
}
+
+
subjectURI, ok := subject["uri"].(string)
+
if !ok || subjectURI == "" {
+
continue
+
}
+
+
direction, _ := rec.Value["direction"].(string)
+
if direction == "" {
+
continue
+
}
+
+
// Extract rkey from URI
+
rkey := extractRKeyFromURI(rec.URI)
+
+
votes[subjectURI] = &CachedVote{
+
Direction: direction,
+
URI: rec.URI,
+
RKey: rkey,
+
}
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return votes, nil
+
}
+
+
// extractRKeyFromURI extracts the rkey from an AT-URI (at://did/collection/rkey)
+
func extractRKeyFromURI(uri string) string {
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 5 {
+
return parts[len(parts)-1]
+
}
+
return ""
+
}
+14
internal/core/votes/service.go
···
// - Deletes the user's vote record from their PDS
// - AppView will soft-delete via Jetstream consumer
DeleteVote(ctx context.Context, session *oauthlib.ClientSessionData, req DeleteVoteRequest) error
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
// This should be called before rendering feeds to ensure vote state is available.
+
// If cache is already populated and not expired, this is a no-op.
+
EnsureCachePopulated(ctx context.Context, session *oauthlib.ClientSessionData) error
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
// Returns from cache if available, otherwise returns nil (caller should ensure cache is populated).
+
GetViewerVote(userDID, subjectURI string) *CachedVote
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
// Returns a map of subjectURI -> CachedVote for subjects the user has voted on.
+
// This is efficient for batch lookups when rendering feeds.
+
GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote
}
// CreateVoteRequest contains the parameters for creating a vote
+84 -2
internal/core/votes/service_impl.go
···
oauthStore oauth.ClientAuthStore
logger *slog.Logger
pdsClientFactory PDSClientFactory // Optional, for testing. If nil, uses OAuth.
+
cache *VoteCache // In-memory cache of user votes from PDS
}
// NewService creates a new vote service instance
-
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, logger *slog.Logger) Service {
+
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, cache *VoteCache, logger *slog.Logger) Service {
if logger == nil {
logger = slog.Default()
}
···
repo: repo,
oauthClient: oauthClient,
oauthStore: oauthStore,
+
cache: cache,
logger: logger,
}
}
// NewServiceWithPDSFactory creates a vote service with a custom PDS client factory.
// This is primarily for testing with password-based authentication.
-
func NewServiceWithPDSFactory(repo Repository, logger *slog.Logger, factory PDSClientFactory) Service {
+
func NewServiceWithPDSFactory(repo Repository, cache *VoteCache, logger *slog.Logger, factory PDSClientFactory) Service {
if logger == nil {
logger = slog.Default()
}
return &voteService{
repo: repo,
+
cache: cache,
logger: logger,
pdsClientFactory: factory,
}
···
"subject", req.Subject.URI,
"direction", req.Direction)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
// Return empty response to indicate deletion
return &CreateVoteResponse{
URI: "",
···
"uri", uri,
"cid", cid)
+
// Update cache - add the new vote
+
if s.cache != nil {
+
s.cache.SetVote(session.AccountDID.String(), req.Subject.URI, &CachedVote{
+
Direction: req.Direction,
+
URI: uri,
+
RKey: extractRKeyFromURI(uri),
+
})
+
}
+
return &CreateVoteResponse{
URI: uri,
CID: cid,
···
"subject", req.Subject.URI,
"uri", existing.URI)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
return nil
}
···
// No vote found for this subject after checking all pages
return nil, nil
}
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
func (s *voteService) EnsureCachePopulated(ctx context.Context, session *oauth.ClientSessionData) error {
+
if s.cache == nil {
+
return nil // No cache configured
+
}
+
+
// Check if already cached
+
if s.cache.IsCached(session.AccountDID.String()) {
+
return nil
+
}
+
+
// Create PDS client for this session
+
pdsClient, err := s.getPDSClient(ctx, session)
+
if err != nil {
+
s.logger.Error("failed to create PDS client for cache population",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to create PDS client: %w", err)
+
}
+
+
// Fetch and cache votes from PDS
+
if err := s.cache.FetchAndCacheFromPDS(ctx, pdsClient); err != nil {
+
s.logger.Error("failed to populate vote cache from PDS",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to populate vote cache: %w", err)
+
}
+
+
return nil
+
}
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
func (s *voteService) GetViewerVote(userDID, subjectURI string) *CachedVote {
+
if s.cache == nil {
+
return nil
+
}
+
return s.cache.GetVote(userDID, subjectURI)
+
}
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
func (s *voteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote {
+
result := make(map[string]*CachedVote)
+
if s.cache == nil {
+
return result
+
}
+
+
allVotes := s.cache.GetVotesForUser(userDID)
+
if allVotes == nil {
+
return result
+
}
+
+
for _, uri := range subjectURIs {
+
if vote, exists := allVotes[uri]; exists {
+
result[uri] = vote
+
}
+
}
+
+
return result
+
}
+76 -16
internal/atproto/jetstream/vote_consumer.go
···
}
// Atomically: Index vote + Update post counts
-
if err := c.indexVoteAndUpdateCounts(ctx, vote); err != nil {
+
wasNew, err := c.indexVoteAndUpdateCounts(ctx, vote)
+
if err != nil {
return fmt.Errorf("failed to index vote and update counts: %w", err)
}
-
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
if wasNew {
+
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
}
return nil
}
···
}
// indexVoteAndUpdateCounts atomically indexes a vote and updates post vote counts
-
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) error {
+
// Returns (true, nil) if vote was newly inserted, (false, nil) if already existed (idempotent)
+
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) (bool, error) {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
-
return fmt.Errorf("failed to begin transaction: %w", err)
+
return false, fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if rollbackErr := tx.Rollback(); rollbackErr != nil && rollbackErr != sql.ErrTxDone {
···
}
}()
-
// 1. Index the vote (idempotent with ON CONFLICT DO NOTHING)
+
// 1. Check for existing active vote with different URI (stale record)
+
// This handles cases where:
+
// - User voted on another client and we missed the delete event
+
// - Vote was reindexed but user created a new vote with different rkey
+
// - Any other state mismatch between PDS and AppView
+
var existingDirection sql.NullString
+
checkQuery := `
+
SELECT direction FROM votes
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
LIMIT 1
+
`
+
if err := tx.QueryRowContext(ctx, checkQuery, vote.VoterDID, vote.SubjectURI, vote.URI).Scan(&existingDirection); err != nil && err != sql.ErrNoRows {
+
return false, fmt.Errorf("failed to check existing vote: %w", err)
+
}
+
+
// If there's a stale vote, soft-delete it and adjust counts
+
if existingDirection.Valid {
+
softDeleteQuery := `
+
UPDATE votes
+
SET deleted_at = NOW()
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
`
+
if _, err := tx.ExecContext(ctx, softDeleteQuery, vote.VoterDID, vote.SubjectURI, vote.URI); err != nil {
+
return false, fmt.Errorf("failed to soft-delete existing votes: %w", err)
+
}
+
+
// Decrement the old vote's count (will be re-incremented below if same direction)
+
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
+
var decrementQuery string
+
if existingDirection.String == "up" {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
} else {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
}
+
if decrementQuery != "" {
+
if _, err := tx.ExecContext(ctx, decrementQuery, vote.SubjectURI); err != nil {
+
return false, fmt.Errorf("failed to decrement old vote count: %w", err)
+
}
+
}
+
log.Printf("Cleaned up stale vote for %s on %s (was %s)", vote.VoterDID, vote.SubjectURI, existingDirection.String)
+
}
+
+
// 2. Index the vote (idempotent with ON CONFLICT DO NOTHING)
query := `
INSERT INTO votes (
uri, cid, rkey, voter_did,
···
// If no rows returned, vote already exists (idempotent - OK for Jetstream replays)
if err == sql.ErrNoRows {
-
log.Printf("Vote already indexed: %s (idempotent)", vote.URI)
+
// Silently handle idempotent case - no log needed for replayed events
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return false, nil // Vote already existed
}
if err != nil {
-
return fmt.Errorf("failed to insert vote: %w", err)
+
return false, fmt.Errorf("failed to insert vote: %w", err)
}
-
// 2. Update vote counts on the subject (post or comment)
+
// 3. Update vote counts on the subject (post or comment)
// Parse collection from subject URI to determine target table
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
···
// Vote is still indexed in votes table, we just don't update denormalized counts
log.Printf("Vote subject has unsupported collection: %s (vote indexed, counts not updated)", collection)
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
result, err := tx.ExecContext(ctx, updateQuery, vote.SubjectURI)
if err != nil {
-
return fmt.Errorf("failed to update vote counts: %w", err)
+
return false, fmt.Errorf("failed to update vote counts: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
-
return fmt.Errorf("failed to check update result: %w", err)
+
return false, fmt.Errorf("failed to check update result: %w", err)
}
// If subject doesn't exist or is deleted, that's OK (vote still indexed)
···
// Commit transaction
if err := tx.Commit(); err != nil {
-
return fmt.Errorf("failed to commit transaction: %w", err)
+
return false, fmt.Errorf("failed to commit transaction: %w", err)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
// deleteVoteAndUpdateCounts atomically soft-deletes a vote and updates post vote counts
+109
internal/atproto/lexicon/social/coves/community/comment/create.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.create",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Create a comment on a post or another comment. Comments support nested threading, rich text, embeds, and self-labeling.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["reply", "content"],
+
"properties": {
+
"reply": {
+
"type": "object",
+
"description": "References for maintaining thread structure. Root always points to the original post, parent points to the immediate parent (post or comment).",
+
"required": ["root", "parent"],
+
"properties": {
+
"root": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the original post that started the thread"
+
},
+
"parent": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the immediate parent (post or comment) being replied to"
+
}
+
}
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the created comment"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "CID of the created comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "InvalidReply",
+
"description": "The reply reference is invalid, malformed, or refers to non-existent content"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to create comments on this content"
+
}
+
]
+
}
+
}
+
}
+41
internal/atproto/lexicon/social/coves/community/comment/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a comment. Only the comment author can delete their own comments.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to delete"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+97
internal/atproto/lexicon/social/coves/community/comment/update.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.update",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Update an existing comment's content, facets, embed, languages, or labels. Threading references (reply.root and reply.parent) are immutable and cannot be changed.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "content"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to update"
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Updated comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Updated annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Updated embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Updated languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Updated self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the updated comment (unchanged from input)"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "New CID of the updated comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Updated comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Updated comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to update this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+38
internal/core/comments/types.go
···
+
package comments
+
+
// CreateCommentRequest contains parameters for creating a comment
+
type CreateCommentRequest struct {
+
Reply ReplyRef `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// CreateCommentResponse contains the result of creating a comment
+
type CreateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// UpdateCommentRequest contains parameters for updating a comment
+
type UpdateCommentRequest struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentResponse contains the result of updating a comment
+
type UpdateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// DeleteCommentRequest contains parameters for deleting a comment
+
type DeleteCommentRequest struct {
+
URI string `json:"uri"`
+
}
+130
internal/api/handlers/comments/create_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateCommentHandler handles comment creation requests
+
type CreateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewCreateCommentHandler creates a new handler for creating comments
+
func NewCreateCommentHandler(service comments.Service) *CreateCommentHandler {
+
return &CreateCommentHandler{
+
service: service,
+
}
+
}
+
+
// CreateCommentInput matches the lexicon input schema for social.coves.community.comment.create
+
type CreateCommentInput struct {
+
Reply struct {
+
Root struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"root"`
+
Parent struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"parent"`
+
} `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// CreateCommentOutput matches the lexicon output schema
+
type CreateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreate handles comment creation requests
+
// POST /xrpc/social.coves.community.comment.create
+
//
+
// Request body: { "reply": { "root": {...}, "parent": {...} }, "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *CreateCommentHandler) HandleCreate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into CreateCommentInput
+
var input CreateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to CreateCommentRequest
+
req := comments.CreateCommentRequest{
+
Reply: comments.ReplyRef{
+
Root: comments.StrongRef{
+
URI: input.Reply.Root.URI,
+
CID: input.Reply.Root.CID,
+
},
+
Parent: comments.StrongRef{
+
URI: input.Reply.Parent.URI,
+
CID: input.Reply.Parent.CID,
+
},
+
},
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to create comment
+
response, err := h.service.CreateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := CreateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+80
internal/api/handlers/comments/delete_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteCommentHandler handles comment deletion requests
+
type DeleteCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewDeleteCommentHandler creates a new handler for deleting comments
+
func NewDeleteCommentHandler(service comments.Service) *DeleteCommentHandler {
+
return &DeleteCommentHandler{
+
service: service,
+
}
+
}
+
+
// DeleteCommentInput matches the lexicon input schema for social.coves.community.comment.delete
+
type DeleteCommentInput struct {
+
URI string `json:"uri"`
+
}
+
+
// DeleteCommentOutput is empty per lexicon specification
+
type DeleteCommentOutput struct{}
+
+
// HandleDelete handles comment deletion requests
+
// POST /xrpc/social.coves.community.comment.delete
+
//
+
// Request body: { "uri": "at://..." }
+
// Response: {}
+
func (h *DeleteCommentHandler) HandleDelete(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into DeleteCommentInput
+
var input DeleteCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert input to DeleteCommentRequest
+
req := comments.DeleteCommentRequest{
+
URI: input.URI,
+
}
+
+
// 6. Call service to delete comment
+
err := h.service.DeleteComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 7. Return empty JSON object per lexicon specification
+
output := DeleteCommentOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+34 -2
internal/api/handlers/comments/errors.go
···
import (
"Coves/internal/core/comments"
"encoding/json"
+
"errors"
"log"
"net/http"
)
···
func handleServiceError(w http.ResponseWriter, err error) {
switch {
case comments.IsNotFound(err):
-
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
// Map specific not found errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrCommentNotFound):
+
writeError(w, http.StatusNotFound, "CommentNotFound", "Comment not found")
+
case errors.Is(err, comments.ErrParentNotFound):
+
writeError(w, http.StatusNotFound, "ParentNotFound", "Parent post or comment not found")
+
case errors.Is(err, comments.ErrRootNotFound):
+
writeError(w, http.StatusNotFound, "RootNotFound", "Root post not found")
+
default:
+
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
}
case comments.IsValidationError(err):
-
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
// Map specific validation errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrInvalidReply):
+
writeError(w, http.StatusBadRequest, "InvalidReply", "The reply reference is invalid or malformed")
+
case errors.Is(err, comments.ErrContentTooLong):
+
writeError(w, http.StatusBadRequest, "ContentTooLong", "Comment content exceeds 10000 graphemes")
+
case errors.Is(err, comments.ErrContentEmpty):
+
writeError(w, http.StatusBadRequest, "ContentEmpty", "Comment content is required")
+
default:
+
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
}
+
+
case errors.Is(err, comments.ErrNotAuthorized):
+
writeError(w, http.StatusForbidden, "NotAuthorized", "User is not authorized to perform this action")
+
+
case errors.Is(err, comments.ErrBanned):
+
writeError(w, http.StatusForbidden, "Banned", "User is banned from this community")
+
+
// NOTE: IsConflict case removed - the PDS handles duplicate detection via CreateRecord,
+
// so ErrCommentAlreadyExists is never returned from the service layer. If the PDS rejects
+
// a duplicate record, it returns an auth/validation error which is handled by other cases.
+
// Keeping this code would be dead code that never executes.
default:
// Don't leak internal error details to clients
+112
internal/api/handlers/comments/update_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// UpdateCommentHandler handles comment update requests
+
type UpdateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewUpdateCommentHandler creates a new handler for updating comments
+
func NewUpdateCommentHandler(service comments.Service) *UpdateCommentHandler {
+
return &UpdateCommentHandler{
+
service: service,
+
}
+
}
+
+
// UpdateCommentInput matches the lexicon input schema for social.coves.community.comment.update
+
type UpdateCommentInput struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentOutput matches the lexicon output schema
+
type UpdateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleUpdate handles comment update requests
+
// POST /xrpc/social.coves.community.comment.update
+
//
+
// Request body: { "uri": "at://...", "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *UpdateCommentHandler) HandleUpdate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into UpdateCommentInput
+
var input UpdateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to UpdateCommentRequest
+
req := comments.UpdateCommentRequest{
+
URI: input.URI,
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to update comment
+
response, err := h.service.UpdateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := UpdateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+35
internal/api/routes/comment.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/comments"
+
"Coves/internal/api/middleware"
+
commentsCore "Coves/internal/core/comments"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterCommentRoutes registers comment-related XRPC endpoints on the router
+
// Implements social.coves.community.comment.* lexicon endpoints
+
// All write operations (create, update, delete) require authentication
+
func RegisterCommentRoutes(r chi.Router, service commentsCore.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := comments.NewCreateCommentHandler(service)
+
updateHandler := comments.NewUpdateCommentHandler(service)
+
deleteHandler := comments.NewDeleteCommentHandler(service)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.community.comment.create - create a new comment on a post or another comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.create",
+
createHandler.HandleCreate)
+
+
// social.coves.community.comment.update - update an existing comment's content
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.update",
+
updateHandler.HandleUpdate)
+
+
// social.coves.community.comment.delete - soft delete a comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.delete",
+
deleteHandler.HandleDelete)
+
}
+4 -2
tests/integration/comment_query_test.go
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
return comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
return comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
}
// Helper: createTestCommentWithScore creates a comment with specific vote counts
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
service := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
service := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
return &testCommentServiceAdapter{service: service}
}
+6 -3
tests/integration/comment_vote_test.go
···
}
// Query comments with viewer authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
}
// Query with authentication but no vote
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
t.Run("Unauthenticated request has no viewer state", func(t *testing.T) {
// Query without authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
+2 -1
tests/integration/concurrent_scenarios_test.go
···
}
// Verify all comments are retrievable via service
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: postURI,
Sort: "new",
+2
go.sum
···
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.1.0 h1:+2KBaVoUmb9XzDsrx/Ct0W/EYOSFf/nWTauy++DprtY=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
+66
internal/db/migrations/021_add_comment_deletion_metadata.sql
···
+
-- +goose Up
+
-- Add deletion reason tracking to preserve thread structure while respecting privacy
+
-- When comments are deleted, we blank content but keep the record for threading
+
+
-- Create enum type for deletion reasons
+
CREATE TYPE deletion_reason AS ENUM ('author', 'moderator');
+
+
-- Add new columns to comments table
+
ALTER TABLE comments ADD COLUMN deletion_reason deletion_reason;
+
ALTER TABLE comments ADD COLUMN deleted_by TEXT;
+
+
-- Add comments for new columns
+
COMMENT ON COLUMN comments.deletion_reason IS 'Reason for deletion: author (user deleted), moderator (community mod removed)';
+
COMMENT ON COLUMN comments.deleted_by IS 'DID of the actor who performed the deletion';
+
+
-- Backfill existing deleted comments as author-deleted
+
-- This handles existing soft-deleted comments gracefully
+
UPDATE comments
+
SET deletion_reason = 'author',
+
deleted_by = commenter_did
+
WHERE deleted_at IS NOT NULL AND deletion_reason IS NULL;
+
+
-- Modify existing indexes to NOT filter deleted_at IS NULL
+
-- This allows deleted comments to appear in thread queries for structure preservation
+
-- Note: We drop and recreate to change the partial index condition
+
+
-- Drop old partial indexes that exclude deleted comments
+
DROP INDEX IF EXISTS idx_comments_root;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_uri_active;
+
+
-- Recreate indexes without the deleted_at filter (include all comments for threading)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC);
+
CREATE INDEX idx_comments_uri_lookup ON comments(uri);
+
+
-- Add index for querying by deletion_reason (for moderation dashboard)
+
CREATE INDEX idx_comments_deleted_reason ON comments(deletion_reason, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- Add index for querying by deleted_by (for moderation audit/filtering)
+
CREATE INDEX idx_comments_deleted_by ON comments(deleted_by, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- +goose Down
+
-- Remove deletion metadata columns and restore original indexes
+
+
DROP INDEX IF EXISTS idx_comments_deleted_by;
+
DROP INDEX IF EXISTS idx_comments_deleted_reason;
+
DROP INDEX IF EXISTS idx_comments_uri_lookup;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_root;
+
+
-- Restore original partial indexes (excluding deleted comments)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_uri_active ON comments(uri) WHERE deleted_at IS NULL;
+
+
ALTER TABLE comments DROP COLUMN IF EXISTS deleted_by;
+
ALTER TABLE comments DROP COLUMN IF EXISTS deletion_reason;
+
+
DROP TYPE IF EXISTS deletion_reason;
+17 -13
internal/core/comments/view_models.go
···
// CommentView represents the full view of a comment with all metadata
// Matches social.coves.community.comment.getComments#commentView lexicon
// Used in thread views and get endpoints
+
// For deleted comments, IsDeleted=true and content-related fields are empty/nil
type CommentView struct {
-
Embed interface{} `json:"embed,omitempty"`
-
Record interface{} `json:"record"`
-
Viewer *CommentViewerState `json:"viewer,omitempty"`
-
Author *posts.AuthorView `json:"author"`
-
Post *CommentRef `json:"post"`
-
Parent *CommentRef `json:"parent,omitempty"`
-
Stats *CommentStats `json:"stats"`
-
Content string `json:"content"`
-
CreatedAt string `json:"createdAt"`
-
IndexedAt string `json:"indexedAt"`
-
URI string `json:"uri"`
-
CID string `json:"cid"`
-
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Record interface{} `json:"record"`
+
Viewer *CommentViewerState `json:"viewer,omitempty"`
+
Author *posts.AuthorView `json:"author"`
+
Post *CommentRef `json:"post"`
+
Parent *CommentRef `json:"parent,omitempty"`
+
Stats *CommentStats `json:"stats"`
+
Content string `json:"content"`
+
CreatedAt string `json:"createdAt"`
+
IndexedAt string `json:"indexedAt"`
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
IsDeleted bool `json:"isDeleted,omitempty"`
+
DeletionReason *string `json:"deletionReason,omitempty"`
+
DeletedAt *string `json:"deletedAt,omitempty"`
}
// ThreadViewComment represents a comment with its nested replies
+23 -1
internal/core/comments/interfaces.go
···
package comments
-
import "context"
+
import (
+
"context"
+
"database/sql"
+
)
// Repository defines the data access interface for comments
// Used by Jetstream consumer to index comments from firehose
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
Delete(ctx context.Context, uri string) error
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// reason: "author" (user deleted) or "moderator" (mod removed)
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error
+
// ListByRoot retrieves all comments in a thread (flat)
// Used for fetching entire comment threads on posts
ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*Comment, error)
···
limitPerParent int,
) (map[string][]*Comment, error)
}
+
+
// RepositoryTx provides transaction-aware operations for consumers that need atomicity
+
// Used by Jetstream consumer to perform atomic delete + count updates
+
// Implementations that support transactions should also implement this interface
+
type RepositoryTx interface {
+
// SoftDeleteWithReasonTx performs a soft delete within a transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// reason: must be DeletionReasonAuthor or DeletionReasonModerator
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error)
+
}
+87 -27
internal/db/postgres/comment_repo.go
···
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE uri = $1
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
// Idempotent: Returns success if comment already deleted
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
func (r *postgresCommentRepo) Delete(ctx context.Context, uri string) error {
query := `
UPDATE comments
···
return nil
}
-
// ListByRoot retrieves all active comments in a thread (flat)
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// Idempotent: Returns success if comment already deleted
+
// Validates that reason is a known deletion reason constant
+
func (r *postgresCommentRepo) SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error {
+
// Validate deletion reason
+
if reason != comments.DeletionReasonAuthor && reason != comments.DeletionReasonModerator {
+
return fmt.Errorf("invalid deletion reason: %s", reason)
+
}
+
+
_, err := r.SoftDeleteWithReasonTx(ctx, nil, uri, reason, deletedByDID)
+
return err
+
}
+
+
// SoftDeleteWithReasonTx performs a soft delete within an optional transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// This method is used by both the repository and the Jetstream consumer
+
func (r *postgresCommentRepo) SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error) {
+
query := `
+
UPDATE comments
+
SET
+
content = '',
+
content_facets = NULL,
+
embed = NULL,
+
content_labels = NULL,
+
deleted_at = NOW(),
+
deletion_reason = $2,
+
deleted_by = $3
+
WHERE uri = $1 AND deleted_at IS NULL
+
`
+
+
var result sql.Result
+
var err error
+
+
if tx != nil {
+
result, err = tx.ExecContext(ctx, query, uri, reason, deletedByDID)
+
} else {
+
result, err = r.db.ExecContext(ctx, query, uri, reason, deletedByDID)
+
}
+
+
if err != nil {
+
return 0, fmt.Errorf("failed to soft delete comment: %w", err)
+
}
+
+
rowsAffected, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to check delete result: %w", err)
+
}
+
+
return rowsAffected, nil
+
}
+
+
// ListByRoot retrieves all comments in a thread (flat), including deleted ones
// Used for fetching entire comment threads on posts
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE root_uri = $1 AND deleted_at IS NULL
+
WHERE root_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
return result, nil
}
-
// ListByParent retrieves direct replies to a post or comment
+
// ListByParent retrieves direct replies to a post or comment, including deleted ones
// Used for building nested/threaded comment views
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByParent(ctx context.Context, parentURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE parent_uri = $1 AND deleted_at IS NULL
+
WHERE parent_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
}
// ListByCommenter retrieves all active comments by a specific user
-
// Future: Used for user comment history
+
// Used for user comment history - filters out deleted comments
func (r *postgresCommentRepo) ListByCommenter(ctx context.Context, commenterDID string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE commenter_did = $1 AND deleted_at IS NULL
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
// Build complete query with JOINs and filters
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
%s
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = $1 AND c.deleted_at IS NULL
+
WHERE c.parent_uri = $1
%s
%s
ORDER BY %s
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
···
// GetByURIsBatch retrieves multiple comments by their AT-URIs in a single query
// Returns map[uri]*Comment for efficient lookups without N+1 queries
+
// Includes deleted comments to preserve thread structure
func (r *postgresCommentRepo) GetByURIsBatch(ctx context.Context, uris []string) (map[string]*comments.Comment, error) {
if len(uris) == 0 {
return make(map[string]*comments.Comment), nil
···
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
// COALESCE falls back to DID when handle is NULL (user not yet in users table)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := `
SELECT
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
COALESCE(u.handle, c.commenter_did) as author_handle
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.uri = ANY($1)
`
rows, err := r.db.QueryContext(ctx, query, pq.Array(uris))
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&authorHandle,
)
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
// Use window function to limit results per parent
// This is more efficient than LIMIT in a subquery per parent
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
WITH ranked_comments AS (
SELECT
···
) as rn
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.parent_uri = ANY($1)
)
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count,
hot_rank, author_handle
FROM ranked_comments
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
+5 -6
internal/core/comments/comment_service.go
···
CreatedAt: createdAt, // Preserve original timestamp
}
-
// Update the record on PDS (putRecord)
-
// Note: This creates a new CID even though the URI stays the same
-
// TODO: Use PutRecord instead of CreateRecord for proper update semantics with optimistic locking.
-
// PutRecord should accept the existing CID (existingRecord.CID) to ensure concurrent updates are detected.
-
// However, PutRecord is not yet implemented in internal/atproto/pds/client.go.
-
uri, cid, err := pdsClient.CreateRecord(ctx, commentCollection, rkey, updatedRecord)
+
// Update the record on PDS with optimistic locking via swapRecord CID
+
uri, cid, err := pdsClient.PutRecord(ctx, commentCollection, rkey, updatedRecord, existingRecord.CID)
if err != nil {
s.logger.Error("failed to update comment on PDS",
"error", err,
···
if pds.IsAuthError(err) {
return nil, ErrNotAuthorized
}
+
if errors.Is(err, pds.ErrConflict) {
+
return nil, ErrConcurrentModification
+
}
return nil, fmt.Errorf("failed to update comment: %w", err)
}
+73
internal/api/handlers/common/viewer_state.go
···
+
package common
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
+
"context"
+
"log"
+
"net/http"
+
)
+
+
// FeedPostProvider is implemented by any feed post wrapper that contains a PostView.
+
// This allows the helper to work with different feed post types (discover, timeline, communityFeed).
+
type FeedPostProvider interface {
+
GetPost() *posts.PostView
+
}
+
+
// PopulateViewerVoteState enriches feed posts with the authenticated user's vote state.
+
// This is a no-op if voteService is nil or the request is unauthenticated.
+
//
+
// Parameters:
+
// - ctx: Request context for PDS calls
+
// - r: HTTP request (used to extract OAuth session)
+
// - voteService: Vote service for cache lookup (may be nil)
+
// - feedPosts: Posts to enrich with viewer state (must implement FeedPostProvider)
+
//
+
// The function logs but does not fail on errors - viewer state is optional enrichment.
+
func PopulateViewerVoteState[T FeedPostProvider](
+
ctx context.Context,
+
r *http.Request,
+
voteService votes.Service,
+
feedPosts []T,
+
) {
+
if voteService == nil {
+
return
+
}
+
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
return
+
}
+
+
userDID := middleware.GetUserDID(r)
+
+
// Ensure vote cache is populated from PDS
+
if err := voteService.EnsureCachePopulated(ctx, session); err != nil {
+
log.Printf("Warning: failed to populate vote cache: %v", err)
+
return
+
}
+
+
// Collect post URIs to batch lookup
+
postURIs := make([]string, 0, len(feedPosts))
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
postURIs = append(postURIs, post.URI)
+
}
+
}
+
+
// Get viewer votes for all posts
+
viewerVotes := voteService.GetViewerVotesForSubjects(userDID, postURIs)
+
+
// Populate viewer state on each post
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
if vote, exists := viewerVotes[post.URI]; exists {
+
post.Viewer = &posts.ViewerState{
+
Vote: &vote.Direction,
+
VoteURI: &vote.URI,
+
}
+
}
+
}
+
}
+
}
+11 -4
internal/api/handlers/discover/get_discover.go
···
package discover
import (
+
"Coves/internal/api/handlers/common"
"Coves/internal/core/discover"
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
"encoding/json"
"log"
"net/http"
···
// GetDiscoverHandler handles discover feed retrieval
type GetDiscoverHandler struct {
-
service discover.Service
+
service discover.Service
+
voteService votes.Service
}
// NewGetDiscoverHandler creates a new discover handler
-
func NewGetDiscoverHandler(service discover.Service) *GetDiscoverHandler {
+
func NewGetDiscoverHandler(service discover.Service, voteService votes.Service) *GetDiscoverHandler {
return &GetDiscoverHandler{
-
service: service,
+
service: service,
+
voteService: voteService,
}
}
// HandleGetDiscover retrieves posts from all communities (public feed)
// GET /xrpc/social.coves.feed.getDiscover?sort=hot&limit=15&cursor=...
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth - if authenticated, includes viewer vote state
func (h *GetDiscoverHandler) HandleGetDiscover(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
···
return
}
+
// Populate viewer vote state if authenticated
+
common.PopulateViewerVoteState(r.Context(), r, h.voteService, response.Feed)
+
// Transform blob refs to URLs for all posts
for _, feedPost := range response.Feed {
if feedPost.Post != nil {
+9 -4
internal/api/routes/discover.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
discoverCore "Coves/internal/core/discover"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
// RegisterDiscoverRoutes registers discover-related XRPC endpoints
//
// SECURITY & RATE LIMITING:
-
// - Discover feed is PUBLIC (no authentication required)
+
// - Discover feed is PUBLIC (works without authentication)
+
// - Optional auth: if authenticated, includes viewer vote state on posts
// - Protected by global rate limiter: 100 requests/minute per IP (main.go:84)
// - Query timeout enforced via context (prevents long-running queries)
// - Result limit capped at 50 posts per request (validated in service layer)
···
func RegisterDiscoverRoutes(
r chi.Router,
discoverService discoverCore.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService)
+
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService, voteService)
// GET /xrpc/social.coves.feed.getDiscover
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth for viewer-specific state (vote state)
// Shows posts from ALL communities (not personalized)
// Rate limited: 100 req/min per IP via global middleware
-
r.Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
}
+5
internal/core/communityFeeds/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Can be reasonRepost or reasonPin
type FeedReason struct {
+5
internal/core/discover/types.go
···
Reply *ReplyRef `json:"reply,omitempty"`
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
type FeedReason struct {
Repost *ReasonRepost `json:"-"`
+5
internal/core/timeline/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Future: Can be reasonRepost or reasonCommunity
type FeedReason struct {
+193 -5
tests/integration/discover_test.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
"Coves/internal/db/postgres"
"context"
"encoding/json"
···
discoverCore "Coves/internal/core/discover"
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+
// mockVoteService implements votes.Service for testing viewer vote state
+
type mockVoteService struct {
+
cachedVotes map[string]*votes.CachedVote // userDID:subjectURI -> vote
+
}
+
+
func newMockVoteService() *mockVoteService {
+
return &mockVoteService{
+
cachedVotes: make(map[string]*votes.CachedVote),
+
}
+
}
+
+
func (m *mockVoteService) AddVote(userDID, subjectURI, direction, voteURI string) {
+
key := userDID + ":" + subjectURI
+
m.cachedVotes[key] = &votes.CachedVote{
+
Direction: direction,
+
URI: voteURI,
+
}
+
}
+
+
func (m *mockVoteService) CreateVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.CreateVoteRequest) (*votes.CreateVoteResponse, error) {
+
return &votes.CreateVoteResponse{}, nil
+
}
+
+
func (m *mockVoteService) DeleteVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.DeleteVoteRequest) error {
+
return nil
+
}
+
+
func (m *mockVoteService) EnsureCachePopulated(_ context.Context, _ *oauthlib.ClientSessionData) error {
+
return nil // Mock always succeeds - votes pre-populated via AddVote
+
}
+
+
func (m *mockVoteService) GetViewerVote(userDID, subjectURI string) *votes.CachedVote {
+
key := userDID + ":" + subjectURI
+
return m.cachedVotes[key]
+
}
+
+
func (m *mockVoteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*votes.CachedVote {
+
result := make(map[string]*votes.CachedVote)
+
for _, uri := range subjectURIs {
+
key := userDID + ":" + uri
+
if vote, exists := m.cachedVotes[key]; exists {
+
result[uri] = vote
+
}
+
}
+
return result
+
}
+
// TestGetDiscover_ShowsAllCommunities tests discover feed shows posts from ALL communities
func TestGetDiscover_ShowsAllCommunities(t *testing.T) {
if testing.Short() {
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
t.Run("Limit exceeds maximum", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=100", nil)
···
assert.Contains(t, errorResp["message"], "limit")
})
}
+
+
// TestGetDiscover_ViewerVoteState tests that authenticated users see their vote state on posts
+
func TestGetDiscover_ViewerVoteState(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and posts
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("votes-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
post1URI := createTestPost(t, db, communityDID, "did:plc:author1", "Post with upvote", 10, time.Now().Add(-1*time.Hour))
+
post2URI := createTestPost(t, db, communityDID, "did:plc:author2", "Post with downvote", 5, time.Now().Add(-2*time.Hour))
+
_ = createTestPost(t, db, communityDID, "did:plc:author3", "Post without vote", 3, time.Now().Add(-3*time.Hour))
+
+
// Setup mock vote service with pre-populated votes
+
viewerDID := "did:plc:viewer123"
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote(viewerDID, post1URI, "up", "at://"+viewerDID+"/social.coves.vote/vote1")
+
mockVotes.AddVote(viewerDID, post2URI, "down", "at://"+viewerDID+"/social.coves.vote/vote2")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request with authenticated user context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
+
// Inject OAuth session into context (simulates OptionalAuth middleware)
+
did, _ := syntax.ParseDID(viewerDID)
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
AccessToken: "test_token",
+
}
+
reqCtx := context.WithValue(req.Context(), middleware.UserDIDKey, viewerDID)
+
reqCtx = context.WithValue(reqCtx, middleware.OAuthSessionKey, session)
+
req = req.WithContext(reqCtx)
+
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Assertions
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our test posts and verify vote state
+
var foundPost1, foundPost2, foundPost3 bool
+
for _, feedPost := range response.Feed {
+
switch feedPost.Post.URI {
+
case post1URI:
+
foundPost1 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post1 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post1 should have vote direction")
+
assert.Equal(t, "up", *feedPost.Post.Viewer.Vote, "Post1 should show upvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post1 should have vote URI")
+
assert.Contains(t, *feedPost.Post.Viewer.VoteURI, "vote1", "Post1 should have correct vote URI")
+
+
case post2URI:
+
foundPost2 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post2 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post2 should have vote direction")
+
assert.Equal(t, "down", *feedPost.Post.Viewer.Vote, "Post2 should show downvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post2 should have vote URI")
+
+
default:
+
// Posts without votes should have nil Viewer or nil Vote
+
if feedPost.Post.Viewer != nil && feedPost.Post.Viewer.Vote != nil {
+
// This post has a vote from our viewer - it's not post3
+
continue
+
}
+
foundPost3 = true
+
}
+
}
+
+
assert.True(t, foundPost1, "Should find post1 with upvote")
+
assert.True(t, foundPost2, "Should find post2 with downvote")
+
assert.True(t, foundPost3, "Should find post3 without vote")
+
}
+
+
// TestGetDiscover_NoViewerStateWithoutAuth tests that unauthenticated users don't get viewer state
+
func TestGetDiscover_NoViewerStateWithoutAuth(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and post
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("noauth-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
postURI := createTestPost(t, db, communityDID, "did:plc:author", "Some post", 10, time.Now())
+
+
// Setup mock vote service with a vote (but request will be unauthenticated)
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote("did:plc:someuser", postURI, "up", "at://did:plc:someuser/social.coves.vote/vote1")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request WITHOUT auth context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Should succeed
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our post and verify NO viewer state (unauthenticated)
+
for _, feedPost := range response.Feed {
+
if feedPost.Post.URI == postURI {
+
assert.Nil(t, feedPost.Post.Viewer, "Unauthenticated request should not have viewer state")
+
return
+
}
+
}
+
t.Fatal("Test post not found in response")
+
}
+11 -11
tests/integration/feed_test.go
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data: community, users, and posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data with many posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Request feed for non-existent community
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.communityFeed.getCommunity?community=did:plc:nonexistent&sort=hot&limit=10", nil)
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Create community with no posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
+7 -7
tests/integration/timeline_test.go
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
// Request timeline WITHOUT auth context
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getTimeline?sort=new&limit=10", nil)
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
+1 -1
tests/integration/user_journey_e2e_test.go
···
r := chi.NewRouter()
routes.RegisterCommunityRoutes(r, communityService, e2eAuth.OAuthAuthMiddleware, nil) // nil = allow all community creators
routes.RegisterPostRoutes(r, postService, e2eAuth.OAuthAuthMiddleware)
-
routes.RegisterTimelineRoutes(r, timelineService, e2eAuth.OAuthAuthMiddleware)
+
routes.RegisterTimelineRoutes(r, timelineService, nil, e2eAuth.OAuthAuthMiddleware)
httpServer := httptest.NewServer(r)
defer httpServer.Close()