A community based topic aggregation platform built on atproto

Compare changes

Choose any two refs to compare.

+5 -5
docs/COMMENT_SYSTEM_IMPLEMENTATION.md
···
- Lexicon definitions: `social.coves.community.comment.defs` and `getComments`
- Database query methods with Lemmy hot ranking algorithm
- Service layer with iterative loading strategy for nested replies
-
- XRPC HTTP handler with optional authentication
+
- XRPC HTTP handler with optional DPoP authentication
- Comprehensive integration test suite (11 test scenarios)
**What works:**
···
- Nested replies up to configurable depth (default 10, max 100)
- Lemmy hot ranking: `log(greatest(2, score + 2)) / power(time_decay, 1.8)`
- Cursor-based pagination for stable scrolling
-
- Optional authentication for viewer state (stubbed for Phase 2B)
+
- Optional DPoP authentication for viewer state (stubbed for Phase 2B)
- Timeframe filtering for "top" sort (hour/day/week/month/year/all)
**Endpoints:**
···
- Required: `post` (AT-URI)
- Optional: `sort` (hot/top/new), `depth` (0-100), `limit` (1-100), `cursor`, `timeframe`
- Returns: Array of `threadViewComment` with nested replies + post context
-
- Supports Bearer token for authenticated requests (viewer state)
+
- Supports DPoP-bound access token for authenticated requests (viewer state)
**Files created (9):**
1. `internal/atproto/lexicon/social/coves/community/comment/defs.json` - View definitions
···
**8. Viewer Authentication Validation (Non-Issue - Architecture Working as Designed)**
- **Initial Concern:** ViewerDID field trusted without verification in service layer
- **Investigation:** Authentication IS properly validated at middleware layer
-
- `OptionalAuth` middleware extracts and validates JWT Bearer tokens
+
- `OptionalAuth` middleware extracts and validates DPoP-bound access tokens
- Uses PDS public keys (JWKS) for signature verification
-
- Validates token expiration, DID format, issuer
+
- Validates DPoP proof, token expiration, DID format, issuer
- Only injects verified DIDs into request context
- Handler extracts DID using `middleware.GetUserDID(r)`
- **Architecture:** Follows industry best practices (authentication at perimeter)
+7 -4
docs/FEED_SYSTEM_IMPLEMENTATION.md
···
# Get personalized timeline (hot posts from subscriptions)
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=hot&limit=15' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get top posts from last week
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=top&timeframe=week&limit=20' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get newest posts with pagination
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=new&limit=10&cursor=<cursor>' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
```
**Response:**
···
- โœ… Context timeout support
### Authentication (Timeline)
-
- โœ… JWT Bearer token required
+
- โœ… DPoP-bound access token required
- โœ… DID extracted from auth context
- โœ… Validates token signature (when AUTH_SKIP_VERIFY=false)
- โœ… Returns 401 on auth failure
+3 -3
docs/PRD_OAUTH.md
···
- โœ… Auth middleware protecting community endpoints
- โœ… Handlers updated to use `GetUserDID(r)`
- โœ… Comprehensive middleware auth tests (11 test cases)
-
- โœ… E2E tests updated to use Bearer tokens
+
- โœ… E2E tests updated to use DPoP-bound tokens
- โœ… Security logging with IP, method, path, issuer
- โœ… Scope validation (atproto required)
- โœ… Issuer HTTPS validation
···
Authorization: DPoP eyJhbGciOiJFUzI1NiIsInR5cCI6ImF0K2p3dCIsImtpZCI6ImRpZDpwbGM6YWxpY2UjYXRwcm90by1wZHMifQ...
```
-
Format: `DPoP <access_token>`
+
Format: `DPoP <access_token>` (note: uses "DPoP" scheme, not "Bearer")
The access token is a JWT containing:
```json
···
- [x] All community endpoints reject requests without valid JWT structure
- [x] Integration tests pass with mock tokens (11/11 middleware tests passing)
- [x] Zero security regressions from X-User-DID (JWT validation is strictly better)
-
- [x] E2E tests updated to use proper Bearer token authentication
+
- [x] E2E tests updated to use proper DPoP token authentication
- [x] Build succeeds without compilation errors
### Phase 2 (Beta) - โœ… READY FOR TESTING
+3 -1
docs/aggregators/SETUP_GUIDE.md
···
**Request**:
```bash
+
# Note: This calls the PDS directly, so it uses Bearer authorization (not DPoP)
curl -X POST https://bsky.social/xrpc/com.atproto.repo.createRecord \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
···
**Request**:
```bash
+
# Note: This calls the Coves API, so it uses DPoP authorization
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
+
-H "Authorization: DPoP YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:community123...",
+8 -2
docs/federation-prd.md
···
req, _ := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
// Use service auth token instead of community credentials
+
// NOTE: Auth scheme depends on target PDS implementation:
+
// - Standard atproto service auth uses "Bearer" scheme
+
// - Our AppView uses "DPoP" scheme when DPoP-bound tokens are required
+
// For server-to-server with standard PDS, use Bearer; adjust based on target.
req.Header.Set("Authorization", "Bearer "+serviceAuthToken)
req.Header.Set("Content-Type", "application/json")
···
**Request to Remote PDS:**
```http
POST https://covesinstance.com/xrpc/com.atproto.server.getServiceAuth
-
Authorization: Bearer {coves-social-instance-jwt}
+
Authorization: DPoP {coves-social-instance-jwt}
+
DPoP: {coves-social-dpop-proof}
Content-Type: application/json
{
···
**Using Token to Create Post:**
```http
POST https://covesinstance.com/xrpc/com.atproto.repo.createRecord
-
Authorization: Bearer {service-auth-token}
+
Authorization: DPoP {service-auth-token}
+
DPoP: {service-auth-dpop-proof}
Content-Type: application/json
{
+1 -1
scripts/aggregator-setup/README.md
···
```bash
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer $AGGREGATOR_ACCESS_JWT" \
+
-H "Authorization: DPoP $AGGREGATOR_ACCESS_JWT" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:...",
+1
tests/integration/blob_upload_e2e_test.go
···
assert.Equal(t, "POST", r.Method, "Should be POST request")
assert.Equal(t, "/xrpc/com.atproto.repo.uploadBlob", r.URL.Path, "Should hit uploadBlob endpoint")
assert.Equal(t, "image/png", r.Header.Get("Content-Type"), "Should have correct content type")
+
// Note: This is a PDS call, so it uses Bearer (not DPoP)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer ", "Should have auth header")
// Return mock blob reference
+477
internal/atproto/oauth/handlers_security_test.go
···
+
package oauth
+
+
import (
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestIsAllowedMobileRedirectURI tests the mobile redirect URI allowlist with EXACT URI matching
+
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
func TestIsAllowedMobileRedirectURI(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: true,
+
},
+
{
+
name: "rejected - custom scheme coves-app (vulnerable to interception)",
+
uri: "coves-app://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - custom scheme coves (vulnerable to interception)",
+
uri: "coves://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - evil scheme",
+
uri: "evil://callback",
+
expected: false,
+
},
+
{
+
name: "rejected - http (not secure)",
+
uri: "http://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https different domain",
+
uri: "https://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https coves.social wrong path",
+
uri: "https://coves.social/wrong/path",
+
expected: false,
+
},
+
{
+
name: "rejected - invalid URI",
+
uri: "not a uri",
+
expected: false,
+
},
+
{
+
name: "rejected - empty string",
+
uri: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.uri)
+
assert.Equal(t, tt.expected, result,
+
"isAllowedMobileRedirectURI(%q) = %v, want %v", tt.uri, result, tt.expected)
+
})
+
}
+
}
+
+
// TestExtractScheme tests the scheme extraction function
+
func TestExtractScheme(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected string
+
}{
+
{
+
name: "https scheme",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: "https",
+
},
+
{
+
name: "custom scheme",
+
uri: "coves-app://callback",
+
expected: "coves-app",
+
},
+
{
+
name: "invalid URI",
+
uri: "not a uri",
+
expected: "invalid",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := extractScheme(tt.uri)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
+
+
// TestGenerateCSRFToken tests CSRF token generation
+
func TestGenerateCSRFToken(t *testing.T) {
+
// Generate two tokens and verify they are different (randomness check)
+
token1, err1 := generateCSRFToken()
+
require.NoError(t, err1)
+
require.NotEmpty(t, token1)
+
+
token2, err2 := generateCSRFToken()
+
require.NoError(t, err2)
+
require.NotEmpty(t, token2)
+
+
assert.NotEqual(t, token1, token2, "CSRF tokens should be unique")
+
+
// Verify token is base64 encoded (should decode without error)
+
assert.Greater(t, len(token1), 40, "CSRF token should be reasonably long (32 bytes base64 encoded)")
+
}
+
+
// TestHandleMobileLogin_RedirectURIValidation tests that HandleMobileLogin validates redirect URIs
+
func TestHandleMobileLogin_RedirectURIValidation(t *testing.T) {
+
// Note: This is a unit test for the validation logic only.
+
// Full integration tests with OAuth flow are in tests/integration/oauth_e2e_test.go
+
+
tests := []struct {
+
name string
+
redirectURI string
+
expectedLog string
+
expectedStatus int
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
expectedStatus: http.StatusBadRequest, // Will fail at StartAuthFlow (no OAuth client setup)
+
},
+
{
+
name: "rejected - custom scheme coves-app (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected evil scheme",
+
redirectURI: "evil://callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected http",
+
redirectURI: "http://evil.com/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "scheme not allowed",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
// Test the validation function directly
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
if tt.expectedLog != "" {
+
assert.False(t, result, "Should reject %s", tt.redirectURI)
+
}
+
})
+
}
+
}
+
+
// TestHandleCallback_CSRFValidation tests that HandleCallback validates CSRF tokens for mobile flow
+
func TestHandleCallback_CSRFValidation(t *testing.T) {
+
// This is a conceptual test structure. Full implementation would require:
+
// 1. Mock OAuthClient
+
// 2. Mock OAuth store
+
// 3. Simulated OAuth callback with cookies
+
+
t.Run("mobile callback requires CSRF token", func(t *testing.T) {
+
// Setup: Create request with mobile_redirect_uri cookie but NO oauth_csrf cookie
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
// Missing: oauth_csrf cookie
+
+
// This would be rejected with 403 Forbidden in the actual handler
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
+
t.Run("mobile callback with valid CSRF token", func(t *testing.T) {
+
// Setup: Create request with both cookies
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: "valid-csrf-token",
+
})
+
+
// This would be accepted (assuming valid OAuth code/state)
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
}
+
+
// TestHandleMobileCallback_RevalidatesRedirectURI tests that handleMobileCallback re-validates the redirect URI
+
func TestHandleMobileCallback_RevalidatesRedirectURI(t *testing.T) {
+
// This is a critical security test: even if an attacker somehow bypasses the initial check,
+
// the callback handler should re-validate the redirect URI before redirecting.
+
+
tests := []struct {
+
name string
+
redirectURI string
+
shouldPass bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
shouldPass: true,
+
},
+
{
+
name: "blocked - custom scheme (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
shouldPass: false,
+
},
+
{
+
name: "blocked - evil scheme",
+
redirectURI: "evil://callback",
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestGenerateMobileRedirectBinding tests the binding token generation
+
// The binding now includes the CSRF token for proper double-submit validation
+
func TestGenerateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-12345"
+
tests := []struct {
+
name string
+
redirectURI string
+
}{
+
{
+
name: "Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
},
+
{
+
name: "different path",
+
redirectURI: "https://coves.social/different/path",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
binding1 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
binding2 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
+
// Same CSRF token + URI should produce same binding (deterministic)
+
assert.Equal(t, binding1, binding2, "binding should be deterministic for same inputs")
+
+
// Binding should not be empty
+
assert.NotEmpty(t, binding1, "binding should not be empty")
+
+
// Binding should be base64 encoded (should decode without error)
+
assert.Greater(t, len(binding1), 20, "binding should be reasonably long")
+
})
+
}
+
+
// Different URIs should produce different bindings
+
binding1 := generateMobileRedirectBinding(csrfToken, "https://coves.social/app/oauth/callback")
+
binding2 := generateMobileRedirectBinding(csrfToken, "https://coves.social/different/path")
+
assert.NotEqual(t, binding1, binding2, "different URIs should produce different bindings")
+
+
// Different CSRF tokens should produce different bindings
+
binding3 := generateMobileRedirectBinding("different-csrf-token", "https://coves.social/app/oauth/callback")
+
assert.NotEqual(t, binding1, binding3, "different CSRF tokens should produce different bindings")
+
}
+
+
// TestValidateMobileRedirectBinding tests the binding validation
+
// Now validates both CSRF token and redirect URI together (double-submit pattern)
+
func TestValidateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-for-validation"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
validBinding := generateMobileRedirectBinding(csrfToken, redirectURI)
+
+
tests := []struct {
+
name string
+
csrfToken string
+
redirectURI string
+
binding string
+
shouldPass bool
+
}{
+
{
+
name: "valid - correct CSRF token and redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: true,
+
},
+
{
+
name: "invalid - wrong redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: "https://coves.social/different/path",
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - wrong CSRF token",
+
csrfToken: "wrong-csrf-token",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - random binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "random-invalid-binding",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty CSRF token",
+
csrfToken: "",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := validateMobileRedirectBinding(tt.csrfToken, tt.redirectURI, tt.binding)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestSessionFixationAttackPrevention tests that the binding prevents session fixation
+
func TestSessionFixationAttackPrevention(t *testing.T) {
+
// Simulate attack scenario:
+
// 1. Attacker plants a cookie for evil://steal with binding for evil://steal
+
// 2. User does a web login (no mobile_redirect_binding cookie)
+
// 3. Callback should NOT redirect to evil://steal
+
+
attackerCSRF := "attacker-csrf-token"
+
attackerRedirectURI := "evil://steal"
+
attackerBinding := generateMobileRedirectBinding(attackerCSRF, attackerRedirectURI)
+
+
// Later, user's legitimate mobile login
+
userCSRF := "user-csrf-token"
+
userRedirectURI := "https://coves.social/app/oauth/callback"
+
userBinding := generateMobileRedirectBinding(userCSRF, userRedirectURI)
+
+
// The attacker's binding should NOT validate for the user's redirect URI
+
assert.False(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, attackerBinding),
+
"attacker's binding should not validate for user's CSRF token and redirect URI")
+
+
// The user's binding should validate for the user's CSRF token and redirect URI
+
assert.True(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, userBinding),
+
"user's binding should validate for user's CSRF token and redirect URI")
+
+
// Cross-validation should fail
+
assert.False(t, validateMobileRedirectBinding(attackerCSRF, attackerRedirectURI, userBinding),
+
"user's binding should not validate for attacker's CSRF token and redirect URI")
+
}
+
+
// TestCSRFTokenValidation tests that CSRF token VALUE is validated, not just presence
+
func TestCSRFTokenValidation(t *testing.T) {
+
// This test verifies the fix for the P1 security issue:
+
// "The callback never validates the token... the csrfToken argument is ignored entirely"
+
//
+
// The fix ensures that the CSRF token VALUE is cryptographically bound to the
+
// binding token, so changing the CSRF token will invalidate the binding.
+
+
t.Run("CSRF token value must match", func(t *testing.T) {
+
originalCSRF := "original-csrf-token-from-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
binding := generateMobileRedirectBinding(originalCSRF, redirectURI)
+
+
// Original CSRF token should validate
+
assert.True(t, validateMobileRedirectBinding(originalCSRF, redirectURI, binding),
+
"original CSRF token should validate")
+
+
// Different CSRF token should NOT validate (this is the key security fix)
+
differentCSRF := "attacker-forged-csrf-token"
+
assert.False(t, validateMobileRedirectBinding(differentCSRF, redirectURI, binding),
+
"different CSRF token should NOT validate - this is the security fix")
+
})
+
+
t.Run("attacker cannot forge binding without CSRF token", func(t *testing.T) {
+
// Attacker knows the redirect URI but not the CSRF token
+
redirectURI := "https://coves.social/app/oauth/callback"
+
victimCSRF := "victim-secret-csrf-token"
+
victimBinding := generateMobileRedirectBinding(victimCSRF, redirectURI)
+
+
// Attacker tries various CSRF tokens to forge the binding
+
attackerGuesses := []string{
+
"",
+
"guess1",
+
"attacker-csrf",
+
redirectURI, // trying the redirect URI as CSRF
+
}
+
+
for _, guess := range attackerGuesses {
+
assert.False(t, validateMobileRedirectBinding(guess, redirectURI, victimBinding),
+
"attacker's CSRF guess %q should not validate", guess)
+
}
+
})
+
}
+
+
// TestConstantTimeCompare tests the timing-safe comparison function
+
func TestConstantTimeCompare(t *testing.T) {
+
tests := []struct {
+
name string
+
a string
+
b string
+
expected bool
+
}{
+
{
+
name: "equal strings",
+
a: "abc123",
+
b: "abc123",
+
expected: true,
+
},
+
{
+
name: "different strings same length",
+
a: "abc123",
+
b: "xyz789",
+
expected: false,
+
},
+
{
+
name: "different lengths",
+
a: "short",
+
b: "longer",
+
expected: false,
+
},
+
{
+
name: "empty strings",
+
a: "",
+
b: "",
+
expected: true,
+
},
+
{
+
name: "one empty",
+
a: "abc",
+
b: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := constantTimeCompare(tt.a, tt.b)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
+152
internal/atproto/oauth/seal.go
···
+
package oauth
+
+
import (
+
"crypto/aes"
+
"crypto/cipher"
+
"crypto/rand"
+
"encoding/base64"
+
"encoding/json"
+
"fmt"
+
"time"
+
)
+
+
// SealedSession represents the data sealed in a mobile session token
+
type SealedSession struct {
+
DID string `json:"did"` // User's DID
+
SessionID string `json:"sid"` // Session identifier
+
ExpiresAt int64 `json:"exp"` // Unix timestamp when token expires
+
}
+
+
// SealSession creates an encrypted token containing session information.
+
// The token is encrypted using AES-256-GCM and encoded as base64url.
+
//
+
// Token format: base64url(nonce || ciphertext || tag)
+
// - nonce: 12 bytes (GCM standard nonce size)
+
// - ciphertext: encrypted JSON payload
+
// - tag: 16 bytes (GCM authentication tag)
+
//
+
// The sealed token can be safely given to mobile clients and used as
+
// a reference to the server-side session without exposing sensitive data.
+
func (c *OAuthClient) SealSession(did, sessionID string, ttl time.Duration) (string, error) {
+
if len(c.SealSecret) == 0 {
+
return "", fmt.Errorf("seal secret not configured")
+
}
+
+
if did == "" {
+
return "", fmt.Errorf("DID is required")
+
}
+
+
if sessionID == "" {
+
return "", fmt.Errorf("session ID is required")
+
}
+
+
// Create the session data
+
expiresAt := time.Now().Add(ttl).Unix()
+
session := SealedSession{
+
DID: did,
+
SessionID: sessionID,
+
ExpiresAt: expiresAt,
+
}
+
+
// Marshal to JSON
+
plaintext, err := json.Marshal(session)
+
if err != nil {
+
return "", fmt.Errorf("failed to marshal session: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return "", fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return "", fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Generate random nonce
+
nonce := make([]byte, gcm.NonceSize())
+
if _, err := rand.Read(nonce); err != nil {
+
return "", fmt.Errorf("failed to generate nonce: %w", err)
+
}
+
+
// Encrypt and authenticate
+
// GCM.Seal appends the ciphertext and tag to the nonce
+
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+
+
// Encode as base64url (no padding)
+
token := base64.RawURLEncoding.EncodeToString(ciphertext)
+
+
return token, nil
+
}
+
+
// UnsealSession decrypts and validates a sealed session token.
+
// Returns the session information if the token is valid and not expired.
+
func (c *OAuthClient) UnsealSession(token string) (*SealedSession, error) {
+
if len(c.SealSecret) == 0 {
+
return nil, fmt.Errorf("seal secret not configured")
+
}
+
+
if token == "" {
+
return nil, fmt.Errorf("token is required")
+
}
+
+
// Decode from base64url
+
ciphertext, err := base64.RawURLEncoding.DecodeString(token)
+
if err != nil {
+
return nil, fmt.Errorf("invalid token encoding: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Verify minimum size (nonce + tag)
+
nonceSize := gcm.NonceSize()
+
if len(ciphertext) < nonceSize {
+
return nil, fmt.Errorf("invalid token: too short")
+
}
+
+
// Extract nonce and ciphertext
+
nonce := ciphertext[:nonceSize]
+
ciphertextData := ciphertext[nonceSize:]
+
+
// Decrypt and authenticate
+
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decrypt token: %w", err)
+
}
+
+
// Unmarshal JSON
+
var session SealedSession
+
if err := json.Unmarshal(plaintext, &session); err != nil {
+
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
+
}
+
+
// Validate required fields
+
if session.DID == "" {
+
return nil, fmt.Errorf("invalid session: missing DID")
+
}
+
+
if session.SessionID == "" {
+
return nil, fmt.Errorf("invalid session: missing session ID")
+
}
+
+
// Check expiration
+
now := time.Now().Unix()
+
if session.ExpiresAt <= now {
+
return nil, fmt.Errorf("token expired at %v", time.Unix(session.ExpiresAt, 0))
+
}
+
+
return &session, nil
+
}
+331
internal/atproto/oauth/seal_test.go
···
+
package oauth
+
+
import (
+
"crypto/rand"
+
"encoding/base64"
+
"strings"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// generateSealSecret generates a random 32-byte seal secret for testing
+
func generateSealSecret() []byte {
+
secret := make([]byte, 32)
+
if _, err := rand.Read(secret); err != nil {
+
panic(err)
+
}
+
return secret
+
}
+
+
func TestSealSession_RoundTrip(t *testing.T) {
+
// Create client with seal secret
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
require.NotEmpty(t, token)
+
+
// Token should be base64url encoded
+
_, err = base64.RawURLEncoding.DecodeString(token)
+
require.NoError(t, err, "token should be valid base64url")
+
+
// Unseal the session
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
require.NotNil(t, session)
+
+
// Verify data
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
+
// Verify expiration is approximately correct (within 1 second)
+
expectedExpiry := time.Now().Add(ttl).Unix()
+
assert.InDelta(t, expectedExpiry, session.ExpiresAt, 1.0)
+
}
+
+
func TestSealSession_ExpirationValidation(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 2 * time.Second // Short TTL (must be >= 1 second due to Unix timestamp granularity)
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Should work immediately
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
+
// Wait well past expiration
+
time.Sleep(2500 * time.Millisecond)
+
+
// Should fail after expiration
+
session, err = client.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "token expired")
+
}
+
+
func TestSealSession_TamperedTokenDetection(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tamper with the token by modifying one character
+
tampered := token[:len(token)-5] + "XXXX" + token[len(token)-1:]
+
+
// Should fail to unseal tampered token
+
session, err := client.UnsealSession(tampered)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_InvalidTokenFormats(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
tests := []struct {
+
name string
+
token string
+
}{
+
{
+
name: "empty token",
+
token: "",
+
},
+
{
+
name: "invalid base64",
+
token: "not-valid-base64!@#$",
+
},
+
{
+
name: "too short",
+
token: base64.RawURLEncoding.EncodeToString([]byte("short")),
+
},
+
{
+
name: "random bytes",
+
token: base64.RawURLEncoding.EncodeToString(make([]byte, 50)),
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
session, err := client.UnsealSession(tt.token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
})
+
}
+
}
+
+
func TestSealSession_DifferentSecrets(t *testing.T) {
+
// Create two clients with different secrets
+
client1 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
client2 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal with client1
+
token, err := client1.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Try to unseal with client2 (different secret)
+
session, err := client2.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_NoSecretConfigured(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: nil,
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Should fail to seal without secret
+
token, err := client.SealSession(did, sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
+
// Should fail to unseal without secret
+
session, err := client.UnsealSession("dummy-token")
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
}
+
+
func TestSealSession_MissingRequiredFields(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
ttl := 1 * time.Hour
+
+
tests := []struct {
+
name string
+
did string
+
sessionID string
+
errorMsg string
+
}{
+
{
+
name: "missing DID",
+
did: "",
+
sessionID: "session-123",
+
errorMsg: "DID is required",
+
},
+
{
+
name: "missing session ID",
+
did: "did:plc:abc123",
+
sessionID: "",
+
errorMsg: "session ID is required",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
token, err := client.SealSession(tt.did, tt.sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), tt.errorMsg)
+
})
+
}
+
}
+
+
func TestSealSession_UniquenessPerCall(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the same session twice
+
token1, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
token2, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tokens should be different (different nonces)
+
assert.NotEqual(t, token1, token2, "tokens should be unique due to different nonces")
+
+
// But both should unseal to the same session data
+
session1, err := client.UnsealSession(token1)
+
require.NoError(t, err)
+
+
session2, err := client.UnsealSession(token2)
+
require.NoError(t, err)
+
+
assert.Equal(t, session1.DID, session2.DID)
+
assert.Equal(t, session1.SessionID, session2.SessionID)
+
}
+
+
func TestSealSession_LongDIDAndSessionID(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
// Test with very long DID and session ID
+
did := "did:plc:" + strings.Repeat("a", 200)
+
sessionID := "session-" + strings.Repeat("x", 200)
+
ttl := 1 * time.Hour
+
+
// Should work with long values
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
}
+
+
func TestSealSession_URLSafeEncoding(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal multiple times to get different nonces
+
for i := 0; i < 100; i++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Token should not contain URL-unsafe characters
+
assert.NotContains(t, token, "+", "token should not contain '+'")
+
assert.NotContains(t, token, "/", "token should not contain '/'")
+
assert.NotContains(t, token, "=", "token should not contain '='")
+
+
// Should unseal successfully
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
}
+
+
func TestSealSession_ConcurrentAccess(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Run concurrent seal/unseal operations
+
done := make(chan bool)
+
for i := 0; i < 10; i++ {
+
go func() {
+
for j := 0; j < 100; j++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
done <- true
+
}()
+
}
+
+
// Wait for all goroutines
+
for i := 0; i < 10; i++ {
+
<-done
+
}
+
}
+614
internal/atproto/oauth/store.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"errors"
+
"fmt"
+
"log/slog"
+
"net/url"
+
"strings"
+
"time"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/lib/pq"
+
)
+
+
var (
+
ErrSessionNotFound = errors.New("oauth session not found")
+
ErrAuthRequestNotFound = errors.New("oauth auth request not found")
+
)
+
+
// PostgresOAuthStore implements oauth.ClientAuthStore interface using PostgreSQL
+
type PostgresOAuthStore struct {
+
db *sql.DB
+
sessionTTL time.Duration
+
}
+
+
// NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store
+
func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore {
+
if sessionTTL == 0 {
+
sessionTTL = 7 * 24 * time.Hour // Default to 7 days
+
}
+
return &PostgresOAuthStore{
+
db: db,
+
sessionTTL: sessionTTL,
+
}
+
}
+
+
// GetSession retrieves a session by DID and session ID
+
func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
+
query := `
+
SELECT
+
did, session_id, host_url, auth_server_iss,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, access_token, refresh_token,
+
dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase
+
FROM oauth_sessions
+
WHERE did = $1 AND session_id = $2 AND expires_at > NOW()
+
`
+
+
var session oauth.ClientSessionData
+
var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var hostURL, dpopPrivateKeyMultibase sql.NullString
+
var scopes pq.StringArray
+
var dpopAuthServerNonce, dpopHostNonce sql.NullString
+
+
err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan(
+
&session.AccountDID,
+
&session.SessionID,
+
&hostURL,
+
&authServerIss,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&session.AccessToken,
+
&session.RefreshToken,
+
&dpopAuthServerNonce,
+
&dpopHostNonce,
+
&dpopPrivateKeyMultibase,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrSessionNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get session: %w", err)
+
}
+
+
// Convert nullable fields
+
if hostURL.Valid {
+
session.HostURL = hostURL.String
+
}
+
if authServerIss.Valid {
+
session.AuthServerURL = authServerIss.String
+
}
+
if authServerTokenEndpoint.Valid {
+
session.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
if dpopAuthServerNonce.Valid {
+
session.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if dpopHostNonce.Valid {
+
session.DPoPHostNonce = dpopHostNonce.String
+
}
+
if dpopPrivateKeyMultibase.Valid {
+
session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
session.Scopes = scopes
+
+
return &session, nil
+
}
+
+
// SaveSession saves or updates a session (upsert operation)
+
func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
+
// Input validation per atProto OAuth security requirements
+
+
// Validate DID format
+
if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil {
+
return fmt.Errorf("invalid DID format: %w", err)
+
}
+
+
// Validate token lengths (max 10000 chars to prevent memory issues)
+
const maxTokenLength = 10000
+
if len(sess.AccessToken) > maxTokenLength {
+
return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
if len(sess.RefreshToken) > maxTokenLength {
+
return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
+
// Validate session ID is not empty
+
if sess.SessionID == "" {
+
return fmt.Errorf("session_id cannot be empty")
+
}
+
+
// Validate URLs if provided
+
if sess.HostURL != "" {
+
if _, err := url.Parse(sess.HostURL); err != nil {
+
return fmt.Errorf("invalid host_url: %w", err)
+
}
+
}
+
if sess.AuthServerURL != "" {
+
if _, err := url.Parse(sess.AuthServerURL); err != nil {
+
return fmt.Errorf("invalid auth_server URL: %w", err)
+
}
+
}
+
if sess.AuthServerTokenEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_token_endpoint: %w", err)
+
}
+
}
+
if sess.AuthServerRevocationEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err)
+
}
+
}
+
+
query := `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_jwk, dpop_private_key_multibase,
+
dpop_authserver_nonce, dpop_pds_nonce,
+
auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, expires_at, created_at, updated_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
NULL, $8,
+
$9, $10,
+
$11, $12, $13,
+
$14, $15, NOW(), NOW()
+
)
+
ON CONFLICT (did, session_id) DO UPDATE SET
+
handle = EXCLUDED.handle,
+
pds_url = EXCLUDED.pds_url,
+
host_url = EXCLUDED.host_url,
+
access_token = EXCLUDED.access_token,
+
refresh_token = EXCLUDED.refresh_token,
+
dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase,
+
dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce,
+
dpop_pds_nonce = EXCLUDED.dpop_pds_nonce,
+
auth_server_iss = EXCLUDED.auth_server_iss,
+
auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint,
+
auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint,
+
scopes = EXCLUDED.scopes,
+
expires_at = EXCLUDED.expires_at,
+
updated_at = NOW()
+
`
+
+
// Calculate token expiration using configured TTL
+
expiresAt := time.Now().Add(s.sessionTTL)
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if sess.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Extract handle from DID (placeholder - in real implementation, resolve from identity)
+
// For now, use DID as handle since we don't have the handle in ClientSessionData
+
handle := sess.AccountDID.String()
+
+
// Use HostURL as PDS URL
+
pdsURL := sess.HostURL
+
if pdsURL == "" {
+
pdsURL = sess.AuthServerURL // Fallback to auth server URL
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
sess.AccountDID.String(),
+
sess.SessionID,
+
handle,
+
pdsURL,
+
sess.HostURL,
+
sess.AccessToken,
+
sess.RefreshToken,
+
sess.DPoPPrivateKeyMultibase,
+
sess.DPoPAuthServerNonce,
+
sess.DPoPHostNonce,
+
sess.AuthServerURL,
+
sess.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(sess.Scopes),
+
expiresAt,
+
)
+
if err != nil {
+
return fmt.Errorf("failed to save session: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteSession deletes a session by DID and session ID
+
func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
+
query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2`
+
+
result, err := s.db.ExecContext(ctx, query, did.String(), sessionID)
+
if err != nil {
+
return fmt.Errorf("failed to delete session: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrSessionNotFound
+
}
+
+
return nil
+
}
+
+
// GetAuthRequestInfo retrieves auth request information by state
+
func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
+
query := `
+
SELECT
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, created_at
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var info oauth.AuthRequestData
+
var did, handle, pdsURL sql.NullString
+
var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString
+
var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var scopes pq.StringArray
+
var createdAt time.Time
+
+
err := s.db.QueryRowContext(ctx, query, state).Scan(
+
&info.State,
+
&did,
+
&handle,
+
&pdsURL,
+
&info.PKCEVerifier,
+
&dpopPrivateKeyMultibase,
+
&dpopAuthServerNonce,
+
&info.AuthServerURL,
+
&requestURI,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&createdAt,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get auth request info: %w", err)
+
}
+
+
// Parse DID if present
+
if did.Valid && did.String != "" {
+
parsedDID, err := syntax.ParseDID(did.String)
+
if err != nil {
+
return nil, fmt.Errorf("failed to parse DID: %w", err)
+
}
+
info.AccountDID = &parsedDID
+
}
+
+
// Convert nullable fields
+
if dpopPrivateKeyMultibase.Valid {
+
info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
if dpopAuthServerNonce.Valid {
+
info.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if requestURI.Valid {
+
info.RequestURI = requestURI.String
+
}
+
if authServerTokenEndpoint.Valid {
+
info.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
info.Scopes = scopes
+
+
return &info, nil
+
}
+
+
// SaveAuthRequestInfo saves auth request information (create only, not upsert)
+
func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
query := `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, return_url, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
$12, NULL, NOW()
+
)
+
`
+
+
// Extract DID string if present
+
var didStr sql.NullString
+
if info.AccountDID != nil {
+
didStr.String = info.AccountDID.String()
+
didStr.Valid = true
+
}
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if info.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Placeholder values for handle and pds_url (not in AuthRequestData)
+
// In production, these would be resolved during the auth flow
+
handle := ""
+
pdsURL := ""
+
if info.AccountDID != nil {
+
handle = info.AccountDID.String() // Temporary placeholder
+
pdsURL = info.AuthServerURL // Temporary placeholder
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
info.State,
+
didStr,
+
handle,
+
pdsURL,
+
info.PKCEVerifier,
+
info.DPoPPrivateKeyMultibase,
+
info.DPoPAuthServerNonce,
+
info.AuthServerURL,
+
info.RequestURI,
+
info.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(info.Scopes),
+
)
+
if err != nil {
+
// Check for duplicate state
+
if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") {
+
return fmt.Errorf("auth request with state already exists: %s", info.State)
+
}
+
return fmt.Errorf("failed to save auth request info: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteAuthRequestInfo deletes auth request information by state
+
func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
+
query := `DELETE FROM oauth_requests WHERE state = $1`
+
+
result, err := s.db.ExecContext(ctx, query, state)
+
if err != nil {
+
return fmt.Errorf("failed to delete auth request info: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// CleanupExpiredSessions removes sessions that have expired
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// CleanupExpiredAuthRequests removes auth requests older than 30 minutes
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// MobileOAuthData holds mobile-specific OAuth flow data
+
type MobileOAuthData struct {
+
CSRFToken string
+
RedirectURI string
+
}
+
+
// mobileFlowContextKey is the context key for mobile flow data
+
type mobileFlowContextKey struct{}
+
+
// ContextWithMobileFlowData adds mobile flow data to a context.
+
// This is used by HandleMobileLogin to pass mobile data to the store wrapper,
+
// which will save it when SaveAuthRequestInfo is called by indigo.
+
func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context {
+
return context.WithValue(ctx, mobileFlowContextKey{}, data)
+
}
+
+
// getMobileFlowDataFromContext retrieves mobile flow data from context, if present
+
func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) {
+
data, ok := ctx.Value(mobileFlowContextKey{}).(MobileOAuthData)
+
return data, ok
+
}
+
+
// MobileAwareClientStore is a marker interface that indicates a store is properly
+
// configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo
+
// to save mobile CSRF data should implement this interface.
+
// This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used.
+
type MobileAwareClientStore interface {
+
IsMobileAware() bool
+
}
+
+
// MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile
+
// CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow.
+
// This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state,
+
// so we intercept the SaveAuthRequestInfo call to capture it.
+
type MobileAwareStoreWrapper struct {
+
oauth.ClientAuthStore
+
mobileStore MobileOAuthStore
+
}
+
+
// IsMobileAware implements MobileAwareClientStore, indicating this store
+
// properly saves mobile CSRF data during OAuth flow initiation.
+
func (w *MobileAwareStoreWrapper) IsMobileAware() bool {
+
return true
+
}
+
+
// NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo
+
// to also save mobile CSRF data when present in context.
+
func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper {
+
wrapper := &MobileAwareStoreWrapper{
+
ClientAuthStore: store,
+
}
+
// Check if the underlying store implements MobileOAuthStore
+
if ms, ok := store.(MobileOAuthStore); ok {
+
wrapper.mobileStore = ms
+
}
+
return wrapper
+
}
+
+
// SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data
+
// if mobile flow data is present in the context.
+
func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
// First, save the auth request to the underlying store
+
if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil {
+
return err
+
}
+
+
// Check if this is a mobile flow (mobile data in context)
+
if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil {
+
// Save mobile CSRF data tied to this OAuth state
+
// IMPORTANT: If this fails, we MUST propagate the error. Otherwise:
+
// 1. No server-side CSRF record is stored
+
// 2. Every mobile callback will "fail closed" to web flow
+
// 3. Mobile sign-in silently breaks with no indication
+
// Failing loudly here lets the user retry rather than being confused
+
// about why they're getting a web flow instead of mobile.
+
if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil {
+
slog.Error("failed to save mobile CSRF data - mobile login will fail",
+
"state", info.State, "error", err)
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
if w.mobileStore != nil {
+
return w.mobileStore.GetMobileOAuthData(ctx, state)
+
}
+
return nil, nil
+
}
+
+
// SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
if w.mobileStore != nil {
+
return w.mobileStore.SaveMobileOAuthData(ctx, state, data)
+
}
+
return nil
+
}
+
+
// UnwrapPostgresStore returns the underlying PostgresOAuthStore if present.
+
// This is useful for accessing cleanup methods that aren't part of the interface.
+
func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore {
+
if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok {
+
return ps
+
}
+
return nil
+
}
+
+
// SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state
+
// This ties the CSRF token to the OAuth flow via the state parameter,
+
// which comes back through the OAuth response for server-side validation.
+
func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
query := `
+
UPDATE oauth_requests
+
SET mobile_csrf_token = $2, mobile_redirect_uri = $3
+
WHERE state = $1
+
`
+
+
result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI)
+
if err != nil {
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData retrieves mobile CSRF data by OAuth state
+
// This is called during callback to compare the server-side CSRF token
+
// (retrieved by state from the OAuth response) against the cookie CSRF.
+
func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
query := `
+
SELECT mobile_csrf_token, mobile_redirect_uri
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var csrfToken, redirectURI sql.NullString
+
err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err)
+
}
+
+
// Return nil if no mobile data was stored (this was a web flow)
+
if !csrfToken.Valid {
+
return nil, nil
+
}
+
+
return &MobileOAuthData{
+
CSRFToken: csrfToken.String,
+
RedirectURI: redirectURI.String,
+
}, nil
+
}
+522
internal/atproto/oauth/store_test.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"os"
+
"testing"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// setupTestDB creates a test database connection and runs migrations
+
func setupTestDB(t *testing.T) *sql.DB {
+
dsn := os.Getenv("TEST_DATABASE_URL")
+
if dsn == "" {
+
dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable"
+
}
+
+
db, err := sql.Open("postgres", dsn)
+
require.NoError(t, err, "Failed to connect to test database")
+
+
// Run migrations
+
require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations")
+
+
return db
+
}
+
+
// cleanupOAuth removes all test OAuth data from the database
+
func cleanupOAuth(t *testing.T, db *sql.DB) {
+
_, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_sessions")
+
+
_, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_requests")
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:test123abc")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session123",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
Scopes: []string{"atproto"},
+
AccessToken: "at_test_token_abc123",
+
RefreshToken: "rt_test_token_xyz789",
+
DPoPAuthServerNonce: "nonce_auth_123",
+
DPoPHostNonce: "nonce_host_456",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
assert.NoError(t, err)
+
+
// Retrieve session
+
retrieved, err := store.GetSession(ctx, did, "session123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, session.SessionID, retrieved.SessionID)
+
assert.Equal(t, session.HostURL, retrieved.HostURL)
+
assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL)
+
assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken)
+
assert.Equal(t, session.RefreshToken, retrieved.RefreshToken)
+
assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce)
+
assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
assert.Equal(t, session.Scopes, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testupsert")
+
require.NoError(t, err)
+
+
// Initial session
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds1.example.com",
+
AuthServerURL: "https://auth1.example.com",
+
AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "old_access_token",
+
RefreshToken: "old_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
+
// Updated session (same DID and session ID)
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds2.example.com",
+
AuthServerURL: "https://auth2.example.com",
+
AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token",
+
Scopes: []string{"atproto", "transition:generic"},
+
AccessToken: "new_access_token",
+
RefreshToken: "new_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save again - should update
+
err = store.SaveSession(ctx, session2)
+
assert.NoError(t, err)
+
+
// Retrieve should get updated values
+
retrieved, err := store.GetSession(ctx, did, "session_upsert")
+
assert.NoError(t, err)
+
assert.Equal(t, "new_access_token", retrieved.AccessToken)
+
assert.Equal(t, "new_refresh_token", retrieved.RefreshToken)
+
assert.Equal(t, "https://pds2.example.com", retrieved.HostURL)
+
assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testdelete")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_delete",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "test_token",
+
RefreshToken: "test_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err)
+
+
// Delete session
+
err = store.DeleteSession(ctx, did, "session_delete")
+
assert.NoError(t, err)
+
+
// Verify session is gone
+
_, err = store.GetSession(ctx, did, "session_delete")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
err = store.DeleteSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testrequest")
+
require.NoError(t, err)
+
+
info := oauth.AuthRequestData{
+
State: "test_state_abc123",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: &did,
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:abc123",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
PKCEVerifier: "verifier_xyz789",
+
DPoPAuthServerNonce: "nonce_abc",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err = store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve auth request info
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, info.State, retrieved.State)
+
assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL)
+
assert.NotNil(t, retrieved.AccountDID)
+
assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, info.Scopes, retrieved.Scopes)
+
assert.Equal(t, info.RequestURI, retrieved.RequestURI)
+
assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier)
+
assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
}
+
+
func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_nodid",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: nil, // No DID provided
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:nodid",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_nodid",
+
DPoPAuthServerNonce: "nonce_nodid",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info without DID
+
err := store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve and verify DID is nil
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid")
+
assert.NoError(t, err)
+
assert.Nil(t, retrieved.AccountDID)
+
assert.Equal(t, info.State, retrieved.State)
+
}
+
+
func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
_, err := store.GetAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_delete",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:delete",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_delete",
+
DPoPAuthServerNonce: "nonce_delete",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err := store.SaveAuthRequestInfo(ctx, info)
+
require.NoError(t, err)
+
+
// Delete auth request info
+
err = store.DeleteAuthRequestInfo(ctx, "test_state_delete")
+
assert.NoError(t, err)
+
+
// Verify it's gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_state_delete")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL
+
store, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
ctx := context.Background()
+
+
did1, err := syntax.ParseDID("did:plc:testexpired1")
+
require.NoError(t, err)
+
did2, err := syntax.ParseDID("did:plc:testexpired2")
+
require.NoError(t, err)
+
+
// Create an expired session (manually insert with past expiration)
+
_, err = db.ExecContext(ctx, `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_key_multibase, auth_server_iss,
+
auth_server_token_endpoint, scopes,
+
expires_at, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 day', NOW()
+
)
+
`, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com",
+
"expired_token", "expired_refresh",
+
"z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a valid session
+
validSession := oauth.ClientSessionData{
+
AccountDID: did2,
+
SessionID: "valid_session",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "valid_token",
+
RefreshToken: "valid_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveSession(ctx, validSession)
+
require.NoError(t, err)
+
+
// Cleanup expired sessions
+
count, err := store.CleanupExpiredSessions(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired session")
+
+
// Verify expired session is gone
+
_, err = store.GetSession(ctx, did1, "expired_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// Verify valid session still exists
+
_, err = store.GetSession(ctx, did2, "valid_session")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0)
+
pgStore, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
store := oauth.ClientAuthStore(pgStore)
+
ctx := context.Background()
+
+
// Create an old auth request (manually insert with old timestamp)
+
_, err := db.ExecContext(ctx, `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, scopes,
+
created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 hour'
+
)
+
`, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com",
+
"old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
"nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a recent auth request
+
recentInfo := oauth.AuthRequestData{
+
State: "test_recent_state",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:recent",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "recent_verifier",
+
DPoPAuthServerNonce: "nonce_recent",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveAuthRequestInfo(ctx, recentInfo)
+
require.NoError(t, err)
+
+
// Cleanup expired auth requests (older than 30 minutes)
+
count, err := pgStore.CleanupExpiredAuthRequests(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired auth request")
+
+
// Verify old request is gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_old_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
+
// Verify recent request still exists
+
_, err = store.GetAuthRequestInfo(ctx, "test_recent_state")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_MultipleSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testmulti")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same DID
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "browser1",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_browser1",
+
RefreshToken: "refresh_browser1",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "mobile_app",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_mobile",
+
RefreshToken: "refresh_mobile",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save both sessions
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
err = store.SaveSession(ctx, session2)
+
require.NoError(t, err)
+
+
// Retrieve both sessions
+
retrieved1, err := store.GetSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_browser1", retrieved1.AccessToken)
+
+
retrieved2, err := store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_mobile", retrieved2.AccessToken)
+
+
// Delete one session
+
err = store.DeleteSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
+
// Verify only browser1 is deleted
+
_, err = store.GetSession(ctx, did, "browser1")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// mobile_app should still exist
+
_, err = store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
}
+99
internal/atproto/oauth/transport.go
···
+
package oauth
+
+
import (
+
"fmt"
+
"net"
+
"net/http"
+
"time"
+
)
+
+
// ssrfSafeTransport wraps http.Transport to prevent SSRF attacks
+
type ssrfSafeTransport struct {
+
base *http.Transport
+
allowPrivate bool // For dev/testing only
+
}
+
+
// isPrivateIP checks if an IP is in a private/reserved range
+
func isPrivateIP(ip net.IP) bool {
+
if ip == nil {
+
return false
+
}
+
+
// Check for loopback
+
if ip.IsLoopback() {
+
return true
+
}
+
+
// Check for link-local
+
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+
return true
+
}
+
+
// Check for private ranges
+
privateRanges := []string{
+
"10.0.0.0/8",
+
"172.16.0.0/12",
+
"192.168.0.0/16",
+
"169.254.0.0/16",
+
"::1/128",
+
"fc00::/7",
+
"fe80::/10",
+
}
+
+
for _, cidr := range privateRanges {
+
_, network, err := net.ParseCIDR(cidr)
+
if err == nil && network.Contains(ip) {
+
return true
+
}
+
}
+
+
return false
+
}
+
+
func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+
host := req.URL.Hostname()
+
+
// Resolve hostname to IP
+
ips, err := net.LookupIP(host)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resolve host: %w", err)
+
}
+
+
// Check all resolved IPs
+
if !t.allowPrivate {
+
for _, ip := range ips {
+
if isPrivateIP(ip) {
+
return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip)
+
}
+
}
+
}
+
+
return t.base.RoundTrip(req)
+
}
+
+
// NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections
+
func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client {
+
transport := &ssrfSafeTransport{
+
base: &http.Transport{
+
DialContext: (&net.Dialer{
+
Timeout: 10 * time.Second,
+
KeepAlive: 30 * time.Second,
+
}).DialContext,
+
MaxIdleConns: 100,
+
IdleConnTimeout: 90 * time.Second,
+
TLSHandshakeTimeout: 10 * time.Second,
+
},
+
allowPrivate: allowPrivate,
+
}
+
+
return &http.Client{
+
Timeout: 15 * time.Second,
+
Transport: transport,
+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
+
if len(via) >= 5 {
+
return fmt.Errorf("too many redirects")
+
}
+
return nil
+
},
+
}
+
}
+132
internal/atproto/oauth/transport_test.go
···
+
package oauth
+
+
import (
+
"net"
+
"net/http"
+
"testing"
+
)
+
+
func TestIsPrivateIP(t *testing.T) {
+
tests := []struct {
+
name string
+
ip string
+
expected bool
+
}{
+
// Loopback addresses
+
{"IPv4 loopback", "127.0.0.1", true},
+
{"IPv6 loopback", "::1", true},
+
+
// Private IPv4 ranges
+
{"Private 10.x.x.x", "10.0.0.1", true},
+
{"Private 10.x.x.x edge", "10.255.255.255", true},
+
{"Private 172.16.x.x", "172.16.0.1", true},
+
{"Private 172.31.x.x edge", "172.31.255.255", true},
+
{"Private 192.168.x.x", "192.168.1.1", true},
+
{"Private 192.168.x.x edge", "192.168.255.255", true},
+
+
// Link-local addresses
+
{"Link-local IPv4", "169.254.1.1", true},
+
{"Link-local IPv6", "fe80::1", true},
+
+
// IPv6 private ranges
+
{"IPv6 unique local fc00", "fc00::1", true},
+
{"IPv6 unique local fd00", "fd00::1", true},
+
+
// Public addresses
+
{"Public IP 1.1.1.1", "1.1.1.1", false},
+
{"Public IP 8.8.8.8", "8.8.8.8", false},
+
{"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12
+
{"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12
+
{"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8
+
{"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
ip := net.ParseIP(tt.ip)
+
if ip == nil {
+
t.Fatalf("Failed to parse IP: %s", tt.ip)
+
}
+
+
result := isPrivateIP(ip)
+
if result != tt.expected {
+
t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
+
}
+
})
+
}
+
}
+
+
func TestIsPrivateIP_NilIP(t *testing.T) {
+
result := isPrivateIP(nil)
+
if result != false {
+
t.Errorf("isPrivateIP(nil) = %v, expected false", result)
+
}
+
}
+
+
func TestNewSSRFSafeHTTPClient(t *testing.T) {
+
tests := []struct {
+
name string
+
allowPrivate bool
+
}{
+
{"Production client (no private IPs)", false},
+
{"Development client (allow private IPs)", true},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(tt.allowPrivate)
+
+
if client == nil {
+
t.Fatal("NewSSRFSafeHTTPClient returned nil")
+
}
+
+
if client.Timeout == 0 {
+
t.Error("Expected timeout to be set")
+
}
+
+
if client.Transport == nil {
+
t.Error("Expected transport to be set")
+
}
+
+
transport, ok := client.Transport.(*ssrfSafeTransport)
+
if !ok {
+
t.Error("Expected ssrfSafeTransport")
+
}
+
+
if transport.allowPrivate != tt.allowPrivate {
+
t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate)
+
}
+
})
+
}
+
}
+
+
func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(false)
+
+
// Simulate checking redirect limit
+
if client.CheckRedirect == nil {
+
t.Fatal("Expected CheckRedirect to be set")
+
}
+
+
// Test redirect limit (5 redirects)
+
var via []*http.Request
+
for i := 0; i < 5; i++ {
+
req := &http.Request{}
+
via = append(via, req)
+
}
+
+
err := client.CheckRedirect(nil, via)
+
if err == nil {
+
t.Error("Expected error for too many redirects")
+
}
+
if err.Error() != "too many redirects" {
+
t.Errorf("Expected 'too many redirects' error, got: %v", err)
+
}
+
+
// Test within limit (4 redirects)
+
via = via[:4]
+
err = client.CheckRedirect(nil, via)
+
if err != nil {
+
t.Errorf("Expected no error for 4 redirects, got: %v", err)
+
}
+
}
+124
internal/db/migrations/019_update_oauth_for_indigo.sql
···
+
-- +goose Up
+
-- Update OAuth tables to match indigo's ClientAuthStore interface requirements
+
-- This migration adds columns needed for OAuth client sessions and auth requests
+
+
-- Update oauth_requests table
+
-- Add columns for request URI, auth server endpoints, scopes, and DPoP key
+
ALTER TABLE oauth_requests
+
ADD COLUMN request_uri TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_requests ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Make did nullable (indigo's AuthRequestData.AccountDID is a pointer - optional)
+
ALTER TABLE oauth_requests ALTER COLUMN did DROP NOT NULL;
+
+
-- Make handle and pds_url nullable too (derived from DID resolution, not always available at auth request time)
+
ALTER TABLE oauth_requests ALTER COLUMN handle DROP NOT NULL;
+
ALTER TABLE oauth_requests ALTER COLUMN pds_url DROP NOT NULL;
+
+
-- Update existing oauth_requests data
+
-- Convert dpop_private_jwk (JSONB) to multibase format if needed
+
-- Note: This will leave the multibase column NULL for now since conversion requires crypto logic
+
-- The application will need to handle NULL values or regenerate keys on next auth flow
+
UPDATE oauth_requests
+
SET
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE auth_server_token_endpoint IS NULL;
+
+
-- Add indexes for new columns
+
CREATE INDEX idx_oauth_requests_request_uri ON oauth_requests(request_uri) WHERE request_uri IS NOT NULL;
+
+
-- Update oauth_sessions table
+
-- Add session_id column (will become part of composite key)
+
ALTER TABLE oauth_sessions
+
ADD COLUMN session_id TEXT,
+
ADD COLUMN host_url TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_sessions ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Populate session_id for existing sessions (use DID as default for single-session per account)
+
-- In production, you may want to generate unique session IDs
+
UPDATE oauth_sessions
+
SET
+
session_id = 'default',
+
host_url = pds_url,
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE session_id IS NULL;
+
+
-- Make session_id NOT NULL after populating existing data
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN session_id SET NOT NULL;
+
+
-- Drop old unique constraint on did only
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_key;
+
+
-- Create new composite unique constraint for (did, session_id)
+
-- This allows multiple sessions per account
+
-- Note: UNIQUE constraint automatically creates an index, so no separate index needed
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_session_id_key UNIQUE (did, session_id);
+
+
-- Add comment explaining the schema change
+
COMMENT ON COLUMN oauth_sessions.session_id IS 'Session identifier to support multiple concurrent sessions per account';
+
COMMENT ON CONSTRAINT oauth_sessions_did_session_id_key ON oauth_sessions IS 'Composite key allowing multiple sessions per DID';
+
+
-- +goose Down
+
-- Rollback: Remove added columns and restore original unique constraint
+
+
-- oauth_sessions rollback
+
-- Drop composite unique constraint (this also drops the associated index)
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_session_id_key;
+
+
-- Delete all but the most recent session per DID before restoring unique constraint
+
-- This ensures the UNIQUE (did) constraint can be added without conflicts
+
DELETE FROM oauth_sessions a
+
USING oauth_sessions b
+
WHERE a.did = b.did
+
AND a.created_at < b.created_at;
+
+
-- Restore old unique constraint
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_key UNIQUE (did);
+
+
-- Restore NOT NULL constraint on dpop_private_jwk
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN dpop_private_jwk SET NOT NULL;
+
+
ALTER TABLE oauth_sessions
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS host_url,
+
DROP COLUMN IF EXISTS session_id;
+
+
-- oauth_requests rollback
+
DROP INDEX IF EXISTS idx_oauth_requests_request_uri;
+
+
-- Restore NOT NULL constraints
+
ALTER TABLE oauth_requests
+
ALTER COLUMN dpop_private_jwk SET NOT NULL,
+
ALTER COLUMN did SET NOT NULL,
+
ALTER COLUMN handle SET NOT NULL,
+
ALTER COLUMN pds_url SET NOT NULL;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS request_uri;
+23
internal/db/migrations/020_add_mobile_oauth_csrf.sql
···
+
-- +goose Up
+
-- Add columns for mobile OAuth CSRF protection with server-side state
+
-- This ties the CSRF token to the OAuth state, allowing validation against
+
-- a value that comes back through the OAuth response (the state parameter)
+
-- rather than only validating cookies against each other.
+
+
ALTER TABLE oauth_requests
+
ADD COLUMN mobile_csrf_token TEXT,
+
ADD COLUMN mobile_redirect_uri TEXT;
+
+
-- Index for quick lookup of mobile data when callback is received
+
CREATE INDEX idx_oauth_requests_mobile_csrf ON oauth_requests(state)
+
WHERE mobile_csrf_token IS NOT NULL;
+
+
COMMENT ON COLUMN oauth_requests.mobile_csrf_token IS 'CSRF token for mobile OAuth flows, validated against cookie on callback';
+
COMMENT ON COLUMN oauth_requests.mobile_redirect_uri IS 'Mobile redirect URI (Universal Link) for this OAuth flow';
+
+
-- +goose Down
+
DROP INDEX IF EXISTS idx_oauth_requests_mobile_csrf;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS mobile_redirect_uri,
+
DROP COLUMN IF EXISTS mobile_csrf_token;
+137
internal/api/handlers/wellknown/universal_links.go
···
+
package wellknown
+
+
import (
+
"encoding/json"
+
"log/slog"
+
"net/http"
+
"os"
+
)
+
+
// HandleAppleAppSiteAssociation serves the iOS Universal Links configuration
+
// GET /.well-known/apple-app-site-association
+
//
+
// Universal Links provide cryptographic binding between the app and domain:
+
// - Requires apple-app-site-association file served over HTTPS
+
// - App must have Associated Domains capability configured
+
// - System verifies domain ownership before routing deep links
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.apple.com/documentation/xcode/supporting-universal-links-in-your-app
+
func HandleAppleAppSiteAssociation(w http.ResponseWriter, r *http.Request) {
+
// Get Apple App ID from environment (format: <Team ID>.<Bundle ID>)
+
// Example: "ABCD1234.social.coves.app"
+
// Find Team ID in Apple Developer Portal -> Membership
+
// Bundle ID is configured in Xcode project
+
appleAppID := os.Getenv("APPLE_APP_ID")
+
if appleAppID == "" {
+
// Development fallback - allows testing without real Team ID
+
// IMPORTANT: This MUST be set in production for Universal Links to work
+
appleAppID = "DEVELOPMENT.social.coves.app"
+
slog.Warn("APPLE_APP_ID not set, using development placeholder",
+
"app_id", appleAppID,
+
"note", "Set APPLE_APP_ID env var for production Universal Links")
+
}
+
+
// Apple requires application/json content type (no charset)
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Apple's spec
+
// See: https://developer.apple.com/documentation/bundleresources/applinks
+
response := map[string]interface{}{
+
"applinks": map[string]interface{}{
+
"apps": []string{}, // Must be empty array per Apple spec
+
"details": []map[string]interface{}{
+
{
+
"appID": appleAppID,
+
// Paths that trigger Universal Links when opened in Safari/other apps
+
// These URLs will open the app instead of the browser
+
"paths": []string{
+
"/app/oauth/callback", // Primary Universal Link OAuth callback
+
"/app/oauth/callback/*", // Catch-all for query params
+
},
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode apple-app-site-association", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served apple-app-site-association", "app_id", appleAppID)
+
}
+
+
// HandleAssetLinks serves the Android App Links configuration
+
// GET /.well-known/assetlinks.json
+
//
+
// App Links provide cryptographic binding between the app and domain:
+
// - Requires assetlinks.json file served over HTTPS
+
// - App must have intent-filter with android:autoVerify="true"
+
// - System verifies domain ownership via SHA-256 certificate fingerprint
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.android.com/training/app-links/verify-android-applinks
+
func HandleAssetLinks(w http.ResponseWriter, r *http.Request) {
+
// Get Android package name from environment
+
// Example: "social.coves.app"
+
androidPackage := os.Getenv("ANDROID_PACKAGE_NAME")
+
if androidPackage == "" {
+
androidPackage = "social.coves.app" // Default for development
+
slog.Warn("ANDROID_PACKAGE_NAME not set, using default",
+
"package", androidPackage,
+
"note", "Set ANDROID_PACKAGE_NAME env var for production App Links")
+
}
+
+
// Get SHA-256 fingerprint from environment
+
// This is the SHA-256 fingerprint of the app's signing certificate
+
//
+
// To get the fingerprint:
+
// Production: keytool -list -v -keystore release.jks -alias release
+
// Debug: keytool -list -v -keystore ~/.android/debug.keystore -alias androiddebugkey -storepass android -keypass android
+
//
+
// Look for "SHA256:" in the output
+
// Format: AA:BB:CC:DD:...:FF (64 hex characters separated by colons)
+
androidFingerprint := os.Getenv("ANDROID_SHA256_FINGERPRINT")
+
if androidFingerprint == "" {
+
// Development fallback - this won't work for real App Links verification
+
// IMPORTANT: This MUST be set in production for App Links to work
+
androidFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
+
slog.Warn("ANDROID_SHA256_FINGERPRINT not set, using development placeholder",
+
"fingerprint", androidFingerprint,
+
"note", "Set ANDROID_SHA256_FINGERPRINT env var for production App Links")
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Google's Digital Asset Links spec
+
// See: https://developers.google.com/digital-asset-links/v1/getting-started
+
response := []map[string]interface{}{
+
{
+
// delegate_permission/common.handle_all_urls grants the app permission
+
// to handle URLs for this domain
+
"relation": []string{"delegate_permission/common.handle_all_urls"},
+
"target": map[string]interface{}{
+
"namespace": "android_app",
+
"package_name": androidPackage,
+
// List of certificate fingerprints that can sign the app
+
// Multiple fingerprints can be provided for different signing keys
+
// (e.g., debug + release)
+
"sha256_cert_fingerprints": []string{
+
androidFingerprint,
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode assetlinks.json", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served assetlinks.json",
+
"package", androidPackage,
+
"fingerprint", androidFingerprint)
+
}
+25
internal/api/routes/wellknown.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/wellknown"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterWellKnownRoutes registers RFC 8615 well-known URI endpoints
+
// These endpoints are used for service discovery and mobile app deep linking
+
//
+
// Spec: https://www.rfc-editor.org/rfc/rfc8615.html
+
func RegisterWellKnownRoutes(r chi.Router) {
+
// iOS Universal Links configuration
+
// Required for cryptographically-bound deep linking on iOS
+
// Must be served at exact path /.well-known/apple-app-site-association
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/apple-app-site-association", wellknown.HandleAppleAppSiteAssociation)
+
+
// Android App Links configuration
+
// Required for cryptographically-bound deep linking on Android
+
// Must be served at exact path /.well-known/assetlinks.json
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/assetlinks.json", wellknown.HandleAssetLinks)
+
}
+1 -1
internal/api/handlers/comments/middleware.go
···
// The middleware extracts the viewer DID from the Authorization header if present and valid,
// making it available via middleware.GetUserDID(r) in the handler.
// If no valid token is present, the request continues as anonymous (empty DID).
-
func OptionalAuthMiddleware(authMiddleware *middleware.AtProtoAuthMiddleware, next http.HandlerFunc) http.Handler {
+
func OptionalAuthMiddleware(authMiddleware *middleware.OAuthAuthMiddleware, next http.HandlerFunc) http.Handler {
return authMiddleware.OptionalAuth(http.HandlerFunc(next))
}
+1 -1
internal/api/routes/community.go
···
// RegisterCommunityRoutes registers community-related XRPC endpoints on the router
// Implements social.coves.community.* lexicon endpoints
// allowedCommunityCreators restricts who can create communities. If empty, anyone can create.
-
func RegisterCommunityRoutes(r chi.Router, service communities.Service, authMiddleware *middleware.AtProtoAuthMiddleware, allowedCommunityCreators []string) {
+
func RegisterCommunityRoutes(r chi.Router, service communities.Service, authMiddleware *middleware.OAuthAuthMiddleware, allowedCommunityCreators []string) {
// Initialize handlers
createHandler := community.NewCreateHandler(service, allowedCommunityCreators)
getHandler := community.NewGetHandler(service)
+1 -1
internal/api/routes/post.go
···
// RegisterPostRoutes registers post-related XRPC endpoints on the router
// Implements social.coves.community.post.* lexicon endpoints
-
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.AtProtoAuthMiddleware) {
+
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
// Initialize handlers
createHandler := post.NewCreateHandler(service)
+1 -1
internal/api/routes/timeline.go
···
func RegisterTimelineRoutes(
r chi.Router,
timelineService timelineCore.Service,
-
authMiddleware *middleware.AtProtoAuthMiddleware,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
getTimelineHandler := timeline.NewGetTimelineHandler(timelineService)
+291
tests/e2e/oauth_ratelimit_e2e_test.go
···
+
package e2e
+
+
import (
+
"Coves/internal/api/middleware"
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
)
+
+
// TestRateLimiting_E2E_OAuthEndpoints tests OAuth-specific rate limiting
+
// OAuth endpoints have stricter rate limits to prevent:
+
// - Credential stuffing attacks on login endpoints (10 req/min)
+
// - OAuth state exhaustion
+
// - Refresh token abuse (20 req/min)
+
func TestRateLimiting_E2E_OAuthEndpoints(t *testing.T) {
+
t.Run("Login endpoints have 10 req/min limit", func(t *testing.T) {
+
// Create rate limiter matching oauth.go config: 10 requests per minute
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
// Mock OAuth login handler
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte("OK"))
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.200:12345"
+
+
// Make exactly 10 requests (at limit)
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 11th request should be rate limited
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Request 11 should be rate limited")
+
assert.Contains(t, rr.Body.String(), "Rate limit exceeded", "Should have rate limit error message")
+
})
+
+
t.Run("Mobile login endpoints have 10 req/min limit", func(t *testing.T) {
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.201:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Mobile login should be rate limited at 10 req/min")
+
})
+
+
t.Run("Refresh endpoint has 20 req/min limit", func(t *testing.T) {
+
// Refresh has higher limit (20 req/min) for legitimate token refresh
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := refreshLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.202:12345"
+
+
// Make 20 requests
+
for i := 0; i < 20; i++ {
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 21st request blocked
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Refresh should be rate limited at 20 req/min")
+
})
+
+
t.Run("Logout endpoint has 10 req/min limit", func(t *testing.T) {
+
logoutLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := logoutLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.203:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Logout should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth callback has 10 req/min limit", func(t *testing.T) {
+
// Callback uses same limiter as login (part of auth flow)
+
callbackLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := callbackLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.204:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Callback should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth rate limits are stricter than global limit", func(t *testing.T) {
+
// Verify OAuth limits are more restrictive than global 100 req/min
+
const globalLimit = 100
+
const oauthLoginLimit = 10
+
const oauthRefreshLimit = 20
+
+
assert.Less(t, oauthLoginLimit, globalLimit, "OAuth login limit should be stricter than global")
+
assert.Less(t, oauthRefreshLimit, globalLimit, "OAuth refresh limit should be stricter than global")
+
assert.Greater(t, oauthRefreshLimit, oauthLoginLimit, "Refresh limit should be higher than login (legitimate use case)")
+
})
+
+
t.Run("OAuth limits prevent credential stuffing", func(t *testing.T) {
+
// Simulate credential stuffing attack
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
// Simulate failed login attempts
+
w.WriteHeader(http.StatusUnauthorized)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
attackerIP := "203.0.113.50:12345"
+
+
// Attacker tries 15 login attempts (credential stuffing)
+
successfulAttempts := 0
+
blockedAttempts := 0
+
+
for i := 0; i < 15; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = attackerIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
if rr.Code == http.StatusUnauthorized {
+
successfulAttempts++ // Reached handler (even if auth failed)
+
} else if rr.Code == http.StatusTooManyRequests {
+
blockedAttempts++
+
}
+
}
+
+
// Rate limiter should block 5 attempts after first 10
+
assert.Equal(t, 10, successfulAttempts, "Should allow 10 login attempts")
+
assert.Equal(t, 5, blockedAttempts, "Should block 5 attempts after limit reached")
+
})
+
+
t.Run("OAuth limits are per-endpoint", func(t *testing.T) {
+
// Each endpoint gets its own rate limiter
+
// This test verifies that limits are independent per endpoint
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
loginHandler := loginLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
refreshHandler := refreshLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
clientIP := "192.168.1.205:12345"
+
+
// Exhaust login limit
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// Login limit exhausted
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Login should be rate limited")
+
+
// Refresh endpoint should still work (independent limiter)
+
req = httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr = httptest.NewRecorder()
+
refreshHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Refresh should not be affected by login rate limit")
+
})
+
}
+
+
// OAuth Rate Limiting Configuration Documentation
+
// ================================================
+
// This test file validates OAuth-specific rate limits applied in oauth.go:
+
//
+
// 1. Login Endpoints (Credential Stuffing Protection)
+
// - Endpoints: /oauth/login, /oauth/mobile/login, /oauth/callback
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent brute force and credential stuffing attacks
+
// - Implementation: internal/api/routes/oauth.go:21
+
//
+
// 2. Refresh Endpoint (Token Refresh)
+
// - Endpoint: /oauth/refresh
+
// - Limit: 20 requests per minute per IP
+
// - Reason: Allow legitimate token refresh while preventing abuse
+
// - Implementation: internal/api/routes/oauth.go:24
+
//
+
// 3. Logout Endpoint
+
// - Endpoint: /oauth/logout
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent session exhaustion attacks
+
// - Implementation: internal/api/routes/oauth.go:27
+
//
+
// 4. Metadata Endpoints (No Extra Limit)
+
// - Endpoints: /oauth/client-metadata.json, /oauth/jwks.json
+
// - Limit: Global 100 requests per minute (from main.go)
+
// - Reason: Public metadata, not sensitive to rate abuse
+
//
+
// Security Benefits:
+
// - Credential Stuffing: Limits password guessing to 10 attempts/min
+
// - State Exhaustion: Prevents OAuth state generation spam
+
// - Token Abuse: Limits refresh token usage while allowing legitimate refresh
+
//
+
// Rate Limit Hierarchy:
+
// - OAuth login: 10 req/min (most restrictive)
+
// - OAuth refresh: 20 req/min (moderate)
+
// - Comments: 20 req/min (expensive queries)
+
// - Global: 100 req/min (baseline)
+910
tests/integration/oauth_e2e_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"encoding/json"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"strings"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_Components tests OAuth component functionality without requiring PDS.
+
// This validates all Coves OAuth code:
+
// - Session storage and retrieval (PostgreSQL)
+
// - Token sealing (AES-GCM encryption)
+
// - Token unsealing (decryption + validation)
+
// - Session cleanup
+
//
+
// NOTE: Full OAuth redirect flow testing requires both HTTPS PDS and HTTPS Coves deployment.
+
// The OAuth redirect flow is handled by indigo's library and enforces OAuth 2.0 spec
+
// (HTTPS required for authorization servers and redirect URIs).
+
func TestOAuth_Components(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth component test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations to ensure OAuth tables exist
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”ง Testing OAuth Components")
+
+
ctx := context.Background()
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Use a test DID (doesn't need to exist on PDS for component tests)
+
testDID := "did:plc:componenttest123"
+
+
// Run component tests
+
testOAuthComponentsWithMockedSession(t, ctx, nil, store, client, testDID, "")
+
+
t.Log("")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("โœ… OAuth Component Tests Complete")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("Components validated:")
+
t.Log(" โœ“ Session storage (PostgreSQL)")
+
t.Log(" โœ“ Token sealing (AES-GCM encryption)")
+
t.Log(" โœ“ Token unsealing (decryption + validation)")
+
t.Log(" โœ“ Session cleanup")
+
t.Log("")
+
t.Log("NOTE: Full OAuth redirect flow requires HTTPS PDS + HTTPS Coves")
+
t.Log(strings.Repeat("=", 60))
+
}
+
+
// testOAuthComponentsWithMockedSession tests OAuth components that work without PDS redirect flow.
+
// This is used when testing with localhost PDS, where the indigo library rejects http:// URLs.
+
func testOAuthComponentsWithMockedSession(t *testing.T, ctx context.Context, _ interface{}, store oauthlib.ClientAuthStore, client *oauth.OAuthClient, userDID, _ string) {
+
t.Helper()
+
+
t.Log("๐Ÿ”ง Testing OAuth components with mocked session...")
+
+
// Parse DID
+
parsedDID, err := syntax.ParseDID(userDID)
+
require.NoError(t, err, "Should parse DID")
+
+
// Component 1: Session Storage
+
t.Log(" ๐Ÿ“ฆ Component 1: Testing session storage...")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: fmt.Sprintf("localhost-test-%d", time.Now().UnixNano()),
+
HostURL: "http://localhost:3001",
+
AccessToken: "mocked-access-token",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err, "Should save session")
+
+
retrieved, err := store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should retrieve session")
+
require.Equal(t, testSession.SessionID, retrieved.SessionID)
+
require.Equal(t, testSession.AccessToken, retrieved.AccessToken)
+
t.Log(" โœ… Session storage working")
+
+
// Component 2: Token Sealing
+
t.Log(" ๐Ÿ” Component 2: Testing token sealing...")
+
sealedToken, err := client.SealSession(parsedDID.String(), testSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
tokenPreview := sealedToken
+
if len(tokenPreview) > 50 {
+
tokenPreview = tokenPreview[:50]
+
}
+
t.Logf(" โœ… Token sealed: %s...", tokenPreview)
+
+
// Component 3: Token Unsealing
+
t.Log(" ๐Ÿ”“ Component 3: Testing token unsealing...")
+
unsealed, err := client.UnsealSession(sealedToken)
+
require.NoError(t, err, "Should unseal token")
+
require.Equal(t, userDID, unsealed.DID)
+
require.Equal(t, testSession.SessionID, unsealed.SessionID)
+
t.Log(" โœ… Token unsealing working")
+
+
// Component 4: Session Cleanup
+
t.Log(" ๐Ÿงน Component 4: Testing session cleanup...")
+
err = store.DeleteSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should delete session")
+
+
_, err = store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.Error(t, err, "Session should not exist after deletion")
+
t.Log(" โœ… Session cleanup working")
+
+
t.Log("โœ… All OAuth components verified!")
+
t.Log("")
+
t.Log("๐Ÿ“ Summary: OAuth implementation validated with mocked session")
+
t.Log(" - Session storage: โœ“")
+
t.Log(" - Token sealing: โœ“")
+
t.Log(" - Token unsealing: โœ“")
+
t.Log(" - Session cleanup: โœ“")
+
t.Log("")
+
t.Log("โš ๏ธ To test full OAuth redirect flow, use a production PDS with HTTPS")
+
}
+
+
// TestOAuthE2E_TokenExpiration tests that expired sealed tokens are rejected
+
func TestOAuthE2E_TokenExpiration(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token expiration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("โฐ Testing OAuth token expiration...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
_ = oauth.NewOAuthHandler(client, store) // Handler created for completeness
+
+
// Create test session with past expiration
+
did, err := syntax.ParseDID("did:plc:expiredtest123")
+
require.NoError(t, err)
+
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "expired-session",
+
HostURL: "http://localhost:3001",
+
AccessToken: "expired-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Manually update expiration to the past
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_sessions SET expires_at = NOW() - INTERVAL '1 day' WHERE did = $1 AND session_id = $2",
+
did.String(), testSession.SessionID)
+
require.NoError(t, err)
+
+
// Try to retrieve expired session
+
_, err = store.GetSession(ctx, did, testSession.SessionID)
+
assert.Error(t, err, "Should not be able to retrieve expired session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound for expired session")
+
+
// Test cleanup of expired sessions
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredSessions(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one session")
+
+
t.Logf("โœ… Expired session handling verified (cleaned %d sessions)", cleaned)
+
}
+
+
// TestOAuthE2E_InvalidToken tests that invalid/tampered tokens are rejected
+
func TestOAuthE2E_InvalidToken(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth invalid token test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”’ Testing OAuth invalid token rejection...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup test server with protected endpoint
+
r := chi.NewRouter()
+
r.Get("/api/me", func(w http.ResponseWriter, r *http.Request) {
+
sessData, err := handler.GetSessionFromRequest(r)
+
if err != nil {
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
+
return
+
}
+
w.Header().Set("Content-Type", "application/json")
+
_ = json.NewEncoder(w).Encode(map[string]string{"did": sessData.AccountDID.String()})
+
})
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
// Test with invalid token formats
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but invalid content
+
{"Short token", "abc"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
req, _ := http.NewRequest("GET", server.URL+"/api/me", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid token should be rejected with 401")
+
})
+
}
+
+
t.Logf("โœ… Invalid token rejection verified")
+
}
+
+
// TestOAuthE2E_SessionNotFound tests behavior when session doesn't exist in DB
+
func TestOAuthE2E_SessionNotFound(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session not found test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ” Testing OAuth session not found behavior...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Try to retrieve non-existent session
+
nonExistentDID, err := syntax.ParseDID("did:plc:nonexistent123")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error for non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
// Try to delete non-existent session
+
err = store.DeleteSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error when deleting non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
t.Logf("โœ… Session not found handling verified")
+
}
+
+
// TestOAuthE2E_MultipleSessionsPerUser tests that a user can have multiple active sessions
+
func TestOAuthE2E_MultipleSessionsPerUser(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth multiple sessions test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ‘ฅ Testing multiple OAuth sessions per user...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test DID
+
did, err := syntax.ParseDID("did:plc:multisession123")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same user
+
sessions := []oauthlib.ClientSessionData{
+
{
+
AccountDID: did,
+
SessionID: "session-1-web",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-1",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-2-mobile",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-2",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-3-tablet",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-3",
+
Scopes: []string{"atproto"},
+
},
+
}
+
+
// Save all sessions
+
for i, session := range sessions {
+
err := store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should be able to save session %d", i+1)
+
}
+
+
t.Logf("โœ… Created %d sessions for user", len(sessions))
+
+
// Verify all sessions can be retrieved independently
+
for i, session := range sessions {
+
retrieved, err := store.GetSession(ctx, did, session.SessionID)
+
require.NoError(t, err, "Should be able to retrieve session %d", i+1)
+
assert.Equal(t, session.SessionID, retrieved.SessionID, "Session ID should match")
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken, "Access token should match")
+
}
+
+
t.Logf("โœ… All sessions retrieved independently")
+
+
// Delete one session and verify others remain
+
err = store.DeleteSession(ctx, did, sessions[0].SessionID)
+
require.NoError(t, err, "Should be able to delete first session")
+
+
// Verify first session is deleted
+
_, err = store.GetSession(ctx, did, sessions[0].SessionID)
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "First session should be deleted")
+
+
// Verify other sessions still exist
+
for i := 1; i < len(sessions); i++ {
+
_, err := store.GetSession(ctx, did, sessions[i].SessionID)
+
require.NoError(t, err, "Session %d should still exist", i+1)
+
}
+
+
t.Logf("โœ… Multiple sessions per user verified")
+
+
// Cleanup
+
for i := 1; i < len(sessions); i++ {
+
_ = store.DeleteSession(ctx, did, sessions[i].SessionID)
+
}
+
}
+
+
// TestOAuthE2E_AuthRequestStorage tests OAuth auth request storage and retrieval
+
func TestOAuthE2E_AuthRequestStorage(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth auth request storage test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ“ Testing OAuth auth request storage...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create test auth request data
+
did, err := syntax.ParseDID("did:plc:authrequest123")
+
require.NoError(t, err)
+
+
authRequest := oauthlib.AuthRequestData{
+
State: "test-state-12345",
+
AccountDID: &did,
+
PKCEVerifier: "test-pkce-verifier",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
AuthServerURL: "http://localhost:3001",
+
RequestURI: "http://localhost:3001/authorize",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save auth request
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
require.NoError(t, err, "Should be able to save auth request")
+
+
t.Logf("โœ… Auth request saved")
+
+
// Retrieve auth request
+
retrieved, err := store.GetAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to retrieve auth request")
+
assert.Equal(t, authRequest.State, retrieved.State, "State should match")
+
assert.Equal(t, authRequest.PKCEVerifier, retrieved.PKCEVerifier, "PKCE verifier should match")
+
assert.Equal(t, authRequest.AuthServerURL, retrieved.AuthServerURL, "Auth server URL should match")
+
assert.Equal(t, len(authRequest.Scopes), len(retrieved.Scopes), "Scopes length should match")
+
+
t.Logf("โœ… Auth request retrieved and verified")
+
+
// Test duplicate state error
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
assert.Error(t, err, "Should not allow duplicate state")
+
assert.Contains(t, err.Error(), "already exists", "Error should indicate duplicate")
+
+
t.Logf("โœ… Duplicate state prevention verified")
+
+
// Delete auth request
+
err = store.DeleteAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to delete auth request")
+
+
// Verify deletion
+
_, err = store.GetAuthRequestInfo(ctx, authRequest.State)
+
assert.Equal(t, oauth.ErrAuthRequestNotFound, err, "Auth request should be deleted")
+
+
t.Logf("โœ… Auth request deletion verified")
+
+
// Test cleanup of expired auth requests
+
// Create an auth request and manually set created_at to the past
+
oldAuthRequest := oauthlib.AuthRequestData{
+
State: "old-state-12345",
+
PKCEVerifier: "old-verifier",
+
AuthServerURL: "http://localhost:3001",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveAuthRequestInfo(ctx, oldAuthRequest)
+
require.NoError(t, err)
+
+
// Update created_at to 1 hour ago
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_requests SET created_at = NOW() - INTERVAL '1 hour' WHERE state = $1",
+
oldAuthRequest.State)
+
require.NoError(t, err)
+
+
// Cleanup expired requests
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredAuthRequests(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one auth request")
+
+
t.Logf("โœ… Expired auth request cleanup verified (cleaned %d requests)", cleaned)
+
}
+
+
// TestOAuthE2E_TokenRefresh tests the refresh token flow
+
func TestOAuthE2E_TokenRefresh(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token refresh test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth token refresh flow...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Create a test DID and session
+
did, err := syntax.ParseDID("did:plc:refreshtest123")
+
require.NoError(t, err)
+
+
// Create initial session with refresh token
+
initialSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "refresh-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
AccessToken: "initial-access-token",
+
RefreshToken: "initial-refresh-token",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, initialSession)
+
require.NoError(t, err, "Should save initial session")
+
+
t.Logf("โœ… Initial session created")
+
+
// Create a sealed token for this session
+
sealedToken, err := client.SealSession(did.String(), initialSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal session token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
+
t.Logf("โœ… Session token sealed")
+
+
// Setup test server with refresh endpoint
+
r := chi.NewRouter()
+
r.Post("/oauth/refresh", handler.HandleRefresh)
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
t.Run("Valid refresh request", func(t *testing.T) {
+
// NOTE: This test verifies that the refresh endpoint can be called
+
// In a real scenario, the indigo client's RefreshTokens() would call the PDS
+
// Since we're in a component test, we're testing the Coves handler logic
+
+
// Create refresh request
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": sealedToken,
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
// NOTE: In component testing mode, the indigo client may not have
+
// real PDS credentials, so RefreshTokens() might fail
+
// We're testing that the handler correctly processes the request
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// In component test mode without real PDS, we may get 401
+
// In production with real PDS, this would return 200 with new tokens
+
t.Logf("Refresh response status: %d", resp.StatusCode)
+
+
// The important thing is that the handler doesn't crash
+
// and properly validates the request structure
+
assert.True(t, resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized,
+
"Refresh should return either success or auth failure, got %d", resp.StatusCode)
+
})
+
+
t.Run("Invalid DID format (with valid token)", func(t *testing.T) {
+
// Create a sealed token with an invalid DID format
+
invalidDID := "invalid-did-format"
+
// Create the token with a valid DID first, then we'll try to use it with invalid DID in request
+
validToken, err := client.SealSession(did.String(), initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": invalidDID, // Invalid DID format in request
+
"session_id": initialSession.SessionID,
+
"sealed_token": validToken, // Valid token for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// Should reject with 401 due to DID mismatch (not 400) since auth happens first
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected with 401 (auth check happens before format validation)")
+
})
+
+
t.Run("Missing sealed_token (security test)", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
// Missing sealed_token - should be rejected for security
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Missing sealed_token should be rejected (proof of possession required)")
+
})
+
+
t.Run("Invalid sealed_token", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": "invalid-token-data",
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid sealed_token should be rejected")
+
})
+
+
t.Run("DID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token for a different DID
+
wrongDID := "did:plc:wronguser123"
+
wrongToken, err := client.SealSession(wrongDID, initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(), // Claiming original DID
+
"session_id": initialSession.SessionID,
+
"sealed_token": wrongToken, // But token is for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Session ID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token with wrong session ID
+
wrongSessionID := "wrong-session-id"
+
wrongToken, err := client.SealSession(did.String(), wrongSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID, // Claiming original session
+
"sealed_token": wrongToken, // But token is for different session
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Session ID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Non-existent session", func(t *testing.T) {
+
// Create a valid sealed token for a non-existent session
+
nonExistentSessionID := "nonexistent-session-id"
+
validToken, err := client.SealSession(did.String(), nonExistentSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": nonExistentSessionID,
+
"sealed_token": validToken, // Valid token but session doesn't exist
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Non-existent session should be rejected with 401")
+
})
+
+
t.Logf("โœ… Token refresh endpoint validation verified")
+
}
+
+
// TestOAuthE2E_SessionUpdate tests that refresh updates the session in database
+
func TestOAuthE2E_SessionUpdate(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session update test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ’พ Testing OAuth session update on refresh...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:sessionupdate123")
+
require.NoError(t, err)
+
+
originalSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "update-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: "original-access-token",
+
RefreshToken: "original-refresh-token",
+
DPoPPrivateKeyMultibase: "original-dpop-key",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save original session
+
err = store.SaveSession(ctx, originalSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Original session saved")
+
+
// Simulate a token refresh by updating the session with new tokens
+
updatedSession := originalSession
+
updatedSession.AccessToken = "new-access-token"
+
updatedSession.RefreshToken = "new-refresh-token"
+
updatedSession.DPoPAuthServerNonce = "new-nonce"
+
+
// Update the session (upsert)
+
err = store.SaveSession(ctx, updatedSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Session updated with new tokens")
+
+
// Retrieve the session and verify it was updated
+
retrieved, err := store.GetSession(ctx, did, originalSession.SessionID)
+
require.NoError(t, err, "Should retrieve updated session")
+
+
assert.Equal(t, "new-access-token", retrieved.AccessToken,
+
"Access token should be updated")
+
assert.Equal(t, "new-refresh-token", retrieved.RefreshToken,
+
"Refresh token should be updated")
+
assert.Equal(t, "new-nonce", retrieved.DPoPAuthServerNonce,
+
"DPoP nonce should be updated")
+
+
// Verify session ID and DID remain the same
+
assert.Equal(t, originalSession.SessionID, retrieved.SessionID,
+
"Session ID should remain the same")
+
assert.Equal(t, did, retrieved.AccountDID,
+
"DID should remain the same")
+
+
t.Logf("โœ… Session update verified - tokens refreshed in database")
+
+
// Verify updated_at was changed
+
var updatedAt time.Time
+
err = db.QueryRowContext(ctx,
+
"SELECT updated_at FROM oauth_sessions WHERE did = $1 AND session_id = $2",
+
did.String(), originalSession.SessionID).Scan(&updatedAt)
+
require.NoError(t, err)
+
+
// Updated timestamp should be recent (within last minute)
+
assert.WithinDuration(t, time.Now(), updatedAt, time.Minute,
+
"Session updated_at should be recent")
+
+
t.Logf("โœ… Session timestamp update verified")
+
}
+
+
// TestOAuthE2E_RefreshTokenRotation tests refresh token rotation behavior
+
func TestOAuthE2E_RefreshTokenRotation(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth refresh token rotation test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth refresh token rotation...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:rotation123")
+
require.NoError(t, err)
+
+
// Simulate multiple refresh cycles
+
sessionID := "rotation-session-1"
+
tokens := []struct {
+
access string
+
refresh string
+
}{
+
{"access-token-v1", "refresh-token-v1"},
+
{"access-token-v2", "refresh-token-v2"},
+
{"access-token-v3", "refresh-token-v3"},
+
}
+
+
for i, tokenPair := range tokens {
+
session := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: tokenPair.access,
+
RefreshToken: tokenPair.refresh,
+
Scopes: []string{"atproto"},
+
}
+
+
// Save/update session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should save session iteration %d", i+1)
+
+
// Retrieve and verify
+
retrieved, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err, "Should retrieve session iteration %d", i+1)
+
+
assert.Equal(t, tokenPair.access, retrieved.AccessToken,
+
"Access token should match iteration %d", i+1)
+
assert.Equal(t, tokenPair.refresh, retrieved.RefreshToken,
+
"Refresh token should match iteration %d", i+1)
+
+
// Small delay to ensure timestamp differences
+
time.Sleep(10 * time.Millisecond)
+
}
+
+
t.Logf("โœ… Refresh token rotation verified through %d cycles", len(tokens))
+
+
// Verify final state
+
finalSession, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err)
+
+
assert.Equal(t, "access-token-v3", finalSession.AccessToken,
+
"Final access token should be from last rotation")
+
assert.Equal(t, "refresh-token-v3", finalSession.RefreshToken,
+
"Final refresh token should be from last rotation")
+
+
t.Logf("โœ… Token rotation state verified")
+
}
+312
tests/integration/oauth_session_fixation_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"crypto/sha256"
+
"encoding/base64"
+
"net/http"
+
"net/http/httptest"
+
"net/url"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_SessionFixationAttackPrevention tests that the mobile redirect binding
+
// prevents session fixation attacks where an attacker plants a mobile_redirect_uri
+
// cookie, then the user does a web login, and credentials get sent to attacker's deep link.
+
//
+
// Attack scenario:
+
// 1. Attacker tricks user into visiting /oauth/mobile/login?redirect_uri=evil://steal
+
// 2. This plants a mobile_redirect_uri cookie (lives 10 minutes)
+
// 3. User later does normal web OAuth login via /oauth/login
+
// 4. HandleCallback sees the stale mobile_redirect_uri cookie
+
// 5. WITHOUT THE FIX: Callback sends sealed token, DID, session_id to attacker's deep link
+
// 6. WITH THE FIX: Binding mismatch is detected, mobile cookies cleared, user gets web session
+
func TestOAuth_SessionFixationAttackPrevention(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session fixation test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Setup handler
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup router
+
r := chi.NewRouter()
+
r.Get("/oauth/callback", handler.HandleCallback)
+
+
t.Run("attack scenario - planted mobile cookie without binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Step 1: Simulate a successful OAuth callback (like a user did web login)
+
// We'll create a mock session to simulate what ProcessCallback would return
+
testDID := "did:plc:test123456"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "test-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "test-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session (simulating successful OAuth flow)
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Step 2: Attacker planted a mobile_redirect_uri cookie (without binding)
+
// This simulates the cookie being planted earlier by attacker
+
attackerRedirectURI := "evil://steal"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Plant the attacker's cookie (URL escaped as it would be in real scenario)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
// NOTE: No mobile_redirect_binding cookie! This is the attack scenario.
+
+
rec := httptest.NewRecorder()
+
+
// Step 3: Try to process the callback
+
// This would fail because ProcessCallback needs real OAuth code/state
+
// For this test, we're verifying the handler's security checks work
+
// even before ProcessCallback is called
+
+
// The handler will try to call ProcessCallback which will fail
+
// But we're testing that even if it succeeded, the mobile redirect
+
// validation would prevent the attack
+
handler.HandleCallback(rec, req)
+
+
// Step 4: Verify the attack was prevented
+
// The handler should reject the request due to missing binding
+
// Since ProcessCallback will fail first (no real OAuth code), we expect
+
// a 400 error, but the important thing is it doesn't redirect to evil://steal
+
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails")
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI")
+
})
+
+
t.Run("legitimate mobile flow - with valid binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup a legitimate mobile session
+
testDID := "did:plc:mobile123"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "mobile-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "mobile-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Create request with BOTH mobile_redirect_uri AND valid binding
+
// Use Universal Link URI that's in the allowlist
+
legitRedirectURI := "https://coves.social/app/oauth/callback"
+
csrfToken := "valid-csrf-token-for-mobile"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Add mobile redirect URI cookie
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(legitRedirectURI),
+
Path: "/oauth",
+
})
+
+
// Add CSRF token (required for mobile flow)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: csrfToken,
+
Path: "/oauth",
+
})
+
+
// Add VALID binding cookie (this is what prevents the attack)
+
// In real flow, this would be set by HandleMobileLogin
+
// The binding now includes the CSRF token for double-submit validation
+
mobileBinding := generateMobileRedirectBindingForTest(csrfToken, legitRedirectURI)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: mobileBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// This will also fail at ProcessCallback (no real OAuth code)
+
// but we're verifying the binding validation logic is in place
+
// In a real integration test with PDS, this would succeed
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails (expected in mock test)")
+
})
+
+
t.Run("binding mismatch - attacker tries wrong binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:bindingtest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "binding-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "binding-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Attacker tries to plant evil redirect with a binding from different URI
+
attackerRedirectURI := "evil://steal"
+
attackerCSRF := "attacker-csrf-token"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Use binding from a DIFFERENT CSRF token and URI (attacker's attempt to forge)
+
// Even if attacker knows the redirect URI, they don't know the user's CSRF token
+
wrongBinding := generateMobileRedirectBindingForTest("different-csrf", "https://coves.social/app/oauth/callback")
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: wrongBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// Should fail due to binding mismatch (even before ProcessCallback)
+
// The binding validation happens after ProcessCallback in the real code,
+
// but the mismatch would be caught and cookies cleared
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI on binding mismatch")
+
})
+
+
t.Run("CSRF token value mismatch - attacker tries different CSRF", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:csrftest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "csrf-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "csrf-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// This tests the P1 security fix: CSRF token VALUE must be validated, not just presence
+
// Attack scenario:
+
// 1. User starts mobile login with CSRF token A and redirect URI X
+
// 2. Binding = hash(A + X) is stored in cookie
+
// 3. Attacker somehow gets user to have CSRF token B in cookie (different from A)
+
// 4. Callback receives CSRF token B, redirect URI X, binding = hash(A + X)
+
// 5. hash(B + X) != hash(A + X), so attack is detected
+
+
originalCSRF := "original-csrf-token-set-at-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
// Binding was created with original CSRF token
+
originalBinding := generateMobileRedirectBindingForTest(originalCSRF, redirectURI)
+
+
// But attacker managed to change the CSRF cookie
+
attackerCSRF := "attacker-replaced-csrf"
+
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(redirectURI),
+
Path: "/oauth",
+
})
+
+
// Attacker's CSRF token (different from what created the binding)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Original binding (created with original CSRF token)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: originalBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// Should fail because hash(attackerCSRF + redirectURI) != hash(originalCSRF + redirectURI)
+
// This is the key security fix - CSRF token VALUE is now validated
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when CSRF token doesn't match binding")
+
})
+
}
+
+
// generateMobileRedirectBindingForTest generates a binding for testing
+
// This mirrors the actual logic in handlers_security.go:
+
// binding = base64(sha256(csrfToken + "|" + redirectURI)[:16])
+
func generateMobileRedirectBindingForTest(csrfToken, mobileRedirectURI string) string {
+
combined := csrfToken + "|" + mobileRedirectURI
+
hash := sha256.Sum256([]byte(combined))
+
return base64.URLEncoding.EncodeToString(hash[:16])
+
}
+169
tests/integration/oauth_token_verification_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/api/middleware"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"os"
+
"testing"
+
"time"
+
)
+
+
// TestOAuthTokenVerification tests end-to-end OAuth token verification
+
// with real PDS-issued OAuth tokens. This replaces the old JWT verification test
+
// since we now use OAuth sealed session tokens instead of raw JWTs.
+
//
+
// Flow:
+
// 1. Create account on local PDS (or use existing)
+
// 2. Authenticate to get OAuth tokens and create sealed session token
+
// 3. Verify our auth middleware can unseal and validate the token
+
// 4. Test token validation and session retrieval
+
//
+
// NOTE: This test uses the E2E OAuth middleware which mocks the session unsealing
+
// for testing purposes. Real OAuth tokens from PDS would be sealed using the
+
// OAuth client's seal secret.
+
func TestOAuthTokenVerification(t *testing.T) {
+
// Skip in short mode since this requires real PDS
+
if testing.Short() {
+
t.Skip("Skipping OAuth token verification test in short mode")
+
}
+
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
// Check if PDS is running
+
healthResp, err := http.Get(pdsURL + "/xrpc/_health")
+
if err != nil {
+
t.Skipf("PDS not running at %s: %v", pdsURL, err)
+
}
+
_ = healthResp.Body.Close()
+
+
t.Run("OAuth token validation and middleware integration", func(t *testing.T) {
+
// Step 1: Create a test account on PDS
+
// Keep handle short to avoid PDS validation errors
+
timestamp := time.Now().Unix() % 100000 // Last 5 digits
+
handle := fmt.Sprintf("oauth%d.local.coves.dev", timestamp)
+
password := "testpass123"
+
email := fmt.Sprintf("oauth%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
t.Logf("โœ“ Created test account: %s (DID: %s)", handle, did)
+
+
// Step 2: Create OAuth middleware with mock unsealer for testing
+
// In production, this would unseal real OAuth tokens from PDS
+
t.Log("Testing OAuth middleware with sealed session tokens...")
+
+
e2eAuth := NewE2EOAuthMiddleware()
+
testToken := e2eAuth.AddUser(did)
+
+
handlerCalled := false
+
var extractedDID string
+
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
extractedDID = middleware.GetUserDID(r)
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte(`{"success": true}`))
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer "+testToken)
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if !handlerCalled {
+
t.Errorf("Handler was not called - auth middleware rejected valid token")
+
t.Logf("Response status: %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if w.Code != http.StatusOK {
+
t.Errorf("Expected status 200, got %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if extractedDID != did {
+
t.Errorf("Middleware extracted wrong DID: expected %s, got %s", did, extractedDID)
+
}
+
+
t.Logf("โœ… OAuth middleware with token validation working correctly!")
+
t.Logf(" Handler called: %v", handlerCalled)
+
t.Logf(" Extracted DID: %s", extractedDID)
+
t.Logf(" Response status: %d", w.Code)
+
})
+
+
t.Run("Rejects tampered/invalid sealed tokens", func(t *testing.T) {
+
// Create valid user
+
timestamp := time.Now().Unix() % 100000
+
handle := fmt.Sprintf("tamp%d.local.coves.dev", timestamp)
+
password := "testpass456"
+
email := fmt.Sprintf("tamp%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
+
// Create OAuth middleware
+
e2eAuth := NewE2EOAuthMiddleware()
+
validToken := e2eAuth.AddUser(did)
+
+
// Create various invalid tokens to test
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but not a real sealed session
+
{"Short token", "abc"},
+
{"Modified valid token", validToken + "extra"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
handlerCalled := false
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if handlerCalled {
+
t.Error("Handler was called for invalid token - should have been rejected")
+
}
+
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("Expected status 401 for invalid token, got %d", w.Code)
+
}
+
+
t.Logf("โœ“ Middleware correctly rejected %s with status %d", tc.name, w.Code)
+
})
+
}
+
+
t.Logf("โœ… All invalid token types correctly rejected")
+
})
+
+
t.Run("Session expiration handling", func(t *testing.T) {
+
// OAuth session expiration is handled at the database level
+
// See TestOAuthE2E_TokenExpiration in oauth_e2e_test.go for full expiration testing
+
t.Log("โ„น๏ธ Session expiration testing is covered in oauth_e2e_test.go")
+
t.Log(" OAuth sessions expire based on database timestamps and are cleaned up periodically")
+
t.Log(" This is different from JWT expiration which was timestamp-based in the token itself")
+
t.Skip("Session expiration is tested in oauth_e2e_test.go - see TestOAuthE2E_TokenExpiration")
+
})
+
}
+2
go.mod
···
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/earthboundkid/versioninfo/v2 v2.24.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
+
github.com/go-chi/cors v1.2.2 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
+
github.com/google/go-querystring v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.5 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
+5
.env.dev
···
# Also supports base64: prefix for consistency
OAUTH_COOKIE_SECRET=f1132c01b1a625a865c6c455a75ee793572cedb059cebe0c4c1ae4c446598f7d
+
# Seal secret for OAuth session tokens (AES-256-GCM encryption)
+
# Generate with: openssl rand -base64 32
+
# This must be 32 bytes when base64-decoded for AES-256
+
# OAUTH_SEAL_SECRET=ryW6xNVxYhP6hCDA90NGCmK58Q2ONnkYXbHL0oZN2no=
+
# AppView public URL (used for OAuth callback and client metadata)
# Dev: http://127.0.0.1:8081 (use 127.0.0.1 instead of localhost per RFC 8252)
# Prod: https://coves.social
-73
cmd/genjwks/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"log"
-
"os"
-
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
// genjwks generates an ES256 keypair for OAuth client authentication
-
// The private key is stored in the config/env, public key is served at /oauth/jwks.json
-
//
-
// Usage:
-
//
-
// go run cmd/genjwks/main.go
-
//
-
// This will output a JSON private key that should be stored in OAUTH_PRIVATE_JWK
-
func main() {
-
fmt.Println("Generating ES256 keypair for OAuth client authentication...")
-
-
// Generate ES256 (NIST P-256) private key
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
log.Fatalf("Failed to generate private key: %v", err)
-
}
-
-
// Convert to JWK
-
jwkKey, err := jwk.FromRaw(privateKey)
-
if err != nil {
-
log.Fatalf("Failed to create JWK from private key: %v", err)
-
}
-
-
// Set key parameters
-
if err = jwkKey.Set(jwk.KeyIDKey, "oauth-client-key"); err != nil {
-
log.Fatalf("Failed to set kid: %v", err)
-
}
-
if err = jwkKey.Set(jwk.AlgorithmKey, "ES256"); err != nil {
-
log.Fatalf("Failed to set alg: %v", err)
-
}
-
if err = jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
-
log.Fatalf("Failed to set use: %v", err)
-
}
-
-
// Marshal to JSON
-
jsonData, err := json.MarshalIndent(jwkKey, "", " ")
-
if err != nil {
-
log.Fatalf("Failed to marshal JWK: %v", err)
-
}
-
-
// Output instructions
-
fmt.Println("\nโœ… ES256 keypair generated successfully!")
-
fmt.Println("\n๐Ÿ“ Add this to your .env.dev file:")
-
fmt.Println("\nOAUTH_PRIVATE_JWK='" + string(jsonData) + "'")
-
fmt.Println("\nโš ๏ธ IMPORTANT:")
-
fmt.Println(" - Keep this private key SECRET")
-
fmt.Println(" - Never commit it to version control")
-
fmt.Println(" - Generate a new key for production")
-
fmt.Println(" - The public key will be automatically derived and served at /oauth/jwks.json")
-
-
// Optionally write to a file (not committed)
-
if len(os.Args) > 1 && os.Args[1] == "--save" {
-
filename := "oauth-private-key.json"
-
if err := os.WriteFile(filename, jsonData, 0o600); err != nil {
-
log.Fatalf("Failed to write key file: %v", err)
-
}
-
fmt.Printf("\n๐Ÿ’พ Private key saved to %s (remember to add to .gitignore!)\n", filename)
-
}
-
}
-52
internal/atproto/auth/combined_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"fmt"
-
"strings"
-
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
)
-
-
// CombinedKeyFetcher handles JWT public key fetching for both:
-
// - DID issuers (did:plc:, did:web:) โ†’ resolves via DID document
-
// - URL issuers (https://) โ†’ fetches via JWKS endpoint (legacy/fallback)
-
//
-
// For atproto service authentication, the issuer is typically the user's DID,
-
// and the signing key is published in their DID document.
-
type CombinedKeyFetcher struct {
-
didFetcher *DIDKeyFetcher
-
jwksFetcher JWKSFetcher
-
}
-
-
// NewCombinedKeyFetcher creates a key fetcher that supports both DID and URL issuers.
-
// Parameters:
-
// - directory: Indigo's identity directory for DID resolution
-
// - jwksFetcher: fallback JWKS fetcher for URL issuers (can be nil if not needed)
-
func NewCombinedKeyFetcher(directory indigoIdentity.Directory, jwksFetcher JWKSFetcher) *CombinedKeyFetcher {
-
return &CombinedKeyFetcher{
-
didFetcher: NewDIDKeyFetcher(directory),
-
jwksFetcher: jwksFetcher,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT.
-
// Routes to the appropriate fetcher based on issuer format:
-
// - DID (did:plc:, did:web:) โ†’ DIDKeyFetcher
-
// - URL (https://) โ†’ JWKSFetcher
-
func (f *CombinedKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Check if issuer is a DID
-
if strings.HasPrefix(issuer, "did:") {
-
return f.didFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
// Check if issuer is a URL (https:// or http:// in dev)
-
if strings.HasPrefix(issuer, "https://") || strings.HasPrefix(issuer, "http://") {
-
if f.jwksFetcher == nil {
-
return nil, fmt.Errorf("URL issuer %s requires JWKS fetcher, but none configured", issuer)
-
}
-
return f.jwksFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
return nil, fmt.Errorf("unsupported issuer format: %s (expected DID or URL)", issuer)
-
}
-616
internal/atproto/auth/dpop.go
···
-
package auth
-
-
import (
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
-
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
-
type NonceCache struct {
-
seen map[string]time.Time // jti -> expiration time
-
stopCh chan struct{}
-
maxAge time.Duration // How long to keep entries
-
cleanup time.Duration // How often to clean up expired entries
-
mu sync.RWMutex
-
}
-
-
// NewNonceCache creates a new nonce cache for DPoP replay protection.
-
// maxAge should match or exceed DPoPVerifier.MaxProofAge.
-
func NewNonceCache(maxAge time.Duration) *NonceCache {
-
nc := &NonceCache{
-
seen: make(map[string]time.Time),
-
maxAge: maxAge,
-
cleanup: maxAge / 2, // Clean up at half the max age
-
stopCh: make(chan struct{}),
-
}
-
-
// Start background cleanup goroutine
-
go nc.cleanupLoop()
-
-
return nc
-
}
-
-
// CheckAndStore checks if a jti has been seen before and stores it if not.
-
// Returns true if the jti is fresh (not a replay), false if it's a replay.
-
func (nc *NonceCache) CheckAndStore(jti string) bool {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
expiry := now.Add(nc.maxAge)
-
-
// Check if already seen
-
if existingExpiry, seen := nc.seen[jti]; seen {
-
// Still valid (not expired) - this is a replay
-
if existingExpiry.After(now) {
-
return false
-
}
-
// Expired entry - allow reuse and update expiry
-
}
-
-
// Store the new jti
-
nc.seen[jti] = expiry
-
return true
-
}
-
-
// cleanupLoop periodically removes expired entries from the cache
-
func (nc *NonceCache) cleanupLoop() {
-
ticker := time.NewTicker(nc.cleanup)
-
defer ticker.Stop()
-
-
for {
-
select {
-
case <-ticker.C:
-
nc.cleanupExpired()
-
case <-nc.stopCh:
-
return
-
}
-
}
-
}
-
-
// cleanupExpired removes expired entries from the cache
-
func (nc *NonceCache) cleanupExpired() {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
for jti, expiry := range nc.seen {
-
if expiry.Before(now) {
-
delete(nc.seen, jti)
-
}
-
}
-
}
-
-
// Stop stops the cleanup goroutine. Call this when done with the cache.
-
func (nc *NonceCache) Stop() {
-
close(nc.stopCh)
-
}
-
-
// Size returns the number of entries in the cache (for testing/monitoring)
-
func (nc *NonceCache) Size() int {
-
nc.mu.RLock()
-
defer nc.mu.RUnlock()
-
return len(nc.seen)
-
}
-
-
// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
-
type DPoPClaims struct {
-
jwt.RegisteredClaims
-
-
// HTTP method of the request (e.g., "GET", "POST")
-
HTTPMethod string `json:"htm"`
-
-
// HTTP URI of the request (without query and fragment parts)
-
HTTPURI string `json:"htu"`
-
-
// Access token hash (optional, for token binding)
-
AccessTokenHash string `json:"ath,omitempty"`
-
}
-
-
// DPoPProof represents a parsed and verified DPoP proof
-
type DPoPProof struct {
-
RawPublicJWK map[string]interface{}
-
Claims *DPoPClaims
-
PublicKey interface{} // *ecdsa.PublicKey or similar
-
Thumbprint string // JWK thumbprint (base64url)
-
}
-
-
// DPoPVerifier verifies DPoP proofs for OAuth token binding
-
type DPoPVerifier struct {
-
// Optional: custom nonce validation function (for server-issued nonces)
-
ValidateNonce func(nonce string) bool
-
-
// NonceCache for replay protection (optional but recommended)
-
// If nil, jti replay protection is disabled
-
NonceCache *NonceCache
-
-
// Maximum allowed clock skew for timestamp validation
-
MaxClockSkew time.Duration
-
-
// Maximum age of DPoP proof (prevents replay with old proofs)
-
MaxProofAge time.Duration
-
}
-
-
// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
-
func NewDPoPVerifier() *DPoPVerifier {
-
maxProofAge := 5 * time.Minute
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: maxProofAge,
-
NonceCache: NewNonceCache(maxProofAge),
-
}
-
}
-
-
// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
-
// This should only be used in testing or when replay protection is handled externally.
-
func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: 5 * time.Minute,
-
NonceCache: nil, // No replay protection
-
}
-
}
-
-
// Stop stops background goroutines. Call this when shutting down.
-
func (v *DPoPVerifier) Stop() {
-
if v.NonceCache != nil {
-
v.NonceCache.Stop()
-
}
-
}
-
-
// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
-
// This supports all atProto-compatible ECDSA algorithms including ES256K (secp256k1).
-
func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
-
// Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
-
header, claims, err := parseJWTHeaderAndClaims(dpopProof)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
-
}
-
-
// Extract and validate the typ header
-
typ, ok := header["typ"].(string)
-
if !ok || typ != "dpop+jwt" {
-
return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
-
}
-
-
alg, ok := header["alg"].(string)
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
-
}
-
-
// Extract the JWK from the header first (needed for algorithm-curve validation)
-
jwkRaw, ok := header["jwk"]
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
-
}
-
-
jwkMap, ok := jwkRaw.(map[string]interface{})
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
-
}
-
-
// Validate the algorithm is supported and matches the JWK curve
-
// This is critical for security - prevents algorithm confusion attacks
-
if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof: %w", err)
-
}
-
-
// Parse the public key using indigo's crypto package
-
// This supports all atProto curves including secp256k1 (ES256K)
-
publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
-
}
-
-
// Calculate the JWK thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
-
}
-
-
// Verify the signature using indigo's crypto package
-
// This works for all ECDSA algorithms including ES256K
-
if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
-
return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
-
}
-
-
// Validate the claims
-
if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
-
return nil, err
-
}
-
-
return &DPoPProof{
-
Claims: claims,
-
PublicKey: publicKey,
-
Thumbprint: thumbprint,
-
RawPublicJWK: jwkMap,
-
}, nil
-
}
-
-
// validateDPoPClaims validates the DPoP proof claims
-
func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
-
// Validate jti (unique identifier) is present
-
if claims.ID == "" {
-
return fmt.Errorf("DPoP proof missing jti claim")
-
}
-
-
// Validate htm (HTTP method)
-
if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
-
return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
-
}
-
-
// Validate htu (HTTP URI) - compare without query/fragment
-
expectedURIBase := stripQueryFragment(expectedURI)
-
claimURIBase := stripQueryFragment(claims.HTTPURI)
-
if expectedURIBase != claimURIBase {
-
return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
-
}
-
-
// Validate iat (issued at) is present and recent
-
if claims.IssuedAt == nil {
-
return fmt.Errorf("DPoP proof missing iat claim")
-
}
-
-
now := time.Now()
-
iat := claims.IssuedAt.Time
-
-
// Check clock skew (not too far in the future)
-
if iat.After(now.Add(v.MaxClockSkew)) {
-
return fmt.Errorf("DPoP proof iat is in the future")
-
}
-
-
// Check proof age (not too old)
-
if now.Sub(iat) > v.MaxProofAge {
-
return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
-
}
-
-
// SECURITY: Validate exp claim if present (RFC standard JWT validation)
-
// While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
-
if claims.ExpiresAt != nil {
-
expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
-
if now.After(expWithSkew) {
-
return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
-
}
-
}
-
-
// SECURITY: Validate nbf claim if present (RFC standard JWT validation)
-
if claims.NotBefore != nil {
-
nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
-
if now.Before(nbfWithSkew) {
-
return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
-
}
-
}
-
-
// SECURITY: Check for replay attack using jti
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
-
if v.NonceCache != nil {
-
if !v.NonceCache.CheckAndStore(claims.ID) {
-
return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
-
}
-
}
-
-
return nil
-
}
-
-
// VerifyTokenBinding verifies that the DPoP proof binds to the access token
-
// by comparing the proof's thumbprint to the token's cnf.jkt claim
-
func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
-
if proof.Thumbprint != expectedThumbprint {
-
return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
-
expectedThumbprint, proof.Thumbprint)
-
}
-
return nil
-
}
-
-
// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
-
// matches the SHA-256 hash of the presented access token.
-
// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
-
func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
-
// If ath claim is not present, that's acceptable per RFC 9449
-
// (ath is only required when the RS mandates it)
-
if proof.Claims.AccessTokenHash == "" {
-
return nil
-
}
-
-
// Calculate the expected ath: base64url(SHA-256(access_token))
-
hash := sha256.Sum256([]byte(accessToken))
-
expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
if proof.Claims.AccessTokenHash != expectedAth {
-
return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
-
}
-
-
return nil
-
}
-
-
// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638
-
// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation
-
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
-
kty, ok := jwk["kty"].(string)
-
if !ok {
-
return "", fmt.Errorf("JWK missing kty")
-
}
-
-
// Build the canonical JWK representation based on key type
-
// Per RFC 7638, only specific members are included, in lexicographic order
-
var canonical map[string]string
-
-
switch kty {
-
case "EC":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing x")
-
}
-
y, ok := jwk["y"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing y")
-
}
-
// Lexicographic order: crv, kty, x, y
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
"y": y,
-
}
-
case "RSA":
-
e, ok := jwk["e"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing e")
-
}
-
n, ok := jwk["n"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing n")
-
}
-
// Lexicographic order: e, kty, n
-
canonical = map[string]string{
-
"e": e,
-
"kty": kty,
-
"n": n,
-
}
-
case "OKP":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing x")
-
}
-
// Lexicographic order: crv, kty, x
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
}
-
default:
-
return "", fmt.Errorf("unsupported JWK key type: %s", kty)
-
}
-
-
// Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string)
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
-
}
-
-
// SHA-256 hash
-
hash := sha256.Sum256(canonicalJSON)
-
-
// Base64url encode (no padding)
-
thumbprint := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
return thumbprint, nil
-
}
-
-
// validateAlgorithmCurveBinding validates that the JWT algorithm matches the JWK curve.
-
// This is critical for security - an attacker could claim alg: "ES256K" but provide
-
// a P-256 key, potentially bypassing algorithm binding requirements.
-
func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error {
-
kty, ok := jwkMap["kty"].(string)
-
if !ok {
-
return fmt.Errorf("JWK missing kty")
-
}
-
-
// ECDSA algorithms require EC key type
-
switch alg {
-
case "ES256K", "ES256", "ES384", "ES512":
-
if kty != "EC" {
-
return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty)
-
}
-
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
-
return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
-
default:
-
return fmt.Errorf("unsupported DPoP algorithm: %s", alg)
-
}
-
-
// Validate curve matches algorithm
-
crv, ok := jwkMap["crv"].(string)
-
if !ok {
-
return fmt.Errorf("EC JWK missing crv")
-
}
-
-
var expectedCurve string
-
switch alg {
-
case "ES256K":
-
expectedCurve = "secp256k1"
-
case "ES256":
-
expectedCurve = "P-256"
-
case "ES384":
-
expectedCurve = "P-384"
-
case "ES512":
-
expectedCurve = "P-521"
-
}
-
-
if crv != expectedCurve {
-
return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, expectedCurve, crv)
-
}
-
-
return nil
-
}
-
-
// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey.
-
// This returns indigo's PublicKey interface which supports all atProto curves
-
// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512).
-
func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - this supports all atProto curves
-
// including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt.
-
// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize.
-
func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode header
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first to extract standard claims
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build DPoPClaims struct
-
claims := &DPoPClaims{}
-
-
// Extract jti
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract exp (expiration) if present
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before) if present
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract htm (HTTP method)
-
if htm, ok := rawClaims["htm"].(string); ok {
-
claims.HTTPMethod = htm
-
}
-
-
// Extract htu (HTTP URI)
-
if htu, ok := rawClaims["htu"].(string); ok {
-
claims.HTTPURI = htu
-
}
-
-
// Extract ath (access token hash) if present
-
if ath, ok := rawClaims["ath"].(string); ok {
-
claims.AccessTokenHash = ath
-
}
-
-
return header, claims, nil
-
}
-
-
// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package.
-
// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K).
-
// It parses the JWT, extracts the signing input and signature, and uses indigo's
-
// PublicKey.HashAndVerifyLenient() for verification.
-
//
-
// JWT format: header.payload.signature (all base64url-encoded)
-
// Signature is verified over the raw bytes of "header.payload"
-
// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally)
-
func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
// Split the JWT into parts
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature)
-
if err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// stripQueryFragment removes query and fragment from a URI
-
func stripQueryFragment(uri string) string {
-
if idx := strings.Index(uri, "?"); idx != -1 {
-
uri = uri[:idx]
-
}
-
if idx := strings.Index(uri, "#"); idx != -1 {
-
uri = uri[:idx]
-
}
-
return uri
-
}
-
-
// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
-
func ExtractCnfJkt(claims *Claims) (string, error) {
-
if claims.Confirmation == nil {
-
return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
-
}
-
-
jkt, ok := claims.Confirmation["jkt"].(string)
-
if !ok || jkt == "" {
-
return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
-
}
-
-
return jkt, nil
-
}
-189
internal/atproto/auth/jwks_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"encoding/json"
-
"fmt"
-
"net/http"
-
"strings"
-
"sync"
-
"time"
-
)
-
-
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
-
type CachedJWKSFetcher struct {
-
cache map[string]*cachedJWKS
-
httpClient *http.Client
-
cacheMutex sync.RWMutex
-
cacheTTL time.Duration
-
}
-
-
type cachedJWKS struct {
-
jwks *JWKS
-
expiresAt time.Time
-
}
-
-
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
-
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
-
return &CachedJWKSFetcher{
-
cache: make(map[string]*cachedJWKS),
-
httpClient: &http.Client{
-
Timeout: 10 * time.Second,
-
},
-
cacheTTL: cacheTTL,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
-
// Implements JWKSFetcher interface
-
// Returns interface{} to support both RSA and ECDSA keys
-
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Extract key ID from token
-
kid, err := ExtractKeyID(token)
-
if err != nil {
-
return nil, fmt.Errorf("failed to extract key ID: %w", err)
-
}
-
-
// Get JWKS from cache or fetch
-
jwks, err := f.getJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Find the key by ID
-
jwk, err := jwks.FindKeyByID(kid)
-
if err != nil {
-
// Key not found in cache - try refreshing
-
jwks, err = f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
-
}
-
f.cacheJWKS(issuer, jwks)
-
-
// Try again with fresh JWKS
-
jwk, err = jwks.FindKeyByID(kid)
-
if err != nil {
-
return nil, err
-
}
-
}
-
-
// Convert JWK to public key (RSA or ECDSA)
-
return jwk.ToPublicKey()
-
}
-
-
// getJWKS gets JWKS from cache or fetches if not cached/expired
-
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Check cache first
-
f.cacheMutex.RLock()
-
cached, exists := f.cache[issuer]
-
f.cacheMutex.RUnlock()
-
-
if exists && time.Now().Before(cached.expiresAt) {
-
return cached.jwks, nil
-
}
-
-
// Not in cache or expired - fetch from issuer
-
jwks, err := f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Cache it
-
f.cacheJWKS(issuer, jwks)
-
-
return jwks, nil
-
}
-
-
// fetchJWKS fetches JWKS from the authorization server
-
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Step 1: Fetch OAuth server metadata to get JWKS URI
-
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
-
-
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create metadata request: %w", err)
-
}
-
-
resp, err := f.httpClient.Do(req)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
-
}
-
defer func() {
-
_ = resp.Body.Close()
-
}()
-
-
if resp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
-
}
-
-
var metadata struct {
-
JWKSURI string `json:"jwks_uri"`
-
}
-
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
-
return nil, fmt.Errorf("failed to decode metadata: %w", err)
-
}
-
-
if metadata.JWKSURI == "" {
-
return nil, fmt.Errorf("jwks_uri not found in metadata")
-
}
-
-
// Step 2: Fetch JWKS from the JWKS URI
-
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
-
}
-
-
jwksResp, err := f.httpClient.Do(jwksReq)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
-
}
-
defer func() {
-
_ = jwksResp.Body.Close()
-
}()
-
-
if jwksResp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
-
}
-
-
var jwks JWKS
-
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
-
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
-
}
-
-
if len(jwks.Keys) == 0 {
-
return nil, fmt.Errorf("no keys found in JWKS")
-
}
-
-
return &jwks, nil
-
}
-
-
// cacheJWKS stores JWKS in the cache
-
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
f.cache[issuer] = &cachedJWKS{
-
jwks: jwks,
-
expiresAt: time.Now().Add(f.cacheTTL),
-
}
-
}
-
-
// ClearCache clears the entire JWKS cache
-
func (f *CachedJWKSFetcher) ClearCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
f.cache = make(map[string]*cachedJWKS)
-
}
-
-
// CleanupExpiredCache removes expired entries from the cache
-
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
now := time.Now()
-
for issuer, cached := range f.cache {
-
if now.After(cached.expiresAt) {
-
delete(f.cache, issuer)
-
}
-
}
-
}
-496
internal/atproto/auth/jwt_test.go
···
-
package auth
-
-
import (
-
"context"
-
"testing"
-
"time"
-
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
func TestParseJWT(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing
-
parsedClaims, err := ParseJWT(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
-
if parsedClaims.Issuer != "https://test-pds.example.com" {
-
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
-
}
-
-
if parsedClaims.Scope != "atproto transition:generic" {
-
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
-
}
-
}
-
-
func TestParseJWT_MissingSubject(t *testing.T) {
-
// Create a token without subject
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing subject, got nil")
-
}
-
}
-
-
func TestParseJWT_MissingIssuer(t *testing.T) {
-
// Create a token without issuer
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing issuer, got nil")
-
}
-
}
-
-
func TestParseJWT_WithBearerPrefix(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing with Bearer prefix
-
parsedClaims, err := ParseJWT("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
}
-
-
func TestValidateClaims_Expired(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for expired token, got nil")
-
}
-
}
-
-
func TestValidateClaims_InvalidDID(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "invalid-did-format",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for invalid DID format, got nil")
-
}
-
}
-
-
func TestExtractKeyID(t *testing.T) {
-
// Create a test JWT token with kid in header
-
token := jwt.New(jwt.SigningMethodRS256)
-
token.Header["kid"] = "test-key-id"
-
token.Claims = &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
},
-
}
-
-
// Sign with a dummy RSA key (we just need a valid token structure)
-
tokenString, err := token.SignedString([]byte("dummy"))
-
if err == nil {
-
// If it succeeds (shouldn't with wrong key type, but let's handle it)
-
kid, err := ExtractKeyID(tokenString)
-
if err != nil {
-
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
-
} else if kid != "test-key-id" {
-
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
-
}
-
}
-
}
-
-
// === HS256 Verification Tests ===
-
-
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
-
type mockJWKSFetcher struct {
-
publicKey interface{}
-
err error
-
}
-
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
return m.publicKey, m.err
-
}
-
-
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
-
t.Helper()
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte(secret))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
return tokenString
-
}
-
-
func TestVerifyJWT_HS256_Valid(t *testing.T) {
-
// Setup: Configure environment for HS256 verification
-
secret := "test-jwt-secret-key-12345"
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", secret)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
-
-
// Verify token
-
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err != nil {
-
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
-
}
-
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
-
}
-
if claims.Issuer != issuer {
-
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
-
}
-
}
-
-
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
-
// Setup: Configure environment with one secret, sign with another
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "correct-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create token with wrong secret
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
-
-
// Verify should fail
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error for HS256 token with wrong secret, got nil")
-
}
-
}
-
-
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
-
// Setup: Whitelist issuer but don't configure secret
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
-
-
// Verify should fail with descriptive error
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
-
}
-
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
-
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
-
}
-
}
-
-
// === Algorithm Confusion Attack Prevention Tests ===
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
-
// SECURITY TEST: This tests the algorithm confusion attack prevention
-
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create HS256 token with non-whitelisted issuer (simulating attack)
-
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because issuer is not in HS256 whitelist
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
-
}
-
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
-
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
-
}
-
}
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
-
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because no issuers are whitelisted for HS256
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
-
}
-
}
-
-
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
-
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "test-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
-
// This will create an invalid signature but valid header structure
-
// The test should fail at algorithm check, not signature verification
-
tokenString, _ := token.SignedString([]byte("dummy-key"))
-
-
if tokenString != "" {
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when HS256 issuer receives non-HS256 token")
-
}
-
}
-
}
-
-
// === ParseJWTHeader Tests ===
-
-
func TestParseJWTHeader_Valid(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
-
testCases := []struct {
-
name string
-
input string
-
}{
-
{"empty string", ""},
-
{"single part", "abc"},
-
{"two parts", "abc.def"},
-
{"too many parts", "a.b.c.d"},
-
}
-
-
for _, tc := range testCases {
-
t.Run(tc.name, func(t *testing.T) {
-
_, err := ParseJWTHeader(tc.input)
-
if err == nil {
-
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
-
}
-
})
-
}
-
}
-
-
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
-
-
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected pds1 to be whitelisted")
-
}
-
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
-
t.Error("Expected pds2 to be whitelisted")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://attacker.example.com") {
-
t.Error("Expected non-whitelisted issuer to return false")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://any.example.com") {
-
t.Error("Expected false when whitelist is empty (safe default)")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
-
}
-
}
-
-
// === shouldUseHS256 Tests (kid-based logic) ===
-
-
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
-
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "some-key-id", // Has kid
-
}
-
-
// Even whitelisted issuer should not use HS256 if token has kid
-
if shouldUseHS256(header, "https://whitelisted.example.com") {
-
t.Error("Tokens with kid should never use HS256 (supports federation)")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if !shouldUseHS256(header, "https://my-pds.example.com") {
-
t.Error("Token without kid from whitelisted issuer should use HS256")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if shouldUseHS256(header, "https://external-pds.example.com") {
-
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
-
}
-
}
-
-
// Helper function
-
func contains(s, substr string) bool {
-
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-
}
-
-
func containsHelper(s, substr string) bool {
-
for i := 0; i <= len(s)-len(substr); i++ {
-
if s[i:i+len(substr)] == substr {
-
return true
-
}
-
}
-
return false
-
}
+1
docker-compose.prod.yml
···
# Instance identity
INSTANCE_DID: did:web:coves.social
INSTANCE_DOMAIN: coves.social
+
APPVIEW_PUBLIC_URL: https://coves.social
# PDS connection (separate domain!)
PDS_URL: https://coves.me
+3 -1
internal/atproto/oauth/client.go
···
clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes)
} else {
// Production mode: public OAuth client with HTTPS
+
// client_id must be the URL of the client metadata document per atproto OAuth spec
+
clientID := config.PublicURL + "/oauth/client-metadata.json"
callbackURL := config.PublicURL + "/oauth/callback"
-
clientConfig = oauth.NewPublicConfig(config.PublicURL, callbackURL, config.Scopes)
+
clientConfig = oauth.NewPublicConfig(clientID, callbackURL, config.Scopes)
}
// Set user agent
-97
static/oauth/callback.html
···
-
<!DOCTYPE html>
-
<html>
-
<head>
-
<meta charset="utf-8">
-
<meta name="viewport" content="width=device-width, initial-scale=1">
-
<meta http-equiv="Content-Security-Policy" content="default-src 'self'; script-src 'unsafe-inline'; style-src 'unsafe-inline'">
-
<title>Authorization Successful - Coves</title>
-
<style>
-
body {
-
font-family: system-ui, -apple-system, sans-serif;
-
display: flex;
-
align-items: center;
-
justify-content: center;
-
min-height: 100vh;
-
margin: 0;
-
background: #f5f5f5;
-
}
-
.container {
-
text-align: center;
-
padding: 2rem;
-
background: white;
-
border-radius: 8px;
-
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
-
max-width: 400px;
-
}
-
.success { color: #22c55e; font-size: 3rem; margin-bottom: 1rem; }
-
h1 { margin: 0 0 0.5rem; color: #1f2937; font-size: 1.5rem; }
-
p { color: #6b7280; margin: 0.5rem 0; }
-
a {
-
display: inline-block;
-
margin-top: 1rem;
-
padding: 0.75rem 1.5rem;
-
background: #3b82f6;
-
color: white;
-
text-decoration: none;
-
border-radius: 6px;
-
font-weight: 500;
-
}
-
a:hover { background: #2563eb; }
-
</style>
-
</head>
-
<body>
-
<div class="container">
-
<div class="success">โœ“</div>
-
<h1>Authorization Successful!</h1>
-
<p id="status">Returning to Coves...</p>
-
<a href="#" id="manualLink">Open Coves</a>
-
</div>
-
<script>
-
(function() {
-
// Parse and sanitize query params - only allow expected OAuth parameters
-
const urlParams = new URLSearchParams(window.location.search);
-
const safeParams = new URLSearchParams();
-
-
// Whitelist only expected OAuth callback parameters
-
const code = urlParams.get('code');
-
const state = urlParams.get('state');
-
const error = urlParams.get('error');
-
const errorDescription = urlParams.get('error_description');
-
const iss = urlParams.get('iss');
-
-
if (code) safeParams.set('code', code);
-
if (state) safeParams.set('state', state);
-
if (error) safeParams.set('error', error);
-
if (errorDescription) safeParams.set('error_description', errorDescription);
-
if (iss) safeParams.set('iss', iss);
-
-
const sanitizedQuery = safeParams.toString() ? '?' + safeParams.toString() : '';
-
-
const userAgent = navigator.userAgent || '';
-
const isAndroid = /Android/i.test(userAgent);
-
-
// Build deep link based on platform
-
let deepLink;
-
if (isAndroid) {
-
// Android: Intent URL format
-
const pathAndQuery = '/oauth/callback' + sanitizedQuery;
-
deepLink = 'intent:/' + pathAndQuery + '#Intent;scheme=social.coves;package=social.coves;end';
-
} else {
-
// iOS: Custom scheme
-
deepLink = 'social.coves:/oauth/callback' + sanitizedQuery;
-
}
-
-
// Update manual link
-
document.getElementById('manualLink').href = deepLink;
-
-
// Attempt automatic redirect
-
window.location.href = deepLink;
-
-
// Update status after 2 seconds if redirect didn't work
-
setTimeout(function() {
-
document.getElementById('status').textContent = 'Click the button above to continue';
-
}, 2000);
-
})();
-
</script>
-
</body>
-
</html>
+6 -5
internal/api/routes/oauth.go
···
// Use login limiter since callback completes the authentication flow
r.With(corsMiddleware(allowedOrigins), loginLimiter.Middleware).Get("/oauth/callback", handler.HandleCallback)
-
// Mobile Universal Link callback route
-
// This route is used for iOS Universal Links and Android App Links
-
// Path must match the path in .well-known/apple-app-site-association
-
// Uses the same handler as web callback - the system routes it to the mobile app
-
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleCallback)
+
// Mobile Universal Link callback route (fallback when app doesn't intercept)
+
// This route exists for iOS Universal Links and Android App Links.
+
// When properly configured, the mobile OS intercepts this URL and opens the app
+
// BEFORE the request reaches the server. If this handler is reached, it means
+
// Universal Links failed to intercept.
+
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleMobileDeepLinkFallback)
// Session management - dedicated rate limits
r.With(logoutLimiter.Middleware).Post("/oauth/logout", handler.HandleLogout)
+11
static/.well-known/apple-app-site-association
···
+
{
+
"applinks": {
+
"apps": [],
+
"details": [
+
{
+
"appID": "TEAM_ID.social.coves",
+
"paths": ["/app/oauth/callback"]
+
}
+
]
+
}
+
}
+10
static/.well-known/assetlinks.json
···
+
[{
+
"relation": ["delegate_permission/common.handle_all_urls"],
+
"target": {
+
"namespace": "android_app",
+
"package_name": "social.coves",
+
"sha256_cert_fingerprints": [
+
"0B:D8:8C:99:66:25:E5:CD:06:54:80:88:01:6F:B7:38:B9:F4:5B:41:71:F7:95:C8:68:94:87:AD:EA:9F:D9:ED"
+
]
+
}
+
}]
+16 -9
internal/atproto/oauth/handlers_test.go
···
}
// TestIsMobileRedirectURI tests mobile redirect URI validation with EXACT URI matching
-
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
// Per atproto spec, custom schemes must match client_id hostname in reverse-domain order
func TestIsMobileRedirectURI(t *testing.T) {
tests := []struct {
uri string
expected bool
}{
-
{"https://coves.social/app/oauth/callback", true}, // Universal Link - allowed
-
{"coves-app://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves-app://callback", false}, // Custom scheme - blocked
-
{"coves://oauth", false}, // Custom scheme - blocked
-
{"myapp://oauth", false}, // Not in allowlist
-
{"https://example.com", false}, // Wrong domain
-
{"http://localhost", false}, // HTTP not allowed
+
// Custom scheme per atproto spec (reverse domain of coves.social)
+
{"social.coves:/callback", true},
+
{"social.coves://callback", true},
+
{"social.coves:/oauth/callback", true},
+
{"social.coves://oauth/callback", true},
+
// Universal Link - allowed (strongest security)
+
{"https://coves.social/app/oauth/callback", true},
+
// Wrong custom schemes - not reverse-domain of coves.social
+
{"coves-app://oauth/callback", false},
+
{"coves://oauth/callback", false},
+
{"coves.social://callback", false}, // Not reversed
+
{"myapp://oauth", false},
+
// Wrong domain/scheme
+
{"https://example.com", false},
+
{"http://localhost", false},
{"", false},
{"not-a-uri", false},
}
+41
internal/atproto/lexicon/social/coves/feed/vote/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.feed.vote.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a vote on a post or comment",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["subject"],
+
"properties": {
+
"subject": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the post or comment to remove the vote from"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "VoteNotFound",
+
"description": "No vote found for this subject"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this vote"
+
}
+
]
+
}
+
}
+
}
+115
internal/api/handlers/vote/create_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateVoteHandler handles vote creation
+
type CreateVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewCreateVoteHandler creates a new create vote handler
+
func NewCreateVoteHandler(service votes.Service) *CreateVoteHandler {
+
return &CreateVoteHandler{
+
service: service,
+
}
+
}
+
+
// CreateVoteInput represents the request body for creating a vote
+
type CreateVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
Direction string `json:"direction"`
+
}
+
+
// CreateVoteOutput represents the response body for creating a vote
+
type CreateVoteOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreateVote creates a vote on a post or comment
+
// POST /xrpc/social.coves.vote.create
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." }, "direction": "up" }
+
// Response: { "uri": "at://...", "cid": "..." }
+
//
+
// Behavior:
+
// - If no vote exists: creates new vote with given direction
+
// - If vote exists with same direction: deletes vote (toggle off)
+
// - If vote exists with different direction: updates to new direction
+
func (h *CreateVoteHandler) HandleCreateVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input CreateVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
if input.Direction == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction is required")
+
return
+
}
+
+
// Validate direction
+
if input.Direction != "up" && input.Direction != "down" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction must be 'up' or 'down'")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create vote request
+
req := votes.CreateVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
Direction: input.Direction,
+
}
+
+
// Call service to create vote
+
response, err := h.service.CreateVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response
+
output := CreateVoteOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+93
internal/api/handlers/vote/delete_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteVoteHandler handles vote deletion
+
type DeleteVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewDeleteVoteHandler creates a new delete vote handler
+
func NewDeleteVoteHandler(service votes.Service) *DeleteVoteHandler {
+
return &DeleteVoteHandler{
+
service: service,
+
}
+
}
+
+
// DeleteVoteInput represents the request body for deleting a vote
+
type DeleteVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
}
+
+
// DeleteVoteOutput represents the response body for deleting a vote
+
// Per lexicon: output is an empty object
+
type DeleteVoteOutput struct{}
+
+
// HandleDeleteVote removes a vote from a post or comment
+
// POST /xrpc/social.coves.vote.delete
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." } }
+
// Response: { "success": true }
+
func (h *DeleteVoteHandler) HandleDeleteVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input DeleteVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create delete vote request
+
req := votes.DeleteVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
}
+
+
// Call service to delete vote
+
err := h.service.DeleteVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response (empty object per lexicon)
+
output := DeleteVoteOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+24
internal/api/routes/vote.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/vote"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterVoteRoutes registers vote-related XRPC endpoints on the router
+
// Implements social.coves.feed.vote.* lexicon endpoints
+
func RegisterVoteRoutes(r chi.Router, voteService votes.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := vote.NewCreateVoteHandler(voteService)
+
deleteHandler := vote.NewDeleteVoteHandler(voteService)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.feed.vote.create - create or update a vote on a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.create", createHandler.HandleCreateVote)
+
+
// social.coves.feed.vote.delete - delete a vote from a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.delete", deleteHandler.HandleDeleteVote)
+
}
+3
.beads/beads.left.jsonl
···
+
{"id":"Coves-95q","content_hash":"8ec99d598f067780436b985f9ad57f0fa19632026981038df4f65f192186620b","title":"Add comprehensive API documentation","description":"","status":"open","priority":2,"issue_type":"task","created_at":"2025-11-17T20:30:34.835721854-08:00","updated_at":"2025-11-17T20:30:34.835721854-08:00","source_repo":".","dependencies":[{"issue_id":"Coves-95q","depends_on_id":"Coves-e16","type":"blocks","created_at":"2025-11-17T20:30:46.273899399-08:00","created_by":"daemon"}]}
+
{"id":"Coves-e16","content_hash":"7c5d0fc8f0e7f626be3dad62af0e8412467330bad01a244e5a7e52ac5afff1c1","title":"Complete post creation and moderation features","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:12.885991306-08:00","updated_at":"2025-11-17T20:30:12.885991306-08:00","source_repo":"."}
+
{"id":"Coves-fce","content_hash":"26b3e16b99f827316ee0d741cc959464bd0c813446c95aef8105c7fd1e6b09ff","title":"Implement aggregator feed federation","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:21.453326012-08:00","updated_at":"2025-11-17T20:30:21.453326012-08:00","source_repo":"."}
+1
.beads/beads.left.meta.json
···
+
{"version":"0.23.1","timestamp":"2025-12-02T18:25:24.009187871-08:00","commit":"00d7d8d"}
-3
internal/api/handlers/vote/errors.go
···
case errors.Is(err, votes.ErrVoteNotFound):
// Matches: social.coves.feed.vote.delete#VoteNotFound
writeError(w, http.StatusNotFound, "VoteNotFound", "No vote found for this subject")
-
case errors.Is(err, votes.ErrSubjectNotFound):
-
// Matches: social.coves.feed.vote.create#SubjectNotFound
-
writeError(w, http.StatusNotFound, "SubjectNotFound", "The subject post or comment was not found")
case errors.Is(err, votes.ErrInvalidDirection):
writeError(w, http.StatusBadRequest, "InvalidRequest", "Vote direction must be 'up' or 'down'")
case errors.Is(err, votes.ErrInvalidSubject):
+4 -4
internal/atproto/oauth/handlers_security.go
···
// - Android: Verified via /.well-known/assetlinks.json
var allowedMobileRedirectURIs = map[string]bool{
// Custom scheme per atproto spec (reverse-domain of coves.social)
-
"social.coves:/callback": true,
-
"social.coves://callback": true, // Some platforms add double slash
-
"social.coves:/oauth/callback": true, // Alternative path
-
"social.coves://oauth/callback": true,
+
"social.coves:/callback": true,
+
"social.coves://callback": true, // Some platforms add double slash
+
"social.coves:/oauth/callback": true, // Alternative path
+
"social.coves://oauth/callback": true,
// Universal Links - cryptographically bound to app (preferred for security)
"https://coves.social/app/oauth/callback": true,
}
-3
internal/core/votes/errors.go
···
// ErrVoteNotFound indicates the requested vote doesn't exist
ErrVoteNotFound = errors.New("vote not found")
-
// ErrSubjectNotFound indicates the post/comment being voted on doesn't exist
-
ErrSubjectNotFound = errors.New("subject not found")
-
// ErrInvalidDirection indicates the vote direction is not "up" or "down"
ErrInvalidDirection = errors.New("invalid vote direction: must be 'up' or 'down'")
+14 -27
internal/core/votes/service_impl.go
···
// voteService implements the Service interface for vote operations
type voteService struct {
-
repo Repository
-
subjectValidator SubjectValidator
-
oauthClient *oauthclient.OAuthClient
-
oauthStore oauth.ClientAuthStore
-
logger *slog.Logger
+
repo Repository
+
oauthClient *oauthclient.OAuthClient
+
oauthStore oauth.ClientAuthStore
+
logger *slog.Logger
}
// NewService creates a new vote service instance
-
// subjectValidator can be nil to skip subject existence checks (not recommended for production)
-
func NewService(repo Repository, subjectValidator SubjectValidator, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, logger *slog.Logger) Service {
+
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, logger *slog.Logger) Service {
if logger == nil {
logger = slog.Default()
}
return &voteService{
-
repo: repo,
-
subjectValidator: subjectValidator,
-
oauthClient: oauthClient,
-
oauthStore: oauthStore,
-
logger: logger,
+
repo: repo,
+
oauthClient: oauthClient,
+
oauthStore: oauthStore,
+
logger: logger,
}
}
···
return nil, ErrInvalidSubject
}
-
// Validate subject exists in AppView (post or comment)
-
// This prevents creating votes on non-existent content
-
if s.subjectValidator != nil {
-
exists, err := s.subjectValidator.SubjectExists(ctx, req.Subject.URI)
-
if err != nil {
-
s.logger.Error("failed to validate subject existence",
-
"error", err,
-
"subject", req.Subject.URI)
-
return nil, fmt.Errorf("failed to validate subject: %w", err)
-
}
-
if !exists {
-
return nil, ErrSubjectNotFound
-
}
-
}
+
// Note: We intentionally don't validate subject existence here.
+
// The vote record goes to the user's PDS regardless. The Jetstream consumer
+
// handles orphaned votes correctly by only updating counts for non-deleted subjects.
+
// This avoids race conditions and eventual consistency issues.
// Check for existing vote by querying PDS directly (source of truth)
// This avoids eventual consistency issues with the AppView database
···
// Parse the listRecords response
var result struct {
+
Cursor string `json:"cursor"`
Records []struct {
URI string `json:"uri"`
CID string `json:"cid"`
···
CreatedAt string `json:"createdAt"`
} `json:"value"`
} `json:"records"`
-
Cursor string `json:"cursor"`
}
if err := json.Unmarshal(body, &result); err != nil {
+3 -2
internal/db/postgres/vote_repo.go
···
return nil
}
-
// GetByURI retrieves a vote by its AT-URI
+
// GetByURI retrieves an active vote by its AT-URI
// Used by Jetstream consumer for DELETE operations
+
// Returns ErrVoteNotFound for soft-deleted votes
func (r *postgresVoteRepo) GetByURI(ctx context.Context, uri string) (*votes.Vote, error) {
query := `
SELECT
···
subject_uri, subject_cid, direction,
created_at, indexed_at, deleted_at
FROM votes
-
WHERE uri = $1
+
WHERE uri = $1 AND deleted_at IS NULL
`
var vote votes.Vote