A community based topic aggregation platform built on atproto

Compare changes

Choose any two refs to compare.

-73
cmd/genjwks/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"log"
-
"os"
-
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
// genjwks generates an ES256 keypair for OAuth client authentication
-
// The private key is stored in the config/env, public key is served at /oauth/jwks.json
-
//
-
// Usage:
-
//
-
// go run cmd/genjwks/main.go
-
//
-
// This will output a JSON private key that should be stored in OAUTH_PRIVATE_JWK
-
func main() {
-
fmt.Println("Generating ES256 keypair for OAuth client authentication...")
-
-
// Generate ES256 (NIST P-256) private key
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
log.Fatalf("Failed to generate private key: %v", err)
-
}
-
-
// Convert to JWK
-
jwkKey, err := jwk.FromRaw(privateKey)
-
if err != nil {
-
log.Fatalf("Failed to create JWK from private key: %v", err)
-
}
-
-
// Set key parameters
-
if err = jwkKey.Set(jwk.KeyIDKey, "oauth-client-key"); err != nil {
-
log.Fatalf("Failed to set kid: %v", err)
-
}
-
if err = jwkKey.Set(jwk.AlgorithmKey, "ES256"); err != nil {
-
log.Fatalf("Failed to set alg: %v", err)
-
}
-
if err = jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
-
log.Fatalf("Failed to set use: %v", err)
-
}
-
-
// Marshal to JSON
-
jsonData, err := json.MarshalIndent(jwkKey, "", " ")
-
if err != nil {
-
log.Fatalf("Failed to marshal JWK: %v", err)
-
}
-
-
// Output instructions
-
fmt.Println("\nโœ… ES256 keypair generated successfully!")
-
fmt.Println("\n๐Ÿ“ Add this to your .env.dev file:")
-
fmt.Println("\nOAUTH_PRIVATE_JWK='" + string(jsonData) + "'")
-
fmt.Println("\nโš ๏ธ IMPORTANT:")
-
fmt.Println(" - Keep this private key SECRET")
-
fmt.Println(" - Never commit it to version control")
-
fmt.Println(" - Generate a new key for production")
-
fmt.Println(" - The public key will be automatically derived and served at /oauth/jwks.json")
-
-
// Optionally write to a file (not committed)
-
if len(os.Args) > 1 && os.Args[1] == "--save" {
-
filename := "oauth-private-key.json"
-
if err := os.WriteFile(filename, jsonData, 0o600); err != nil {
-
log.Fatalf("Failed to write key file: %v", err)
-
}
-
fmt.Printf("\n๐Ÿ’พ Private key saved to %s (remember to add to .gitignore!)\n", filename)
-
}
-
}
-330
internal/atproto/auth/README.md
···
-
# atProto OAuth Authentication
-
-
This package implements third-party OAuth authentication for Coves, validating DPoP-bound access tokens from mobile apps and other atProto clients.
-
-
## Architecture
-
-
This is **third-party authentication** (validating incoming requests), not first-party authentication (logging users into Coves web frontend).
-
-
### Components
-
-
1. **JWT Parser** (`jwt.go`) - Parses and validates JWT tokens
-
2. **JWKS Fetcher** (`jwks_fetcher.go`) - Fetches and caches public keys from PDS authorization servers
-
3. **Auth Middleware** (`internal/api/middleware/auth.go`) - HTTP middleware that protects endpoints
-
-
### Flow
-
-
```
-
Client Request
-
โ†“
-
Authorization: DPoP <access_token>
-
DPoP: <proof-jwt>
-
โ†“
-
Auth Middleware
-
โ†“
-
Extract JWT โ†’ Parse Claims โ†’ Verify Signature (via JWKS) โ†’ Verify DPoP Proof
-
โ†“
-
Inject DID into Context โ†’ Call Handler
-
```
-
-
## Usage
-
-
### Phase 1: Parse-Only Mode (Testing)
-
-
Set `AUTH_SKIP_VERIFY=true` to only parse JWTs without signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=true
-
```
-
-
This is useful for:
-
- Initial integration testing
-
- Testing with mock tokens
-
- Debugging JWT structure
-
-
### Phase 2: Full Verification (Production)
-
-
Set `AUTH_SKIP_VERIFY=false` (or unset) to enable full JWT signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=false
-
# or just unset it
-
```
-
-
This is **required for production** and validates:
-
- JWT signature using PDS public key
-
- Token expiration
-
- Required claims (sub, iss)
-
- DID format
-
-
## Protected Endpoints
-
-
The following endpoints require authentication:
-
-
- `POST /xrpc/social.coves.community.create`
-
- `POST /xrpc/social.coves.community.update`
-
- `POST /xrpc/social.coves.community.subscribe`
-
- `POST /xrpc/social.coves.community.unsubscribe`
-
-
### Making Authenticated Requests
-
-
Include the JWT in the `Authorization` header:
-
-
```bash
-
curl -X POST https://coves.social/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP eyJhbGc..." \
-
-H "DPoP: eyJhbGc..." \
-
-H "Content-Type: application/json" \
-
-d '{"name":"Gaming","hostedByDid":"did:plc:..."}'
-
```
-
-
### Getting User DID in Handlers
-
-
The middleware injects the authenticated user's DID into the request context:
-
-
```go
-
import "Coves/internal/api/middleware"
-
-
func (h *Handler) HandleCreate(w http.ResponseWriter, r *http.Request) {
-
// Extract authenticated user DID
-
userDID := middleware.GetUserDID(r)
-
if userDID == "" {
-
// Not authenticated (should never happen with RequireAuth middleware)
-
http.Error(w, "Unauthorized", http.StatusUnauthorized)
-
return
-
}
-
-
// Use userDID for authorization checks
-
// ...
-
}
-
```
-
-
## Key Caching
-
-
Public keys are fetched from PDS authorization servers and cached for 1 hour. The cache is automatically cleaned up hourly to remove expired entries.
-
-
### JWKS Discovery Flow
-
-
1. Extract `iss` claim from JWT (e.g., `https://pds.example.com`)
-
2. Fetch `https://pds.example.com/.well-known/oauth-authorization-server`
-
3. Extract `jwks_uri` from metadata
-
4. Fetch JWKS from `jwks_uri`
-
5. Find matching key by `kid` from JWT header
-
6. Cache the JWKS for 1 hour
-
-
## DPoP Token Binding
-
-
DPoP (Demonstrating Proof-of-Possession) binds access tokens to client-controlled cryptographic keys, preventing token theft and replay attacks.
-
-
### What is DPoP?
-
-
DPoP is an OAuth extension (RFC 9449) that adds proof-of-possession semantics to bearer tokens. When a PDS issues a DPoP-bound access token:
-
-
1. Access token contains `cnf.jkt` claim (JWK thumbprint of client's public key)
-
2. Client creates a DPoP proof JWT signed with their private key
-
3. Server verifies the proof signature and checks it matches the token's `cnf.jkt`
-
-
### CRITICAL: DPoP Security Model
-
-
> โš ๏ธ **DPoP is an ADDITIONAL security layer, NOT a replacement for token signature verification.**
-
-
The correct verification order is:
-
1. **ALWAYS verify the access token signature first** (via JWKS, HS256 shared secret, or DID resolution)
-
2. **If the verified token has `cnf.jkt`, REQUIRE valid DPoP proof**
-
3. **NEVER use DPoP as a fallback when signature verification fails**
-
-
**Why This Matters**: An attacker could create a fake token with `sub: "did:plc:victim"` and their own `cnf.jkt`, then present a valid DPoP proof signed with their key. If we accept DPoP as a fallback, the attacker can impersonate any user.
-
-
### How DPoP Works
-
-
```
-
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
-
โ”‚ Client โ”‚ โ”‚ Server โ”‚
-
โ”‚ โ”‚ โ”‚ (Coves) โ”‚
-
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
-
โ”‚ โ”‚
-
โ”‚ 1. Authorization: DPoP <token> โ”‚
-
โ”‚ DPoP: <proof-jwt> โ”‚
-
โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€>โ”‚
-
โ”‚ โ”‚
-
โ”‚ โ”‚ 2. VERIFY token signature
-
โ”‚ โ”‚ (REQUIRED - no fallback!)
-
โ”‚ โ”‚
-
โ”‚ โ”‚ 3. If token has cnf.jkt:
-
โ”‚ โ”‚ - Verify DPoP proof
-
โ”‚ โ”‚ - Check thumbprint match
-
โ”‚ โ”‚
-
โ”‚ 200 OK โ”‚
-
โ”‚<โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚
-
```
-
-
### When DPoP is Required
-
-
DPoP verification is **REQUIRED** when:
-
- Access token signature has been verified AND
-
- Access token contains `cnf.jkt` claim (DPoP-bound)
-
-
If the token has `cnf.jkt` but no DPoP header is present, the request is **REJECTED**.
-
-
### Replay Protection
-
-
DPoP proofs include a unique `jti` (JWT ID) claim. The server tracks seen `jti` values to prevent replay attacks:
-
-
```go
-
// Create a verifier with replay protection (default)
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop() // Stop cleanup goroutine on shutdown
-
-
// The verifier automatically rejects reused jti values within the proof validity window (5 minutes)
-
```
-
-
### DPoP Implementation
-
-
The `dpop.go` module provides:
-
-
```go
-
// Create a verifier with replay protection
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop()
-
-
// Verify the DPoP proof
-
proof, err := verifier.VerifyDPoPProof(dpopHeader, "POST", "https://coves.social/xrpc/...")
-
if err != nil {
-
// Invalid proof (includes replay detection)
-
}
-
-
// Verify it binds to the VERIFIED access token
-
expectedThumbprint, err := auth.ExtractCnfJkt(claims)
-
if err != nil {
-
// Token not DPoP-bound
-
}
-
-
if err := verifier.VerifyTokenBinding(proof, expectedThumbprint); err != nil {
-
// Proof doesn't match token
-
}
-
```
-
-
### DPoP Proof Format
-
-
The DPoP header contains a JWT with:
-
-
**Header**:
-
- `typ`: `"dpop+jwt"` (required)
-
- `alg`: `"ES256"` (or other supported algorithm)
-
- `jwk`: Client's public key (JWK format)
-
-
**Claims**:
-
- `jti`: Unique proof identifier (tracked for replay protection)
-
- `htm`: HTTP method (e.g., `"POST"`)
-
- `htu`: HTTP URI (without query/fragment)
-
- `iat`: Timestamp (must be recent, within 5 minutes)
-
-
**Example**:
-
```json
-
{
-
"typ": "dpop+jwt",
-
"alg": "ES256",
-
"jwk": {
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "...",
-
"y": "..."
-
}
-
}
-
{
-
"jti": "unique-id-123",
-
"htm": "POST",
-
"htu": "https://coves.social/xrpc/social.coves.community.create",
-
"iat": 1700000000
-
}
-
```
-
-
## Security Considerations
-
-
### โœ… Implemented
-
-
- JWT signature verification with PDS public keys
-
- Token expiration validation
-
- DID format validation
-
- Required claims validation (sub, iss)
-
- Key caching with TTL
-
- Secure error messages (no internal details leaked)
-
- **DPoP proof verification** (proof-of-possession for token binding)
-
- **DPoP thumbprint validation** (prevents token theft attacks)
-
- **DPoP freshness checks** (5-minute proof validity window)
-
- **DPoP replay protection** (jti tracking with in-memory cache)
-
- **Secure DPoP model** (DPoP required AFTER signature verification, never as fallback)
-
-
### โš ๏ธ Not Yet Implemented
-
-
- Server-issued DPoP nonces (additional replay protection)
-
- Scope validation (checking `scope` claim)
-
- Audience validation (checking `aud` claim)
-
- Rate limiting per DID
-
- Token revocation checking
-
-
## Testing
-
-
Run the test suite:
-
-
```bash
-
go test ./internal/atproto/auth/... -v
-
```
-
-
### Manual Testing
-
-
1. **Phase 1 (Parse Only)**:
-
```bash
-
# Create a test JWT (use jwt.io or a tool)
-
export AUTH_SKIP_VERIFY=true
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <test-jwt>" \
-
-H "DPoP: <test-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
2. **Phase 2 (Full Verification)**:
-
```bash
-
# Use a real JWT from a PDS
-
export AUTH_SKIP_VERIFY=false
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <real-jwt>" \
-
-H "DPoP: <real-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
## Error Responses
-
-
### 401 Unauthorized
-
-
Missing or invalid token:
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Missing Authorization header"
-
}
-
```
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Invalid or expired token"
-
}
-
```
-
-
### Common Issues
-
-
1. **Missing Authorization header** โ†’ Add `Authorization: DPoP <token>` and `DPoP: <proof>`
-
2. **Token expired** โ†’ Get a new token from PDS
-
3. **Invalid signature** โ†’ Ensure token is from a valid PDS
-
4. **JWKS fetch fails** โ†’ Check PDS availability and network connectivity
-
-
## Future Enhancements
-
-
- [ ] DPoP nonce validation (server-managed nonce for additional replay protection)
-
- [ ] Scope-based authorization
-
- [ ] Audience claim validation
-
- [ ] Token revocation support
-
- [ ] Rate limiting per DID
-
- [ ] Metrics and monitoring
-52
internal/atproto/auth/combined_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"fmt"
-
"strings"
-
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
)
-
-
// CombinedKeyFetcher handles JWT public key fetching for both:
-
// - DID issuers (did:plc:, did:web:) โ†’ resolves via DID document
-
// - URL issuers (https://) โ†’ fetches via JWKS endpoint (legacy/fallback)
-
//
-
// For atproto service authentication, the issuer is typically the user's DID,
-
// and the signing key is published in their DID document.
-
type CombinedKeyFetcher struct {
-
didFetcher *DIDKeyFetcher
-
jwksFetcher JWKSFetcher
-
}
-
-
// NewCombinedKeyFetcher creates a key fetcher that supports both DID and URL issuers.
-
// Parameters:
-
// - directory: Indigo's identity directory for DID resolution
-
// - jwksFetcher: fallback JWKS fetcher for URL issuers (can be nil if not needed)
-
func NewCombinedKeyFetcher(directory indigoIdentity.Directory, jwksFetcher JWKSFetcher) *CombinedKeyFetcher {
-
return &CombinedKeyFetcher{
-
didFetcher: NewDIDKeyFetcher(directory),
-
jwksFetcher: jwksFetcher,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT.
-
// Routes to the appropriate fetcher based on issuer format:
-
// - DID (did:plc:, did:web:) โ†’ DIDKeyFetcher
-
// - URL (https://) โ†’ JWKSFetcher
-
func (f *CombinedKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Check if issuer is a DID
-
if strings.HasPrefix(issuer, "did:") {
-
return f.didFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
// Check if issuer is a URL (https:// or http:// in dev)
-
if strings.HasPrefix(issuer, "https://") || strings.HasPrefix(issuer, "http://") {
-
if f.jwksFetcher == nil {
-
return nil, fmt.Errorf("URL issuer %s requires JWKS fetcher, but none configured", issuer)
-
}
-
return f.jwksFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
return nil, fmt.Errorf("unsupported issuer format: %s (expected DID or URL)", issuer)
-
}
-122
internal/atproto/auth/did_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"encoding/base64"
-
"fmt"
-
"math/big"
-
"strings"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
"github.com/bluesky-social/indigo/atproto/syntax"
-
)
-
-
// DIDKeyFetcher fetches public keys from DID documents for JWT verification.
-
// This is the primary method for atproto service authentication, where:
-
// - The JWT issuer is the user's DID (e.g., did:plc:abc123)
-
// - The signing key is published in the user's DID document
-
// - Verification happens by resolving the DID and checking the signature
-
type DIDKeyFetcher struct {
-
directory indigoIdentity.Directory
-
}
-
-
// NewDIDKeyFetcher creates a new DID-based key fetcher.
-
func NewDIDKeyFetcher(directory indigoIdentity.Directory) *DIDKeyFetcher {
-
return &DIDKeyFetcher{
-
directory: directory,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer's DID document.
-
// For DID issuers (did:plc: or did:web:), resolves the DID and extracts the signing key.
-
//
-
// Returns:
-
// - indigoCrypto.PublicKey for secp256k1 (ES256K) keys - use indigo for verification
-
// - *ecdsa.PublicKey for NIST curves (P-256, P-384, P-521) - compatible with golang-jwt
-
func (f *DIDKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Only handle DID issuers
-
if !strings.HasPrefix(issuer, "did:") {
-
return nil, fmt.Errorf("DIDKeyFetcher only handles DID issuers, got: %s", issuer)
-
}
-
-
// Parse the DID
-
did, err := syntax.ParseDID(issuer)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DID format: %w", err)
-
}
-
-
// Resolve the DID to get the identity (includes public keys)
-
ident, err := f.directory.LookupDID(ctx, did)
-
if err != nil {
-
return nil, fmt.Errorf("failed to resolve DID %s: %w", issuer, err)
-
}
-
-
// Get the atproto signing key from the DID document
-
pubKey, err := ident.PublicKey()
-
if err != nil {
-
return nil, fmt.Errorf("failed to get public key from DID document: %w", err)
-
}
-
-
// Convert to JWK format to check curve type
-
jwk, err := pubKey.JWK()
-
if err != nil {
-
return nil, fmt.Errorf("failed to convert public key to JWK: %w", err)
-
}
-
-
// For secp256k1 (ES256K), return indigo's PublicKey directly
-
// since Go's crypto/ecdsa doesn't support this curve
-
if jwk.Curve == "secp256k1" {
-
return pubKey, nil
-
}
-
-
// For NIST curves, convert to Go's ecdsa.PublicKey for golang-jwt compatibility
-
return atcryptoJWKToECDSA(jwk)
-
}
-
-
// atcryptoJWKToECDSA converts an indigoCrypto.JWK to a Go ecdsa.PublicKey.
-
// Note: secp256k1 is handled separately in FetchPublicKey by returning indigo's PublicKey directly.
-
func atcryptoJWKToECDSA(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) {
-
if jwk.KeyType != "EC" {
-
return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType)
-
}
-
-
// Decode X and Y coordinates (base64url, no padding)
-
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK X coordinate encoding: %w", err)
-
}
-
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK Y coordinate encoding: %w", err)
-
}
-
-
var ecCurve elliptic.Curve
-
switch jwk.Curve {
-
case "P-256":
-
ecCurve = elliptic.P256()
-
case "P-384":
-
ecCurve = elliptic.P384()
-
case "P-521":
-
ecCurve = elliptic.P521()
-
default:
-
// secp256k1 should be handled before calling this function
-
return nil, fmt.Errorf("unsupported JWK curve for Go ecdsa: %s (secp256k1 uses indigo)", jwk.Curve)
-
}
-
-
// Create the public key
-
pubKey := &ecdsa.PublicKey{
-
Curve: ecCurve,
-
X: new(big.Int).SetBytes(xBytes),
-
Y: new(big.Int).SetBytes(yBytes),
-
}
-
-
// Validate point is on curve
-
if !ecCurve.IsOnCurve(pubKey.X, pubKey.Y) {
-
return nil, fmt.Errorf("invalid public key: point not on curve")
-
}
-
-
return pubKey, nil
-
}
-616
internal/atproto/auth/dpop.go
···
-
package auth
-
-
import (
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
-
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
-
type NonceCache struct {
-
seen map[string]time.Time // jti -> expiration time
-
stopCh chan struct{}
-
maxAge time.Duration // How long to keep entries
-
cleanup time.Duration // How often to clean up expired entries
-
mu sync.RWMutex
-
}
-
-
// NewNonceCache creates a new nonce cache for DPoP replay protection.
-
// maxAge should match or exceed DPoPVerifier.MaxProofAge.
-
func NewNonceCache(maxAge time.Duration) *NonceCache {
-
nc := &NonceCache{
-
seen: make(map[string]time.Time),
-
maxAge: maxAge,
-
cleanup: maxAge / 2, // Clean up at half the max age
-
stopCh: make(chan struct{}),
-
}
-
-
// Start background cleanup goroutine
-
go nc.cleanupLoop()
-
-
return nc
-
}
-
-
// CheckAndStore checks if a jti has been seen before and stores it if not.
-
// Returns true if the jti is fresh (not a replay), false if it's a replay.
-
func (nc *NonceCache) CheckAndStore(jti string) bool {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
expiry := now.Add(nc.maxAge)
-
-
// Check if already seen
-
if existingExpiry, seen := nc.seen[jti]; seen {
-
// Still valid (not expired) - this is a replay
-
if existingExpiry.After(now) {
-
return false
-
}
-
// Expired entry - allow reuse and update expiry
-
}
-
-
// Store the new jti
-
nc.seen[jti] = expiry
-
return true
-
}
-
-
// cleanupLoop periodically removes expired entries from the cache
-
func (nc *NonceCache) cleanupLoop() {
-
ticker := time.NewTicker(nc.cleanup)
-
defer ticker.Stop()
-
-
for {
-
select {
-
case <-ticker.C:
-
nc.cleanupExpired()
-
case <-nc.stopCh:
-
return
-
}
-
}
-
}
-
-
// cleanupExpired removes expired entries from the cache
-
func (nc *NonceCache) cleanupExpired() {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
for jti, expiry := range nc.seen {
-
if expiry.Before(now) {
-
delete(nc.seen, jti)
-
}
-
}
-
}
-
-
// Stop stops the cleanup goroutine. Call this when done with the cache.
-
func (nc *NonceCache) Stop() {
-
close(nc.stopCh)
-
}
-
-
// Size returns the number of entries in the cache (for testing/monitoring)
-
func (nc *NonceCache) Size() int {
-
nc.mu.RLock()
-
defer nc.mu.RUnlock()
-
return len(nc.seen)
-
}
-
-
// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
-
type DPoPClaims struct {
-
jwt.RegisteredClaims
-
-
// HTTP method of the request (e.g., "GET", "POST")
-
HTTPMethod string `json:"htm"`
-
-
// HTTP URI of the request (without query and fragment parts)
-
HTTPURI string `json:"htu"`
-
-
// Access token hash (optional, for token binding)
-
AccessTokenHash string `json:"ath,omitempty"`
-
}
-
-
// DPoPProof represents a parsed and verified DPoP proof
-
type DPoPProof struct {
-
RawPublicJWK map[string]interface{}
-
Claims *DPoPClaims
-
PublicKey interface{} // *ecdsa.PublicKey or similar
-
Thumbprint string // JWK thumbprint (base64url)
-
}
-
-
// DPoPVerifier verifies DPoP proofs for OAuth token binding
-
type DPoPVerifier struct {
-
// Optional: custom nonce validation function (for server-issued nonces)
-
ValidateNonce func(nonce string) bool
-
-
// NonceCache for replay protection (optional but recommended)
-
// If nil, jti replay protection is disabled
-
NonceCache *NonceCache
-
-
// Maximum allowed clock skew for timestamp validation
-
MaxClockSkew time.Duration
-
-
// Maximum age of DPoP proof (prevents replay with old proofs)
-
MaxProofAge time.Duration
-
}
-
-
// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
-
func NewDPoPVerifier() *DPoPVerifier {
-
maxProofAge := 5 * time.Minute
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: maxProofAge,
-
NonceCache: NewNonceCache(maxProofAge),
-
}
-
}
-
-
// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
-
// This should only be used in testing or when replay protection is handled externally.
-
func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: 5 * time.Minute,
-
NonceCache: nil, // No replay protection
-
}
-
}
-
-
// Stop stops background goroutines. Call this when shutting down.
-
func (v *DPoPVerifier) Stop() {
-
if v.NonceCache != nil {
-
v.NonceCache.Stop()
-
}
-
}
-
-
// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
-
// This supports all atProto-compatible ECDSA algorithms including ES256K (secp256k1).
-
func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
-
// Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
-
header, claims, err := parseJWTHeaderAndClaims(dpopProof)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
-
}
-
-
// Extract and validate the typ header
-
typ, ok := header["typ"].(string)
-
if !ok || typ != "dpop+jwt" {
-
return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
-
}
-
-
alg, ok := header["alg"].(string)
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
-
}
-
-
// Extract the JWK from the header first (needed for algorithm-curve validation)
-
jwkRaw, ok := header["jwk"]
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
-
}
-
-
jwkMap, ok := jwkRaw.(map[string]interface{})
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
-
}
-
-
// Validate the algorithm is supported and matches the JWK curve
-
// This is critical for security - prevents algorithm confusion attacks
-
if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof: %w", err)
-
}
-
-
// Parse the public key using indigo's crypto package
-
// This supports all atProto curves including secp256k1 (ES256K)
-
publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
-
}
-
-
// Calculate the JWK thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
-
}
-
-
// Verify the signature using indigo's crypto package
-
// This works for all ECDSA algorithms including ES256K
-
if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
-
return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
-
}
-
-
// Validate the claims
-
if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
-
return nil, err
-
}
-
-
return &DPoPProof{
-
Claims: claims,
-
PublicKey: publicKey,
-
Thumbprint: thumbprint,
-
RawPublicJWK: jwkMap,
-
}, nil
-
}
-
-
// validateDPoPClaims validates the DPoP proof claims
-
func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
-
// Validate jti (unique identifier) is present
-
if claims.ID == "" {
-
return fmt.Errorf("DPoP proof missing jti claim")
-
}
-
-
// Validate htm (HTTP method)
-
if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
-
return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
-
}
-
-
// Validate htu (HTTP URI) - compare without query/fragment
-
expectedURIBase := stripQueryFragment(expectedURI)
-
claimURIBase := stripQueryFragment(claims.HTTPURI)
-
if expectedURIBase != claimURIBase {
-
return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
-
}
-
-
// Validate iat (issued at) is present and recent
-
if claims.IssuedAt == nil {
-
return fmt.Errorf("DPoP proof missing iat claim")
-
}
-
-
now := time.Now()
-
iat := claims.IssuedAt.Time
-
-
// Check clock skew (not too far in the future)
-
if iat.After(now.Add(v.MaxClockSkew)) {
-
return fmt.Errorf("DPoP proof iat is in the future")
-
}
-
-
// Check proof age (not too old)
-
if now.Sub(iat) > v.MaxProofAge {
-
return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
-
}
-
-
// SECURITY: Validate exp claim if present (RFC standard JWT validation)
-
// While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
-
if claims.ExpiresAt != nil {
-
expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
-
if now.After(expWithSkew) {
-
return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
-
}
-
}
-
-
// SECURITY: Validate nbf claim if present (RFC standard JWT validation)
-
if claims.NotBefore != nil {
-
nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
-
if now.Before(nbfWithSkew) {
-
return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
-
}
-
}
-
-
// SECURITY: Check for replay attack using jti
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
-
if v.NonceCache != nil {
-
if !v.NonceCache.CheckAndStore(claims.ID) {
-
return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
-
}
-
}
-
-
return nil
-
}
-
-
// VerifyTokenBinding verifies that the DPoP proof binds to the access token
-
// by comparing the proof's thumbprint to the token's cnf.jkt claim
-
func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
-
if proof.Thumbprint != expectedThumbprint {
-
return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
-
expectedThumbprint, proof.Thumbprint)
-
}
-
return nil
-
}
-
-
// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
-
// matches the SHA-256 hash of the presented access token.
-
// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
-
func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
-
// If ath claim is not present, that's acceptable per RFC 9449
-
// (ath is only required when the RS mandates it)
-
if proof.Claims.AccessTokenHash == "" {
-
return nil
-
}
-
-
// Calculate the expected ath: base64url(SHA-256(access_token))
-
hash := sha256.Sum256([]byte(accessToken))
-
expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
if proof.Claims.AccessTokenHash != expectedAth {
-
return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
-
}
-
-
return nil
-
}
-
-
// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638
-
// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation
-
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
-
kty, ok := jwk["kty"].(string)
-
if !ok {
-
return "", fmt.Errorf("JWK missing kty")
-
}
-
-
// Build the canonical JWK representation based on key type
-
// Per RFC 7638, only specific members are included, in lexicographic order
-
var canonical map[string]string
-
-
switch kty {
-
case "EC":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing x")
-
}
-
y, ok := jwk["y"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing y")
-
}
-
// Lexicographic order: crv, kty, x, y
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
"y": y,
-
}
-
case "RSA":
-
e, ok := jwk["e"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing e")
-
}
-
n, ok := jwk["n"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing n")
-
}
-
// Lexicographic order: e, kty, n
-
canonical = map[string]string{
-
"e": e,
-
"kty": kty,
-
"n": n,
-
}
-
case "OKP":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing x")
-
}
-
// Lexicographic order: crv, kty, x
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
}
-
default:
-
return "", fmt.Errorf("unsupported JWK key type: %s", kty)
-
}
-
-
// Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string)
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
-
}
-
-
// SHA-256 hash
-
hash := sha256.Sum256(canonicalJSON)
-
-
// Base64url encode (no padding)
-
thumbprint := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
return thumbprint, nil
-
}
-
-
// validateAlgorithmCurveBinding validates that the JWT algorithm matches the JWK curve.
-
// This is critical for security - an attacker could claim alg: "ES256K" but provide
-
// a P-256 key, potentially bypassing algorithm binding requirements.
-
func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error {
-
kty, ok := jwkMap["kty"].(string)
-
if !ok {
-
return fmt.Errorf("JWK missing kty")
-
}
-
-
// ECDSA algorithms require EC key type
-
switch alg {
-
case "ES256K", "ES256", "ES384", "ES512":
-
if kty != "EC" {
-
return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty)
-
}
-
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
-
return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
-
default:
-
return fmt.Errorf("unsupported DPoP algorithm: %s", alg)
-
}
-
-
// Validate curve matches algorithm
-
crv, ok := jwkMap["crv"].(string)
-
if !ok {
-
return fmt.Errorf("EC JWK missing crv")
-
}
-
-
var expectedCurve string
-
switch alg {
-
case "ES256K":
-
expectedCurve = "secp256k1"
-
case "ES256":
-
expectedCurve = "P-256"
-
case "ES384":
-
expectedCurve = "P-384"
-
case "ES512":
-
expectedCurve = "P-521"
-
}
-
-
if crv != expectedCurve {
-
return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, expectedCurve, crv)
-
}
-
-
return nil
-
}
-
-
// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey.
-
// This returns indigo's PublicKey interface which supports all atProto curves
-
// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512).
-
func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - this supports all atProto curves
-
// including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt.
-
// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize.
-
func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode header
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first to extract standard claims
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build DPoPClaims struct
-
claims := &DPoPClaims{}
-
-
// Extract jti
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract exp (expiration) if present
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before) if present
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract htm (HTTP method)
-
if htm, ok := rawClaims["htm"].(string); ok {
-
claims.HTTPMethod = htm
-
}
-
-
// Extract htu (HTTP URI)
-
if htu, ok := rawClaims["htu"].(string); ok {
-
claims.HTTPURI = htu
-
}
-
-
// Extract ath (access token hash) if present
-
if ath, ok := rawClaims["ath"].(string); ok {
-
claims.AccessTokenHash = ath
-
}
-
-
return header, claims, nil
-
}
-
-
// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package.
-
// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K).
-
// It parses the JWT, extracts the signing input and signature, and uses indigo's
-
// PublicKey.HashAndVerifyLenient() for verification.
-
//
-
// JWT format: header.payload.signature (all base64url-encoded)
-
// Signature is verified over the raw bytes of "header.payload"
-
// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally)
-
func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
// Split the JWT into parts
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature)
-
if err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// stripQueryFragment removes query and fragment from a URI
-
func stripQueryFragment(uri string) string {
-
if idx := strings.Index(uri, "?"); idx != -1 {
-
uri = uri[:idx]
-
}
-
if idx := strings.Index(uri, "#"); idx != -1 {
-
uri = uri[:idx]
-
}
-
return uri
-
}
-
-
// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
-
func ExtractCnfJkt(claims *Claims) (string, error) {
-
if claims.Confirmation == nil {
-
return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
-
}
-
-
jkt, ok := claims.Confirmation["jkt"].(string)
-
if !ok || jkt == "" {
-
return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
-
}
-
-
return jkt, nil
-
}
-1308
internal/atproto/auth/dpop_test.go
···
-
package auth
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"strings"
-
"testing"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
"github.com/google/uuid"
-
)
-
-
// === Test Helpers ===
-
-
// testECKey holds a test ES256 key pair
-
type testECKey struct {
-
privateKey *ecdsa.PrivateKey
-
publicKey *ecdsa.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256Key generates a test ES256 key pair and JWK
-
func generateTestES256Key(t *testing.T) *testECKey {
-
t.Helper()
-
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("Failed to generate test key: %v", err)
-
}
-
-
// Encode public key coordinates as base64url
-
xBytes := privateKey.PublicKey.X.Bytes()
-
yBytes := privateKey.PublicKey.Y.Bytes()
-
-
// P-256 coordinates must be 32 bytes (pad if needed)
-
xBytes = padTo32Bytes(xBytes)
-
yBytes = padTo32Bytes(yBytes)
-
-
x := base64.RawURLEncoding.EncodeToString(xBytes)
-
y := base64.RawURLEncoding.EncodeToString(yBytes)
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": x,
-
"y": y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate thumbprint: %v", err)
-
}
-
-
return &testECKey{
-
privateKey: privateKey,
-
publicKey: &privateKey.PublicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// padTo32Bytes pads a byte slice to 32 bytes (required for P-256 coordinates)
-
func padTo32Bytes(b []byte) []byte {
-
if len(b) >= 32 {
-
return b
-
}
-
padded := make([]byte, 32)
-
copy(padded[32-len(b):], b)
-
return padded
-
}
-
-
// createDPoPProof creates a DPoP proof JWT for testing
-
func createDPoPProof(t *testing.T, key *testECKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
tokenString, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create DPoP proof: %v", err)
-
}
-
-
return tokenString
-
}
-
-
// === JWK Thumbprint Tests (RFC 7638) ===
-
-
func TestCalculateJWKThumbprint_EC_P256(t *testing.T) {
-
// Test with known values from RFC 7638 Appendix A (adapted for P-256)
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis",
-
"y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("Thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
func TestCalculateJWKThumbprint_Deterministic(t *testing.T) {
-
// Same key should produce same thumbprint
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x-coordinate",
-
"y": "test-y-coordinate",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 != thumbprint2 {
-
t.Errorf("Thumbprints are not deterministic: %s != %s", thumbprint1, thumbprint2)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_DifferentKeys(t *testing.T) {
-
// Different keys should produce different thumbprints
-
jwk1 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-1",
-
"y": "coordinate-y-1",
-
}
-
-
jwk2 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-2",
-
"y": "coordinate-y-2",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk1)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk2)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 == thumbprint2 {
-
t.Error("Different keys produced same thumbprint (collision)")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_MissingKty(t *testing.T) {
-
jwk := map[string]interface{}{
-
"crv": "P-256",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing kty, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing kty") {
-
t.Errorf("Expected error about missing kty, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingCrv(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing crv, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing crv") {
-
t.Errorf("Expected error about missing crv, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingX(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing x, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing x") {
-
t.Errorf("Expected error about missing x, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingY(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing y, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing y") {
-
t.Errorf("Expected error about missing y, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_RSA(t *testing.T) {
-
// Test RSA key thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "RSA",
-
"e": "AQAB",
-
"n": "test-modulus",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for RSA: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for RSA key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_OKP(t *testing.T) {
-
// Test OKP (Octet Key Pair) thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "OKP",
-
"crv": "Ed25519",
-
"x": "test-x-coordinate",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for OKP: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for OKP key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_UnsupportedKeyType(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "UNKNOWN",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for unsupported key type, got nil")
-
}
-
if err != nil && !contains(err.Error(), "unsupported JWK key type") {
-
t.Errorf("Expected error about unsupported key type, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_CanonicalJSON(t *testing.T) {
-
// RFC 7638 requires lexicographic ordering of keys in canonical JSON
-
// This test verifies that the canonical JSON is correctly ordered
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
// The canonical JSON should be: {"crv":"P-256","kty":"EC","x":"x-coord","y":"y-coord"}
-
// (lexicographically ordered: crv, kty, x, y)
-
-
canonical := map[string]string{
-
"crv": "P-256",
-
"kty": "EC",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
t.Fatalf("Failed to marshal canonical JSON: %v", err)
-
}
-
-
expectedHash := sha256.Sum256(canonicalJSON)
-
expectedThumbprint := base64.RawURLEncoding.EncodeToString(expectedHash[:])
-
-
actualThumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if actualThumbprint != expectedThumbprint {
-
t.Errorf("Thumbprint doesn't match expected canonical JSON hash\nExpected: %s\nGot: %s",
-
expectedThumbprint, actualThumbprint)
-
}
-
}
-
-
// === DPoP Proof Verification Tests ===
-
-
func TestVerifyDPoPProof_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Claims.ID != jti {
-
t.Errorf("Expected jti %s, got %s", jti, result.Claims.ID)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Parse and modify to use wrong key's JWK in header (signature won't match)
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongHTTPMethod(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
wrongMethod := "GET"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, wrongMethod, uri)
-
if err == nil {
-
t.Error("Expected error for HTTP method mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htm mismatch") {
-
t.Errorf("Expected htm mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongURI(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
wrongURI := "https://api.example.com/different"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, wrongURI)
-
if err == nil {
-
t.Error("Expected error for URI mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htu mismatch") {
-
t.Errorf("Expected htu mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithQuery(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithQuery := baseURI + "?param=value"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because query is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithQuery)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with query: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithFragment(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithFragment := baseURI + "#section"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because fragment is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithFragment)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with fragment: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_ExpiredProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 10 minutes ago (exceeds default MaxProofAge of 5 minutes)
-
iat := time.Now().Add(-10 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "too old") {
-
t.Errorf("Expected 'too old' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_FutureProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 1 minute in the future (exceeds MaxClockSkew)
-
iat := time.Now().Add(1 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for future proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "in the future") {
-
t.Errorf("Expected 'in the future' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WithinClockSkew(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 15 seconds in the future (within MaxClockSkew of 30s)
-
iat := time.Now().Add(15 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for proof within clock skew: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingJti(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
// No ID (jti)
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jti, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jti") {
-
t.Errorf("Expected missing jti error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
// Don't set typ header
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "JWT" // Wrong typ
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for wrong typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingJWK(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
// Don't include JWK
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jwk header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jwk") {
-
t.Errorf("Expected missing jwk error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_CustomTimeSettings(t *testing.T) {
-
verifier := &DPoPVerifier{
-
MaxClockSkew: 1 * time.Minute,
-
MaxProofAge: 10 * time.Minute,
-
}
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 50 seconds in the future (within custom MaxClockSkew)
-
iat := time.Now().Add(50 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed with custom time settings: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_HTTPMethodCaseInsensitive(t *testing.T) {
-
// HTTP method comparison should be case-insensitive per spec
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "post"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Verify with uppercase method
-
_, err := verifier.VerifyDPoPProof(proof, "POST", uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for case-insensitive method: %v", err)
-
}
-
}
-
-
// === Token Binding Verification Tests ===
-
-
func TestVerifyTokenBinding_Matching(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with matching thumbprint
-
err = verifier.VerifyTokenBinding(result, key.thumbprint)
-
if err != nil {
-
t.Fatalf("VerifyTokenBinding failed for matching thumbprint: %v", err)
-
}
-
}
-
-
func TestVerifyTokenBinding_Mismatch(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with wrong thumbprint
-
err = verifier.VerifyTokenBinding(result, wrongKey.thumbprint)
-
if err == nil {
-
t.Error("Expected error for thumbprint mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "thumbprint mismatch") {
-
t.Errorf("Expected thumbprint mismatch error, got: %v", err)
-
}
-
}
-
-
// === ExtractCnfJkt Tests ===
-
-
func TestExtractCnfJkt_Valid(t *testing.T) {
-
expectedJkt := "test-thumbprint-123"
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": expectedJkt,
-
},
-
}
-
-
jkt, err := ExtractCnfJkt(claims)
-
if err != nil {
-
t.Fatalf("ExtractCnfJkt failed for valid claims: %v", err)
-
}
-
-
if jkt != expectedJkt {
-
t.Errorf("Expected jkt %s, got %s", expectedJkt, jkt)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingCnf(t *testing.T) {
-
claims := &Claims{
-
// No Confirmation
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_NilCnf(t *testing.T) {
-
claims := &Claims{
-
Confirmation: nil,
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for nil cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"other": "value",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_EmptyJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": "",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for empty jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_WrongType(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": 123, // Not a string
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for wrong type jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
// === Helper Functions for Tests ===
-
-
// splitJWT splits a JWT into its three parts
-
func splitJWT(token string) []string {
-
return []string{
-
token[:strings.IndexByte(token, '.')],
-
token[strings.IndexByte(token, '.')+1 : strings.LastIndexByte(token, '.')],
-
token[strings.LastIndexByte(token, '.')+1:],
-
}
-
}
-
-
// parseJWTHeader parses a base64url-encoded JWT header
-
func parseJWTHeader(t *testing.T, encoded string) map[string]interface{} {
-
t.Helper()
-
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
-
if err != nil {
-
t.Fatalf("Failed to decode header: %v", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(decoded, &header); err != nil {
-
t.Fatalf("Failed to unmarshal header: %v", err)
-
}
-
-
return header
-
}
-
-
// encodeJSON encodes a value to base64url-encoded JSON
-
func encodeJSON(t *testing.T, v interface{}) string {
-
t.Helper()
-
data, err := json.Marshal(v)
-
if err != nil {
-
t.Fatalf("Failed to marshal JSON: %v", err)
-
}
-
return base64.RawURLEncoding.EncodeToString(data)
-
}
-
-
// === ES256K (secp256k1) Test Helpers ===
-
-
// testES256KKey holds a test ES256K key pair using indigo
-
type testES256KKey struct {
-
privateKey indigoCrypto.PrivateKey
-
publicKey indigoCrypto.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256KKey generates a test ES256K (secp256k1) key pair and JWK
-
func generateTestES256KKey(t *testing.T) *testES256KKey {
-
t.Helper()
-
-
privateKey, err := indigoCrypto.GeneratePrivateKeyK256()
-
if err != nil {
-
t.Fatalf("Failed to generate ES256K test key: %v", err)
-
}
-
-
publicKey, err := privateKey.PublicKey()
-
if err != nil {
-
t.Fatalf("Failed to get public key from ES256K private key: %v", err)
-
}
-
-
// Get the JWK representation
-
jwkStruct, err := publicKey.JWK()
-
if err != nil {
-
t.Fatalf("Failed to get JWK from ES256K public key: %v", err)
-
}
-
jwk := map[string]interface{}{
-
"kty": jwkStruct.KeyType,
-
"crv": jwkStruct.Curve,
-
"x": jwkStruct.X,
-
"y": jwkStruct.Y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate ES256K thumbprint: %v", err)
-
}
-
-
return &testES256KKey{
-
privateKey: privateKey,
-
publicKey: publicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// createES256KDPoPProof creates a DPoP proof JWT using ES256K for testing
-
func createES256KDPoPProof(t *testing.T, key *testES256KKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256K",
-
"jwk": key.jwk,
-
}
-
-
// Encode header and claims
-
headerJSON, err := json.Marshal(header)
-
if err != nil {
-
t.Fatalf("Failed to marshal header: %v", err)
-
}
-
claimsJSON, err := json.Marshal(claims)
-
if err != nil {
-
t.Fatalf("Failed to marshal claims: %v", err)
-
}
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
// Sign with indigo
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign ES256K proof: %v", err)
-
}
-
-
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
-
return signingInput + "." + signatureB64
-
}
-
-
// === ES256K Tests ===
-
-
func TestVerifyDPoPProof_ES256K_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid ES256K proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_ES256K_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
wrongKey := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
// Tamper by replacing JWK with wrong key
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid ES256K signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_ES256K(t *testing.T) {
-
// Test thumbprint calculation for secp256k1 keys
-
key := generateTestES256KKey(t)
-
-
thumbprint, err := CalculateJWKThumbprint(key.jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for ES256K: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for ES256K key")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("ES256K thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected ES256K thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
// === Algorithm-Curve Binding Tests ===
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256KWithP256Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t) // P-256 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create a proof claiming ES256K but using P-256 key
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["alg"] = "ES256K" // Claim ES256K
-
token.Header["jwk"] = key.jwk // But use P-256 key
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256K algorithm with P-256 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve secp256k1") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256WithSecp256k1Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t) // secp256k1 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header claiming ES256 but using secp256k1 key
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256", // Claim ES256
-
"jwk": key.jwk, // But use secp256k1 key
-
}
-
-
headerJSON, _ := json.Marshal(header)
-
claimsJSON, _ := json.Marshal(claims)
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign: %v", err)
-
}
-
-
proof := signingInput + "." + base64.RawURLEncoding.EncodeToString(signature)
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256 algorithm with secp256k1 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve P-256") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
-
// === exp/nbf Validation Tests ===
-
-
func TestVerifyDPoPProof_ExpiredWithExpClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now().Add(-2 * time.Minute)
-
exp := time.Now().Add(-1 * time.Minute) // Expired 1 minute ago
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof with exp claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "expired") {
-
t.Errorf("Expected expiration error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_NotYetValidWithNbfClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
nbf := time.Now().Add(5 * time.Minute) // Not valid for another 5 minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
NotBefore: jwt.NewNumericDate(nbf),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for not-yet-valid proof with nbf claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "not valid before") {
-
t.Errorf("Expected not-before error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_ValidWithExpClaimInFuture(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
exp := time.Now().Add(5 * time.Minute) // Valid for 5 more minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof with exp in future: %v", err)
-
}
-
-
if result == nil {
-
t.Error("Expected non-nil result for valid proof")
-
}
-
}
-189
internal/atproto/auth/jwks_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"encoding/json"
-
"fmt"
-
"net/http"
-
"strings"
-
"sync"
-
"time"
-
)
-
-
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
-
type CachedJWKSFetcher struct {
-
cache map[string]*cachedJWKS
-
httpClient *http.Client
-
cacheMutex sync.RWMutex
-
cacheTTL time.Duration
-
}
-
-
type cachedJWKS struct {
-
jwks *JWKS
-
expiresAt time.Time
-
}
-
-
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
-
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
-
return &CachedJWKSFetcher{
-
cache: make(map[string]*cachedJWKS),
-
httpClient: &http.Client{
-
Timeout: 10 * time.Second,
-
},
-
cacheTTL: cacheTTL,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
-
// Implements JWKSFetcher interface
-
// Returns interface{} to support both RSA and ECDSA keys
-
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Extract key ID from token
-
kid, err := ExtractKeyID(token)
-
if err != nil {
-
return nil, fmt.Errorf("failed to extract key ID: %w", err)
-
}
-
-
// Get JWKS from cache or fetch
-
jwks, err := f.getJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Find the key by ID
-
jwk, err := jwks.FindKeyByID(kid)
-
if err != nil {
-
// Key not found in cache - try refreshing
-
jwks, err = f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
-
}
-
f.cacheJWKS(issuer, jwks)
-
-
// Try again with fresh JWKS
-
jwk, err = jwks.FindKeyByID(kid)
-
if err != nil {
-
return nil, err
-
}
-
}
-
-
// Convert JWK to public key (RSA or ECDSA)
-
return jwk.ToPublicKey()
-
}
-
-
// getJWKS gets JWKS from cache or fetches if not cached/expired
-
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Check cache first
-
f.cacheMutex.RLock()
-
cached, exists := f.cache[issuer]
-
f.cacheMutex.RUnlock()
-
-
if exists && time.Now().Before(cached.expiresAt) {
-
return cached.jwks, nil
-
}
-
-
// Not in cache or expired - fetch from issuer
-
jwks, err := f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Cache it
-
f.cacheJWKS(issuer, jwks)
-
-
return jwks, nil
-
}
-
-
// fetchJWKS fetches JWKS from the authorization server
-
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Step 1: Fetch OAuth server metadata to get JWKS URI
-
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
-
-
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create metadata request: %w", err)
-
}
-
-
resp, err := f.httpClient.Do(req)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
-
}
-
defer func() {
-
_ = resp.Body.Close()
-
}()
-
-
if resp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
-
}
-
-
var metadata struct {
-
JWKSURI string `json:"jwks_uri"`
-
}
-
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
-
return nil, fmt.Errorf("failed to decode metadata: %w", err)
-
}
-
-
if metadata.JWKSURI == "" {
-
return nil, fmt.Errorf("jwks_uri not found in metadata")
-
}
-
-
// Step 2: Fetch JWKS from the JWKS URI
-
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
-
}
-
-
jwksResp, err := f.httpClient.Do(jwksReq)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
-
}
-
defer func() {
-
_ = jwksResp.Body.Close()
-
}()
-
-
if jwksResp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
-
}
-
-
var jwks JWKS
-
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
-
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
-
}
-
-
if len(jwks.Keys) == 0 {
-
return nil, fmt.Errorf("no keys found in JWKS")
-
}
-
-
return &jwks, nil
-
}
-
-
// cacheJWKS stores JWKS in the cache
-
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
f.cache[issuer] = &cachedJWKS{
-
jwks: jwks,
-
expiresAt: time.Now().Add(f.cacheTTL),
-
}
-
}
-
-
// ClearCache clears the entire JWKS cache
-
func (f *CachedJWKSFetcher) ClearCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
f.cache = make(map[string]*cachedJWKS)
-
}
-
-
// CleanupExpiredCache removes expired entries from the cache
-
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
now := time.Now()
-
for issuer, cached := range f.cache {
-
if now.After(cached.expiresAt) {
-
delete(f.cache, issuer)
-
}
-
}
-
}
-709
internal/atproto/auth/jwt.go
···
-
package auth
-
-
import (
-
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rsa"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"math/big"
-
"net/url"
-
"os"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// jwtConfig holds cached JWT configuration to avoid reading env vars on every request
-
type jwtConfig struct {
-
hs256Issuers map[string]struct{} // Set of whitelisted HS256 issuers
-
pdsJWTSecret []byte // Cached PDS_JWT_SECRET
-
isDevEnv bool // Cached IS_DEV_ENV
-
}
-
-
var (
-
cachedConfig *jwtConfig
-
configOnce sync.Once
-
)
-
-
// InitJWTConfig initializes the JWT configuration from environment variables.
-
// This should be called once at startup. If not called explicitly, it will be
-
// initialized lazily on first use.
-
func InitJWTConfig() {
-
configOnce.Do(func() {
-
cachedConfig = &jwtConfig{
-
hs256Issuers: make(map[string]struct{}),
-
isDevEnv: os.Getenv("IS_DEV_ENV") == "true",
-
}
-
-
// Parse HS256_ISSUERS into a set for O(1) lookup
-
if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" {
-
for _, issuer := range strings.Split(issuers, ",") {
-
issuer = strings.TrimSpace(issuer)
-
if issuer != "" {
-
cachedConfig.hs256Issuers[issuer] = struct{}{}
-
}
-
}
-
}
-
-
// Cache PDS_JWT_SECRET
-
if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" {
-
cachedConfig.pdsJWTSecret = []byte(secret)
-
}
-
})
-
}
-
-
// getConfig returns the cached config, initializing if needed
-
func getConfig() *jwtConfig {
-
InitJWTConfig()
-
return cachedConfig
-
}
-
-
// ResetJWTConfigForTesting resets the cached config to allow re-initialization.
-
// This should ONLY be used in tests.
-
func ResetJWTConfigForTesting() {
-
cachedConfig = nil
-
configOnce = sync.Once{}
-
}
-
-
// Algorithm constants for JWT signing methods
-
const (
-
AlgorithmHS256 = "HS256"
-
AlgorithmRS256 = "RS256"
-
AlgorithmES256 = "ES256"
-
)
-
-
// JWTHeader represents the parsed JWT header
-
type JWTHeader struct {
-
Alg string `json:"alg"`
-
Kid string `json:"kid"`
-
Typ string `json:"typ,omitempty"`
-
}
-
-
// Claims represents the standard JWT claims we care about
-
type Claims struct {
-
jwt.RegisteredClaims
-
// Confirmation claim for DPoP token binding (RFC 9449)
-
// Contains "jkt" (JWK thumbprint) when token is bound to a DPoP key
-
Confirmation map[string]interface{} `json:"cnf,omitempty"`
-
Scope string `json:"scope,omitempty"`
-
}
-
-
// stripBearerPrefix removes the "Bearer " prefix from a token string
-
func stripBearerPrefix(tokenString string) string {
-
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
-
return strings.TrimSpace(tokenString)
-
}
-
-
// ParseJWTHeader extracts and parses the JWT header from a token string
-
// This is a reusable function for getting algorithm and key ID information
-
func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
-
tokenString = stripBearerPrefix(tokenString)
-
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header JWTHeader
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
return &header, nil
-
}
-
-
// shouldUseHS256 determines if a token should use HS256 verification
-
// This prevents algorithm confusion attacks by using multiple signals:
-
// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
-
// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
-
//
-
// This approach supports open federation because:
-
// - External PDSes publish keys via JWKS and include `kid` in their tokens
-
// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
-
func shouldUseHS256(header *JWTHeader, issuer string) bool {
-
// If token has a key ID, it MUST use asymmetric verification
-
// This is the primary defense against algorithm confusion attacks
-
if header.Kid != "" {
-
return false
-
}
-
-
// No kid - check if issuer is whitelisted for HS256
-
// This should only include your own PDS URL(s)
-
return isHS256IssuerWhitelisted(issuer)
-
}
-
-
// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
-
// Only your own PDS should be in this list - external PDSes should use JWKS
-
func isHS256IssuerWhitelisted(issuer string) bool {
-
cfg := getConfig()
-
_, whitelisted := cfg.hs256Issuers[issuer]
-
return whitelisted
-
}
-
-
// ParseJWT parses a JWT token without verification (Phase 1)
-
// Returns the claims if the token is valid JSON and has required fields
-
func ParseJWT(tokenString string) (*Claims, error) {
-
// Remove "Bearer " prefix if present
-
tokenString = stripBearerPrefix(tokenString)
-
-
// Parse without verification first to extract claims
-
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
-
token, _, err := parser.ParseUnverified(tokenString, &Claims{})
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWT: %w", err)
-
}
-
-
claims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("invalid claims type")
-
}
-
-
// Validate required fields
-
if claims.Subject == "" {
-
return nil, fmt.Errorf("missing 'sub' claim (user DID)")
-
}
-
-
// atProto PDSes may use 'aud' instead of 'iss' for the authorization server
-
// If 'iss' is missing, use 'aud' as the authorization server identifier
-
if claims.Issuer == "" {
-
if len(claims.Audience) > 0 {
-
claims.Issuer = claims.Audience[0]
-
} else {
-
return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
-
}
-
}
-
-
// Validate claims (even in Phase 1, we need basic validation like expiry)
-
if err := validateClaims(claims); err != nil {
-
return nil, err
-
}
-
-
return claims, nil
-
}
-
-
// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
-
// Fetches the public key from the issuer's JWKS endpoint and validates the signature
-
// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
-
//
-
// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
-
// to prevent algorithm confusion attacks where an attacker could re-sign a token
-
// with HS256 using a public key as the secret.
-
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Strip Bearer prefix once at the start
-
tokenString = stripBearerPrefix(tokenString)
-
-
// First parse to get the issuer (needed to determine expected algorithm)
-
claims, err := ParseJWT(tokenString)
-
if err != nil {
-
return nil, err
-
}
-
-
// Parse header to get the claimed algorithm (for validation)
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return nil, err
-
}
-
-
// SECURITY: Determine verification method based on token characteristics
-
// 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
-
// 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
-
useHS256 := shouldUseHS256(header, claims.Issuer)
-
-
if useHS256 {
-
// Verify token actually claims to use HS256
-
if header.Alg != AlgorithmHS256 {
-
return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
-
}
-
return verifyHS256Token(tokenString)
-
}
-
-
// Token must use asymmetric verification
-
// Reject HS256 tokens that don't meet the criteria above
-
if header.Alg == AlgorithmHS256 {
-
if header.Kid != "" {
-
return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
-
}
-
return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
-
}
-
-
// For RSA/ECDSA, fetch public key from JWKS and verify
-
return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
-
}
-
-
// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
-
func verifyHS256Token(tokenString string) (*Claims, error) {
-
cfg := getConfig()
-
if len(cfg.pdsJWTSecret) == 0 {
-
return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
-
}
-
-
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
-
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
-
}
-
return cfg.pdsJWTSecret, nil
-
})
-
if err != nil {
-
return nil, fmt.Errorf("HS256 verification failed: %w", err)
-
}
-
-
if !token.Valid {
-
return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
-
}
-
-
verifiedClaims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
-
}
-
-
if err := validateClaims(verifiedClaims); err != nil {
-
return nil, err
-
}
-
-
return verifiedClaims, nil
-
}
-
-
// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS.
-
// For ES256K (secp256k1), uses indigo's crypto package since golang-jwt doesn't support it.
-
func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Parse header to check algorithm
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// ES256K (secp256k1) requires special handling via indigo's crypto package
-
// golang-jwt doesn't recognize ES256K as a valid signing method
-
if header.Alg == "ES256K" {
-
return verifyES256KToken(ctx, tokenString, issuer, keyFetcher)
-
}
-
-
// For standard algorithms (ES256, ES384, ES512, RS256, etc.), use golang-jwt
-
publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch public key: %w", err)
-
}
-
-
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
-
// Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
-
switch token.Method.(type) {
-
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
-
// Valid signing methods for atProto
-
default:
-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
-
}
-
return publicKey, nil
-
})
-
if err != nil {
-
return nil, fmt.Errorf("asymmetric verification failed: %w", err)
-
}
-
-
if !token.Valid {
-
return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
-
}
-
-
verifiedClaims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
-
}
-
-
if err := validateClaims(verifiedClaims); err != nil {
-
return nil, err
-
}
-
-
return verifiedClaims, nil
-
}
-
-
// verifyES256KToken verifies a JWT signed with ES256K (secp256k1) using indigo's crypto package.
-
// This is necessary because golang-jwt doesn't support ES256K as a signing method.
-
func verifyES256KToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Fetch the public key - for ES256K, the fetcher returns a JWK map or indigo PublicKey
-
keyData, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch public key for ES256K: %w", err)
-
}
-
-
// Convert to indigo PublicKey based on what the fetcher returned
-
var pubKey indigoCrypto.PublicKey
-
switch k := keyData.(type) {
-
case indigoCrypto.PublicKey:
-
// Already an indigo PublicKey (from DIDKeyFetcher or updated JWKSFetcher)
-
pubKey = k
-
case map[string]interface{}:
-
// Raw JWK map - parse with indigo
-
pubKey, err = parseJWKMapToIndigoPublicKey(k)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse ES256K JWK: %w", err)
-
}
-
default:
-
return nil, fmt.Errorf("ES256K verification requires indigo PublicKey or JWK map, got %T", keyData)
-
}
-
-
// Verify signature using indigo
-
if err := verifyJWTSignatureWithIndigoKey(tokenString, pubKey); err != nil {
-
return nil, fmt.Errorf("ES256K signature verification failed: %w", err)
-
}
-
-
// Parse claims (signature already verified)
-
claims, err := parseJWTClaimsManually(tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse ES256K JWT claims: %w", err)
-
}
-
-
if err := validateClaims(claims); err != nil {
-
return nil, err
-
}
-
-
return claims, nil
-
}
-
-
// parseJWKMapToIndigoPublicKey converts a JWK map to an indigo PublicKey.
-
// This uses indigo's crypto package which supports all atProto curves including secp256k1.
-
func parseJWKMapToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - supports all atProto curves
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK with indigo: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// verifyJWTSignatureWithIndigoKey verifies a JWT signature using indigo's crypto package.
-
// This works for all ECDSA algorithms including ES256K (secp256k1).
-
func verifyJWTSignatureWithIndigoKey(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
if err := pubKey.HashAndVerifyLenient([]byte(signingInput), signature); err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// parseJWTClaimsManually parses JWT claims without using golang-jwt.
-
// This is used for ES256K tokens where golang-jwt would reject the algorithm.
-
func parseJWTClaimsManually(tokenString string) (*Claims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build Claims struct
-
claims := &Claims{}
-
-
// Extract sub (subject/DID)
-
if sub, ok := rawClaims["sub"].(string); ok {
-
claims.Subject = sub
-
}
-
-
// Extract iss (issuer)
-
if iss, ok := rawClaims["iss"].(string); ok {
-
claims.Issuer = iss
-
}
-
-
// Extract aud (audience) - can be string or array
-
switch aud := rawClaims["aud"].(type) {
-
case string:
-
claims.Audience = jwt.ClaimStrings{aud}
-
case []interface{}:
-
for _, a := range aud {
-
if s, ok := a.(string); ok {
-
claims.Audience = append(claims.Audience, s)
-
}
-
}
-
}
-
-
// Extract exp (expiration)
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before)
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract jti (JWT ID)
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract scope
-
if scope, ok := rawClaims["scope"].(string); ok {
-
claims.Scope = scope
-
}
-
-
// Extract cnf (confirmation) for DPoP binding
-
if cnf, ok := rawClaims["cnf"].(map[string]interface{}); ok {
-
claims.Confirmation = cnf
-
}
-
-
return claims, nil
-
}
-
-
// validateClaims performs additional validation on JWT claims
-
func validateClaims(claims *Claims) error {
-
now := time.Now()
-
-
// Check expiration
-
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
-
return fmt.Errorf("token has expired")
-
}
-
-
// Check not before
-
if claims.NotBefore != nil && claims.NotBefore.After(now) {
-
return fmt.Errorf("token not yet valid")
-
}
-
-
// Validate DID format in sub claim
-
if !strings.HasPrefix(claims.Subject, "did:") {
-
return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
-
}
-
-
// Validate issuer is either an HTTPS URL or a DID
-
// atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
-
// In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing
-
isHTTP := strings.HasPrefix(claims.Issuer, "http://")
-
isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
-
isDID := strings.HasPrefix(claims.Issuer, "did:")
-
-
if !isHTTPS && !isDID && !isHTTP {
-
return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
-
}
-
-
// In production, reject HTTP issuers (only for non-dev environments)
-
cfg := getConfig()
-
if isHTTP && !cfg.isDevEnv {
-
return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
-
}
-
-
// Parse to ensure it's a valid URL
-
if _, err := url.Parse(claims.Issuer); err != nil {
-
return fmt.Errorf("invalid issuer URL: %w", err)
-
}
-
-
// Validate scope if present (lenient: allow empty, but reject wrong scopes)
-
if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
-
return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
-
}
-
-
return nil
-
}
-
-
// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints
-
// Returns interface{} to support both RSA and ECDSA keys
-
type JWKSFetcher interface {
-
FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
-
}
-
-
// JWK represents a JSON Web Key from a JWKS endpoint
-
// Supports both RSA and EC (ECDSA) keys
-
type JWK struct {
-
Kid string `json:"kid"` // Key ID
-
Kty string `json:"kty"` // Key type ("RSA" or "EC")
-
Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
-
Use string `json:"use"` // Public key use (should be "sig" for signatures)
-
// RSA fields
-
N string `json:"n,omitempty"` // RSA modulus
-
E string `json:"e,omitempty"` // RSA exponent
-
// EC fields
-
Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
-
X string `json:"x,omitempty"` // EC x coordinate
-
Y string `json:"y,omitempty"` // EC y coordinate
-
}
-
-
// ToPublicKey converts a JWK to a public key (RSA, ECDSA, or indigo for secp256k1).
-
//
-
// Returns:
-
// - *rsa.PublicKey for RSA keys
-
// - *ecdsa.PublicKey for NIST EC curves (P-256, P-384, P-521)
-
// - map[string]interface{} for secp256k1 (ES256K) - parsed by indigo
-
func (j *JWK) ToPublicKey() (interface{}, error) {
-
switch j.Kty {
-
case "RSA":
-
return j.toRSAPublicKey()
-
case "EC":
-
// For secp256k1, return raw JWK map for indigo to parse
-
if j.Crv == "secp256k1" {
-
return j.toJWKMap(), nil
-
}
-
return j.toECPublicKey()
-
default:
-
return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
-
}
-
}
-
-
// toJWKMap converts the JWK struct to a map for indigo parsing
-
func (j *JWK) toJWKMap() map[string]interface{} {
-
m := map[string]interface{}{
-
"kty": j.Kty,
-
}
-
if j.Kid != "" {
-
m["kid"] = j.Kid
-
}
-
if j.Alg != "" {
-
m["alg"] = j.Alg
-
}
-
if j.Use != "" {
-
m["use"] = j.Use
-
}
-
// RSA fields
-
if j.N != "" {
-
m["n"] = j.N
-
}
-
if j.E != "" {
-
m["e"] = j.E
-
}
-
// EC fields
-
if j.Crv != "" {
-
m["crv"] = j.Crv
-
}
-
if j.X != "" {
-
m["x"] = j.X
-
}
-
if j.Y != "" {
-
m["y"] = j.Y
-
}
-
return m
-
}
-
-
// toRSAPublicKey converts a JWK to an RSA public key
-
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
-
// Decode modulus
-
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
-
}
-
-
// Decode exponent
-
eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
-
}
-
-
// Convert exponent to int
-
var eInt int
-
for _, b := range eBytes {
-
eInt = eInt*256 + int(b)
-
}
-
-
return &rsa.PublicKey{
-
N: new(big.Int).SetBytes(nBytes),
-
E: eInt,
-
}, nil
-
}
-
-
// toECPublicKey converts a JWK to an ECDSA public key
-
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
-
// Determine curve
-
var curve elliptic.Curve
-
switch j.Crv {
-
case "P-256":
-
curve = elliptic.P256()
-
case "P-384":
-
curve = elliptic.P384()
-
case "P-521":
-
curve = elliptic.P521()
-
default:
-
return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
-
}
-
-
// Decode X coordinate
-
xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
-
}
-
-
// Decode Y coordinate
-
yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
-
}
-
-
return &ecdsa.PublicKey{
-
Curve: curve,
-
X: new(big.Int).SetBytes(xBytes),
-
Y: new(big.Int).SetBytes(yBytes),
-
}, nil
-
}
-
-
// JWKS represents a JSON Web Key Set
-
type JWKS struct {
-
Keys []JWK `json:"keys"`
-
}
-
-
// FindKeyByID finds a key in the JWKS by its key ID
-
func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
-
for _, key := range j.Keys {
-
if key.Kid == kid {
-
return &key, nil
-
}
-
}
-
return nil, fmt.Errorf("key with kid %s not found", kid)
-
}
-
-
// ExtractKeyID extracts the key ID from a JWT token header
-
func ExtractKeyID(tokenString string) (string, error) {
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return "", err
-
}
-
-
if header.Kid == "" {
-
return "", fmt.Errorf("missing kid in token header")
-
}
-
-
return header.Kid, nil
-
}
-496
internal/atproto/auth/jwt_test.go
···
-
package auth
-
-
import (
-
"context"
-
"testing"
-
"time"
-
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
func TestParseJWT(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing
-
parsedClaims, err := ParseJWT(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
-
if parsedClaims.Issuer != "https://test-pds.example.com" {
-
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
-
}
-
-
if parsedClaims.Scope != "atproto transition:generic" {
-
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
-
}
-
}
-
-
func TestParseJWT_MissingSubject(t *testing.T) {
-
// Create a token without subject
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing subject, got nil")
-
}
-
}
-
-
func TestParseJWT_MissingIssuer(t *testing.T) {
-
// Create a token without issuer
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing issuer, got nil")
-
}
-
}
-
-
func TestParseJWT_WithBearerPrefix(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing with Bearer prefix
-
parsedClaims, err := ParseJWT("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
}
-
-
func TestValidateClaims_Expired(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for expired token, got nil")
-
}
-
}
-
-
func TestValidateClaims_InvalidDID(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "invalid-did-format",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for invalid DID format, got nil")
-
}
-
}
-
-
func TestExtractKeyID(t *testing.T) {
-
// Create a test JWT token with kid in header
-
token := jwt.New(jwt.SigningMethodRS256)
-
token.Header["kid"] = "test-key-id"
-
token.Claims = &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
},
-
}
-
-
// Sign with a dummy RSA key (we just need a valid token structure)
-
tokenString, err := token.SignedString([]byte("dummy"))
-
if err == nil {
-
// If it succeeds (shouldn't with wrong key type, but let's handle it)
-
kid, err := ExtractKeyID(tokenString)
-
if err != nil {
-
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
-
} else if kid != "test-key-id" {
-
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
-
}
-
}
-
}
-
-
// === HS256 Verification Tests ===
-
-
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
-
type mockJWKSFetcher struct {
-
publicKey interface{}
-
err error
-
}
-
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
return m.publicKey, m.err
-
}
-
-
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
-
t.Helper()
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte(secret))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
return tokenString
-
}
-
-
func TestVerifyJWT_HS256_Valid(t *testing.T) {
-
// Setup: Configure environment for HS256 verification
-
secret := "test-jwt-secret-key-12345"
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", secret)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
-
-
// Verify token
-
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err != nil {
-
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
-
}
-
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
-
}
-
if claims.Issuer != issuer {
-
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
-
}
-
}
-
-
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
-
// Setup: Configure environment with one secret, sign with another
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "correct-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create token with wrong secret
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
-
-
// Verify should fail
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error for HS256 token with wrong secret, got nil")
-
}
-
}
-
-
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
-
// Setup: Whitelist issuer but don't configure secret
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
-
-
// Verify should fail with descriptive error
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
-
}
-
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
-
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
-
}
-
}
-
-
// === Algorithm Confusion Attack Prevention Tests ===
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
-
// SECURITY TEST: This tests the algorithm confusion attack prevention
-
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create HS256 token with non-whitelisted issuer (simulating attack)
-
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because issuer is not in HS256 whitelist
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
-
}
-
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
-
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
-
}
-
}
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
-
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because no issuers are whitelisted for HS256
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
-
}
-
}
-
-
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
-
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "test-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
-
// This will create an invalid signature but valid header structure
-
// The test should fail at algorithm check, not signature verification
-
tokenString, _ := token.SignedString([]byte("dummy-key"))
-
-
if tokenString != "" {
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when HS256 issuer receives non-HS256 token")
-
}
-
}
-
}
-
-
// === ParseJWTHeader Tests ===
-
-
func TestParseJWTHeader_Valid(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
-
testCases := []struct {
-
name string
-
input string
-
}{
-
{"empty string", ""},
-
{"single part", "abc"},
-
{"two parts", "abc.def"},
-
{"too many parts", "a.b.c.d"},
-
}
-
-
for _, tc := range testCases {
-
t.Run(tc.name, func(t *testing.T) {
-
_, err := ParseJWTHeader(tc.input)
-
if err == nil {
-
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
-
}
-
})
-
}
-
}
-
-
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
-
-
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected pds1 to be whitelisted")
-
}
-
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
-
t.Error("Expected pds2 to be whitelisted")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://attacker.example.com") {
-
t.Error("Expected non-whitelisted issuer to return false")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://any.example.com") {
-
t.Error("Expected false when whitelist is empty (safe default)")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
-
}
-
}
-
-
// === shouldUseHS256 Tests (kid-based logic) ===
-
-
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
-
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "some-key-id", // Has kid
-
}
-
-
// Even whitelisted issuer should not use HS256 if token has kid
-
if shouldUseHS256(header, "https://whitelisted.example.com") {
-
t.Error("Tokens with kid should never use HS256 (supports federation)")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if !shouldUseHS256(header, "https://my-pds.example.com") {
-
t.Error("Token without kid from whitelisted issuer should use HS256")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if shouldUseHS256(header, "https://external-pds.example.com") {
-
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
-
}
-
}
-
-
// Helper function
-
func contains(s, substr string) bool {
-
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-
}
-
-
func containsHelper(s, substr string) bool {
-
for i := 0; i <= len(s)-len(substr); i++ {
-
if s[i:i+len(substr)] == substr {
-
return true
-
}
-
}
-
return false
-
}
-4
tests/integration/oauth_helpers.go
···
// Production mode: HTTPS PDS, use real PLC directory
config = &oauth.OAuthConfig{
PublicURL: "http://localhost:3000", // Test server callback URL
-
ClientSecret: "", // Public client
-
ClientKID: "", // Public client
SealSecret: sealSecretB64, // For sealing mobile tokens
Scopes: []string{"atproto", "transition:generic"},
DevMode: false, // Production mode for HTTPS PDS
···
// Dev mode: localhost PDS with HTTP
config = &oauth.OAuthConfig{
PublicURL: "http://localhost:3000", // Match the callback URL expected by PDS
-
ClientSecret: "", // Empty for public client in dev mode
-
ClientKID: "", // Empty for public client
SealSecret: sealSecretB64, // For sealing mobile tokens
Scopes: []string{"atproto", "transition:generic"},
DevMode: true, // Enable dev mode for localhost testing
+1
docker-compose.prod.yml
···
# Instance identity
INSTANCE_DID: did:web:coves.social
INSTANCE_DOMAIN: coves.social
+
APPVIEW_PUBLIC_URL: https://coves.social
# PDS connection (separate domain!)
PDS_URL: https://coves.me
+1 -8
Caddyfile
···
file_server
}
-
# Serve OAuth callback page
-
handle /oauth/callback {
-
root * /srv
-
rewrite * /oauth/callback.html
-
file_server
-
}
-
-
# Proxy all other requests to AppView
+
# Proxy all requests to AppView
handle {
reverse_proxy appview:8080 {
# Health check
-97
static/oauth/callback.html
···
-
<!DOCTYPE html>
-
<html>
-
<head>
-
<meta charset="utf-8">
-
<meta name="viewport" content="width=device-width, initial-scale=1">
-
<meta http-equiv="Content-Security-Policy" content="default-src 'self'; script-src 'unsafe-inline'; style-src 'unsafe-inline'">
-
<title>Authorization Successful - Coves</title>
-
<style>
-
body {
-
font-family: system-ui, -apple-system, sans-serif;
-
display: flex;
-
align-items: center;
-
justify-content: center;
-
min-height: 100vh;
-
margin: 0;
-
background: #f5f5f5;
-
}
-
.container {
-
text-align: center;
-
padding: 2rem;
-
background: white;
-
border-radius: 8px;
-
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
-
max-width: 400px;
-
}
-
.success { color: #22c55e; font-size: 3rem; margin-bottom: 1rem; }
-
h1 { margin: 0 0 0.5rem; color: #1f2937; font-size: 1.5rem; }
-
p { color: #6b7280; margin: 0.5rem 0; }
-
a {
-
display: inline-block;
-
margin-top: 1rem;
-
padding: 0.75rem 1.5rem;
-
background: #3b82f6;
-
color: white;
-
text-decoration: none;
-
border-radius: 6px;
-
font-weight: 500;
-
}
-
a:hover { background: #2563eb; }
-
</style>
-
</head>
-
<body>
-
<div class="container">
-
<div class="success">โœ“</div>
-
<h1>Authorization Successful!</h1>
-
<p id="status">Returning to Coves...</p>
-
<a href="#" id="manualLink">Open Coves</a>
-
</div>
-
<script>
-
(function() {
-
// Parse and sanitize query params - only allow expected OAuth parameters
-
const urlParams = new URLSearchParams(window.location.search);
-
const safeParams = new URLSearchParams();
-
-
// Whitelist only expected OAuth callback parameters
-
const code = urlParams.get('code');
-
const state = urlParams.get('state');
-
const error = urlParams.get('error');
-
const errorDescription = urlParams.get('error_description');
-
const iss = urlParams.get('iss');
-
-
if (code) safeParams.set('code', code);
-
if (state) safeParams.set('state', state);
-
if (error) safeParams.set('error', error);
-
if (errorDescription) safeParams.set('error_description', errorDescription);
-
if (iss) safeParams.set('iss', iss);
-
-
const sanitizedQuery = safeParams.toString() ? '?' + safeParams.toString() : '';
-
-
const userAgent = navigator.userAgent || '';
-
const isAndroid = /Android/i.test(userAgent);
-
-
// Build deep link based on platform
-
let deepLink;
-
if (isAndroid) {
-
// Android: Intent URL format
-
const pathAndQuery = '/oauth/callback' + sanitizedQuery;
-
deepLink = 'intent:/' + pathAndQuery + '#Intent;scheme=social.coves;package=social.coves;end';
-
} else {
-
// iOS: Custom scheme
-
deepLink = 'social.coves:/oauth/callback' + sanitizedQuery;
-
}
-
-
// Update manual link
-
document.getElementById('manualLink').href = deepLink;
-
-
// Attempt automatic redirect
-
window.location.href = deepLink;
-
-
// Update status after 2 seconds if redirect didn't work
-
setTimeout(function() {
-
document.getElementById('status').textContent = 'Click the button above to continue';
-
}, 2000);
-
})();
-
</script>
-
</body>
-
</html>
+11
static/.well-known/apple-app-site-association
···
+
{
+
"applinks": {
+
"apps": [],
+
"details": [
+
{
+
"appID": "TEAM_ID.social.coves",
+
"paths": ["/app/oauth/callback"]
+
}
+
]
+
}
+
}
+10
static/.well-known/assetlinks.json
···
+
[{
+
"relation": ["delegate_permission/common.handle_all_urls"],
+
"target": {
+
"namespace": "android_app",
+
"package_name": "social.coves",
+
"sha256_cert_fingerprints": [
+
"0B:D8:8C:99:66:25:E5:CD:06:54:80:88:01:6F:B7:38:B9:F4:5B:41:71:F7:95:C8:68:94:87:AD:EA:9F:D9:ED"
+
]
+
}
+
}]
+41
internal/atproto/lexicon/social/coves/feed/vote/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.feed.vote.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a vote on a post or comment",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["subject"],
+
"properties": {
+
"subject": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the post or comment to remove the vote from"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "VoteNotFound",
+
"description": "No vote found for this subject"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this vote"
+
}
+
]
+
}
+
}
+
}
+115
internal/api/handlers/vote/create_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateVoteHandler handles vote creation
+
type CreateVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewCreateVoteHandler creates a new create vote handler
+
func NewCreateVoteHandler(service votes.Service) *CreateVoteHandler {
+
return &CreateVoteHandler{
+
service: service,
+
}
+
}
+
+
// CreateVoteInput represents the request body for creating a vote
+
type CreateVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
Direction string `json:"direction"`
+
}
+
+
// CreateVoteOutput represents the response body for creating a vote
+
type CreateVoteOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreateVote creates a vote on a post or comment
+
// POST /xrpc/social.coves.vote.create
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." }, "direction": "up" }
+
// Response: { "uri": "at://...", "cid": "..." }
+
//
+
// Behavior:
+
// - If no vote exists: creates new vote with given direction
+
// - If vote exists with same direction: deletes vote (toggle off)
+
// - If vote exists with different direction: updates to new direction
+
func (h *CreateVoteHandler) HandleCreateVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input CreateVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
if input.Direction == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction is required")
+
return
+
}
+
+
// Validate direction
+
if input.Direction != "up" && input.Direction != "down" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction must be 'up' or 'down'")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create vote request
+
req := votes.CreateVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
Direction: input.Direction,
+
}
+
+
// Call service to create vote
+
response, err := h.service.CreateVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response
+
output := CreateVoteOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+93
internal/api/handlers/vote/delete_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteVoteHandler handles vote deletion
+
type DeleteVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewDeleteVoteHandler creates a new delete vote handler
+
func NewDeleteVoteHandler(service votes.Service) *DeleteVoteHandler {
+
return &DeleteVoteHandler{
+
service: service,
+
}
+
}
+
+
// DeleteVoteInput represents the request body for deleting a vote
+
type DeleteVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
}
+
+
// DeleteVoteOutput represents the response body for deleting a vote
+
// Per lexicon: output is an empty object
+
type DeleteVoteOutput struct{}
+
+
// HandleDeleteVote removes a vote from a post or comment
+
// POST /xrpc/social.coves.vote.delete
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." } }
+
// Response: { "success": true }
+
func (h *DeleteVoteHandler) HandleDeleteVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input DeleteVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create delete vote request
+
req := votes.DeleteVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
}
+
+
// Call service to delete vote
+
err := h.service.DeleteVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response (empty object per lexicon)
+
output := DeleteVoteOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+24
internal/api/routes/vote.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/vote"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterVoteRoutes registers vote-related XRPC endpoints on the router
+
// Implements social.coves.feed.vote.* lexicon endpoints
+
func RegisterVoteRoutes(r chi.Router, voteService votes.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := vote.NewCreateVoteHandler(voteService)
+
deleteHandler := vote.NewDeleteVoteHandler(voteService)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.feed.vote.create - create or update a vote on a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.create", createHandler.HandleCreateVote)
+
+
// social.coves.feed.vote.delete - delete a vote from a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.delete", deleteHandler.HandleDeleteVote)
+
}
+3
.beads/beads.left.jsonl
···
+
{"id":"Coves-95q","content_hash":"8ec99d598f067780436b985f9ad57f0fa19632026981038df4f65f192186620b","title":"Add comprehensive API documentation","description":"","status":"open","priority":2,"issue_type":"task","created_at":"2025-11-17T20:30:34.835721854-08:00","updated_at":"2025-11-17T20:30:34.835721854-08:00","source_repo":".","dependencies":[{"issue_id":"Coves-95q","depends_on_id":"Coves-e16","type":"blocks","created_at":"2025-11-17T20:30:46.273899399-08:00","created_by":"daemon"}]}
+
{"id":"Coves-e16","content_hash":"7c5d0fc8f0e7f626be3dad62af0e8412467330bad01a244e5a7e52ac5afff1c1","title":"Complete post creation and moderation features","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:12.885991306-08:00","updated_at":"2025-11-17T20:30:12.885991306-08:00","source_repo":"."}
+
{"id":"Coves-fce","content_hash":"26b3e16b99f827316ee0d741cc959464bd0c813446c95aef8105c7fd1e6b09ff","title":"Implement aggregator feed federation","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:21.453326012-08:00","updated_at":"2025-11-17T20:30:21.453326012-08:00","source_repo":"."}
+1
.beads/beads.left.meta.json
···
+
{"version":"0.23.1","timestamp":"2025-12-02T18:25:24.009187871-08:00","commit":"00d7d8d"}
-3
internal/api/handlers/vote/errors.go
···
case errors.Is(err, votes.ErrVoteNotFound):
// Matches: social.coves.feed.vote.delete#VoteNotFound
writeError(w, http.StatusNotFound, "VoteNotFound", "No vote found for this subject")
-
case errors.Is(err, votes.ErrSubjectNotFound):
-
// Matches: social.coves.feed.vote.create#SubjectNotFound
-
writeError(w, http.StatusNotFound, "SubjectNotFound", "The subject post or comment was not found")
case errors.Is(err, votes.ErrInvalidDirection):
writeError(w, http.StatusBadRequest, "InvalidRequest", "Vote direction must be 'up' or 'down'")
case errors.Is(err, votes.ErrInvalidSubject):
-3
internal/core/votes/errors.go
···
// ErrVoteNotFound indicates the requested vote doesn't exist
ErrVoteNotFound = errors.New("vote not found")
-
// ErrSubjectNotFound indicates the post/comment being voted on doesn't exist
-
ErrSubjectNotFound = errors.New("subject not found")
-
// ErrInvalidDirection indicates the vote direction is not "up" or "down"
ErrInvalidDirection = errors.New("invalid vote direction: must be 'up' or 'down'")
+3 -2
internal/db/postgres/vote_repo.go
···
return nil
}
-
// GetByURI retrieves a vote by its AT-URI
+
// GetByURI retrieves an active vote by its AT-URI
// Used by Jetstream consumer for DELETE operations
+
// Returns ErrVoteNotFound for soft-deleted votes
func (r *postgresVoteRepo) GetByURI(ctx context.Context, uri string) (*votes.Vote, error) {
query := `
SELECT
···
subject_uri, subject_cid, direction,
created_at, indexed_at, deleted_at
FROM votes
-
WHERE uri = $1
+
WHERE uri = $1 AND deleted_at IS NULL
`
var vote votes.Vote
+92
.env.dev.example
···
+
# Coves Local Development Environment Configuration
+
# Copy this to .env.dev and fill in your values
+
#
+
# Quick Start:
+
# 1. cp .env.dev.example .env.dev
+
# 2. Generate OAuth key: go run cmd/genjwks/main.go (copy output to OAUTH_PRIVATE_JWK)
+
# 3. Generate cookie secret: openssl rand -hex 32
+
# 4. make dev-up # Start Docker services
+
# 5. make run # Start the server (uses -tags dev)
+
+
# =============================================================================
+
# Dev Mode Quick Reference
+
# =============================================================================
+
# REQUIRED for local OAuth to work with local PDS:
+
# IS_DEV_ENV=true # Master switch for dev mode
+
# PDS_URL=http://localhost:3001 # Local PDS for handle resolution
+
# PLC_DIRECTORY_URL=http://localhost:3002 # Local PLC directory
+
# APPVIEW_PUBLIC_URL=http://127.0.0.1:8081 # Use IP not localhost (RFC 8252)
+
#
+
# BUILD TAGS:
+
# make run - Runs with -tags dev (includes localhost OAuth resolvers)
+
# make build - Production binary (no dev code)
+
# make build-dev - Dev binary (includes dev code)
+
+
# =============================================================================
+
# PostgreSQL Configuration
+
# =============================================================================
+
POSTGRES_HOST=localhost
+
POSTGRES_PORT=5435
+
POSTGRES_DB=coves_dev
+
POSTGRES_USER=dev_user
+
POSTGRES_PASSWORD=dev_password
+
+
# Test database
+
POSTGRES_TEST_DB=coves_test
+
POSTGRES_TEST_USER=test_user
+
POSTGRES_TEST_PASSWORD=test_password
+
POSTGRES_TEST_PORT=5434
+
+
# =============================================================================
+
# PDS Configuration
+
# =============================================================================
+
PDS_HOSTNAME=localhost
+
PDS_PORT=3001
+
PDS_SERVICE_ENDPOINT=http://localhost:3000
+
PDS_DID_PLC_URL=http://plc-directory:3000
+
PDS_JWT_SECRET=local-dev-jwt-secret-change-in-production
+
PDS_ADMIN_PASSWORD=admin
+
PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.community.coves.social
+
PDS_PLC_ROTATION_KEY=<generate-a-random-hex-key>
+
+
# =============================================================================
+
# AppView Configuration
+
# =============================================================================
+
APPVIEW_PORT=8081
+
FIREHOSE_URL=ws://localhost:3001/xrpc/com.atproto.sync.subscribeRepos
+
PDS_URL=http://localhost:3001
+
APPVIEW_PUBLIC_URL=http://127.0.0.1:8081
+
+
# =============================================================================
+
# Jetstream Configuration
+
# =============================================================================
+
JETSTREAM_URL=ws://localhost:6008/subscribe
+
+
# =============================================================================
+
# Identity Resolution
+
# =============================================================================
+
IDENTITY_CACHE_TTL=24h
+
PLC_DIRECTORY_URL=http://localhost:3002
+
+
# =============================================================================
+
# OAuth Configuration (MUST GENERATE YOUR OWN)
+
# =============================================================================
+
# Generate with: go run cmd/genjwks/main.go
+
OAUTH_PRIVATE_JWK=<generate-your-own-jwk>
+
+
# Generate with: openssl rand -hex 32
+
OAUTH_COOKIE_SECRET=<generate-your-own-secret>
+
+
# =============================================================================
+
# Development Settings
+
# =============================================================================
+
ENV=development
+
NODE_ENV=development
+
IS_DEV_ENV=true
+
LOG_LEVEL=debug
+
LOG_ENABLED=true
+
+
# Security settings (ONLY for local dev - set to false in production!)
+
SKIP_DID_WEB_VERIFICATION=true
+
AUTH_SKIP_VERIFY=true
+
HS256_ISSUERS=http://localhost:3001
+25 -3
Makefile
···
-
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean
+
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean verify-stack create-test-account mobile-full-setup
# Default target - show help
.DEFAULT_GOAL := help
···
##@ Build & Run
-
build: ## Build the Coves server
-
@echo "$(GREEN)Building Coves server...$(RESET)"
+
build: ## Build the Coves server (production - no dev code)
+
@echo "$(GREEN)Building Coves server (production)...$(RESET)"
@go build -o server ./cmd/server
@echo "$(GREEN)โœ“ Build complete: ./server$(RESET)"
+
build-dev: ## Build the Coves server with dev mode (includes localhost OAuth resolvers)
+
@echo "$(GREEN)Building Coves server (dev mode)...$(RESET)"
+
@go build -tags dev -o server ./cmd/server
+
@echo "$(GREEN)โœ“ Build complete: ./server (with dev tags)$(RESET)"
+
run: ## Run the Coves server with dev environment (requires database running)
@./scripts/dev-run.sh
···
@adb reverse --remove-all || echo "$(YELLOW)No device connected$(RESET)"
@echo "$(GREEN)โœ“ Port forwarding removed$(RESET)"
+
verify-stack: ## Verify local development stack (PLC, PDS, configs)
+
@./scripts/verify-local-stack.sh
+
+
create-test-account: ## Create a test account on local PDS for OAuth testing
+
@./scripts/create-test-account.sh
+
+
mobile-full-setup: verify-stack create-test-account mobile-setup ## Full mobile setup: verify stack, create account, setup ports
+
@echo ""
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo "$(GREEN) Mobile development environment ready! $(RESET)"
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo ""
+
@echo "$(CYAN)Run the Flutter app with:$(RESET)"
+
@echo " $(YELLOW)cd /home/bretton/Code/coves-mobile$(RESET)"
+
@echo " $(YELLOW)flutter run --dart-define=ENVIRONMENT=local$(RESET)"
+
@echo ""
+
ngrok-up: ## Start ngrok tunnels (for iOS or WiFi testing - requires paid plan for 3 tunnels)
@echo "$(GREEN)Starting ngrok tunnels for mobile testing...$(RESET)"
@./scripts/start-ngrok.sh
+5 -1
docker-compose.dev.yml
···
# Bluesky Personal Data Server (PDS)
# Handles user repositories, DIDs, and CAR files
+
# NOTE: When using --profile plc, PDS waits for PLC directory to be healthy
pds:
image: ghcr.io/bluesky-social/pds:latest
container_name: coves-dev-pds
···
PDS_PORT: 3001 # Match external port for correct DID registration
PDS_DATA_DIRECTORY: /pds
PDS_BLOBSTORE_DISK_LOCATION: /pds/blocks
-
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-https://plc.directory}
+
# IMPORTANT: For local E2E testing, this MUST point to local PLC directory
+
# Default to local PLC (http://plc-directory:3000) for full local stack
+
# The container hostname 'plc-directory' is used for Docker network communication
+
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-http://plc-directory:3000}
# PDS_CRAWLERS not needed - we're not using a relay for local dev
# Note: PDS uses its own internal SQLite database and CAR file storage
+13 -2
internal/atproto/oauth/client.go
···
import (
"encoding/base64"
"fmt"
+
"log/slog"
"net/url"
"time"
···
PublicURL string
SealSecret string
PLCURL string
+
PDSURL string // For dev mode: resolve handles via local PDS
Scopes []string
SessionTTL time.Duration
SealedTokenTTL time.Duration
···
// Create indigo client config
var clientConfig oauth.ClientConfig
if config.DevMode {
-
// Dev mode: localhost with HTTP
-
callbackURL := "http://localhost:3000/oauth/callback"
+
// Dev mode: loopback with HTTP
+
// IMPORTANT: Use 127.0.0.1 instead of localhost per RFC 8252 - PDS rejects localhost
+
// The callback URL must match the APPVIEW_PUBLIC_URL from .env.dev
+
callbackURL := config.PublicURL + "/oauth/callback"
clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes)
+
slog.Info("dev mode: OAuth client configured",
+
"callback_url", callbackURL,
+
"client_id", clientConfig.ClientID)
} else {
// Production mode: public OAuth client with HTTPS
// client_id must be the URL of the client metadata document per atproto OAuth spec
···
// Use pointer since CacheDirectory methods have pointer receivers
cacheDir := identity.NewCacheDirectory(baseDir, 100_000, time.Hour*24, time.Minute*2, time.Minute*5)
clientApp.Dir = &cacheDir
+
// Log the PLC URL being used for OAuth directory resolution
+
fmt.Printf("๐Ÿ” OAuth client directory configured with PLC URL: %s (AllowPrivateIPs: %v)\n", config.PLCURL, config.AllowPrivateIPs)
+
} else {
+
fmt.Println("โš ๏ธ OAuth client using DEFAULT PLC directory (production plc.directory)")
}
return &OAuthClient{
+285
internal/atproto/oauth/dev_auth_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/identity"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// DevAuthResolver is a custom OAuth resolver that allows HTTP localhost URLs for development.
+
// The standard indigo OAuth resolver requires HTTPS and no port numbers, which breaks local testing.
+
type DevAuthResolver struct {
+
Client *http.Client
+
UserAgent string
+
PDSURL string // For resolving handles via local PDS
+
handleResolver *DevHandleResolver
+
}
+
+
// ProtectedResourceMetadata matches the OAuth protected resource metadata document format
+
type ProtectedResourceMetadata struct {
+
Resource string `json:"resource"`
+
AuthorizationServers []string `json:"authorization_servers"`
+
}
+
+
// NewDevAuthResolver creates a resolver that accepts localhost HTTP URLs
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
resolver := &DevAuthResolver{
+
Client: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
UserAgent: "Coves/1.0",
+
PDSURL: pdsURL,
+
}
+
// Create handle resolver for resolving handles via local PDS
+
if pdsURL != "" {
+
resolver.handleResolver = NewDevHandleResolver(pdsURL, allowPrivateIPs)
+
}
+
return resolver
+
}
+
+
// ResolveAuthServerURL resolves a PDS URL to an auth server URL.
+
// Unlike indigo's standard resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerURL(ctx context.Context, hostURL string) (string, error) {
+
u, err := url.Parse(hostURL)
+
if err != nil {
+
return "", err
+
}
+
+
// For localhost, allow HTTP and port numbers
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if !isLocalhost {
+
// For non-localhost, enforce HTTPS and no port (standard rules)
+
if u.Scheme != "https" || u.Port() != "" {
+
return "", fmt.Errorf("not a valid public host URL: %s", hostURL)
+
}
+
}
+
+
// Build the protected resource document URL
+
var docURL string
+
if isLocalhost {
+
// For localhost, preserve the port and use HTTP
+
port := u.Port()
+
if port == "" {
+
port = "3001" // Default PDS port
+
}
+
docURL = fmt.Sprintf("http://%s:%s/.well-known/oauth-protected-resource", u.Hostname(), port)
+
} else {
+
docURL = fmt.Sprintf("https://%s/.well-known/oauth-protected-resource", u.Hostname())
+
}
+
+
// Fetch the protected resource document
+
req, err := http.NewRequestWithContext(ctx, "GET", docURL, nil)
+
if err != nil {
+
return "", err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("fetching protected resource document: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("HTTP error fetching protected resource document: %d", resp.StatusCode)
+
}
+
+
var body ProtectedResourceMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+
return "", fmt.Errorf("invalid protected resource document: %w", err)
+
}
+
+
if len(body.AuthorizationServers) < 1 {
+
return "", fmt.Errorf("no auth server URL in protected resource document")
+
}
+
+
authURL := body.AuthorizationServers[0]
+
+
// Validate the auth server URL (with localhost exception)
+
au, err := url.Parse(authURL)
+
if err != nil {
+
return "", fmt.Errorf("invalid auth server URL: %w", err)
+
}
+
+
authIsLocalhost := au.Hostname() == "localhost" || au.Hostname() == "127.0.0.1"
+
if !authIsLocalhost {
+
if au.Scheme != "https" || au.Port() != "" {
+
return "", fmt.Errorf("invalid auth server URL: %s", authURL)
+
}
+
}
+
+
return authURL, nil
+
}
+
+
// ResolveAuthServerMetadataDev fetches OAuth server metadata from a given auth server URL.
+
// Unlike indigo's resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerMetadataDev(ctx context.Context, serverURL string) (*oauthlib.AuthServerMetadata, error) {
+
u, err := url.Parse(serverURL)
+
if err != nil {
+
return nil, err
+
}
+
+
// Build metadata URL - preserve port for localhost
+
var metaURL string
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if isLocalhost && u.Port() != "" {
+
metaURL = fmt.Sprintf("%s://%s:%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname(), u.Port())
+
} else if isLocalhost {
+
metaURL = fmt.Sprintf("%s://%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname())
+
} else {
+
metaURL = fmt.Sprintf("https://%s/.well-known/oauth-authorization-server", u.Hostname())
+
}
+
+
slog.Debug("dev mode: fetching auth server metadata", "url", metaURL)
+
+
req, err := http.NewRequestWithContext(ctx, "GET", metaURL, nil)
+
if err != nil {
+
return nil, err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return nil, fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return nil, fmt.Errorf("HTTP error fetching auth server metadata: %d", resp.StatusCode)
+
}
+
+
var metadata oauthlib.AuthServerMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
+
// Skip validation for localhost (indigo's Validate checks HTTPS)
+
if !isLocalhost {
+
if err := metadata.Validate(serverURL); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
}
+
+
return &metadata, nil
+
}
+
+
// StartDevAuthFlow performs OAuth flow for localhost development.
+
// This bypasses indigo's HTTPS validation for the auth server URL.
+
// It resolves the identity, gets the PDS endpoint, fetches auth server metadata,
+
// and returns a redirect URL for the user to approve.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
var accountDID syntax.DID
+
var pdsEndpoint string
+
+
// Check if identifier is a handle or DID
+
if strings.HasPrefix(identifier, "did:") {
+
// It's a DID - look up via directory (PLC)
+
atid, err := syntax.ParseAtIdentifier(identifier)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", identifier, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID (%s): %w", identifier, err)
+
}
+
accountDID = ident.DID
+
pdsEndpoint = ident.PDSEndpoint()
+
} else {
+
// It's a handle - resolve via local PDS first
+
if r.handleResolver == nil {
+
return "", fmt.Errorf("handle resolution not configured (PDS URL not set)")
+
}
+
+
// Resolve handle to DID via local PDS
+
did, err := r.handleResolver.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS (%s): %w", identifier, err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found: %s", identifier)
+
}
+
+
slog.Info("dev mode: resolved handle via local PDS", "handle", identifier, "did", did)
+
+
// Parse the DID
+
parsedDID, err := syntax.ParseDID(did)
+
if err != nil {
+
return "", fmt.Errorf("invalid DID from PDS (%s): %w", did, err)
+
}
+
accountDID = parsedDID
+
+
// Now look up the DID document via PLC to get PDS endpoint
+
atid, err := syntax.ParseAtIdentifier(did)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", did, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID document (%s): %w", did, err)
+
}
+
pdsEndpoint = ident.PDSEndpoint()
+
}
+
+
if pdsEndpoint == "" {
+
return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
+
}
+
+
slog.Debug("dev mode: resolving auth server",
+
"did", accountDID,
+
"pds", pdsEndpoint)
+
+
// Resolve auth server URL (allowing HTTP for localhost)
+
authServerURL, err := r.ResolveAuthServerURL(ctx, pdsEndpoint)
+
if err != nil {
+
return "", fmt.Errorf("resolving auth server: %w", err)
+
}
+
+
slog.Info("dev mode: resolved auth server", "url", authServerURL)
+
+
// Fetch auth server metadata using our dev-friendly resolver
+
authMeta, err := r.ResolveAuthServerMetadataDev(ctx, authServerURL)
+
if err != nil {
+
return "", fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
+
slog.Debug("dev mode: got auth server metadata",
+
"issuer", authMeta.Issuer,
+
"authorization_endpoint", authMeta.AuthorizationEndpoint,
+
"token_endpoint", authMeta.TokenEndpoint)
+
+
// Send auth request (PAR) using indigo's method
+
info, err := client.ClientApp.SendAuthRequest(ctx, authMeta, client.Config.Scopes, identifier)
+
if err != nil {
+
return "", fmt.Errorf("auth request failed: %w", err)
+
}
+
+
// Set the account DID
+
info.AccountDID = &accountDID
+
+
// Persist auth request info
+
client.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
+
+
// Build redirect URL
+
params := url.Values{}
+
params.Set("client_id", client.ClientApp.Config.ClientID)
+
params.Set("request_uri", info.RequestURI)
+
+
authEndpoint := authMeta.AuthorizationEndpoint
+
redirectURL := fmt.Sprintf("%s?%s", authEndpoint, params.Encode())
+
+
slog.Info("dev mode: OAuth redirect URL built", "url_prefix", authEndpoint)
+
+
return redirectURL, nil
+
}
+106
internal/atproto/oauth/dev_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
"time"
+
)
+
+
// DevHandleResolver resolves handles via local PDS for development
+
// This is needed because local handles (e.g., user.local.coves.dev) can't be
+
// resolved via standard DNS/HTTP well-known methods - they only exist on the local PDS.
+
type DevHandleResolver struct {
+
pdsURL string
+
httpClient *http.Client
+
}
+
+
// NewDevHandleResolver creates a resolver that queries local PDS for handle resolution
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return &DevHandleResolver{
+
pdsURL: strings.TrimSuffix(pdsURL, "/"),
+
httpClient: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
}
+
}
+
+
// ResolveHandle queries the local PDS to resolve a handle to a DID
+
// Returns the DID if successful, or empty string if not found
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
if r.pdsURL == "" {
+
return "", fmt.Errorf("PDS URL not configured")
+
}
+
+
// Build the resolve handle URL
+
resolveURL := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s",
+
r.pdsURL, url.QueryEscape(handle))
+
+
// Create request with context and timeout
+
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+
defer cancel()
+
+
req, err := http.NewRequestWithContext(ctx, "GET", resolveURL, nil)
+
if err != nil {
+
return "", fmt.Errorf("failed to create request: %w", err)
+
}
+
req.Header.Set("User-Agent", "Coves/1.0")
+
+
// Execute request
+
resp, err := r.httpClient.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("failed to query PDS: %w", err)
+
}
+
defer resp.Body.Close()
+
+
// Check response status
+
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusBadRequest {
+
return "", nil // Handle not found
+
}
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("PDS returned status %d", resp.StatusCode)
+
}
+
+
// Parse response
+
var result struct {
+
DID string `json:"did"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return "", fmt.Errorf("failed to parse PDS response: %w", err)
+
}
+
+
if result.DID == "" {
+
return "", nil // No DID in response
+
}
+
+
slog.Debug("resolved handle via local PDS",
+
"handle", handle,
+
"did", result.DID,
+
"pds_url", r.pdsURL)
+
+
return result.DID, nil
+
}
+
+
// ResolveIdentifier attempts to resolve a handle to DID, or returns the DID if already provided
+
// This is the main entry point for the handlers
+
func (r *DevHandleResolver) ResolveIdentifier(ctx context.Context, identifier string) (string, error) {
+
// If it's already a DID, return as-is
+
if strings.HasPrefix(identifier, "did:") {
+
return identifier, nil
+
}
+
+
// Try to resolve the handle via local PDS
+
did, err := r.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS: %w", err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found on local PDS: %s", identifier)
+
}
+
+
return did, nil
+
}
+41
internal/atproto/oauth/dev_stubs.go
···
+
//go:build !dev
+
+
package oauth
+
+
import (
+
"context"
+
+
"github.com/bluesky-social/indigo/atproto/identity"
+
)
+
+
// DevHandleResolver is a stub for production builds.
+
// The actual implementation is in dev_resolver.go (only compiled with -tags dev).
+
type DevHandleResolver struct{}
+
+
// NewDevHandleResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return nil
+
}
+
+
// ResolveHandle is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
panic("dev mode: ResolveHandle called in production build - this should never happen")
+
}
+
+
// DevAuthResolver is a stub for production builds.
+
// The actual implementation is in dev_auth_resolver.go (only compiled with -tags dev).
+
type DevAuthResolver struct{}
+
+
// NewDevAuthResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
return nil
+
}
+
+
// StartDevAuthFlow is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
panic("dev mode: StartDevAuthFlow called in production build - this should never happen")
+
}
+5 -1
scripts/dev-run.sh
···
#!/bin/bash
# Development server runner - loads .env.dev before starting
+
# Uses -tags dev to include dev-only code (localhost OAuth resolvers, etc.)
set -a # automatically export all variables
source .env.dev
···
echo " IS_DEV_ENV: $IS_DEV_ENV"
echo " PLC_DIRECTORY_URL: $PLC_DIRECTORY_URL"
echo " JETSTREAM_URL: $JETSTREAM_URL"
+
echo " APPVIEW_PUBLIC_URL: $APPVIEW_PUBLIC_URL"
+
echo " PDS_URL: $PDS_URL"
+
echo " Build tags: dev"
echo ""
-
go run ./cmd/server
+
go run -tags dev ./cmd/server
+125
internal/atproto/pds/factory.go
···
+
package pds
+
+
import (
+
"context"
+
"fmt"
+
"net/http"
+
+
"github.com/bluesky-social/indigo/atproto/atclient"
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// NewFromOAuthSession creates a PDS client from an OAuth session.
+
// This uses DPoP authentication - the correct method for OAuth tokens.
+
//
+
// The oauthClient is used to resume the session and get a properly configured
+
// APIClient that handles DPoP proof generation and nonce rotation automatically.
+
func NewFromOAuthSession(ctx context.Context, oauthClient *oauth.ClientApp, sessionData *oauth.ClientSessionData) (Client, error) {
+
if oauthClient == nil {
+
return nil, fmt.Errorf("oauthClient is required")
+
}
+
if sessionData == nil {
+
return nil, fmt.Errorf("sessionData is required")
+
}
+
+
// ResumeSession reconstructs the OAuth session with DPoP key
+
// and returns a ClientSession that can generate authenticated requests
+
sess, err := oauthClient.ResumeSession(ctx, sessionData.AccountDID, sessionData.SessionID)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resume OAuth session: %w", err)
+
}
+
+
// APIClient() returns an *atclient.APIClient configured with DPoP auth
+
apiClient := sess.APIClient()
+
+
return &client{
+
apiClient: apiClient,
+
did: sessionData.AccountDID.String(),
+
host: sessionData.HostURL,
+
}, nil
+
}
+
+
// NewFromPasswordAuth creates a PDS client using password authentication.
+
// This uses Bearer token authentication from com.atproto.server.createSession.
+
//
+
// Primarily used for:
+
// - E2E tests with local PDS
+
// - Development/debugging tools
+
// - Non-OAuth clients
+
//
+
// Note: This establishes a new session with the PDS. For repeated calls,
+
// consider using NewFromAccessToken if you already have a valid access token.
+
func NewFromPasswordAuth(ctx context.Context, host, handle, password string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if handle == "" {
+
return nil, fmt.Errorf("handle is required")
+
}
+
if password == "" {
+
return nil, fmt.Errorf("password is required")
+
}
+
+
// LoginWithPasswordHost creates a session and returns an authenticated APIClient
+
// This handles the createSession call and Bearer token setup
+
apiClient, err := atclient.LoginWithPasswordHost(ctx, host, handle, password, "", nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to login with password: %w", err)
+
}
+
+
// Get DID from the authenticated client
+
did := ""
+
if apiClient.AccountDID != nil {
+
did = apiClient.AccountDID.String()
+
}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// NewFromAccessToken creates a PDS client from an existing access token.
+
// This is useful when you already have a valid Bearer token (e.g., from createSession)
+
// and don't want to re-authenticate.
+
//
+
// WARNING: This creates a client with Bearer auth only. Do NOT use this with
+
// OAuth access tokens - those require DPoP proofs. Use NewFromOAuthSession instead.
+
func NewFromAccessToken(host, did, accessToken string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if did == "" {
+
return nil, fmt.Errorf("did is required")
+
}
+
if accessToken == "" {
+
return nil, fmt.Errorf("accessToken is required")
+
}
+
+
// Create APIClient with Bearer auth
+
apiClient := atclient.NewAPIClient(host)
+
apiClient.Auth = &bearerAuth{token: accessToken}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// bearerAuth implements atclient.AuthMethod for simple Bearer token auth.
+
// This is used for password-based sessions where DPoP is not required.
+
type bearerAuth struct {
+
token string
+
}
+
+
// Ensure bearerAuth implements atclient.AuthMethod.
+
var _ atclient.AuthMethod = (*bearerAuth)(nil)
+
+
// DoWithAuth adds the Bearer token to the request and executes it.
+
func (b *bearerAuth) DoWithAuth(c *http.Client, req *http.Request, _ syntax.NSID) (*http.Response, error) {
+
req.Header.Set("Authorization", "Bearer "+b.token)
+
return c.Do(req)
+
}
+267
cmd/reindex-votes/main.go
···
+
// cmd/reindex-votes/main.go
+
// Quick tool to reindex votes from PDS to AppView database
+
package main
+
+
import (
+
"context"
+
"database/sql"
+
"encoding/json"
+
"fmt"
+
"log"
+
"net/http"
+
"net/url"
+
"os"
+
"strings"
+
"time"
+
+
_ "github.com/lib/pq"
+
)
+
+
type ListRecordsResponse struct {
+
Records []Record `json:"records"`
+
Cursor string `json:"cursor"`
+
}
+
+
type Record struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
Value map[string]interface{} `json:"value"`
+
}
+
+
func main() {
+
// Get config from env
+
dbURL := os.Getenv("DATABASE_URL")
+
if dbURL == "" {
+
dbURL = "postgres://dev_user:dev_password@localhost:5435/coves_dev?sslmode=disable"
+
}
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
log.Printf("Connecting to database...")
+
db, err := sql.Open("postgres", dbURL)
+
if err != nil {
+
log.Fatalf("Failed to connect to database: %v", err)
+
}
+
defer db.Close()
+
+
ctx := context.Background()
+
+
// Get all accounts directly from the PDS
+
log.Printf("Fetching accounts from PDS (%s)...", pdsURL)
+
dids, err := fetchAllAccountsFromPDS(pdsURL)
+
if err != nil {
+
log.Fatalf("Failed to fetch accounts from PDS: %v", err)
+
}
+
log.Printf("Found %d accounts on PDS to check for votes", len(dids))
+
+
// Reset vote counts first
+
log.Printf("Resetting all vote counts...")
+
if _, err := db.ExecContext(ctx, "DELETE FROM votes"); err != nil {
+
log.Fatalf("Failed to clear votes table: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE posts SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset post vote counts: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE comments SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset comment vote counts: %v", err)
+
}
+
+
// For each user, fetch their votes from PDS
+
totalVotes := 0
+
for _, did := range dids {
+
votes, err := fetchVotesFromPDS(pdsURL, did)
+
if err != nil {
+
log.Printf("Warning: failed to fetch votes for %s: %v", did, err)
+
continue
+
}
+
+
if len(votes) == 0 {
+
continue
+
}
+
+
log.Printf("Found %d votes for %s", len(votes), did)
+
+
// Index each vote
+
for _, vote := range votes {
+
if err := indexVote(ctx, db, did, vote); err != nil {
+
log.Printf("Warning: failed to index vote %s: %v", vote.URI, err)
+
continue
+
}
+
totalVotes++
+
}
+
}
+
+
log.Printf("โœ“ Reindexed %d votes from PDS", totalVotes)
+
}
+
+
// fetchAllAccountsFromPDS queries the PDS sync API to get all repo DIDs
+
func fetchAllAccountsFromPDS(pdsURL string) ([]string, error) {
+
// Use com.atproto.sync.listRepos to get all repos on this PDS
+
var allDIDs []string
+
cursor := ""
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.sync.listRepos?limit=100", pdsURL)
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result struct {
+
Repos []struct {
+
DID string `json:"did"`
+
} `json:"repos"`
+
Cursor string `json:"cursor"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
for _, repo := range result.Repos {
+
allDIDs = append(allDIDs, repo.DID)
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allDIDs, nil
+
}
+
+
func fetchVotesFromPDS(pdsURL, did string) ([]Record, error) {
+
var allRecords []Record
+
cursor := ""
+
collection := "social.coves.feed.vote"
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.listRecords?repo=%s&collection=%s&limit=100",
+
pdsURL, url.QueryEscape(did), url.QueryEscape(collection))
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode == 400 {
+
// User doesn't exist on this PDS or has no records - that's OK
+
return nil, nil
+
}
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result ListRecordsResponse
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
allRecords = append(allRecords, result.Records...)
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allRecords, nil
+
}
+
+
func indexVote(ctx context.Context, db *sql.DB, voterDID string, record Record) error {
+
// Extract vote data from record
+
subject, ok := record.Value["subject"].(map[string]interface{})
+
if !ok {
+
return fmt.Errorf("missing subject")
+
}
+
subjectURI, _ := subject["uri"].(string)
+
subjectCID, _ := subject["cid"].(string)
+
direction, _ := record.Value["direction"].(string)
+
createdAtStr, _ := record.Value["createdAt"].(string)
+
+
if subjectURI == "" || direction == "" {
+
return fmt.Errorf("invalid vote record: missing required fields")
+
}
+
+
// Parse created_at
+
createdAt, err := time.Parse(time.RFC3339, createdAtStr)
+
if err != nil {
+
createdAt = time.Now()
+
}
+
+
// Extract rkey from URI (at://did/collection/rkey)
+
parts := strings.Split(record.URI, "/")
+
if len(parts) < 5 {
+
return fmt.Errorf("invalid URI format: %s", record.URI)
+
}
+
rkey := parts[len(parts)-1]
+
+
// Start transaction
+
tx, err := db.BeginTx(ctx, nil)
+
if err != nil {
+
return fmt.Errorf("failed to begin transaction: %w", err)
+
}
+
defer tx.Rollback()
+
+
// Insert vote
+
_, err = tx.ExecContext(ctx, `
+
INSERT INTO votes (uri, cid, rkey, voter_did, subject_uri, subject_cid, direction, created_at, indexed_at)
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
+
ON CONFLICT (uri) DO NOTHING
+
`, record.URI, record.CID, rkey, voterDID, subjectURI, subjectCID, direction, createdAt)
+
if err != nil {
+
return fmt.Errorf("failed to insert vote: %w", err)
+
}
+
+
// Update post/comment counts
+
collection := extractCollectionFromURI(subjectURI)
+
var updateQuery string
+
+
switch collection {
+
case "social.coves.community.post":
+
if direction == "up" {
+
updateQuery = `UPDATE posts SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE posts SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
case "social.coves.community.comment":
+
if direction == "up" {
+
updateQuery = `UPDATE comments SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE comments SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
default:
+
// Unknown collection, just index the vote
+
return tx.Commit()
+
}
+
+
if _, err := tx.ExecContext(ctx, updateQuery, subjectURI); err != nil {
+
return fmt.Errorf("failed to update vote counts: %w", err)
+
}
+
+
return tx.Commit()
+
}
+
+
func extractCollectionFromURI(uri string) string {
+
// at://did:plc:xxx/social.coves.community.post/rkey
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 4 {
+
return parts[3]
+
}
+
return ""
+
}
+7 -5
internal/api/routes/communityFeed.go
···
import (
"Coves/internal/api/handlers/communityFeed"
+
"Coves/internal/api/middleware"
"Coves/internal/core/communityFeeds"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
func RegisterCommunityFeedRoutes(
r chi.Router,
feedService communityFeeds.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService)
+
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService, voteService)
// GET /xrpc/social.coves.communityFeed.getCommunity
-
// Public endpoint - basic community sorting only for Alpha
-
// TODO(feed-generator): Add OptionalAuth middleware when implementing viewer-specific state
-
// (blocks, upvotes, saves, etc.) in feed generator skeleton
-
r.Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
+
// Public endpoint with optional auth for viewer-specific state (vote state)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
}
+3 -1
internal/api/routes/timeline.go
···
"Coves/internal/api/handlers/timeline"
"Coves/internal/api/middleware"
timelineCore "Coves/internal/core/timeline"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
func RegisterTimelineRoutes(
r chi.Router,
timelineService timelineCore.Service,
+
voteService votes.Service,
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getTimelineHandler := timeline.NewGetTimelineHandler(timelineService)
+
getTimelineHandler := timeline.NewGetTimelineHandler(timelineService, voteService)
// GET /xrpc/social.coves.feed.getTimeline
// Requires authentication - user must be logged in to see their timeline
+221
internal/core/votes/cache.go
···
+
package votes
+
+
import (
+
"context"
+
"fmt"
+
"log/slog"
+
"strings"
+
"sync"
+
"time"
+
+
"Coves/internal/atproto/pds"
+
)
+
+
// CachedVote represents a vote stored in the cache
+
type CachedVote struct {
+
Direction string // "up" or "down"
+
URI string // vote record URI (at://did/collection/rkey)
+
RKey string // record key
+
}
+
+
// VoteCache provides an in-memory cache of user votes fetched from their PDS.
+
// This avoids eventual consistency issues with the AppView database.
+
type VoteCache struct {
+
mu sync.RWMutex
+
votes map[string]map[string]*CachedVote // userDID -> subjectURI -> vote
+
expiry map[string]time.Time // userDID -> expiry time
+
ttl time.Duration
+
logger *slog.Logger
+
}
+
+
// NewVoteCache creates a new vote cache with the specified TTL
+
func NewVoteCache(ttl time.Duration, logger *slog.Logger) *VoteCache {
+
if logger == nil {
+
logger = slog.Default()
+
}
+
return &VoteCache{
+
votes: make(map[string]map[string]*CachedVote),
+
expiry: make(map[string]time.Time),
+
ttl: ttl,
+
logger: logger,
+
}
+
}
+
+
// GetVotesForUser returns all cached votes for a user.
+
// Returns nil if cache is empty or expired for this user.
+
func (c *VoteCache) GetVotesForUser(userDID string) map[string]*CachedVote {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
// Check if cache exists and is not expired
+
expiry, exists := c.expiry[userDID]
+
if !exists || time.Now().After(expiry) {
+
return nil
+
}
+
+
return c.votes[userDID]
+
}
+
+
// GetVote returns the cached vote for a specific subject, or nil if not found/expired
+
func (c *VoteCache) GetVote(userDID, subjectURI string) *CachedVote {
+
votes := c.GetVotesForUser(userDID)
+
if votes == nil {
+
return nil
+
}
+
return votes[subjectURI]
+
}
+
+
// IsCached returns true if the user's votes are cached and not expired
+
func (c *VoteCache) IsCached(userDID string) bool {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
expiry, exists := c.expiry[userDID]
+
return exists && time.Now().Before(expiry)
+
}
+
+
// SetVotesForUser replaces all cached votes for a user
+
func (c *VoteCache) SetVotesForUser(userDID string, votes map[string]*CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
c.votes[userDID] = votes
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cache updated",
+
"user", userDID,
+
"vote_count", len(votes),
+
"expires_at", c.expiry[userDID])
+
}
+
+
// SetVote adds or updates a single vote in the cache
+
func (c *VoteCache) SetVote(userDID, subjectURI string, vote *CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] == nil {
+
c.votes[userDID] = make(map[string]*CachedVote)
+
}
+
+
c.votes[userDID][subjectURI] = vote
+
+
// Always extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cached",
+
"user", userDID,
+
"subject", subjectURI,
+
"direction", vote.Direction)
+
}
+
+
// RemoveVote removes a vote from the cache (for toggle-off)
+
func (c *VoteCache) RemoveVote(userDID, subjectURI string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] != nil {
+
delete(c.votes[userDID], subjectURI)
+
+
// Extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote removed from cache",
+
"user", userDID,
+
"subject", subjectURI)
+
}
+
}
+
+
// Invalidate removes all cached votes for a user
+
func (c *VoteCache) Invalidate(userDID string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
delete(c.votes, userDID)
+
delete(c.expiry, userDID)
+
+
c.logger.Debug("vote cache invalidated", "user", userDID)
+
}
+
+
// FetchAndCacheFromPDS fetches all votes from the user's PDS and caches them.
+
// This should be called on first authenticated request or when cache is expired.
+
func (c *VoteCache) FetchAndCacheFromPDS(ctx context.Context, pdsClient pds.Client) error {
+
userDID := pdsClient.DID()
+
+
c.logger.Debug("fetching votes from PDS",
+
"user", userDID,
+
"pds", pdsClient.HostURL())
+
+
votes, err := c.fetchAllVotesFromPDS(ctx, pdsClient)
+
if err != nil {
+
return fmt.Errorf("failed to fetch votes from PDS: %w", err)
+
}
+
+
c.SetVotesForUser(userDID, votes)
+
+
c.logger.Info("vote cache populated from PDS",
+
"user", userDID,
+
"vote_count", len(votes))
+
+
return nil
+
}
+
+
// fetchAllVotesFromPDS paginates through all vote records on the user's PDS
+
func (c *VoteCache) fetchAllVotesFromPDS(ctx context.Context, pdsClient pds.Client) (map[string]*CachedVote, error) {
+
votes := make(map[string]*CachedVote)
+
cursor := ""
+
const pageSize = 100
+
const collection = "social.coves.feed.vote"
+
+
for {
+
result, err := pdsClient.ListRecords(ctx, collection, pageSize, cursor)
+
if err != nil {
+
if pds.IsAuthError(err) {
+
return nil, ErrNotAuthorized
+
}
+
return nil, fmt.Errorf("listRecords failed: %w", err)
+
}
+
+
for _, rec := range result.Records {
+
// Extract subject from record value
+
subject, ok := rec.Value["subject"].(map[string]any)
+
if !ok {
+
continue
+
}
+
+
subjectURI, ok := subject["uri"].(string)
+
if !ok || subjectURI == "" {
+
continue
+
}
+
+
direction, _ := rec.Value["direction"].(string)
+
if direction == "" {
+
continue
+
}
+
+
// Extract rkey from URI
+
rkey := extractRKeyFromURI(rec.URI)
+
+
votes[subjectURI] = &CachedVote{
+
Direction: direction,
+
URI: rec.URI,
+
RKey: rkey,
+
}
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return votes, nil
+
}
+
+
// extractRKeyFromURI extracts the rkey from an AT-URI (at://did/collection/rkey)
+
func extractRKeyFromURI(uri string) string {
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 5 {
+
return parts[len(parts)-1]
+
}
+
return ""
+
}
+14
internal/core/votes/service.go
···
// - Deletes the user's vote record from their PDS
// - AppView will soft-delete via Jetstream consumer
DeleteVote(ctx context.Context, session *oauthlib.ClientSessionData, req DeleteVoteRequest) error
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
// This should be called before rendering feeds to ensure vote state is available.
+
// If cache is already populated and not expired, this is a no-op.
+
EnsureCachePopulated(ctx context.Context, session *oauthlib.ClientSessionData) error
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
// Returns from cache if available, otherwise returns nil (caller should ensure cache is populated).
+
GetViewerVote(userDID, subjectURI string) *CachedVote
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
// Returns a map of subjectURI -> CachedVote for subjects the user has voted on.
+
// This is efficient for batch lookups when rendering feeds.
+
GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote
}
// CreateVoteRequest contains the parameters for creating a vote
+84 -2
internal/core/votes/service_impl.go
···
oauthStore oauth.ClientAuthStore
logger *slog.Logger
pdsClientFactory PDSClientFactory // Optional, for testing. If nil, uses OAuth.
+
cache *VoteCache // In-memory cache of user votes from PDS
}
// NewService creates a new vote service instance
-
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, logger *slog.Logger) Service {
+
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, cache *VoteCache, logger *slog.Logger) Service {
if logger == nil {
logger = slog.Default()
}
···
repo: repo,
oauthClient: oauthClient,
oauthStore: oauthStore,
+
cache: cache,
logger: logger,
}
}
// NewServiceWithPDSFactory creates a vote service with a custom PDS client factory.
// This is primarily for testing with password-based authentication.
-
func NewServiceWithPDSFactory(repo Repository, logger *slog.Logger, factory PDSClientFactory) Service {
+
func NewServiceWithPDSFactory(repo Repository, cache *VoteCache, logger *slog.Logger, factory PDSClientFactory) Service {
if logger == nil {
logger = slog.Default()
}
return &voteService{
repo: repo,
+
cache: cache,
logger: logger,
pdsClientFactory: factory,
}
···
"subject", req.Subject.URI,
"direction", req.Direction)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
// Return empty response to indicate deletion
return &CreateVoteResponse{
URI: "",
···
"uri", uri,
"cid", cid)
+
// Update cache - add the new vote
+
if s.cache != nil {
+
s.cache.SetVote(session.AccountDID.String(), req.Subject.URI, &CachedVote{
+
Direction: req.Direction,
+
URI: uri,
+
RKey: extractRKeyFromURI(uri),
+
})
+
}
+
return &CreateVoteResponse{
URI: uri,
CID: cid,
···
"subject", req.Subject.URI,
"uri", existing.URI)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
return nil
}
···
// No vote found for this subject after checking all pages
return nil, nil
}
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
func (s *voteService) EnsureCachePopulated(ctx context.Context, session *oauth.ClientSessionData) error {
+
if s.cache == nil {
+
return nil // No cache configured
+
}
+
+
// Check if already cached
+
if s.cache.IsCached(session.AccountDID.String()) {
+
return nil
+
}
+
+
// Create PDS client for this session
+
pdsClient, err := s.getPDSClient(ctx, session)
+
if err != nil {
+
s.logger.Error("failed to create PDS client for cache population",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to create PDS client: %w", err)
+
}
+
+
// Fetch and cache votes from PDS
+
if err := s.cache.FetchAndCacheFromPDS(ctx, pdsClient); err != nil {
+
s.logger.Error("failed to populate vote cache from PDS",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to populate vote cache: %w", err)
+
}
+
+
return nil
+
}
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
func (s *voteService) GetViewerVote(userDID, subjectURI string) *CachedVote {
+
if s.cache == nil {
+
return nil
+
}
+
return s.cache.GetVote(userDID, subjectURI)
+
}
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
func (s *voteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote {
+
result := make(map[string]*CachedVote)
+
if s.cache == nil {
+
return result
+
}
+
+
allVotes := s.cache.GetVotesForUser(userDID)
+
if allVotes == nil {
+
return result
+
}
+
+
for _, uri := range subjectURIs {
+
if vote, exists := allVotes[uri]; exists {
+
result[uri] = vote
+
}
+
}
+
+
return result
+
}
+76 -16
internal/atproto/jetstream/vote_consumer.go
···
}
// Atomically: Index vote + Update post counts
-
if err := c.indexVoteAndUpdateCounts(ctx, vote); err != nil {
+
wasNew, err := c.indexVoteAndUpdateCounts(ctx, vote)
+
if err != nil {
return fmt.Errorf("failed to index vote and update counts: %w", err)
}
-
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
if wasNew {
+
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
}
return nil
}
···
}
// indexVoteAndUpdateCounts atomically indexes a vote and updates post vote counts
-
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) error {
+
// Returns (true, nil) if vote was newly inserted, (false, nil) if already existed (idempotent)
+
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) (bool, error) {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
-
return fmt.Errorf("failed to begin transaction: %w", err)
+
return false, fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if rollbackErr := tx.Rollback(); rollbackErr != nil && rollbackErr != sql.ErrTxDone {
···
}
}()
-
// 1. Index the vote (idempotent with ON CONFLICT DO NOTHING)
+
// 1. Check for existing active vote with different URI (stale record)
+
// This handles cases where:
+
// - User voted on another client and we missed the delete event
+
// - Vote was reindexed but user created a new vote with different rkey
+
// - Any other state mismatch between PDS and AppView
+
var existingDirection sql.NullString
+
checkQuery := `
+
SELECT direction FROM votes
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
LIMIT 1
+
`
+
if err := tx.QueryRowContext(ctx, checkQuery, vote.VoterDID, vote.SubjectURI, vote.URI).Scan(&existingDirection); err != nil && err != sql.ErrNoRows {
+
return false, fmt.Errorf("failed to check existing vote: %w", err)
+
}
+
+
// If there's a stale vote, soft-delete it and adjust counts
+
if existingDirection.Valid {
+
softDeleteQuery := `
+
UPDATE votes
+
SET deleted_at = NOW()
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
`
+
if _, err := tx.ExecContext(ctx, softDeleteQuery, vote.VoterDID, vote.SubjectURI, vote.URI); err != nil {
+
return false, fmt.Errorf("failed to soft-delete existing votes: %w", err)
+
}
+
+
// Decrement the old vote's count (will be re-incremented below if same direction)
+
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
+
var decrementQuery string
+
if existingDirection.String == "up" {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
} else {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
}
+
if decrementQuery != "" {
+
if _, err := tx.ExecContext(ctx, decrementQuery, vote.SubjectURI); err != nil {
+
return false, fmt.Errorf("failed to decrement old vote count: %w", err)
+
}
+
}
+
log.Printf("Cleaned up stale vote for %s on %s (was %s)", vote.VoterDID, vote.SubjectURI, existingDirection.String)
+
}
+
+
// 2. Index the vote (idempotent with ON CONFLICT DO NOTHING)
query := `
INSERT INTO votes (
uri, cid, rkey, voter_did,
···
// If no rows returned, vote already exists (idempotent - OK for Jetstream replays)
if err == sql.ErrNoRows {
-
log.Printf("Vote already indexed: %s (idempotent)", vote.URI)
+
// Silently handle idempotent case - no log needed for replayed events
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return false, nil // Vote already existed
}
if err != nil {
-
return fmt.Errorf("failed to insert vote: %w", err)
+
return false, fmt.Errorf("failed to insert vote: %w", err)
}
-
// 2. Update vote counts on the subject (post or comment)
+
// 3. Update vote counts on the subject (post or comment)
// Parse collection from subject URI to determine target table
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
···
// Vote is still indexed in votes table, we just don't update denormalized counts
log.Printf("Vote subject has unsupported collection: %s (vote indexed, counts not updated)", collection)
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
result, err := tx.ExecContext(ctx, updateQuery, vote.SubjectURI)
if err != nil {
-
return fmt.Errorf("failed to update vote counts: %w", err)
+
return false, fmt.Errorf("failed to update vote counts: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
-
return fmt.Errorf("failed to check update result: %w", err)
+
return false, fmt.Errorf("failed to check update result: %w", err)
}
// If subject doesn't exist or is deleted, that's OK (vote still indexed)
···
// Commit transaction
if err := tx.Commit(); err != nil {
-
return fmt.Errorf("failed to commit transaction: %w", err)
+
return false, fmt.Errorf("failed to commit transaction: %w", err)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
// deleteVoteAndUpdateCounts atomically soft-deletes a vote and updates post vote counts
+109
internal/atproto/lexicon/social/coves/community/comment/create.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.create",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Create a comment on a post or another comment. Comments support nested threading, rich text, embeds, and self-labeling.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["reply", "content"],
+
"properties": {
+
"reply": {
+
"type": "object",
+
"description": "References for maintaining thread structure. Root always points to the original post, parent points to the immediate parent (post or comment).",
+
"required": ["root", "parent"],
+
"properties": {
+
"root": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the original post that started the thread"
+
},
+
"parent": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the immediate parent (post or comment) being replied to"
+
}
+
}
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the created comment"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "CID of the created comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "InvalidReply",
+
"description": "The reply reference is invalid, malformed, or refers to non-existent content"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to create comments on this content"
+
}
+
]
+
}
+
}
+
}
+41
internal/atproto/lexicon/social/coves/community/comment/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a comment. Only the comment author can delete their own comments.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to delete"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+97
internal/atproto/lexicon/social/coves/community/comment/update.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.update",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Update an existing comment's content, facets, embed, languages, or labels. Threading references (reply.root and reply.parent) are immutable and cannot be changed.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "content"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to update"
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Updated comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Updated annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Updated embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Updated languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Updated self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the updated comment (unchanged from input)"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "New CID of the updated comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Updated comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Updated comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to update this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+38
internal/core/comments/types.go
···
+
package comments
+
+
// CreateCommentRequest contains parameters for creating a comment
+
type CreateCommentRequest struct {
+
Reply ReplyRef `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// CreateCommentResponse contains the result of creating a comment
+
type CreateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// UpdateCommentRequest contains parameters for updating a comment
+
type UpdateCommentRequest struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentResponse contains the result of updating a comment
+
type UpdateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// DeleteCommentRequest contains parameters for deleting a comment
+
type DeleteCommentRequest struct {
+
URI string `json:"uri"`
+
}
+130
internal/api/handlers/comments/create_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateCommentHandler handles comment creation requests
+
type CreateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewCreateCommentHandler creates a new handler for creating comments
+
func NewCreateCommentHandler(service comments.Service) *CreateCommentHandler {
+
return &CreateCommentHandler{
+
service: service,
+
}
+
}
+
+
// CreateCommentInput matches the lexicon input schema for social.coves.community.comment.create
+
type CreateCommentInput struct {
+
Reply struct {
+
Root struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"root"`
+
Parent struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"parent"`
+
} `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// CreateCommentOutput matches the lexicon output schema
+
type CreateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreate handles comment creation requests
+
// POST /xrpc/social.coves.community.comment.create
+
//
+
// Request body: { "reply": { "root": {...}, "parent": {...} }, "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *CreateCommentHandler) HandleCreate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into CreateCommentInput
+
var input CreateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to CreateCommentRequest
+
req := comments.CreateCommentRequest{
+
Reply: comments.ReplyRef{
+
Root: comments.StrongRef{
+
URI: input.Reply.Root.URI,
+
CID: input.Reply.Root.CID,
+
},
+
Parent: comments.StrongRef{
+
URI: input.Reply.Parent.URI,
+
CID: input.Reply.Parent.CID,
+
},
+
},
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to create comment
+
response, err := h.service.CreateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := CreateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+80
internal/api/handlers/comments/delete_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteCommentHandler handles comment deletion requests
+
type DeleteCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewDeleteCommentHandler creates a new handler for deleting comments
+
func NewDeleteCommentHandler(service comments.Service) *DeleteCommentHandler {
+
return &DeleteCommentHandler{
+
service: service,
+
}
+
}
+
+
// DeleteCommentInput matches the lexicon input schema for social.coves.community.comment.delete
+
type DeleteCommentInput struct {
+
URI string `json:"uri"`
+
}
+
+
// DeleteCommentOutput is empty per lexicon specification
+
type DeleteCommentOutput struct{}
+
+
// HandleDelete handles comment deletion requests
+
// POST /xrpc/social.coves.community.comment.delete
+
//
+
// Request body: { "uri": "at://..." }
+
// Response: {}
+
func (h *DeleteCommentHandler) HandleDelete(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into DeleteCommentInput
+
var input DeleteCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert input to DeleteCommentRequest
+
req := comments.DeleteCommentRequest{
+
URI: input.URI,
+
}
+
+
// 6. Call service to delete comment
+
err := h.service.DeleteComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 7. Return empty JSON object per lexicon specification
+
output := DeleteCommentOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+34 -2
internal/api/handlers/comments/errors.go
···
import (
"Coves/internal/core/comments"
"encoding/json"
+
"errors"
"log"
"net/http"
)
···
func handleServiceError(w http.ResponseWriter, err error) {
switch {
case comments.IsNotFound(err):
-
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
// Map specific not found errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrCommentNotFound):
+
writeError(w, http.StatusNotFound, "CommentNotFound", "Comment not found")
+
case errors.Is(err, comments.ErrParentNotFound):
+
writeError(w, http.StatusNotFound, "ParentNotFound", "Parent post or comment not found")
+
case errors.Is(err, comments.ErrRootNotFound):
+
writeError(w, http.StatusNotFound, "RootNotFound", "Root post not found")
+
default:
+
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
}
case comments.IsValidationError(err):
-
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
// Map specific validation errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrInvalidReply):
+
writeError(w, http.StatusBadRequest, "InvalidReply", "The reply reference is invalid or malformed")
+
case errors.Is(err, comments.ErrContentTooLong):
+
writeError(w, http.StatusBadRequest, "ContentTooLong", "Comment content exceeds 10000 graphemes")
+
case errors.Is(err, comments.ErrContentEmpty):
+
writeError(w, http.StatusBadRequest, "ContentEmpty", "Comment content is required")
+
default:
+
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
}
+
+
case errors.Is(err, comments.ErrNotAuthorized):
+
writeError(w, http.StatusForbidden, "NotAuthorized", "User is not authorized to perform this action")
+
+
case errors.Is(err, comments.ErrBanned):
+
writeError(w, http.StatusForbidden, "Banned", "User is banned from this community")
+
+
// NOTE: IsConflict case removed - the PDS handles duplicate detection via CreateRecord,
+
// so ErrCommentAlreadyExists is never returned from the service layer. If the PDS rejects
+
// a duplicate record, it returns an auth/validation error which is handled by other cases.
+
// Keeping this code would be dead code that never executes.
default:
// Don't leak internal error details to clients
+112
internal/api/handlers/comments/update_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// UpdateCommentHandler handles comment update requests
+
type UpdateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewUpdateCommentHandler creates a new handler for updating comments
+
func NewUpdateCommentHandler(service comments.Service) *UpdateCommentHandler {
+
return &UpdateCommentHandler{
+
service: service,
+
}
+
}
+
+
// UpdateCommentInput matches the lexicon input schema for social.coves.community.comment.update
+
type UpdateCommentInput struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentOutput matches the lexicon output schema
+
type UpdateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleUpdate handles comment update requests
+
// POST /xrpc/social.coves.community.comment.update
+
//
+
// Request body: { "uri": "at://...", "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *UpdateCommentHandler) HandleUpdate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into UpdateCommentInput
+
var input UpdateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to UpdateCommentRequest
+
req := comments.UpdateCommentRequest{
+
URI: input.URI,
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to update comment
+
response, err := h.service.UpdateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := UpdateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+35
internal/api/routes/comment.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/comments"
+
"Coves/internal/api/middleware"
+
commentsCore "Coves/internal/core/comments"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterCommentRoutes registers comment-related XRPC endpoints on the router
+
// Implements social.coves.community.comment.* lexicon endpoints
+
// All write operations (create, update, delete) require authentication
+
func RegisterCommentRoutes(r chi.Router, service commentsCore.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := comments.NewCreateCommentHandler(service)
+
updateHandler := comments.NewUpdateCommentHandler(service)
+
deleteHandler := comments.NewDeleteCommentHandler(service)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.community.comment.create - create a new comment on a post or another comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.create",
+
createHandler.HandleCreate)
+
+
// social.coves.community.comment.update - update an existing comment's content
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.update",
+
updateHandler.HandleUpdate)
+
+
// social.coves.community.comment.delete - soft delete a comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.delete",
+
deleteHandler.HandleDelete)
+
}
+4 -2
tests/integration/comment_query_test.go
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
return comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
return comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
}
// Helper: createTestCommentWithScore creates a comment with specific vote counts
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
service := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
service := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
return &testCommentServiceAdapter{service: service}
}
+6 -3
tests/integration/comment_vote_test.go
···
}
// Query comments with viewer authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
}
// Query with authentication but no vote
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
t.Run("Unauthenticated request has no viewer state", func(t *testing.T) {
// Query without authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
+2 -1
tests/integration/concurrent_scenarios_test.go
···
}
// Verify all comments are retrievable via service
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: postURI,
Sort: "new",
+66
internal/db/migrations/021_add_comment_deletion_metadata.sql
···
+
-- +goose Up
+
-- Add deletion reason tracking to preserve thread structure while respecting privacy
+
-- When comments are deleted, we blank content but keep the record for threading
+
+
-- Create enum type for deletion reasons
+
CREATE TYPE deletion_reason AS ENUM ('author', 'moderator');
+
+
-- Add new columns to comments table
+
ALTER TABLE comments ADD COLUMN deletion_reason deletion_reason;
+
ALTER TABLE comments ADD COLUMN deleted_by TEXT;
+
+
-- Add comments for new columns
+
COMMENT ON COLUMN comments.deletion_reason IS 'Reason for deletion: author (user deleted), moderator (community mod removed)';
+
COMMENT ON COLUMN comments.deleted_by IS 'DID of the actor who performed the deletion';
+
+
-- Backfill existing deleted comments as author-deleted
+
-- This handles existing soft-deleted comments gracefully
+
UPDATE comments
+
SET deletion_reason = 'author',
+
deleted_by = commenter_did
+
WHERE deleted_at IS NOT NULL AND deletion_reason IS NULL;
+
+
-- Modify existing indexes to NOT filter deleted_at IS NULL
+
-- This allows deleted comments to appear in thread queries for structure preservation
+
-- Note: We drop and recreate to change the partial index condition
+
+
-- Drop old partial indexes that exclude deleted comments
+
DROP INDEX IF EXISTS idx_comments_root;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_uri_active;
+
+
-- Recreate indexes without the deleted_at filter (include all comments for threading)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC);
+
CREATE INDEX idx_comments_uri_lookup ON comments(uri);
+
+
-- Add index for querying by deletion_reason (for moderation dashboard)
+
CREATE INDEX idx_comments_deleted_reason ON comments(deletion_reason, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- Add index for querying by deleted_by (for moderation audit/filtering)
+
CREATE INDEX idx_comments_deleted_by ON comments(deleted_by, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- +goose Down
+
-- Remove deletion metadata columns and restore original indexes
+
+
DROP INDEX IF EXISTS idx_comments_deleted_by;
+
DROP INDEX IF EXISTS idx_comments_deleted_reason;
+
DROP INDEX IF EXISTS idx_comments_uri_lookup;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_root;
+
+
-- Restore original partial indexes (excluding deleted comments)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_uri_active ON comments(uri) WHERE deleted_at IS NULL;
+
+
ALTER TABLE comments DROP COLUMN IF EXISTS deleted_by;
+
ALTER TABLE comments DROP COLUMN IF EXISTS deletion_reason;
+
+
DROP TYPE IF EXISTS deletion_reason;
+17 -13
internal/core/comments/view_models.go
···
// CommentView represents the full view of a comment with all metadata
// Matches social.coves.community.comment.getComments#commentView lexicon
// Used in thread views and get endpoints
+
// For deleted comments, IsDeleted=true and content-related fields are empty/nil
type CommentView struct {
-
Embed interface{} `json:"embed,omitempty"`
-
Record interface{} `json:"record"`
-
Viewer *CommentViewerState `json:"viewer,omitempty"`
-
Author *posts.AuthorView `json:"author"`
-
Post *CommentRef `json:"post"`
-
Parent *CommentRef `json:"parent,omitempty"`
-
Stats *CommentStats `json:"stats"`
-
Content string `json:"content"`
-
CreatedAt string `json:"createdAt"`
-
IndexedAt string `json:"indexedAt"`
-
URI string `json:"uri"`
-
CID string `json:"cid"`
-
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Record interface{} `json:"record"`
+
Viewer *CommentViewerState `json:"viewer,omitempty"`
+
Author *posts.AuthorView `json:"author"`
+
Post *CommentRef `json:"post"`
+
Parent *CommentRef `json:"parent,omitempty"`
+
Stats *CommentStats `json:"stats"`
+
Content string `json:"content"`
+
CreatedAt string `json:"createdAt"`
+
IndexedAt string `json:"indexedAt"`
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
IsDeleted bool `json:"isDeleted,omitempty"`
+
DeletionReason *string `json:"deletionReason,omitempty"`
+
DeletedAt *string `json:"deletedAt,omitempty"`
}
// ThreadViewComment represents a comment with its nested replies
+23 -1
internal/core/comments/interfaces.go
···
package comments
-
import "context"
+
import (
+
"context"
+
"database/sql"
+
)
// Repository defines the data access interface for comments
// Used by Jetstream consumer to index comments from firehose
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
Delete(ctx context.Context, uri string) error
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// reason: "author" (user deleted) or "moderator" (mod removed)
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error
+
// ListByRoot retrieves all comments in a thread (flat)
// Used for fetching entire comment threads on posts
ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*Comment, error)
···
limitPerParent int,
) (map[string][]*Comment, error)
}
+
+
// RepositoryTx provides transaction-aware operations for consumers that need atomicity
+
// Used by Jetstream consumer to perform atomic delete + count updates
+
// Implementations that support transactions should also implement this interface
+
type RepositoryTx interface {
+
// SoftDeleteWithReasonTx performs a soft delete within a transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// reason: must be DeletionReasonAuthor or DeletionReasonModerator
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error)
+
}
+87 -27
internal/db/postgres/comment_repo.go
···
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE uri = $1
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
// Idempotent: Returns success if comment already deleted
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
func (r *postgresCommentRepo) Delete(ctx context.Context, uri string) error {
query := `
UPDATE comments
···
return nil
}
-
// ListByRoot retrieves all active comments in a thread (flat)
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// Idempotent: Returns success if comment already deleted
+
// Validates that reason is a known deletion reason constant
+
func (r *postgresCommentRepo) SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error {
+
// Validate deletion reason
+
if reason != comments.DeletionReasonAuthor && reason != comments.DeletionReasonModerator {
+
return fmt.Errorf("invalid deletion reason: %s", reason)
+
}
+
+
_, err := r.SoftDeleteWithReasonTx(ctx, nil, uri, reason, deletedByDID)
+
return err
+
}
+
+
// SoftDeleteWithReasonTx performs a soft delete within an optional transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// This method is used by both the repository and the Jetstream consumer
+
func (r *postgresCommentRepo) SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error) {
+
query := `
+
UPDATE comments
+
SET
+
content = '',
+
content_facets = NULL,
+
embed = NULL,
+
content_labels = NULL,
+
deleted_at = NOW(),
+
deletion_reason = $2,
+
deleted_by = $3
+
WHERE uri = $1 AND deleted_at IS NULL
+
`
+
+
var result sql.Result
+
var err error
+
+
if tx != nil {
+
result, err = tx.ExecContext(ctx, query, uri, reason, deletedByDID)
+
} else {
+
result, err = r.db.ExecContext(ctx, query, uri, reason, deletedByDID)
+
}
+
+
if err != nil {
+
return 0, fmt.Errorf("failed to soft delete comment: %w", err)
+
}
+
+
rowsAffected, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to check delete result: %w", err)
+
}
+
+
return rowsAffected, nil
+
}
+
+
// ListByRoot retrieves all comments in a thread (flat), including deleted ones
// Used for fetching entire comment threads on posts
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE root_uri = $1 AND deleted_at IS NULL
+
WHERE root_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
return result, nil
}
-
// ListByParent retrieves direct replies to a post or comment
+
// ListByParent retrieves direct replies to a post or comment, including deleted ones
// Used for building nested/threaded comment views
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByParent(ctx context.Context, parentURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE parent_uri = $1 AND deleted_at IS NULL
+
WHERE parent_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
}
// ListByCommenter retrieves all active comments by a specific user
-
// Future: Used for user comment history
+
// Used for user comment history - filters out deleted comments
func (r *postgresCommentRepo) ListByCommenter(ctx context.Context, commenterDID string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE commenter_did = $1 AND deleted_at IS NULL
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
// Build complete query with JOINs and filters
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
%s
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = $1 AND c.deleted_at IS NULL
+
WHERE c.parent_uri = $1
%s
%s
ORDER BY %s
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
···
// GetByURIsBatch retrieves multiple comments by their AT-URIs in a single query
// Returns map[uri]*Comment for efficient lookups without N+1 queries
+
// Includes deleted comments to preserve thread structure
func (r *postgresCommentRepo) GetByURIsBatch(ctx context.Context, uris []string) (map[string]*comments.Comment, error) {
if len(uris) == 0 {
return make(map[string]*comments.Comment), nil
···
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
// COALESCE falls back to DID when handle is NULL (user not yet in users table)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := `
SELECT
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
COALESCE(u.handle, c.commenter_did) as author_handle
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.uri = ANY($1)
`
rows, err := r.db.QueryContext(ctx, query, pq.Array(uris))
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&authorHandle,
)
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
// Use window function to limit results per parent
// This is more efficient than LIMIT in a subquery per parent
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
WITH ranked_comments AS (
SELECT
···
) as rn
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.parent_uri = ANY($1)
)
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count,
hot_rank, author_handle
FROM ranked_comments
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
+5 -6
internal/core/comments/comment_service.go
···
CreatedAt: createdAt, // Preserve original timestamp
}
-
// Update the record on PDS (putRecord)
-
// Note: This creates a new CID even though the URI stays the same
-
// TODO: Use PutRecord instead of CreateRecord for proper update semantics with optimistic locking.
-
// PutRecord should accept the existing CID (existingRecord.CID) to ensure concurrent updates are detected.
-
// However, PutRecord is not yet implemented in internal/atproto/pds/client.go.
-
uri, cid, err := pdsClient.CreateRecord(ctx, commentCollection, rkey, updatedRecord)
+
// Update the record on PDS with optimistic locking via swapRecord CID
+
uri, cid, err := pdsClient.PutRecord(ctx, commentCollection, rkey, updatedRecord, existingRecord.CID)
if err != nil {
s.logger.Error("failed to update comment on PDS",
"error", err,
···
if pds.IsAuthError(err) {
return nil, ErrNotAuthorized
}
+
if errors.Is(err, pds.ErrConflict) {
+
return nil, ErrConcurrentModification
+
}
return nil, fmt.Errorf("failed to update comment: %w", err)
}
+73
internal/api/handlers/common/viewer_state.go
···
+
package common
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
+
"context"
+
"log"
+
"net/http"
+
)
+
+
// FeedPostProvider is implemented by any feed post wrapper that contains a PostView.
+
// This allows the helper to work with different feed post types (discover, timeline, communityFeed).
+
type FeedPostProvider interface {
+
GetPost() *posts.PostView
+
}
+
+
// PopulateViewerVoteState enriches feed posts with the authenticated user's vote state.
+
// This is a no-op if voteService is nil or the request is unauthenticated.
+
//
+
// Parameters:
+
// - ctx: Request context for PDS calls
+
// - r: HTTP request (used to extract OAuth session)
+
// - voteService: Vote service for cache lookup (may be nil)
+
// - feedPosts: Posts to enrich with viewer state (must implement FeedPostProvider)
+
//
+
// The function logs but does not fail on errors - viewer state is optional enrichment.
+
func PopulateViewerVoteState[T FeedPostProvider](
+
ctx context.Context,
+
r *http.Request,
+
voteService votes.Service,
+
feedPosts []T,
+
) {
+
if voteService == nil {
+
return
+
}
+
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
return
+
}
+
+
userDID := middleware.GetUserDID(r)
+
+
// Ensure vote cache is populated from PDS
+
if err := voteService.EnsureCachePopulated(ctx, session); err != nil {
+
log.Printf("Warning: failed to populate vote cache: %v", err)
+
return
+
}
+
+
// Collect post URIs to batch lookup
+
postURIs := make([]string, 0, len(feedPosts))
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
postURIs = append(postURIs, post.URI)
+
}
+
}
+
+
// Get viewer votes for all posts
+
viewerVotes := voteService.GetViewerVotesForSubjects(userDID, postURIs)
+
+
// Populate viewer state on each post
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
if vote, exists := viewerVotes[post.URI]; exists {
+
post.Viewer = &posts.ViewerState{
+
Vote: &vote.Direction,
+
VoteURI: &vote.URI,
+
}
+
}
+
}
+
}
+
}
+11 -4
internal/api/handlers/discover/get_discover.go
···
package discover
import (
+
"Coves/internal/api/handlers/common"
"Coves/internal/core/discover"
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
"encoding/json"
"log"
"net/http"
···
// GetDiscoverHandler handles discover feed retrieval
type GetDiscoverHandler struct {
-
service discover.Service
+
service discover.Service
+
voteService votes.Service
}
// NewGetDiscoverHandler creates a new discover handler
-
func NewGetDiscoverHandler(service discover.Service) *GetDiscoverHandler {
+
func NewGetDiscoverHandler(service discover.Service, voteService votes.Service) *GetDiscoverHandler {
return &GetDiscoverHandler{
-
service: service,
+
service: service,
+
voteService: voteService,
}
}
// HandleGetDiscover retrieves posts from all communities (public feed)
// GET /xrpc/social.coves.feed.getDiscover?sort=hot&limit=15&cursor=...
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth - if authenticated, includes viewer vote state
func (h *GetDiscoverHandler) HandleGetDiscover(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
···
return
}
+
// Populate viewer vote state if authenticated
+
common.PopulateViewerVoteState(r.Context(), r, h.voteService, response.Feed)
+
// Transform blob refs to URLs for all posts
for _, feedPost := range response.Feed {
if feedPost.Post != nil {
+9 -4
internal/api/routes/discover.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
discoverCore "Coves/internal/core/discover"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
// RegisterDiscoverRoutes registers discover-related XRPC endpoints
//
// SECURITY & RATE LIMITING:
-
// - Discover feed is PUBLIC (no authentication required)
+
// - Discover feed is PUBLIC (works without authentication)
+
// - Optional auth: if authenticated, includes viewer vote state on posts
// - Protected by global rate limiter: 100 requests/minute per IP (main.go:84)
// - Query timeout enforced via context (prevents long-running queries)
// - Result limit capped at 50 posts per request (validated in service layer)
···
func RegisterDiscoverRoutes(
r chi.Router,
discoverService discoverCore.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService)
+
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService, voteService)
// GET /xrpc/social.coves.feed.getDiscover
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth for viewer-specific state (vote state)
// Shows posts from ALL communities (not personalized)
// Rate limited: 100 req/min per IP via global middleware
-
r.Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
}
+5
internal/core/communityFeeds/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Can be reasonRepost or reasonPin
type FeedReason struct {
+5
internal/core/discover/types.go
···
Reply *ReplyRef `json:"reply,omitempty"`
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
type FeedReason struct {
Repost *ReasonRepost `json:"-"`
+5
internal/core/timeline/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Future: Can be reasonRepost or reasonCommunity
type FeedReason struct {
+193 -5
tests/integration/discover_test.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
"Coves/internal/db/postgres"
"context"
"encoding/json"
···
discoverCore "Coves/internal/core/discover"
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+
// mockVoteService implements votes.Service for testing viewer vote state
+
type mockVoteService struct {
+
cachedVotes map[string]*votes.CachedVote // userDID:subjectURI -> vote
+
}
+
+
func newMockVoteService() *mockVoteService {
+
return &mockVoteService{
+
cachedVotes: make(map[string]*votes.CachedVote),
+
}
+
}
+
+
func (m *mockVoteService) AddVote(userDID, subjectURI, direction, voteURI string) {
+
key := userDID + ":" + subjectURI
+
m.cachedVotes[key] = &votes.CachedVote{
+
Direction: direction,
+
URI: voteURI,
+
}
+
}
+
+
func (m *mockVoteService) CreateVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.CreateVoteRequest) (*votes.CreateVoteResponse, error) {
+
return &votes.CreateVoteResponse{}, nil
+
}
+
+
func (m *mockVoteService) DeleteVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.DeleteVoteRequest) error {
+
return nil
+
}
+
+
func (m *mockVoteService) EnsureCachePopulated(_ context.Context, _ *oauthlib.ClientSessionData) error {
+
return nil // Mock always succeeds - votes pre-populated via AddVote
+
}
+
+
func (m *mockVoteService) GetViewerVote(userDID, subjectURI string) *votes.CachedVote {
+
key := userDID + ":" + subjectURI
+
return m.cachedVotes[key]
+
}
+
+
func (m *mockVoteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*votes.CachedVote {
+
result := make(map[string]*votes.CachedVote)
+
for _, uri := range subjectURIs {
+
key := userDID + ":" + uri
+
if vote, exists := m.cachedVotes[key]; exists {
+
result[uri] = vote
+
}
+
}
+
return result
+
}
+
// TestGetDiscover_ShowsAllCommunities tests discover feed shows posts from ALL communities
func TestGetDiscover_ShowsAllCommunities(t *testing.T) {
if testing.Short() {
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
t.Run("Limit exceeds maximum", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=100", nil)
···
assert.Contains(t, errorResp["message"], "limit")
})
}
+
+
// TestGetDiscover_ViewerVoteState tests that authenticated users see their vote state on posts
+
func TestGetDiscover_ViewerVoteState(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and posts
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("votes-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
post1URI := createTestPost(t, db, communityDID, "did:plc:author1", "Post with upvote", 10, time.Now().Add(-1*time.Hour))
+
post2URI := createTestPost(t, db, communityDID, "did:plc:author2", "Post with downvote", 5, time.Now().Add(-2*time.Hour))
+
_ = createTestPost(t, db, communityDID, "did:plc:author3", "Post without vote", 3, time.Now().Add(-3*time.Hour))
+
+
// Setup mock vote service with pre-populated votes
+
viewerDID := "did:plc:viewer123"
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote(viewerDID, post1URI, "up", "at://"+viewerDID+"/social.coves.vote/vote1")
+
mockVotes.AddVote(viewerDID, post2URI, "down", "at://"+viewerDID+"/social.coves.vote/vote2")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request with authenticated user context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
+
// Inject OAuth session into context (simulates OptionalAuth middleware)
+
did, _ := syntax.ParseDID(viewerDID)
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
AccessToken: "test_token",
+
}
+
reqCtx := context.WithValue(req.Context(), middleware.UserDIDKey, viewerDID)
+
reqCtx = context.WithValue(reqCtx, middleware.OAuthSessionKey, session)
+
req = req.WithContext(reqCtx)
+
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Assertions
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our test posts and verify vote state
+
var foundPost1, foundPost2, foundPost3 bool
+
for _, feedPost := range response.Feed {
+
switch feedPost.Post.URI {
+
case post1URI:
+
foundPost1 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post1 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post1 should have vote direction")
+
assert.Equal(t, "up", *feedPost.Post.Viewer.Vote, "Post1 should show upvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post1 should have vote URI")
+
assert.Contains(t, *feedPost.Post.Viewer.VoteURI, "vote1", "Post1 should have correct vote URI")
+
+
case post2URI:
+
foundPost2 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post2 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post2 should have vote direction")
+
assert.Equal(t, "down", *feedPost.Post.Viewer.Vote, "Post2 should show downvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post2 should have vote URI")
+
+
default:
+
// Posts without votes should have nil Viewer or nil Vote
+
if feedPost.Post.Viewer != nil && feedPost.Post.Viewer.Vote != nil {
+
// This post has a vote from our viewer - it's not post3
+
continue
+
}
+
foundPost3 = true
+
}
+
}
+
+
assert.True(t, foundPost1, "Should find post1 with upvote")
+
assert.True(t, foundPost2, "Should find post2 with downvote")
+
assert.True(t, foundPost3, "Should find post3 without vote")
+
}
+
+
// TestGetDiscover_NoViewerStateWithoutAuth tests that unauthenticated users don't get viewer state
+
func TestGetDiscover_NoViewerStateWithoutAuth(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and post
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("noauth-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
postURI := createTestPost(t, db, communityDID, "did:plc:author", "Some post", 10, time.Now())
+
+
// Setup mock vote service with a vote (but request will be unauthenticated)
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote("did:plc:someuser", postURI, "up", "at://did:plc:someuser/social.coves.vote/vote1")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request WITHOUT auth context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Should succeed
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our post and verify NO viewer state (unauthenticated)
+
for _, feedPost := range response.Feed {
+
if feedPost.Post.URI == postURI {
+
assert.Nil(t, feedPost.Post.Viewer, "Unauthenticated request should not have viewer state")
+
return
+
}
+
}
+
t.Fatal("Test post not found in response")
+
}
+11 -11
tests/integration/feed_test.go
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data: community, users, and posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data with many posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Request feed for non-existent community
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.communityFeed.getCommunity?community=did:plc:nonexistent&sort=hot&limit=10", nil)
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Create community with no posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
+7 -7
tests/integration/timeline_test.go
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
// Request timeline WITHOUT auth context
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getTimeline?sort=new&limit=10", nil)
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
+1 -1
tests/integration/user_journey_e2e_test.go
···
r := chi.NewRouter()
routes.RegisterCommunityRoutes(r, communityService, e2eAuth.OAuthAuthMiddleware, nil) // nil = allow all community creators
routes.RegisterPostRoutes(r, postService, e2eAuth.OAuthAuthMiddleware)
-
routes.RegisterTimelineRoutes(r, timelineService, e2eAuth.OAuthAuthMiddleware)
+
routes.RegisterTimelineRoutes(r, timelineService, nil, e2eAuth.OAuthAuthMiddleware)
httpServer := httptest.NewServer(r)
defer httpServer.Close()