A community-based topic aggregation platform built on atproto


Changed files
+9104 -3918
+131
AGENTS.md
···
+
# AI Agent Guidelines for Coves
+
+
## Issue Tracking with bd (beads)
+
+
**IMPORTANT**: This project uses **bd (beads)** for ALL issue tracking. Do NOT use markdown TODOs, task lists, or other tracking methods.
+
+
### Why bd?
+
+
- Dependency-aware: Track blockers and relationships between issues
+
- Git-friendly: Auto-syncs to JSONL for version control
+
- Agent-optimized: JSON output, ready work detection, discovered-from links
+
- Prevents duplicate tracking systems and confusion
+
+
### Quick Start
+
+
**Check for ready work:**
+
```bash
+
bd ready --json
+
```
+
+
**Create new issues:**
+
```bash
+
bd create "Issue title" -t bug|feature|task -p 0-4 --json
+
bd create "Issue title" -p 1 --deps discovered-from:bd-123 --json
+
```
+
+
**Claim and update:**
+
```bash
+
bd update bd-42 --status in_progress --json
+
bd update bd-42 --priority 1 --json
+
```
+
+
**Complete work:**
+
```bash
+
bd close bd-42 --reason "Completed" --json
+
```
+
+
### Issue Types
+
+
- `bug` - Something broken
+
- `feature` - New functionality
+
- `task` - Work item (tests, docs, refactoring)
+
- `epic` - Large feature with subtasks
+
- `chore` - Maintenance (dependencies, tooling)
+
+
### Priorities
+
+
- `0` - Critical (security, data loss, broken builds)
+
- `1` - High (major features, important bugs)
+
- `2` - Medium (default, nice-to-have)
+
- `3` - Low (polish, optimization)
+
- `4` - Backlog (future ideas)
+
+
### Workflow for AI Agents
+
+
1. **Check ready work**: `bd ready` shows unblocked issues
+
2. **Claim your task**: `bd update <id> --status in_progress`
+
3. **Work on it**: Implement, test, document
+
4. **Discover new work?** Create linked issue:
+
- `bd create "Found bug" -p 1 --deps discovered-from:<parent-id>`
+
5. **Complete**: `bd close <id> --reason "Done"`
+
6. **Commit together**: Always commit `.beads/issues.jsonl` alongside the code changes so issue state stays in sync with the code (see the end-to-end sketch below)
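
Putting the steps together, a typical agent session looks roughly like the sketch below (the issue id `bd-42`, the issue titles, and the commit message are placeholders; the flags are the ones documented above):

```bash
# 1. Find unblocked work
bd ready --json

# 2. Claim an issue (bd-42 is a placeholder id)
bd update bd-42 --status in_progress --json

# 3. Implement, test, and document as usual

# 4. File newly discovered work, linked back to the parent issue
bd create "Found bug while testing" -t bug -p 1 --deps discovered-from:bd-42 --json

# 5. Close the issue when the work lands
bd close bd-42 --reason "Completed" --json

# 6. Commit the auto-synced JSONL together with the code changes
git add .beads/issues.jsonl
git commit -m "Fix bd-42: <summary of the change>"
```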
+
+
### Auto-Sync
+
+
bd automatically syncs with git:
+
- Exports to `.beads/issues.jsonl` after changes (5s debounce)
+
- Imports from JSONL when newer (e.g., after `git pull`)
+
- No manual export/import needed!
+
+
### MCP Server (Recommended)
+
+
If using Claude or MCP-compatible clients, install the beads MCP server:
+
+
```bash
+
pip install beads-mcp
+
```
+
+
Add to MCP config (e.g., `~/.config/claude/config.json`):
+
```json
+
{
+
"beads": {
+
"command": "beads-mcp",
+
"args": []
+
}
+
}
+
```
+
+
Then use `mcp__beads__*` functions instead of CLI commands.
+
+
### Managing AI-Generated Planning Documents
+
+
AI assistants often create planning and design documents during development:
+
- PLAN.md, IMPLEMENTATION.md, ARCHITECTURE.md
+
- DESIGN.md, CODEBASE_SUMMARY.md, INTEGRATION_PLAN.md
+
- TESTING_GUIDE.md, TECHNICAL_DESIGN.md, and similar files
+
+
**Best Practice: Use a dedicated directory for these ephemeral files**
+
+
**Recommended approach:**
+
- Create a `history/` directory in the project root
+
- Store ALL AI-generated planning/design docs in `history/`
+
- Keep the repository root clean and focused on permanent project files
+
- Only access `history/` when explicitly asked to review past planning
+
+
**Example .gitignore entry (optional):**
+
```
+
# AI planning documents (ephemeral)
+
history/
+
```
+
+
**Benefits:**
+
- ✅ Clean repository root
+
- ✅ Clear separation between ephemeral and permanent documentation
+
- ✅ Easy to exclude from version control if desired
+
- ✅ Preserves planning history for archeological research
+
- ✅ Reduces noise when browsing the project
+
+
### Important Rules
+
+
- ✅ Use bd for ALL task tracking
+
- ✅ Always use `--json` flag for programmatic use
+
- ✅ Link discovered work with `discovered-from` dependencies
+
- ✅ Check `bd ready` before asking "what should I work on?"
+
- ✅ Store AI planning docs in `history/` directory
+
- ❌ Do NOT create markdown TODO lists
+
- ❌ Do NOT use external issue trackers
+
- ❌ Do NOT duplicate tracking systems
+
- ❌ Do NOT clutter repo root with planning documents
+
+
For more details, see the [beads repository](https://github.com/steveyegge/beads).
+5
internal/atproto/lexicon/social/coves/community/list.json
···
"type": "string",
"description": "Pagination cursor"
},
+
"visibility": {
+
"type": "string",
+
"knownValues": ["public", "unlisted", "private"],
+
"description": "Filter communities by visibility level"
+
},
"sort": {
"type": "string",
"knownValues": ["popular", "active", "new", "alphabetical"],
+8 -8
internal/core/communities/community.go
···
// ListCommunitiesRequest represents query parameters for listing communities
type ListCommunitiesRequest struct {
-
Visibility string `json:"visibility,omitempty"`
-
HostedBy string `json:"hostedBy,omitempty"`
-
SortBy string `json:"sortBy,omitempty"`
-
SortOrder string `json:"sortOrder,omitempty"`
-
Limit int `json:"limit"`
-
Offset int `json:"offset"`
+
Sort string `json:"sort,omitempty"` // Enum: popular, active, new, alphabetical
+
Visibility string `json:"visibility,omitempty"` // Filter: public, unlisted, private
+
Category string `json:"category,omitempty"` // Optional: filter by category (future)
+
Language string `json:"language,omitempty"` // Optional: filter by language (future)
+
Limit int `json:"limit"` // 1-100, default 50
+
Offset int `json:"offset"` // Pagination offset
}
// SearchCommunitiesRequest represents query parameters for searching communities
···
name := c.Handle[:communityIndex]
// Extract instance domain (everything after ".community.")
-
// len(".community.") = 11
-
instanceDomain := c.Handle[communityIndex+11:]
+
communitySegment := ".community."
+
instanceDomain := c.Handle[communityIndex+len(communitySegment):]
return fmt.Sprintf("!%s@%s", name, instanceDomain)
}
+2 -2
internal/core/communities/interfaces.go
···
UpdateCredentials(ctx context.Context, did, accessToken, refreshToken string) error
// Listing & Search
-
List(ctx context.Context, req ListCommunitiesRequest) ([]*Community, int, error) // Returns communities + total count
+
List(ctx context.Context, req ListCommunitiesRequest) ([]*Community, error)
Search(ctx context.Context, req SearchCommunitiesRequest) ([]*Community, int, error)
// Subscriptions (lightweight feed follows)
···
CreateCommunity(ctx context.Context, req CreateCommunityRequest) (*Community, error)
GetCommunity(ctx context.Context, identifier string) (*Community, error) // identifier can be DID or handle
UpdateCommunity(ctx context.Context, req UpdateCommunityRequest) (*Community, error)
-
ListCommunities(ctx context.Context, req ListCommunitiesRequest) ([]*Community, int, error)
+
ListCommunities(ctx context.Context, req ListCommunitiesRequest) ([]*Community, error)
SearchCommunities(ctx context.Context, req SearchCommunitiesRequest) ([]*Community, int, error)
// Subscription operations (write-forward: creates record in user's PDS)
+57
scripts/backup.sh
···
+
#!/bin/bash
+
# Coves Database Backup Script
+
# Usage: ./scripts/backup.sh
+
#
+
# Creates timestamped PostgreSQL backups in ./backups/
+
# Retention: Keeps last 30 days of backups
+
+
set -e
+
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
BACKUP_DIR="$PROJECT_DIR/backups"
+
COMPOSE_FILE="$PROJECT_DIR/docker-compose.prod.yml"
+
+
# Load environment
+
set -a
+
source "$PROJECT_DIR/.env.prod"
+
set +a
+
+
# Colors
+
GREEN='\033[0;32m'
+
YELLOW='\033[1;33m'
+
NC='\033[0m'
+
+
log() { echo -e "${GREEN}[BACKUP]${NC} $1"; }
+
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+
+
# Create backup directory
+
mkdir -p "$BACKUP_DIR"
+
+
# Generate timestamp
+
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+
BACKUP_FILE="$BACKUP_DIR/coves_${TIMESTAMP}.sql.gz"
+
+
log "Starting backup..."
+
+
# Run pg_dump inside container
+
docker compose -f "$COMPOSE_FILE" exec -T postgres \
+
pg_dump -U "$POSTGRES_USER" -d "$POSTGRES_DB" --clean --if-exists \
+
| gzip > "$BACKUP_FILE"
+
+
# Get file size
+
SIZE=$(du -h "$BACKUP_FILE" | cut -f1)
+
+
log "โœ… Backup complete: $BACKUP_FILE ($SIZE)"
+
+
# Cleanup old backups (keep last 30 days)
+
log "Cleaning up backups older than 30 days..."
+
find "$BACKUP_DIR" -name "coves_*.sql.gz" -mtime +30 -delete
+
+
# List recent backups
+
log ""
+
log "Recent backups:"
+
ls -lh "$BACKUP_DIR"/*.sql.gz 2>/dev/null | tail -5
+
+
log ""
+
log "To restore: gunzip -c $BACKUP_FILE | docker compose -f docker-compose.prod.yml exec -T postgres psql -U $POSTGRES_USER -d $POSTGRES_DB"
+133
scripts/deploy.sh
···
+
#!/bin/bash
+
# Coves Deployment Script
+
# Usage: ./scripts/deploy.sh [service]
+
#
+
# Examples:
+
# ./scripts/deploy.sh # Deploy all services
+
# ./scripts/deploy.sh appview # Deploy only AppView
+
# ./scripts/deploy.sh --pull # Pull from git first, then deploy
+
+
set -e
+
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
COMPOSE_FILE="$PROJECT_DIR/docker-compose.prod.yml"
+
+
# Colors for output
+
RED='\033[0;31m'
+
GREEN='\033[0;32m'
+
YELLOW='\033[1;33m'
+
NC='\033[0m' # No Color
+
+
log() {
+
echo -e "${GREEN}[DEPLOY]${NC} $1"
+
}
+
+
warn() {
+
echo -e "${YELLOW}[WARN]${NC} $1"
+
}
+
+
error() {
+
echo -e "${RED}[ERROR]${NC} $1"
+
exit 1
+
}
+
+
# Parse arguments
+
PULL_GIT=false
+
SERVICE=""
+
+
for arg in "$@"; do
+
case $arg in
+
--pull)
+
PULL_GIT=true
+
;;
+
*)
+
SERVICE="$arg"
+
;;
+
esac
+
done
+
+
cd "$PROJECT_DIR"
+
+
# Load environment variables
+
if [ ! -f ".env.prod" ]; then
+
error ".env.prod not found! Copy from .env.prod.example and configure secrets."
+
fi
+
+
log "Loading environment from .env.prod..."
+
set -a
+
source .env.prod
+
set +a
+
+
# Optional: Pull from git
+
if [ "$PULL_GIT" = true ]; then
+
log "Pulling latest code from git..."
+
git fetch origin
+
git pull origin main
+
fi
+
+
# Check database connectivity before deployment
+
log "Checking database connectivity..."
+
if docker compose -f "$COMPOSE_FILE" exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" > /dev/null 2>&1; then
+
log "Database is ready"
+
else
+
warn "Database not ready yet - it will start with the deployment"
+
fi
+
+
# Build and deploy
+
if [ -n "$SERVICE" ]; then
+
log "Building $SERVICE..."
+
docker compose -f "$COMPOSE_FILE" build --no-cache "$SERVICE"
+
+
log "Deploying $SERVICE..."
+
docker compose -f "$COMPOSE_FILE" up -d "$SERVICE"
+
else
+
log "Building all services..."
+
docker compose -f "$COMPOSE_FILE" build --no-cache
+
+
log "Deploying all services..."
+
docker compose -f "$COMPOSE_FILE" up -d
+
fi
+
+
# Health check
+
log "Waiting for services to be healthy..."
+
sleep 10
+
+
# Wait for database to be ready before running migrations
+
log "Waiting for database..."
+
for i in {1..30}; do
+
if docker compose -f "$COMPOSE_FILE" exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" > /dev/null 2>&1; then
+
break
+
fi
+
sleep 1
+
done
+
+
# Run database migrations
+
# The AppView runs migrations on startup, but we can also trigger them explicitly
+
log "Running database migrations..."
+
if docker compose -f "$COMPOSE_FILE" exec -T appview /app/coves-server migrate 2>/dev/null; then
+
log "โœ… Migrations completed"
+
else
+
warn "โš ๏ธ Migration command not available or failed - AppView will run migrations on startup"
+
fi
+
+
# Check AppView health
+
if docker compose -f "$COMPOSE_FILE" exec -T appview wget --spider -q http://localhost:8080/xrpc/_health 2>/dev/null; then
+
log "โœ… AppView is healthy"
+
else
+
warn "โš ๏ธ AppView health check failed - check logs with: docker compose -f docker-compose.prod.yml logs appview"
+
fi
+
+
# Check PDS health
+
if docker compose -f "$COMPOSE_FILE" exec -T pds wget --spider -q http://localhost:3000/xrpc/_health 2>/dev/null; then
+
log "โœ… PDS is healthy"
+
else
+
warn "โš ๏ธ PDS health check failed - check logs with: docker compose -f docker-compose.prod.yml logs pds"
+
fi
+
+
log "Deployment complete!"
+
log ""
+
log "Useful commands:"
+
log " View logs: docker compose -f docker-compose.prod.yml logs -f"
+
log " Check status: docker compose -f docker-compose.prod.yml ps"
+
log " Rollback: docker compose -f docker-compose.prod.yml down && git checkout HEAD~1 && ./scripts/deploy.sh"
+149
scripts/generate-did-keys.sh
···
+
#!/bin/bash
+
# Generate cryptographic keys for Coves did:web DID document
+
#
+
# This script generates a secp256k1 (K-256) key pair as required by atproto.
+
# Reference: https://atproto.com/specs/cryptography
+
#
+
# Key format:
+
# - Curve: secp256k1 (K-256) - same as Bitcoin/Ethereum
+
# - Type: Multikey
+
# - Encoding: publicKeyMultibase with base58btc ('z' prefix)
+
# - Multicodec: 0xe7 for secp256k1 compressed public key
+
#
+
# Output:
+
# - Private key (hex) for PDS_PLC_ROTATION_KEY_K256_PRIVATE_KEY_HEX
+
# - Public key (multibase) for did.json publicKeyMultibase field
+
# - Complete did.json file
+
+
set -e
+
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
OUTPUT_DIR="$PROJECT_DIR/static/.well-known"
+
+
# Colors
+
GREEN='\033[0;32m'
+
YELLOW='\033[1;33m'
+
RED='\033[0;31m'
+
NC='\033[0m'
+
+
log() { echo -e "${GREEN}[KEYGEN]${NC} $1"; }
+
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+
error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
+
+
# Check for required tools
+
if ! command -v openssl &> /dev/null; then
+
error "openssl is required but not installed"
+
fi
+
+
if ! command -v python3 &> /dev/null; then
+
error "python3 is required for base58 encoding"
+
fi
+
+
# Check for base58 library
+
if ! python3 -c "import base58" 2>/dev/null; then
+
warn "Installing base58 Python library..."
+
pip3 install base58 || error "Failed to install base58. Run: pip3 install base58"
+
fi
+
+
log "Generating secp256k1 key pair for did:web..."
+
+
# Generate private key
+
PRIVATE_KEY_PEM=$(mktemp)
+
openssl ecparam -name secp256k1 -genkey -noout -out "$PRIVATE_KEY_PEM" 2>/dev/null
+
+
# Extract private key as hex (for PDS config)
+
PRIVATE_KEY_HEX=$(openssl ec -in "$PRIVATE_KEY_PEM" -text -noout 2>/dev/null | \
+
grep -A 3 "priv:" | tail -n 3 | tr -d ' :\n' | tr -d '\r')
+
+
# Extract public key as compressed format
+
# OpenSSL outputs the public key, we need to get the compressed form
+
PUBLIC_KEY_HEX=$(openssl ec -in "$PRIVATE_KEY_PEM" -pubout -conv_form compressed -outform DER 2>/dev/null | \
+
tail -c 33 | xxd -p | tr -d '\n')
+
+
# Clean up temp file
+
rm -f "$PRIVATE_KEY_PEM"
+
+
# Encode public key as multibase with multicodec
+
# Multicodec 0xe7 = secp256k1 compressed public key
+
# Then base58btc encode with 'z' prefix
+
PUBLIC_KEY_MULTIBASE=$(python3 << EOF
+
import base58
+
+
# Compressed public key bytes
+
pub_hex = "$PUBLIC_KEY_HEX"
+
pub_bytes = bytes.fromhex(pub_hex)
+
+
# Prepend multicodec 0xe7 for secp256k1-pub
+
# 0xe7 is >= 0x80, so its unsigned varint encoding takes two bytes: 0xe7 0x01
+
multicodec = bytes([0xe7, 0x01]) # 0xe701 for secp256k1-pub compressed
+
key_with_codec = multicodec + pub_bytes
+
+
# Base58btc encode
+
encoded = base58.b58encode(key_with_codec).decode('ascii')
+
+
# Add 'z' prefix for multibase
+
print('z' + encoded)
+
EOF
+
)
+
+
log "Keys generated successfully!"
+
echo ""
+
echo "============================================"
+
echo " PRIVATE KEY (keep secret!)"
+
echo "============================================"
+
echo ""
+
echo "Add this to your .env.prod file:"
+
echo ""
+
echo "PDS_ROTATION_KEY=$PRIVATE_KEY_HEX"
+
echo ""
+
echo "============================================"
+
echo " PUBLIC KEY (for did.json)"
+
echo "============================================"
+
echo ""
+
echo "publicKeyMultibase: $PUBLIC_KEY_MULTIBASE"
+
echo ""
+
+
# Generate the did.json file
+
log "Generating did.json..."
+
+
mkdir -p "$OUTPUT_DIR"
+
+
cat > "$OUTPUT_DIR/did.json" << EOF
+
{
+
"id": "did:web:coves.social",
+
"alsoKnownAs": ["at://coves.social"],
+
"verificationMethod": [
+
{
+
"id": "did:web:coves.social#atproto",
+
"type": "Multikey",
+
"controller": "did:web:coves.social",
+
"publicKeyMultibase": "$PUBLIC_KEY_MULTIBASE"
+
}
+
],
+
"service": [
+
{
+
"id": "#atproto_pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://coves.me"
+
}
+
]
+
}
+
EOF
+
+
log "Created: $OUTPUT_DIR/did.json"
+
echo ""
+
echo "============================================"
+
echo " NEXT STEPS"
+
echo "============================================"
+
echo ""
+
echo "1. Copy the PDS_ROTATION_KEY value to your .env.prod file"
+
echo ""
+
echo "2. Verify the did.json looks correct:"
+
echo " cat $OUTPUT_DIR/did.json"
+
echo ""
+
echo "3. After deployment, verify it's accessible:"
+
echo " curl https://coves.social/.well-known/did.json"
+
echo ""
+
warn "IMPORTANT: Keep the private key secret! Only share the public key."
+
warn "The did.json file with the public key IS safe to commit to git."
+106
scripts/setup-production.sh
···
+
#!/bin/bash
+
# Coves Production Setup Script
+
# Run this once on a fresh server to set up everything
+
#
+
# Prerequisites:
+
# - Docker and docker-compose installed
+
# - Git installed
+
# - .env.prod file configured
+
+
set -e
+
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
+
# Colors
+
GREEN='\033[0;32m'
+
YELLOW='\033[1;33m'
+
RED='\033[0;31m'
+
NC='\033[0m'
+
+
log() { echo -e "${GREEN}[SETUP]${NC} $1"; }
+
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+
error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
+
+
cd "$PROJECT_DIR"
+
+
# Check prerequisites
+
log "Checking prerequisites..."
+
+
if ! command -v docker &> /dev/null; then
+
error "Docker is not installed. Install with: curl -fsSL https://get.docker.com | sh"
+
fi
+
+
if ! docker compose version &> /dev/null; then
+
error "docker compose is not available. Install with: apt install docker-compose-plugin"
+
fi
+
+
# Check for .env.prod
+
if [ ! -f ".env.prod" ]; then
+
error ".env.prod not found! Copy from .env.prod.example and configure secrets."
+
fi
+
+
# Load environment
+
set -a
+
source .env.prod
+
set +a
+
+
# Create required directories
+
log "Creating directories..."
+
mkdir -p backups
+
mkdir -p static/.well-known
+
+
# Check for did.json
+
if [ ! -f "static/.well-known/did.json" ]; then
+
warn "static/.well-known/did.json not found!"
+
warn "Run ./scripts/generate-did-keys.sh to create it."
+
fi
+
+
# Note: Caddy logs are written to Docker volume (caddy-data)
+
# If you need host-accessible logs, uncomment and run as root:
+
# mkdir -p /var/log/caddy && chown 1000:1000 /var/log/caddy
+
+
# Pull Docker images
+
log "Pulling Docker images..."
+
docker compose -f docker-compose.prod.yml pull postgres pds caddy
+
+
# Build AppView
+
log "Building AppView..."
+
docker compose -f docker-compose.prod.yml build appview
+
+
# Start services
+
log "Starting services..."
+
docker compose -f docker-compose.prod.yml up -d
+
+
# Wait for PostgreSQL
+
log "Waiting for PostgreSQL to be ready..."
+
until docker compose -f docker-compose.prod.yml exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" > /dev/null 2>&1; do
+
sleep 2
+
done
+
log "PostgreSQL is ready!"
+
+
# Run migrations
+
log "Running database migrations..."
+
# The AppView runs migrations on startup, but you can also run them manually:
+
# docker compose -f docker-compose.prod.yml exec appview /app/coves-server migrate
+
+
# Final status
+
log ""
+
log "============================================"
+
log " Coves Production Setup Complete!"
+
log "============================================"
+
log ""
+
log "Services running:"
+
docker compose -f docker-compose.prod.yml ps
+
log ""
+
log "Next steps:"
+
log " 1. Configure DNS for coves.social and coves.me"
+
log " 2. Run ./scripts/generate-did-keys.sh to create DID keys"
+
log " 3. Test health endpoints:"
+
log " curl https://coves.social/xrpc/_health"
+
log " curl https://coves.me/xrpc/_health"
+
log ""
+
log "Useful commands:"
+
log " View logs: docker compose -f docker-compose.prod.yml logs -f"
+
log " Deploy update: ./scripts/deploy.sh appview"
+
log " Backup DB: ./scripts/backup.sh"
+19
static/.well-known/did.json.template
···
+
{
+
"id": "did:web:coves.social",
+
"alsoKnownAs": ["at://coves.social"],
+
"verificationMethod": [
+
{
+
"id": "did:web:coves.social#atproto",
+
"type": "Multikey",
+
"controller": "did:web:coves.social",
+
"publicKeyMultibase": "REPLACE_WITH_YOUR_PUBLIC_KEY"
+
}
+
],
+
"service": [
+
{
+
"id": "#atproto_pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://coves.me"
+
}
+
]
+
}
+18
static/client-metadata.json
···
+
{
+
"client_id": "https://coves.social/client-metadata.json",
+
"client_name": "Coves",
+
"client_uri": "https://coves.social",
+
"logo_uri": "https://coves.social/logo.png",
+
"tos_uri": "https://coves.social/terms",
+
"policy_uri": "https://coves.social/privacy",
+
"redirect_uris": [
+
"https://coves.social/oauth/callback",
+
"social.coves:/oauth/callback"
+
],
+
"scope": "atproto transition:generic",
+
"grant_types": ["authorization_code", "refresh_token"],
+
"response_types": ["code"],
+
"application_type": "native",
+
"token_endpoint_auth_method": "none",
+
"dpop_bound_access_tokens": true
+
}
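
Because the `client_id` above is the URL of this metadata document itself, a quick post-deploy check is simply to fetch it from that URL and confirm it parses:

```bash
# The client_id URL must serve this exact JSON document
curl -s https://coves.social/client-metadata.json | python3 -m json.tool
```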
+2 -1
Dockerfile
···
COPY --from=builder /build/coves-server /app/coves-server
# Copy migrations (needed for goose)
-
COPY --from=builder /build/internal/db/migrations /app/migrations
+
# Must maintain path structure as app looks for internal/db/migrations
+
COPY --from=builder /build/internal/db/migrations /app/internal/db/migrations
# Set ownership
RUN chown -R coves:coves /app
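
A quick way to confirm the migrations are baked into the image at the path the app expects (the `coves-appview:check` tag is just a placeholder):

```bash
docker build -t coves-appview:check .
docker run --rm --entrypoint ls coves-appview:check /app/internal/db/migrations
```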
+187
scripts/derive-did-from-key.sh
···
+
#!/bin/bash
+
# Derive public key from existing PDS_ROTATION_KEY and create did.json
+
#
+
# This script takes your existing private key and derives the public key from it.
+
# Use this if you already have a PDS running with a rotation key but need to
+
# create/fix the did.json file.
+
#
+
# Usage: ./scripts/derive-did-from-key.sh
+
+
set -e
+
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
OUTPUT_DIR="$PROJECT_DIR/static/.well-known"
+
+
# Colors
+
GREEN='\033[0;32m'
+
YELLOW='\033[1;33m'
+
RED='\033[0;31m'
+
NC='\033[0m'
+
+
log() { echo -e "${GREEN}[DERIVE]${NC} $1"; }
+
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+
error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
+
+
# Check for required tools
+
if ! command -v openssl &> /dev/null; then
+
error "openssl is required but not installed"
+
fi
+
+
if ! command -v python3 &> /dev/null; then
+
error "python3 is required for base58 encoding"
+
fi
+
+
# Check for base58 library
+
if ! python3 -c "import base58" 2>/dev/null; then
+
warn "Installing base58 Python library..."
+
pip3 install base58 || error "Failed to install base58. Run: pip3 install base58"
+
fi
+
+
# Load environment to get the existing key
+
if [ -f "$PROJECT_DIR/.env.prod" ]; then
+
source "$PROJECT_DIR/.env.prod"
+
elif [ -f "$PROJECT_DIR/.env" ]; then
+
source "$PROJECT_DIR/.env"
+
else
+
error "No .env.prod or .env file found"
+
fi
+
+
if [ -z "$PDS_ROTATION_KEY" ]; then
+
error "PDS_ROTATION_KEY not found in environment"
+
fi
+
+
# Validate key format (should be 64 hex chars)
+
if [[ ! "$PDS_ROTATION_KEY" =~ ^[0-9a-fA-F]{64}$ ]]; then
+
error "PDS_ROTATION_KEY is not a valid 64-character hex string"
+
fi
+
+
log "Deriving public key from existing PDS_ROTATION_KEY..."
+
+
# Create a temporary PEM file from the hex private key
+
TEMP_DIR=$(mktemp -d)
+
PRIVATE_KEY_HEX="$PDS_ROTATION_KEY"
+
+
# Convert hex private key to PEM format
+
# secp256k1 curve OID: 1.3.132.0.10
+
python3 > "$TEMP_DIR/private.pem" << EOF
+
import binascii
+
+
# Private key in hex
+
priv_hex = "$PRIVATE_KEY_HEX"
+
priv_bytes = binascii.unhexlify(priv_hex)
+
+
# secp256k1 OID
+
oid = bytes([0x06, 0x05, 0x2b, 0x81, 0x04, 0x00, 0x0a])
+
+
# Build the EC private key structure
+
# SEQUENCE { version INTEGER, privateKey OCTET STRING, [0] OID, [1] publicKey }
+
# We'll use a simpler approach: just the private key with curve params
+
+
# EC PARAMETERS for secp256k1
+
ec_params = bytes([
+
0x30, 0x07, # SEQUENCE, 7 bytes
+
0x06, 0x05, 0x2b, 0x81, 0x04, 0x00, 0x0a # OID for secp256k1
+
])
+
+
# EC PRIVATE KEY structure
+
# SEQUENCE { version, privateKey, [0] parameters }
+
inner = bytes([0x02, 0x01, 0x01]) # version = 1
+
inner += bytes([0x04, 0x20]) + priv_bytes # OCTET STRING with 32-byte key
+
inner += bytes([0xa0, 0x07]) + bytes([0x06, 0x05, 0x2b, 0x81, 0x04, 0x00, 0x0a]) # [0] OID
+
+
# Wrap in SEQUENCE
+
key_der = bytes([0x30, len(inner)]) + inner
+
+
# Base64 encode
+
import base64
+
key_b64 = base64.b64encode(key_der).decode('ascii')
+
+
# Format as PEM
+
print("-----BEGIN EC PRIVATE KEY-----")
+
for i in range(0, len(key_b64), 64):
+
print(key_b64[i:i+64])
+
print("-----END EC PRIVATE KEY-----")
+
EOF
+
+
# Extract the compressed public key
+
PUBLIC_KEY_HEX=$(openssl ec -in "$TEMP_DIR/private.pem" -pubout -conv_form compressed -outform DER 2>/dev/null | \
+
tail -c 33 | xxd -p | tr -d '\n')
+
+
# Clean up
+
rm -rf "$TEMP_DIR"
+
+
if [ -z "$PUBLIC_KEY_HEX" ] || [ ${#PUBLIC_KEY_HEX} -ne 66 ]; then
+
error "Failed to derive public key. Got: $PUBLIC_KEY_HEX"
+
fi
+
+
log "Derived public key: ${PUBLIC_KEY_HEX:0:8}...${PUBLIC_KEY_HEX: -8}"
+
+
# Encode public key as multibase with multicodec
+
PUBLIC_KEY_MULTIBASE=$(python3 << EOF
+
import base58
+
+
# Compressed public key bytes
+
pub_hex = "$PUBLIC_KEY_HEX"
+
pub_bytes = bytes.fromhex(pub_hex)
+
+
# Prepend multicodec 0xe7 for secp256k1-pub
+
# 0xe7 is >= 0x80, so its unsigned varint encoding takes two bytes: 0xe7 0x01
+
multicodec = bytes([0xe7, 0x01]) # 0xe701 for secp256k1-pub compressed
+
key_with_codec = multicodec + pub_bytes
+
+
# Base58btc encode
+
encoded = base58.b58encode(key_with_codec).decode('ascii')
+
+
# Add 'z' prefix for multibase
+
print('z' + encoded)
+
EOF
+
)
+
+
log "Public key multibase: $PUBLIC_KEY_MULTIBASE"
+
+
# Generate the did.json file
+
log "Generating did.json..."
+
+
mkdir -p "$OUTPUT_DIR"
+
+
cat > "$OUTPUT_DIR/did.json" << EOF
+
{
+
"id": "did:web:coves.social",
+
"alsoKnownAs": ["at://coves.social"],
+
"verificationMethod": [
+
{
+
"id": "did:web:coves.social#atproto",
+
"type": "Multikey",
+
"controller": "did:web:coves.social",
+
"publicKeyMultibase": "$PUBLIC_KEY_MULTIBASE"
+
}
+
],
+
"service": [
+
{
+
"id": "#atproto_pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://coves.me"
+
}
+
]
+
}
+
EOF
+
+
log "Created: $OUTPUT_DIR/did.json"
+
echo ""
+
echo "============================================"
+
echo " DID Document Generated Successfully!"
+
echo "============================================"
+
echo ""
+
echo "Public key multibase: $PUBLIC_KEY_MULTIBASE"
+
echo ""
+
echo "Next steps:"
+
echo " 1. Copy this file to your production server:"
+
echo " scp $OUTPUT_DIR/did.json user@server:/opt/coves/static/.well-known/"
+
echo ""
+
echo " 2. Or if running on production, restart Caddy:"
+
echo " docker compose -f docker-compose.prod.yml restart caddy"
+
echo ""
+
echo " 3. Verify it's accessible:"
+
echo " curl https://coves.social/.well-known/did.json"
+
echo ""
+1 -1
internal/api/handlers/aggregator/register.go
···
if err != nil {
return fmt.Errorf("failed to fetch .well-known/atproto-did from %s: %w", domain, err)
}
-
defer resp.Body.Close()
+
defer func() { _ = resp.Body.Close() }()
// Check status code
if resp.StatusCode != http.StatusOK {
+1 -2
internal/api/handlers/community/list.go
···
package community
import (
+
"Coves/internal/core/communities"
"encoding/json"
"net/http"
"strconv"
-
-
"Coves/internal/core/communities"
)
// ListHandler handles listing communities
+1 -2
internal/core/communities/service.go
···
package communities
import (
+
"Coves/internal/atproto/utils"
"bytes"
"context"
"encoding/json"
···
"strings"
"sync"
"time"
-
-
"Coves/internal/atproto/utils"
)
// Community handle validation regex (DNS-valid handle: name.community.instance.com)
+2 -4
internal/db/postgres/community_repo.go
···
package postgres
import (
+
"Coves/internal/core/communities"
"context"
"database/sql"
"fmt"
"log"
"strings"
-
"Coves/internal/core/communities"
-
"github.com/lib/pq"
)
···
}
// Build sort clause - map sort enum to DB columns
-
sortColumn := "subscriber_count" // default: popular
-
sortOrder := "DESC"
+
var sortColumn, sortOrder string
switch req.Sort {
case "popular":
+1 -2
tests/e2e/ratelimit_e2e_test.go
···
package e2e
import (
+
"Coves/internal/api/middleware"
"bytes"
"encoding/json"
"net/http"
···
"testing"
"time"
-
"Coves/internal/api/middleware"
-
"github.com/stretchr/testify/assert"
)
+14 -14
tests/integration/aggregator_registration_test.go
···
// Setup test database
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
testDID := "did:plc:test123"
testHandle := "aggregator.bsky.social"
···
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/atproto-did" {
w.Header().Set("Content-Type", "text/plain")
-
w.Write([]byte(testDID))
+
_, _ = w.Write([]byte(testDID))
} else {
w.WriteHeader(http.StatusNotFound)
}
···
// Setup test database
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
// Setup test server that returns wrong DID
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/atproto-did" {
w.Header().Set("Content-Type", "text/plain")
-
w.Write([]byte("did:plc:wrongdid"))
+
_, _ = w.Write([]byte("did:plc:wrongdid"))
} else {
w.WriteHeader(http.StatusNotFound)
}
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
tests := []struct {
name string
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
// Pre-create user with same DID
existingDID := "did:plc:existing123"
···
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/atproto-did" {
w.Header().Set("Content-Type", "text/plain")
-
w.Write([]byte(existingDID))
+
_, _ = w.Write([]byte(existingDID))
} else {
w.WriteHeader(http.StatusNotFound)
}
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
// Setup test server that returns 404 for .well-known
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
testDID := "did:plc:toolarge"
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
testDID := "did:plc:nonexistent"
···
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/atproto-did" {
w.Header().Set("Content-Type", "text/plain")
-
w.Write([]byte(testDID))
+
_, _ = w.Write([]byte(testDID))
} else {
w.WriteHeader(http.StatusNotFound)
}
···
}
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
testDID := "did:plc:largedos123"
···
// with real .well-known server and real identity resolution
db := setupTestDB(t)
-
defer db.Close()
+
defer func() { _ = db.Close() }()
testDID := "did:plc:e2etest123"
testHandle := "e2ebot.bsky.social"
···
wellKnownServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/atproto-did" {
w.Header().Set("Content-Type", "text/plain")
-
w.Write([]byte(testDID))
+
_, _ = w.Write([]byte(testDID))
} else {
w.WriteHeader(http.StatusNotFound)
}
+2 -3
tests/integration/community_repo_test.go
···
package integration
import (
+
"Coves/internal/core/communities"
+
"Coves/internal/db/postgres"
"context"
"fmt"
"testing"
"time"
-
-
"Coves/internal/core/communities"
-
"Coves/internal/db/postgres"
)
func TestCommunityRepository_Create(t *testing.T) {
+23
static/.well-known/did.json
···
+
{
+
"@context": [
+
"https://www.w3.org/ns/did/v1",
+
"https://w3id.org/security/multikey/v1"
+
],
+
"id": "did:web:coves.social",
+
"alsoKnownAs": ["at://coves.social"],
+
"verificationMethod": [
+
{
+
"id": "did:web:coves.social#atproto",
+
"type": "Multikey",
+
"controller": "did:web:coves.social",
+
"publicKeyMultibase": "zQ3shu1T3Y3MYoC1n7fCqkZqyrk8FiY3PV3BYM2JwyqcXFY6s"
+
}
+
],
+
"service": [
+
{
+
"id": "#atproto_pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds.coves.me"
+
}
+
]
+
}
+1 -1
docs/E2E_TESTING.md
···
Query via API:
```bash
-
curl "http://localhost:8081/xrpc/social.coves.actor.getProfile?actor=alice.local.coves.dev"
+
curl "http://localhost:8081/xrpc/social.coves.actor.getprofile?actor=alice.local.coves.dev"
```
Expected response:
+3 -3
internal/api/routes/user.go
···
func RegisterUserRoutes(r chi.Router, service users.UserService) {
h := NewUserHandler(service)
-
// social.coves.actor.getProfile - query endpoint
-
r.Get("/xrpc/social.coves.actor.getProfile", h.GetProfile)
+
// social.coves.actor.getprofile - query endpoint
+
r.Get("/xrpc/social.coves.actor.getprofile", h.GetProfile)
// social.coves.actor.signup - procedure endpoint
r.Post("/xrpc/social.coves.actor.signup", h.Signup)
}
-
// GetProfile handles social.coves.actor.getProfile
+
// GetProfile handles social.coves.actor.getprofile
// Query endpoint that retrieves a user profile by DID or handle
func (h *UserHandler) GetProfile(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+1 -1
internal/atproto/lexicon/social/coves/actor/getProfile.json
···
{
"lexicon": 1,
-
"id": "social.coves.actor.getProfile",
+
"id": "social.coves.actor.getprofile",
"defs": {
"main": {
"type": "query",
+1 -1
internal/atproto/lexicon/social/coves/actor/updateProfile.json
···
{
"lexicon": 1,
-
"id": "social.coves.actor.updateProfile",
+
"id": "social.coves.actor.updateprofile",
"defs": {
"main": {
"type": "procedure",
+4 -4
tests/integration/user_test.go
···
// Test 1: Get profile by DID
t.Run("Get Profile By DID", func(t *testing.T) {
-
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getProfile?actor=did:plc:endpoint123", nil)
+
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getprofile?actor=did:plc:endpoint123", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
···
// Test 2: Get profile by handle
t.Run("Get Profile By Handle", func(t *testing.T) {
-
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getProfile?actor=bob.test", nil)
+
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getprofile?actor=bob.test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
···
// Test 3: Missing actor parameter
t.Run("Missing Actor Parameter", func(t *testing.T) {
-
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getProfile", nil)
+
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getprofile", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
···
// Test 4: User not found
t.Run("User Not Found", func(t *testing.T) {
-
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getProfile?actor=nonexistent.test", nil)
+
req := httptest.NewRequest("GET", "/xrpc/social.coves.actor.getprofile?actor=nonexistent.test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
+44 -5
internal/atproto/lexicon/social/coves/embed/external.json
···
"defs": {
"main": {
"type": "object",
-
"description": "External link embed with preview metadata and provider support",
+
"description": "External link embed with optional aggregated sources for megathreads",
"required": ["external"],
"properties": {
"external": {
···
},
"external": {
"type": "object",
-
"description": "External link metadata",
+
"description": "Primary external link metadata",
"required": ["uri"],
"properties": {
"uri": {
"type": "string",
"format": "uri",
-
"description": "URI of the external content"
+
"description": "URI of the primary external content"
},
"title": {
"type": "string",
···
"type": "blob",
"accept": ["image/png", "image/jpeg", "image/webp"],
"maxSize": 1000000,
-
"description": "Thumbnail image for the link"
+
"description": "Thumbnail image for the post (applies to primary link)"
},
"domain": {
"type": "string",
-
"description": "Domain of the linked content"
+
"maxLength": 253,
+
"description": "Domain of the linked content (e.g., nytimes.com)"
},
"embedType": {
"type": "string",
···
},
"provider": {
"type": "string",
+
"maxLength": 100,
"description": "Service provider name (e.g., imgur, streamable)"
},
"images": {
···
"type": "integer",
"minimum": 0,
"description": "Total number of items if more than displayed (for galleries)"
+
},
+
"sources": {
+
"type": "array",
+
"description": "Aggregated source links for megathreads. Each source references an original article and optionally the Coves post that shared it",
+
"maxLength": 50,
+
"items": {
+
"type": "ref",
+
"ref": "#source"
+
}
+
}
+
}
+
},
+
"source": {
+
"type": "object",
+
"description": "A source link aggregated into a megathread",
+
"required": ["uri"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "uri",
+
"description": "URI of the source article"
+
},
+
"title": {
+
"type": "string",
+
"maxLength": 500,
+
"maxGraphemes": 500,
+
"description": "Title of the source article"
+
},
+
"domain": {
+
"type": "string",
+
"maxLength": 253,
+
"description": "Domain of the source (e.g., nytimes.com)"
+
},
+
"sourcePost": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Reference to the Coves post that originally shared this link. Used for feed deprioritization of rolled-up posts"
}
}
}
+5 -5
docs/COMMENT_SYSTEM_IMPLEMENTATION.md
···
- Lexicon definitions: `social.coves.community.comment.defs` and `getComments`
- Database query methods with Lemmy hot ranking algorithm
- Service layer with iterative loading strategy for nested replies
-
- XRPC HTTP handler with optional authentication
+
- XRPC HTTP handler with optional DPoP authentication
- Comprehensive integration test suite (11 test scenarios)
**What works:**
···
- Nested replies up to configurable depth (default 10, max 100)
- Lemmy hot ranking: `log(greatest(2, score + 2)) / power(time_decay, 1.8)`
- Cursor-based pagination for stable scrolling
-
- Optional authentication for viewer state (stubbed for Phase 2B)
+
- Optional DPoP authentication for viewer state (stubbed for Phase 2B)
- Timeframe filtering for "top" sort (hour/day/week/month/year/all)
**Endpoints:**
···
- Required: `post` (AT-URI)
- Optional: `sort` (hot/top/new), `depth` (0-100), `limit` (1-100), `cursor`, `timeframe`
- Returns: Array of `threadViewComment` with nested replies + post context
-
- Supports Bearer token for authenticated requests (viewer state)
+
- Supports DPoP-bound access token for authenticated requests (viewer state)
**Files created (9):**
1. `internal/atproto/lexicon/social/coves/community/comment/defs.json` - View definitions
···
**8. Viewer Authentication Validation (Non-Issue - Architecture Working as Designed)**
- **Initial Concern:** ViewerDID field trusted without verification in service layer
- **Investigation:** Authentication IS properly validated at middleware layer
-
- `OptionalAuth` middleware extracts and validates JWT Bearer tokens
+
- `OptionalAuth` middleware extracts and validates DPoP-bound access tokens
- Uses PDS public keys (JWKS) for signature verification
-
- Validates token expiration, DID format, issuer
+
- Validates DPoP proof, token expiration, DID format, issuer
- Only injects verified DIDs into request context
- Handler extracts DID using `middleware.GetUserDID(r)`
- **Architecture:** Follows industry best practices (authentication at perimeter)
+7 -4
docs/FEED_SYSTEM_IMPLEMENTATION.md
···
# Get personalized timeline (hot posts from subscriptions)
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=hot&limit=15' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get top posts from last week
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=top&timeframe=week&limit=20' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
# Get newest posts with pagination
curl -X GET \
'http://localhost:8081/xrpc/social.coves.feed.getTimeline?sort=new&limit=10&cursor=<cursor>' \
-
-H 'Authorization: Bearer eyJhbGc...'
+
-H 'Authorization: DPoP eyJhbGc...' \
+
-H 'DPoP: eyJhbGc...'
```
**Response:**
···
- ✅ Context timeout support
### Authentication (Timeline)
-
- ✅ JWT Bearer token required
+
- ✅ DPoP-bound access token required
- ✅ DID extracted from auth context
- ✅ Validates token signature (when AUTH_SKIP_VERIFY=false)
- ✅ Returns 401 on auth failure
+3 -3
docs/PRD_OAUTH.md
···
- ✅ Auth middleware protecting community endpoints
- ✅ Handlers updated to use `GetUserDID(r)`
- ✅ Comprehensive middleware auth tests (11 test cases)
-
- ✅ E2E tests updated to use Bearer tokens
+
- ✅ E2E tests updated to use DPoP-bound tokens
- ✅ Security logging with IP, method, path, issuer
- ✅ Scope validation (atproto required)
- ✅ Issuer HTTPS validation
···
Authorization: DPoP eyJhbGciOiJFUzI1NiIsInR5cCI6ImF0K2p3dCIsImtpZCI6ImRpZDpwbGM6YWxpY2UjYXRwcm90by1wZHMifQ...
```
-
Format: `DPoP <access_token>`
+
Format: `DPoP <access_token>` (note: uses "DPoP" scheme, not "Bearer")
The access token is a JWT containing:
```json
···
- [x] All community endpoints reject requests without valid JWT structure
- [x] Integration tests pass with mock tokens (11/11 middleware tests passing)
- [x] Zero security regressions from X-User-DID (JWT validation is strictly better)
-
- [x] E2E tests updated to use proper Bearer token authentication
+
- [x] E2E tests updated to use proper DPoP token authentication
- [x] Build succeeds without compilation errors
### Phase 2 (Beta) - ✅ READY FOR TESTING
+3 -1
docs/aggregators/SETUP_GUIDE.md
···
**Request**:
```bash
+
# Note: This calls the PDS directly, so it uses Bearer authorization (not DPoP)
curl -X POST https://bsky.social/xrpc/com.atproto.repo.createRecord \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
···
**Request**:
```bash
+
# Note: This calls the Coves API, so it uses DPoP authorization
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
+
-H "Authorization: DPoP YOUR_ACCESS_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:community123...",
+8 -2
docs/federation-prd.md
···
req, _ := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
// Use service auth token instead of community credentials
+
// NOTE: Auth scheme depends on target PDS implementation:
+
// - Standard atproto service auth uses "Bearer" scheme
+
// - Our AppView uses "DPoP" scheme when DPoP-bound tokens are required
+
// For server-to-server with standard PDS, use Bearer; adjust based on target.
req.Header.Set("Authorization", "Bearer "+serviceAuthToken)
req.Header.Set("Content-Type", "application/json")
···
**Request to Remote PDS:**
```http
POST https://covesinstance.com/xrpc/com.atproto.server.getServiceAuth
-
Authorization: Bearer {coves-social-instance-jwt}
+
Authorization: DPoP {coves-social-instance-jwt}
+
DPoP: {coves-social-dpop-proof}
Content-Type: application/json
{
···
**Using Token to Create Post:**
```http
POST https://covesinstance.com/xrpc/com.atproto.repo.createRecord
-
Authorization: Bearer {service-auth-token}
+
Authorization: DPoP {service-auth-token}
+
DPoP: {service-auth-dpop-proof}
Content-Type: application/json
{
+1 -1
scripts/aggregator-setup/README.md
···
```bash
curl -X POST https://api.coves.social/xrpc/social.coves.community.post.create \
-
-H "Authorization: Bearer $AGGREGATOR_ACCESS_JWT" \
+
-H "Authorization: DPoP $AGGREGATOR_ACCESS_JWT" \
-H "Content-Type: application/json" \
-d '{
"communityDid": "did:plc:...",
+1
tests/integration/blob_upload_e2e_test.go
···
assert.Equal(t, "POST", r.Method, "Should be POST request")
assert.Equal(t, "/xrpc/com.atproto.repo.uploadBlob", r.URL.Path, "Should hit uploadBlob endpoint")
assert.Equal(t, "image/png", r.Header.Get("Content-Type"), "Should have correct content type")
+
// Note: This is a PDS call, so it uses Bearer (not DPoP)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer ", "Should have auth header")
// Return mock blob reference
+477
internal/atproto/oauth/handlers_security_test.go
···
+
package oauth
+
+
import (
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestIsAllowedMobileRedirectURI tests the mobile redirect URI allowlist with EXACT URI matching
+
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
func TestIsAllowedMobileRedirectURI(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: true,
+
},
+
{
+
name: "rejected - custom scheme coves-app (vulnerable to interception)",
+
uri: "coves-app://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - custom scheme coves (vulnerable to interception)",
+
uri: "coves://oauth/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - evil scheme",
+
uri: "evil://callback",
+
expected: false,
+
},
+
{
+
name: "rejected - http (not secure)",
+
uri: "http://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https different domain",
+
uri: "https://example.com/callback",
+
expected: false,
+
},
+
{
+
name: "rejected - https coves.social wrong path",
+
uri: "https://coves.social/wrong/path",
+
expected: false,
+
},
+
{
+
name: "rejected - invalid URI",
+
uri: "not a uri",
+
expected: false,
+
},
+
{
+
name: "rejected - empty string",
+
uri: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.uri)
+
assert.Equal(t, tt.expected, result,
+
"isAllowedMobileRedirectURI(%q) = %v, want %v", tt.uri, result, tt.expected)
+
})
+
}
+
}
+
+
// TestExtractScheme tests the scheme extraction function
+
func TestExtractScheme(t *testing.T) {
+
tests := []struct {
+
name string
+
uri string
+
expected string
+
}{
+
{
+
name: "https scheme",
+
uri: "https://coves.social/app/oauth/callback",
+
expected: "https",
+
},
+
{
+
name: "custom scheme",
+
uri: "coves-app://callback",
+
expected: "coves-app",
+
},
+
{
+
name: "invalid URI",
+
uri: "not a uri",
+
expected: "invalid",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := extractScheme(tt.uri)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
+
+
// TestGenerateCSRFToken tests CSRF token generation
+
func TestGenerateCSRFToken(t *testing.T) {
+
// Generate two tokens and verify they are different (randomness check)
+
token1, err1 := generateCSRFToken()
+
require.NoError(t, err1)
+
require.NotEmpty(t, token1)
+
+
token2, err2 := generateCSRFToken()
+
require.NoError(t, err2)
+
require.NotEmpty(t, token2)
+
+
assert.NotEqual(t, token1, token2, "CSRF tokens should be unique")
+
+
// Verify token is base64 encoded (should decode without error)
+
assert.Greater(t, len(token1), 40, "CSRF token should be reasonably long (32 bytes base64 encoded)")
+
}
+
+
// TestHandleMobileLogin_RedirectURIValidation tests that HandleMobileLogin validates redirect URIs
+
func TestHandleMobileLogin_RedirectURIValidation(t *testing.T) {
+
// Note: This is a unit test for the validation logic only.
+
// Full integration tests with OAuth flow are in tests/integration/oauth_e2e_test.go
+
+
tests := []struct {
+
name string
+
redirectURI string
+
expectedLog string
+
expectedStatus int
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
expectedStatus: http.StatusBadRequest, // Will fail at StartAuthFlow (no OAuth client setup)
+
},
+
{
+
name: "rejected - custom scheme coves-app (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected evil scheme",
+
redirectURI: "evil://callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "rejected unauthorized mobile redirect URI",
+
},
+
{
+
name: "rejected http",
+
redirectURI: "http://evil.com/callback",
+
expectedStatus: http.StatusBadRequest,
+
expectedLog: "scheme not allowed",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
// Test the validation function directly
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
if tt.expectedLog != "" {
+
assert.False(t, result, "Should reject %s", tt.redirectURI)
+
}
+
})
+
}
+
}
+
+
// TestHandleCallback_CSRFValidation tests that HandleCallback validates CSRF tokens for mobile flow
+
func TestHandleCallback_CSRFValidation(t *testing.T) {
+
// This is a conceptual test structure. Full implementation would require:
+
// 1. Mock OAuthClient
+
// 2. Mock OAuth store
+
// 3. Simulated OAuth callback with cookies
+
+
t.Run("mobile callback requires CSRF token", func(t *testing.T) {
+
// Setup: Create request with mobile_redirect_uri cookie but NO oauth_csrf cookie
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
// Missing: oauth_csrf cookie
+
+
// This would be rejected with 403 Forbidden in the actual handler
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
+
t.Run("mobile callback with valid CSRF token", func(t *testing.T) {
+
// Setup: Create request with both cookies
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: "https%3A%2F%2Fcoves.social%2Fapp%2Foauth%2Fcallback",
+
})
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: "valid-csrf-token",
+
})
+
+
// This would be accepted (assuming valid OAuth code/state)
+
// (Full test in integration tests with real OAuth flow)
+
+
assert.NotNil(t, req) // Placeholder assertion
+
})
+
}
+
+
// TestHandleMobileCallback_RevalidatesRedirectURI tests that handleMobileCallback re-validates the redirect URI
+
func TestHandleMobileCallback_RevalidatesRedirectURI(t *testing.T) {
+
// This is a critical security test: even if an attacker somehow bypasses the initial check,
+
// the callback handler should re-validate the redirect URI before redirecting.
+
+
tests := []struct {
+
name string
+
redirectURI string
+
shouldPass bool
+
}{
+
{
+
name: "allowed - Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
shouldPass: true,
+
},
+
{
+
name: "blocked - custom scheme (insecure)",
+
redirectURI: "coves-app://oauth/callback",
+
shouldPass: false,
+
},
+
{
+
name: "blocked - evil scheme",
+
redirectURI: "evil://callback",
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := isAllowedMobileRedirectURI(tt.redirectURI)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestGenerateMobileRedirectBinding tests the binding token generation
+
// The binding now includes the CSRF token for proper double-submit validation
+
func TestGenerateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-12345"
+
tests := []struct {
+
name string
+
redirectURI string
+
}{
+
{
+
name: "Universal Link",
+
redirectURI: "https://coves.social/app/oauth/callback",
+
},
+
{
+
name: "different path",
+
redirectURI: "https://coves.social/different/path",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
binding1 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
binding2 := generateMobileRedirectBinding(csrfToken, tt.redirectURI)
+
+
// Same CSRF token + URI should produce same binding (deterministic)
+
assert.Equal(t, binding1, binding2, "binding should be deterministic for same inputs")
+
+
// Binding should not be empty
+
assert.NotEmpty(t, binding1, "binding should not be empty")
+
+
// Binding should be base64 encoded (should decode without error)
+
assert.Greater(t, len(binding1), 20, "binding should be reasonably long")
+
})
+
}
+
+
// Different URIs should produce different bindings
+
binding1 := generateMobileRedirectBinding(csrfToken, "https://coves.social/app/oauth/callback")
+
binding2 := generateMobileRedirectBinding(csrfToken, "https://coves.social/different/path")
+
assert.NotEqual(t, binding1, binding2, "different URIs should produce different bindings")
+
+
// Different CSRF tokens should produce different bindings
+
binding3 := generateMobileRedirectBinding("different-csrf-token", "https://coves.social/app/oauth/callback")
+
assert.NotEqual(t, binding1, binding3, "different CSRF tokens should produce different bindings")
+
}
+
+
// TestValidateMobileRedirectBinding tests the binding validation
+
// Now validates both CSRF token and redirect URI together (double-submit pattern)
+
func TestValidateMobileRedirectBinding(t *testing.T) {
+
csrfToken := "test-csrf-token-for-validation"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
validBinding := generateMobileRedirectBinding(csrfToken, redirectURI)
+
+
tests := []struct {
+
name string
+
csrfToken string
+
redirectURI string
+
binding string
+
shouldPass bool
+
}{
+
{
+
name: "valid - correct CSRF token and redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: true,
+
},
+
{
+
name: "invalid - wrong redirect URI",
+
csrfToken: csrfToken,
+
redirectURI: "https://coves.social/different/path",
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - wrong CSRF token",
+
csrfToken: "wrong-csrf-token",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
{
+
name: "invalid - random binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "random-invalid-binding",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty binding",
+
csrfToken: csrfToken,
+
redirectURI: redirectURI,
+
binding: "",
+
shouldPass: false,
+
},
+
{
+
name: "invalid - empty CSRF token",
+
csrfToken: "",
+
redirectURI: redirectURI,
+
binding: validBinding,
+
shouldPass: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := validateMobileRedirectBinding(tt.csrfToken, tt.redirectURI, tt.binding)
+
assert.Equal(t, tt.shouldPass, result)
+
})
+
}
+
}
+
+
// TestSessionFixationAttackPrevention tests that the binding prevents session fixation
+
func TestSessionFixationAttackPrevention(t *testing.T) {
+
// Simulate attack scenario:
+
// 1. Attacker plants a cookie for evil://steal with binding for evil://steal
+
// 2. User does a web login (no mobile_redirect_binding cookie)
+
// 3. Callback should NOT redirect to evil://steal
+
+
attackerCSRF := "attacker-csrf-token"
+
attackerRedirectURI := "evil://steal"
+
attackerBinding := generateMobileRedirectBinding(attackerCSRF, attackerRedirectURI)
+
+
// Later, user's legitimate mobile login
+
userCSRF := "user-csrf-token"
+
userRedirectURI := "https://coves.social/app/oauth/callback"
+
userBinding := generateMobileRedirectBinding(userCSRF, userRedirectURI)
+
+
// The attacker's binding should NOT validate for the user's redirect URI
+
assert.False(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, attackerBinding),
+
"attacker's binding should not validate for user's CSRF token and redirect URI")
+
+
// The user's binding should validate for the user's CSRF token and redirect URI
+
assert.True(t, validateMobileRedirectBinding(userCSRF, userRedirectURI, userBinding),
+
"user's binding should validate for user's CSRF token and redirect URI")
+
+
// Cross-validation should fail
+
assert.False(t, validateMobileRedirectBinding(attackerCSRF, attackerRedirectURI, userBinding),
+
"user's binding should not validate for attacker's CSRF token and redirect URI")
+
}
+
+
// TestCSRFTokenValidation tests that CSRF token VALUE is validated, not just presence
+
func TestCSRFTokenValidation(t *testing.T) {
+
// This test verifies the fix for the P1 security issue:
+
// "The callback never validates the token... the csrfToken argument is ignored entirely"
+
//
+
// The fix ensures that the CSRF token VALUE is cryptographically bound to the
+
// binding token, so changing the CSRF token will invalidate the binding.
+
+
t.Run("CSRF token value must match", func(t *testing.T) {
+
originalCSRF := "original-csrf-token-from-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
binding := generateMobileRedirectBinding(originalCSRF, redirectURI)
+
+
// Original CSRF token should validate
+
assert.True(t, validateMobileRedirectBinding(originalCSRF, redirectURI, binding),
+
"original CSRF token should validate")
+
+
// Different CSRF token should NOT validate (this is the key security fix)
+
differentCSRF := "attacker-forged-csrf-token"
+
assert.False(t, validateMobileRedirectBinding(differentCSRF, redirectURI, binding),
+
"different CSRF token should NOT validate - this is the security fix")
+
})
+
+
t.Run("attacker cannot forge binding without CSRF token", func(t *testing.T) {
+
// Attacker knows the redirect URI but not the CSRF token
+
redirectURI := "https://coves.social/app/oauth/callback"
+
victimCSRF := "victim-secret-csrf-token"
+
victimBinding := generateMobileRedirectBinding(victimCSRF, redirectURI)
+
+
// Attacker tries various CSRF tokens to forge the binding
+
attackerGuesses := []string{
+
"",
+
"guess1",
+
"attacker-csrf",
+
redirectURI, // trying the redirect URI as CSRF
+
}
+
+
for _, guess := range attackerGuesses {
+
assert.False(t, validateMobileRedirectBinding(guess, redirectURI, victimBinding),
+
"attacker's CSRF guess %q should not validate", guess)
+
}
+
})
+
}
+
+
// TestConstantTimeCompare tests the timing-safe comparison function
+
func TestConstantTimeCompare(t *testing.T) {
+
tests := []struct {
+
name string
+
a string
+
b string
+
expected bool
+
}{
+
{
+
name: "equal strings",
+
a: "abc123",
+
b: "abc123",
+
expected: true,
+
},
+
{
+
name: "different strings same length",
+
a: "abc123",
+
b: "xyz789",
+
expected: false,
+
},
+
{
+
name: "different lengths",
+
a: "short",
+
b: "longer",
+
expected: false,
+
},
+
{
+
name: "empty strings",
+
a: "",
+
b: "",
+
expected: true,
+
},
+
{
+
name: "one empty",
+
a: "abc",
+
b: "",
+
expected: false,
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := constantTimeCompare(tt.a, tt.b)
+
assert.Equal(t, tt.expected, result)
+
})
+
}
+
}
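The tests above exercise `generateMobileRedirectBinding`, `validateMobileRedirectBinding`, and `constantTimeCompare`, whose implementations are not part of this hunk. A minimal sketch of what the tests imply — an HMAC-SHA256 over the CSRF token and redirect URI with a server-side secret, compared in constant time — is shown below. The `Sketch` names and `bindingSecret` are illustrative assumptions, not the actual implementation.

```go
// Hypothetical sketch only; the real functions live elsewhere in this package.
package oauth

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
)

// assumption: in the real code this would be loaded from configuration
var bindingSecret = []byte("replace-with-server-secret")

func generateMobileRedirectBindingSketch(csrfToken, redirectURI string) string {
	mac := hmac.New(sha256.New, bindingSecret)
	mac.Write([]byte(csrfToken))
	mac.Write([]byte{0}) // domain separator between the two inputs
	mac.Write([]byte(redirectURI))
	return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

func validateMobileRedirectBindingSketch(csrfToken, redirectURI, binding string) bool {
	expected := generateMobileRedirectBindingSketch(csrfToken, redirectURI)
	// hmac.Equal provides the constant-time comparison the tests expect
	return hmac.Equal([]byte(expected), []byte(binding))
}
```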
+152
internal/atproto/oauth/seal.go
···
+
package oauth
+
+
import (
+
"crypto/aes"
+
"crypto/cipher"
+
"crypto/rand"
+
"encoding/base64"
+
"encoding/json"
+
"fmt"
+
"time"
+
)
+
+
// SealedSession represents the data sealed in a mobile session token
+
type SealedSession struct {
+
DID string `json:"did"` // User's DID
+
SessionID string `json:"sid"` // Session identifier
+
ExpiresAt int64 `json:"exp"` // Unix timestamp when token expires
+
}
+
+
// SealSession creates an encrypted token containing session information.
+
// The token is encrypted using AES-256-GCM and encoded as base64url.
+
//
+
// Token format: base64url(nonce || ciphertext || tag)
+
// - nonce: 12 bytes (GCM standard nonce size)
+
// - ciphertext: encrypted JSON payload
+
// - tag: 16 bytes (GCM authentication tag)
+
//
+
// The sealed token can be safely given to mobile clients and used as
+
// a reference to the server-side session without exposing sensitive data.
+
func (c *OAuthClient) SealSession(did, sessionID string, ttl time.Duration) (string, error) {
+
if len(c.SealSecret) == 0 {
+
return "", fmt.Errorf("seal secret not configured")
+
}
+
+
if did == "" {
+
return "", fmt.Errorf("DID is required")
+
}
+
+
if sessionID == "" {
+
return "", fmt.Errorf("session ID is required")
+
}
+
+
// Create the session data
+
expiresAt := time.Now().Add(ttl).Unix()
+
session := SealedSession{
+
DID: did,
+
SessionID: sessionID,
+
ExpiresAt: expiresAt,
+
}
+
+
// Marshal to JSON
+
plaintext, err := json.Marshal(session)
+
if err != nil {
+
return "", fmt.Errorf("failed to marshal session: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return "", fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return "", fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Generate random nonce
+
nonce := make([]byte, gcm.NonceSize())
+
if _, err := rand.Read(nonce); err != nil {
+
return "", fmt.Errorf("failed to generate nonce: %w", err)
+
}
+
+
// Encrypt and authenticate
+
// GCM.Seal appends the ciphertext and tag to the nonce
+
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+
+
// Encode as base64url (no padding)
+
token := base64.RawURLEncoding.EncodeToString(ciphertext)
+
+
return token, nil
+
}
+
+
// UnsealSession decrypts and validates a sealed session token.
+
// Returns the session information if the token is valid and not expired.
+
func (c *OAuthClient) UnsealSession(token string) (*SealedSession, error) {
+
if len(c.SealSecret) == 0 {
+
return nil, fmt.Errorf("seal secret not configured")
+
}
+
+
if token == "" {
+
return nil, fmt.Errorf("token is required")
+
}
+
+
// Decode from base64url
+
ciphertext, err := base64.RawURLEncoding.DecodeString(token)
+
if err != nil {
+
return nil, fmt.Errorf("invalid token encoding: %w", err)
+
}
+
+
// Create AES cipher
+
block, err := aes.NewCipher(c.SealSecret)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create cipher: %w", err)
+
}
+
+
// Create GCM mode
+
gcm, err := cipher.NewGCM(block)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create GCM: %w", err)
+
}
+
+
// Verify minimum size (nonce + tag)
+
nonceSize := gcm.NonceSize()
+
if len(ciphertext) < nonceSize {
+
return nil, fmt.Errorf("invalid token: too short")
+
}
+
+
// Extract nonce and ciphertext
+
nonce := ciphertext[:nonceSize]
+
ciphertextData := ciphertext[nonceSize:]
+
+
// Decrypt and authenticate
+
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decrypt token: %w", err)
+
}
+
+
// Unmarshal JSON
+
var session SealedSession
+
if err := json.Unmarshal(plaintext, &session); err != nil {
+
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
+
}
+
+
// Validate required fields
+
if session.DID == "" {
+
return nil, fmt.Errorf("invalid session: missing DID")
+
}
+
+
if session.SessionID == "" {
+
return nil, fmt.Errorf("invalid session: missing session ID")
+
}
+
+
// Check expiration
+
now := time.Now().Unix()
+
if session.ExpiresAt <= now {
+
return nil, fmt.Errorf("token expired at %v", time.Unix(session.ExpiresAt, 0))
+
}
+
+
return &session, nil
+
}
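A minimal usage sketch of `SealSession`/`UnsealSession` as defined above. The key handling and the surrounding handoff to a mobile client are assumptions; only `OAuthClient.SealSecret` and the two methods come from this file.

```go
// Same package as seal.go; illustrative only.
package oauth

import (
	"crypto/rand"
	"fmt"
	"time"
)

func exampleSealRoundTrip() error {
	// AES-256-GCM requires a 32-byte key (aes.NewCipher selects AES-256 for 32 bytes).
	// assumption: in production the secret comes from configuration, not rand.
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		return err
	}
	client := &OAuthClient{SealSecret: secret}

	// Hand the mobile app an opaque token that references the server-side session.
	token, err := client.SealSession("did:plc:example", "session-123", 24*time.Hour)
	if err != nil {
		return err
	}

	// Later, on an authenticated call, unseal and inspect what comes back.
	sess, err := client.UnsealSession(token)
	if err != nil {
		return err // wrong key, tampered token, or expired
	}
	fmt.Println(sess.DID, sess.SessionID, time.Unix(sess.ExpiresAt, 0))
	return nil
}
```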
+331
internal/atproto/oauth/seal_test.go
···
+
package oauth
+
+
import (
+
"crypto/rand"
+
"encoding/base64"
+
"strings"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// generateSealSecret generates a random 32-byte seal secret for testing
+
func generateSealSecret() []byte {
+
secret := make([]byte, 32)
+
if _, err := rand.Read(secret); err != nil {
+
panic(err)
+
}
+
return secret
+
}
+
+
func TestSealSession_RoundTrip(t *testing.T) {
+
// Create client with seal secret
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
require.NotEmpty(t, token)
+
+
// Token should be base64url encoded
+
_, err = base64.RawURLEncoding.DecodeString(token)
+
require.NoError(t, err, "token should be valid base64url")
+
+
// Unseal the session
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
require.NotNil(t, session)
+
+
// Verify data
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
+
// Verify expiration is approximately correct (within 1 second)
+
expectedExpiry := time.Now().Add(ttl).Unix()
+
assert.InDelta(t, expectedExpiry, session.ExpiresAt, 1.0)
+
}
+
+
func TestSealSession_ExpirationValidation(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 2 * time.Second // Short TTL (must be >= 1 second due to Unix timestamp granularity)
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Should work immediately
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
+
// Wait well past expiration
+
time.Sleep(2500 * time.Millisecond)
+
+
// Should fail after expiration
+
session, err = client.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "token expired")
+
}
+
+
func TestSealSession_TamperedTokenDetection(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the session
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tamper with the token by replacing a few characters near the end (length is preserved)
+
tampered := token[:len(token)-5] + "XXXX" + token[len(token)-1:]
+
+
// Should fail to unseal tampered token
+
session, err := client.UnsealSession(tampered)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_InvalidTokenFormats(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
tests := []struct {
+
name string
+
token string
+
}{
+
{
+
name: "empty token",
+
token: "",
+
},
+
{
+
name: "invalid base64",
+
token: "not-valid-base64!@#$",
+
},
+
{
+
name: "too short",
+
token: base64.RawURLEncoding.EncodeToString([]byte("short")),
+
},
+
{
+
name: "random bytes",
+
token: base64.RawURLEncoding.EncodeToString(make([]byte, 50)),
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
session, err := client.UnsealSession(tt.token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
})
+
}
+
}
+
+
func TestSealSession_DifferentSecrets(t *testing.T) {
+
// Create two clients with different secrets
+
client1 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
client2 := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal with client1
+
token, err := client1.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Try to unseal with client2 (different secret)
+
session, err := client2.UnsealSession(token)
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "failed to decrypt token")
+
}
+
+
func TestSealSession_NoSecretConfigured(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: nil,
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Should fail to seal without secret
+
token, err := client.SealSession(did, sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
+
// Should fail to unseal without secret
+
session, err := client.UnsealSession("dummy-token")
+
assert.Error(t, err)
+
assert.Nil(t, session)
+
assert.Contains(t, err.Error(), "seal secret not configured")
+
}
+
+
func TestSealSession_MissingRequiredFields(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
ttl := 1 * time.Hour
+
+
tests := []struct {
+
name string
+
did string
+
sessionID string
+
errorMsg string
+
}{
+
{
+
name: "missing DID",
+
did: "",
+
sessionID: "session-123",
+
errorMsg: "DID is required",
+
},
+
{
+
name: "missing session ID",
+
did: "did:plc:abc123",
+
sessionID: "",
+
errorMsg: "session ID is required",
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
token, err := client.SealSession(tt.did, tt.sessionID, ttl)
+
assert.Error(t, err)
+
assert.Empty(t, token)
+
assert.Contains(t, err.Error(), tt.errorMsg)
+
})
+
}
+
}
+
+
func TestSealSession_UniquenessPerCall(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal the same session twice
+
token1, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
token2, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Tokens should be different (different nonces)
+
assert.NotEqual(t, token1, token2, "tokens should be unique due to different nonces")
+
+
// But both should unseal to the same session data
+
session1, err := client.UnsealSession(token1)
+
require.NoError(t, err)
+
+
session2, err := client.UnsealSession(token2)
+
require.NoError(t, err)
+
+
assert.Equal(t, session1.DID, session2.DID)
+
assert.Equal(t, session1.SessionID, session2.SessionID)
+
}
+
+
func TestSealSession_LongDIDAndSessionID(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
// Test with very long DID and session ID
+
did := "did:plc:" + strings.Repeat("a", 200)
+
sessionID := "session-" + strings.Repeat("x", 200)
+
ttl := 1 * time.Hour
+
+
// Should work with long values
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
assert.Equal(t, sessionID, session.SessionID)
+
}
+
+
func TestSealSession_URLSafeEncoding(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Seal multiple times to get different nonces
+
for i := 0; i < 100; i++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
// Token should not contain URL-unsafe characters
+
assert.NotContains(t, token, "+", "token should not contain '+'")
+
assert.NotContains(t, token, "/", "token should not contain '/'")
+
assert.NotContains(t, token, "=", "token should not contain '='")
+
+
// Should unseal successfully
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
}
+
+
func TestSealSession_ConcurrentAccess(t *testing.T) {
+
client := &OAuthClient{
+
SealSecret: generateSealSecret(),
+
}
+
+
did := "did:plc:abc123"
+
sessionID := "session-xyz"
+
ttl := 1 * time.Hour
+
+
// Run concurrent seal/unseal operations
+
done := make(chan bool)
+
for i := 0; i < 10; i++ {
+
go func() {
+
for j := 0; j < 100; j++ {
+
token, err := client.SealSession(did, sessionID, ttl)
+
require.NoError(t, err)
+
+
session, err := client.UnsealSession(token)
+
require.NoError(t, err)
+
assert.Equal(t, did, session.DID)
+
}
+
done <- true
+
}()
+
}
+
+
// Wait for all goroutines
+
for i := 0; i < 10; i++ {
+
<-done
+
}
+
}
+614
internal/atproto/oauth/store.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"errors"
+
"fmt"
+
"log/slog"
+
"net/url"
+
"strings"
+
"time"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/lib/pq"
+
)
+
+
var (
+
ErrSessionNotFound = errors.New("oauth session not found")
+
ErrAuthRequestNotFound = errors.New("oauth auth request not found")
+
)
+
+
// PostgresOAuthStore implements oauth.ClientAuthStore interface using PostgreSQL
+
type PostgresOAuthStore struct {
+
db *sql.DB
+
sessionTTL time.Duration
+
}
+
+
// NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store
+
func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore {
+
if sessionTTL == 0 {
+
sessionTTL = 7 * 24 * time.Hour // Default to 7 days
+
}
+
return &PostgresOAuthStore{
+
db: db,
+
sessionTTL: sessionTTL,
+
}
+
}
+
+
// GetSession retrieves a session by DID and session ID
+
func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
+
query := `
+
SELECT
+
did, session_id, host_url, auth_server_iss,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, access_token, refresh_token,
+
dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase
+
FROM oauth_sessions
+
WHERE did = $1 AND session_id = $2 AND expires_at > NOW()
+
`
+
+
var session oauth.ClientSessionData
+
var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var hostURL, dpopPrivateKeyMultibase sql.NullString
+
var scopes pq.StringArray
+
var dpopAuthServerNonce, dpopHostNonce sql.NullString
+
+
err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan(
+
&session.AccountDID,
+
&session.SessionID,
+
&hostURL,
+
&authServerIss,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&session.AccessToken,
+
&session.RefreshToken,
+
&dpopAuthServerNonce,
+
&dpopHostNonce,
+
&dpopPrivateKeyMultibase,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrSessionNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get session: %w", err)
+
}
+
+
// Convert nullable fields
+
if hostURL.Valid {
+
session.HostURL = hostURL.String
+
}
+
if authServerIss.Valid {
+
session.AuthServerURL = authServerIss.String
+
}
+
if authServerTokenEndpoint.Valid {
+
session.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
if dpopAuthServerNonce.Valid {
+
session.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if dpopHostNonce.Valid {
+
session.DPoPHostNonce = dpopHostNonce.String
+
}
+
if dpopPrivateKeyMultibase.Valid {
+
session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
session.Scopes = scopes
+
+
return &session, nil
+
}
+
+
// SaveSession saves or updates a session (upsert operation)
+
func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
+
// Input validation per atProto OAuth security requirements
+
+
// Validate DID format
+
if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil {
+
return fmt.Errorf("invalid DID format: %w", err)
+
}
+
+
// Validate token lengths (max 10000 chars to prevent memory issues)
+
const maxTokenLength = 10000
+
if len(sess.AccessToken) > maxTokenLength {
+
return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
if len(sess.RefreshToken) > maxTokenLength {
+
return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength)
+
}
+
+
// Validate session ID is not empty
+
if sess.SessionID == "" {
+
return fmt.Errorf("session_id cannot be empty")
+
}
+
+
// Validate URLs if provided
+
if sess.HostURL != "" {
+
if _, err := url.Parse(sess.HostURL); err != nil {
+
return fmt.Errorf("invalid host_url: %w", err)
+
}
+
}
+
if sess.AuthServerURL != "" {
+
if _, err := url.Parse(sess.AuthServerURL); err != nil {
+
return fmt.Errorf("invalid auth_server URL: %w", err)
+
}
+
}
+
if sess.AuthServerTokenEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_token_endpoint: %w", err)
+
}
+
}
+
if sess.AuthServerRevocationEndpoint != "" {
+
if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil {
+
return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err)
+
}
+
}
+
+
query := `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_jwk, dpop_private_key_multibase,
+
dpop_authserver_nonce, dpop_pds_nonce,
+
auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, expires_at, created_at, updated_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
NULL, $8,
+
$9, $10,
+
$11, $12, $13,
+
$14, $15, NOW(), NOW()
+
)
+
ON CONFLICT (did, session_id) DO UPDATE SET
+
handle = EXCLUDED.handle,
+
pds_url = EXCLUDED.pds_url,
+
host_url = EXCLUDED.host_url,
+
access_token = EXCLUDED.access_token,
+
refresh_token = EXCLUDED.refresh_token,
+
dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase,
+
dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce,
+
dpop_pds_nonce = EXCLUDED.dpop_pds_nonce,
+
auth_server_iss = EXCLUDED.auth_server_iss,
+
auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint,
+
auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint,
+
scopes = EXCLUDED.scopes,
+
expires_at = EXCLUDED.expires_at,
+
updated_at = NOW()
+
`
+
+
// Calculate token expiration using configured TTL
+
expiresAt := time.Now().Add(s.sessionTTL)
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if sess.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Extract handle from DID (placeholder - in real implementation, resolve from identity)
+
// For now, use DID as handle since we don't have the handle in ClientSessionData
+
handle := sess.AccountDID.String()
+
+
// Use HostURL as PDS URL
+
pdsURL := sess.HostURL
+
if pdsURL == "" {
+
pdsURL = sess.AuthServerURL // Fallback to auth server URL
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
sess.AccountDID.String(),
+
sess.SessionID,
+
handle,
+
pdsURL,
+
sess.HostURL,
+
sess.AccessToken,
+
sess.RefreshToken,
+
sess.DPoPPrivateKeyMultibase,
+
sess.DPoPAuthServerNonce,
+
sess.DPoPHostNonce,
+
sess.AuthServerURL,
+
sess.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(sess.Scopes),
+
expiresAt,
+
)
+
if err != nil {
+
return fmt.Errorf("failed to save session: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteSession deletes a session by DID and session ID
+
func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
+
query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2`
+
+
result, err := s.db.ExecContext(ctx, query, did.String(), sessionID)
+
if err != nil {
+
return fmt.Errorf("failed to delete session: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrSessionNotFound
+
}
+
+
return nil
+
}
+
+
// GetAuthRequestInfo retrieves auth request information by state
+
func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
+
query := `
+
SELECT
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, created_at
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var info oauth.AuthRequestData
+
var did, handle, pdsURL sql.NullString
+
var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString
+
var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
+
var scopes pq.StringArray
+
var createdAt time.Time
+
+
err := s.db.QueryRowContext(ctx, query, state).Scan(
+
&info.State,
+
&did,
+
&handle,
+
&pdsURL,
+
&info.PKCEVerifier,
+
&dpopPrivateKeyMultibase,
+
&dpopAuthServerNonce,
+
&info.AuthServerURL,
+
&requestURI,
+
&authServerTokenEndpoint,
+
&authServerRevocationEndpoint,
+
&scopes,
+
&createdAt,
+
)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get auth request info: %w", err)
+
}
+
+
// Parse DID if present
+
if did.Valid && did.String != "" {
+
parsedDID, err := syntax.ParseDID(did.String)
+
if err != nil {
+
return nil, fmt.Errorf("failed to parse DID: %w", err)
+
}
+
info.AccountDID = &parsedDID
+
}
+
+
// Convert nullable fields
+
if dpopPrivateKeyMultibase.Valid {
+
info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
+
}
+
if dpopAuthServerNonce.Valid {
+
info.DPoPAuthServerNonce = dpopAuthServerNonce.String
+
}
+
if requestURI.Valid {
+
info.RequestURI = requestURI.String
+
}
+
if authServerTokenEndpoint.Valid {
+
info.AuthServerTokenEndpoint = authServerTokenEndpoint.String
+
}
+
if authServerRevocationEndpoint.Valid {
+
info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
+
}
+
info.Scopes = scopes
+
+
return &info, nil
+
}
+
+
// SaveAuthRequestInfo saves auth request information (create only, not upsert)
+
func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
query := `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, auth_server_revocation_endpoint,
+
scopes, return_url, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
$12, NULL, NOW()
+
)
+
`
+
+
// Extract DID string if present
+
var didStr sql.NullString
+
if info.AccountDID != nil {
+
didStr.String = info.AccountDID.String()
+
didStr.Valid = true
+
}
+
+
// Convert empty strings to NULL for optional fields
+
var authServerRevocationEndpoint sql.NullString
+
if info.AuthServerRevocationEndpoint != "" {
+
authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint
+
authServerRevocationEndpoint.Valid = true
+
}
+
+
// Placeholder values for handle and pds_url (not in AuthRequestData)
+
// In production, these would be resolved during the auth flow
+
handle := ""
+
pdsURL := ""
+
if info.AccountDID != nil {
+
handle = info.AccountDID.String() // Temporary placeholder
+
pdsURL = info.AuthServerURL // Temporary placeholder
+
}
+
+
_, err := s.db.ExecContext(
+
ctx, query,
+
info.State,
+
didStr,
+
handle,
+
pdsURL,
+
info.PKCEVerifier,
+
info.DPoPPrivateKeyMultibase,
+
info.DPoPAuthServerNonce,
+
info.AuthServerURL,
+
info.RequestURI,
+
info.AuthServerTokenEndpoint,
+
authServerRevocationEndpoint,
+
pq.Array(info.Scopes),
+
)
+
if err != nil {
+
// Check for duplicate state
+
if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") {
+
return fmt.Errorf("auth request with state already exists: %s", info.State)
+
}
+
return fmt.Errorf("failed to save auth request info: %w", err)
+
}
+
+
return nil
+
}
+
+
// DeleteAuthRequestInfo deletes auth request information by state
+
func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
+
query := `DELETE FROM oauth_requests WHERE state = $1`
+
+
result, err := s.db.ExecContext(ctx, query, state)
+
if err != nil {
+
return fmt.Errorf("failed to delete auth request info: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// CleanupExpiredSessions removes sessions that have expired
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// CleanupExpiredAuthRequests removes auth requests older than 30 minutes
+
// Should be called periodically (e.g., via cron job)
+
func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) {
+
query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'`
+
+
result, err := s.db.ExecContext(ctx, query)
+
if err != nil {
+
return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
return rows, nil
+
}
+
+
// MobileOAuthData holds mobile-specific OAuth flow data
+
type MobileOAuthData struct {
+
CSRFToken string
+
RedirectURI string
+
}
+
+
// mobileFlowContextKey is the context key for mobile flow data
+
type mobileFlowContextKey struct{}
+
+
// ContextWithMobileFlowData adds mobile flow data to a context.
+
// This is used by HandleMobileLogin to pass mobile data to the store wrapper,
+
// which will save it when SaveAuthRequestInfo is called by indigo.
+
func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context {
+
return context.WithValue(ctx, mobileFlowContextKey{}, data)
+
}
+
+
// getMobileFlowDataFromContext retrieves mobile flow data from context, if present
+
func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) {
+
data, ok := ctx.Value(mobileFlowContextKey{}).(MobileOAuthData)
+
return data, ok
+
}
+
+
// MobileAwareClientStore is a marker interface that indicates a store is properly
+
// configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo
+
// to save mobile CSRF data should implement this interface.
+
// This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used.
+
type MobileAwareClientStore interface {
+
IsMobileAware() bool
+
}
+
+
// MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile
+
// CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow.
+
// This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state,
+
// so we intercept the SaveAuthRequestInfo call to capture it.
+
type MobileAwareStoreWrapper struct {
+
oauth.ClientAuthStore
+
mobileStore MobileOAuthStore
+
}
+
+
// IsMobileAware implements MobileAwareClientStore, indicating this store
+
// properly saves mobile CSRF data during OAuth flow initiation.
+
func (w *MobileAwareStoreWrapper) IsMobileAware() bool {
+
return true
+
}
+
+
// NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo
+
// to also save mobile CSRF data when present in context.
+
func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper {
+
wrapper := &MobileAwareStoreWrapper{
+
ClientAuthStore: store,
+
}
+
// Check if the underlying store implements MobileOAuthStore
+
if ms, ok := store.(MobileOAuthStore); ok {
+
wrapper.mobileStore = ms
+
}
+
return wrapper
+
}
+
+
// SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data
+
// if mobile flow data is present in the context.
+
func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
+
// First, save the auth request to the underlying store
+
if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil {
+
return err
+
}
+
+
// Check if this is a mobile flow (mobile data in context)
+
if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil {
+
// Save mobile CSRF data tied to this OAuth state
+
// IMPORTANT: If this fails, we MUST propagate the error. Otherwise:
+
// 1. No server-side CSRF record is stored
+
// 2. Every mobile callback will "fail closed" to web flow
+
// 3. Mobile sign-in silently breaks with no indication
+
// Failing loudly here lets the user retry rather than being confused
+
// about why they're getting a web flow instead of mobile.
+
if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil {
+
slog.Error("failed to save mobile CSRF data - mobile login will fail",
+
"state", info.State, "error", err)
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
if w.mobileStore != nil {
+
return w.mobileStore.GetMobileOAuthData(ctx, state)
+
}
+
return nil, nil
+
}
+
+
// SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
+
func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
if w.mobileStore != nil {
+
return w.mobileStore.SaveMobileOAuthData(ctx, state, data)
+
}
+
return nil
+
}
+
+
// UnwrapPostgresStore returns the underlying PostgresOAuthStore if present.
+
// This is useful for accessing cleanup methods that aren't part of the interface.
+
func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore {
+
if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok {
+
return ps
+
}
+
return nil
+
}
+
+
// SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state
+
// This ties the CSRF token to the OAuth flow via the state parameter,
+
// which comes back through the OAuth response for server-side validation.
+
func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
+
query := `
+
UPDATE oauth_requests
+
SET mobile_csrf_token = $2, mobile_redirect_uri = $3
+
WHERE state = $1
+
`
+
+
result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI)
+
if err != nil {
+
return fmt.Errorf("failed to save mobile OAuth data: %w", err)
+
}
+
+
rows, err := result.RowsAffected()
+
if err != nil {
+
return fmt.Errorf("failed to get rows affected: %w", err)
+
}
+
+
if rows == 0 {
+
return ErrAuthRequestNotFound
+
}
+
+
return nil
+
}
+
+
// GetMobileOAuthData retrieves mobile CSRF data by OAuth state
+
// This is called during callback to compare the server-side CSRF token
+
// (retrieved by state from the OAuth response) against the cookie CSRF.
+
func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
+
query := `
+
SELECT mobile_csrf_token, mobile_redirect_uri
+
FROM oauth_requests
+
WHERE state = $1
+
`
+
+
var csrfToken, redirectURI sql.NullString
+
err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI)
+
+
if err == sql.ErrNoRows {
+
return nil, ErrAuthRequestNotFound
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err)
+
}
+
+
// Return nil if no mobile data was stored (this was a web flow)
+
if !csrfToken.Valid {
+
return nil, nil
+
}
+
+
return &MobileOAuthData{
+
CSRFToken: csrfToken.String,
+
RedirectURI: redirectURI.String,
+
}, nil
+
}
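A wiring sketch for the store types above: the Postgres store goes behind the mobile-aware wrapper, handlers tag the context before starting the auth flow, and a background loop runs the cleanup methods. The ticker interval and the `buildStore`/`startMobileLogin`/`runCleanup` names are illustrative; only `NewPostgresOAuthStore`, `NewMobileAwareStoreWrapper`, `ContextWithMobileFlowData`, `UnwrapPostgresStore`, and the cleanup methods come from this file.

```go
// Same package as store.go; illustrative only.
package oauth

import (
	"context"
	"database/sql"
	"log/slog"
	"time"
)

func buildStore(db *sql.DB) *MobileAwareStoreWrapper {
	base := NewPostgresOAuthStore(db, 7*24*time.Hour)
	// The wrapper intercepts SaveAuthRequestInfo so mobile CSRF data riding in
	// the context is persisted against the OAuth state.
	return NewMobileAwareStoreWrapper(base)
}

// startMobileLogin shows how a handler would tag the context before calling
// into indigo's StartAuthFlow (not shown here).
func startMobileLogin(ctx context.Context, csrfToken, redirectURI string) context.Context {
	return ContextWithMobileFlowData(ctx, MobileOAuthData{
		CSRFToken:   csrfToken,
		RedirectURI: redirectURI,
	})
}

// runCleanup is a hypothetical background loop for the periodic cleanup the
// comments above call for (e.g. instead of a cron job).
func runCleanup(ctx context.Context, w *MobileAwareStoreWrapper) {
	pg := w.UnwrapPostgresStore()
	if pg == nil {
		return
	}
	ticker := time.NewTicker(15 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if n, err := pg.CleanupExpiredSessions(ctx); err != nil {
				slog.Error("cleanup expired sessions", "error", err)
			} else if n > 0 {
				slog.Info("cleaned expired sessions", "count", n)
			}
			if _, err := pg.CleanupExpiredAuthRequests(ctx); err != nil {
				slog.Error("cleanup expired auth requests", "error", err)
			}
		}
	}
}
```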
+522
internal/atproto/oauth/store_test.go
···
+
package oauth
+
+
import (
+
"context"
+
"database/sql"
+
"os"
+
"testing"
+
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// setupTestDB creates a test database connection and runs migrations
+
func setupTestDB(t *testing.T) *sql.DB {
+
dsn := os.Getenv("TEST_DATABASE_URL")
+
if dsn == "" {
+
dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable"
+
}
+
+
db, err := sql.Open("postgres", dsn)
+
require.NoError(t, err, "Failed to connect to test database")
+
+
// Run migrations
+
require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations")
+
+
return db
+
}
+
+
// cleanupOAuth removes all test OAuth data from the database
+
func cleanupOAuth(t *testing.T, db *sql.DB) {
+
_, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_sessions")
+
+
_, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'")
+
require.NoError(t, err, "Failed to cleanup oauth_requests")
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:test123abc")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session123",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
Scopes: []string{"atproto"},
+
AccessToken: "at_test_token_abc123",
+
RefreshToken: "rt_test_token_xyz789",
+
DPoPAuthServerNonce: "nonce_auth_123",
+
DPoPHostNonce: "nonce_host_456",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
assert.NoError(t, err)
+
+
// Retrieve session
+
retrieved, err := store.GetSession(ctx, did, "session123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, session.SessionID, retrieved.SessionID)
+
assert.Equal(t, session.HostURL, retrieved.HostURL)
+
assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL)
+
assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken)
+
assert.Equal(t, session.RefreshToken, retrieved.RefreshToken)
+
assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce)
+
assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
assert.Equal(t, session.Scopes, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testupsert")
+
require.NoError(t, err)
+
+
// Initial session
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds1.example.com",
+
AuthServerURL: "https://auth1.example.com",
+
AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "old_access_token",
+
RefreshToken: "old_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
+
// Updated session (same DID and session ID)
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_upsert",
+
HostURL: "https://pds2.example.com",
+
AuthServerURL: "https://auth2.example.com",
+
AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token",
+
Scopes: []string{"atproto", "transition:generic"},
+
AccessToken: "new_access_token",
+
RefreshToken: "new_refresh_token",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save again - should update
+
err = store.SaveSession(ctx, session2)
+
assert.NoError(t, err)
+
+
// Retrieve should get updated values
+
retrieved, err := store.GetSession(ctx, did, "session_upsert")
+
assert.NoError(t, err)
+
assert.Equal(t, "new_access_token", retrieved.AccessToken)
+
assert.Equal(t, "new_refresh_token", retrieved.RefreshToken)
+
assert.Equal(t, "https://pds2.example.com", retrieved.HostURL)
+
assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes)
+
}
+
+
func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testdelete")
+
require.NoError(t, err)
+
+
session := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "session_delete",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "test_token",
+
RefreshToken: "test_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err)
+
+
// Delete session
+
err = store.DeleteSession(ctx, did, "session_delete")
+
assert.NoError(t, err)
+
+
// Verify session is gone
+
_, err = store.GetSession(ctx, did, "session_delete")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:nonexistent")
+
require.NoError(t, err)
+
+
err = store.DeleteSession(ctx, did, "nonexistent_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
}
+
+
func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testrequest")
+
require.NoError(t, err)
+
+
info := oauth.AuthRequestData{
+
State: "test_state_abc123",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: &did,
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:abc123",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
+
PKCEVerifier: "verifier_xyz789",
+
DPoPAuthServerNonce: "nonce_abc",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err = store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve auth request info
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123")
+
assert.NoError(t, err)
+
assert.NotNil(t, retrieved)
+
assert.Equal(t, info.State, retrieved.State)
+
assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL)
+
assert.NotNil(t, retrieved.AccountDID)
+
assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String())
+
assert.Equal(t, info.Scopes, retrieved.Scopes)
+
assert.Equal(t, info.RequestURI, retrieved.RequestURI)
+
assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
+
assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier)
+
assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
+
assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
+
}
+
+
func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_nodid",
+
AuthServerURL: "https://auth.example.com",
+
AccountDID: nil, // No DID provided
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:nodid",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_nodid",
+
DPoPAuthServerNonce: "nonce_nodid",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info without DID
+
err := store.SaveAuthRequestInfo(ctx, info)
+
assert.NoError(t, err)
+
+
// Retrieve and verify DID is nil
+
retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid")
+
assert.NoError(t, err)
+
assert.Nil(t, retrieved.AccountDID)
+
assert.Equal(t, info.State, retrieved.State)
+
}
+
+
func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
_, err := store.GetAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
info := oauth.AuthRequestData{
+
State: "test_state_delete",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:delete",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "verifier_delete",
+
DPoPAuthServerNonce: "nonce_delete",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
// Save auth request info
+
err := store.SaveAuthRequestInfo(ctx, info)
+
require.NoError(t, err)
+
+
// Delete auth request info
+
err = store.DeleteAuthRequestInfo(ctx, "test_state_delete")
+
assert.NoError(t, err)
+
+
// Verify it's gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_state_delete")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL
+
store, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
ctx := context.Background()
+
+
did1, err := syntax.ParseDID("did:plc:testexpired1")
+
require.NoError(t, err)
+
did2, err := syntax.ParseDID("did:plc:testexpired2")
+
require.NoError(t, err)
+
+
// Create an expired session (manually insert with past expiration)
+
_, err = db.ExecContext(ctx, `
+
INSERT INTO oauth_sessions (
+
did, session_id, handle, pds_url, host_url,
+
access_token, refresh_token,
+
dpop_private_key_multibase, auth_server_iss,
+
auth_server_token_endpoint, scopes,
+
expires_at, created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 day', NOW()
+
)
+
`, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com",
+
"expired_token", "expired_refresh",
+
"z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a valid session
+
validSession := oauth.ClientSessionData{
+
AccountDID: did2,
+
SessionID: "valid_session",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "valid_token",
+
RefreshToken: "valid_refresh",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveSession(ctx, validSession)
+
require.NoError(t, err)
+
+
// Cleanup expired sessions
+
count, err := store.CleanupExpiredSessions(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired session")
+
+
// Verify expired session is gone
+
_, err = store.GetSession(ctx, did1, "expired_session")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// Verify valid session still exists
+
_, err = store.GetSession(ctx, did2, "valid_session")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
storeInterface := NewPostgresOAuthStore(db, 0)
+
pgStore, ok := storeInterface.(*PostgresOAuthStore)
+
require.True(t, ok, "store should be *PostgresOAuthStore")
+
store := oauth.ClientAuthStore(pgStore)
+
ctx := context.Background()
+
+
// Create an old auth request (manually insert with old timestamp)
+
_, err := db.ExecContext(ctx, `
+
INSERT INTO oauth_requests (
+
state, did, handle, pds_url, pkce_verifier,
+
dpop_private_key_multibase, dpop_authserver_nonce,
+
auth_server_iss, request_uri,
+
auth_server_token_endpoint, scopes,
+
created_at
+
) VALUES (
+
$1, $2, $3, $4, $5,
+
$6, $7,
+
$8, $9,
+
$10, $11,
+
NOW() - INTERVAL '1 hour'
+
)
+
`, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com",
+
"old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
"nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old",
+
"https://auth.example.com/oauth/token", `{"atproto"}`)
+
require.NoError(t, err)
+
+
// Create a recent auth request
+
recentInfo := oauth.AuthRequestData{
+
State: "test_recent_state",
+
AuthServerURL: "https://auth.example.com",
+
Scopes: []string{"atproto"},
+
RequestURI: "urn:ietf:params:oauth:request_uri:recent",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
PKCEVerifier: "recent_verifier",
+
DPoPAuthServerNonce: "nonce_recent",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
err = store.SaveAuthRequestInfo(ctx, recentInfo)
+
require.NoError(t, err)
+
+
// Cleanup expired auth requests (older than 30 minutes)
+
count, err := pgStore.CleanupExpiredAuthRequests(ctx)
+
assert.NoError(t, err)
+
assert.Equal(t, int64(1), count, "Should delete 1 expired auth request")
+
+
// Verify old request is gone
+
_, err = store.GetAuthRequestInfo(ctx, "test_old_state")
+
assert.ErrorIs(t, err, ErrAuthRequestNotFound)
+
+
// Verify recent request still exists
+
_, err = store.GetAuthRequestInfo(ctx, "test_recent_state")
+
assert.NoError(t, err)
+
}
+
+
func TestPostgresOAuthStore_MultipleSessions(t *testing.T) {
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
defer cleanupOAuth(t, db)
+
+
store := NewPostgresOAuthStore(db, 0) // Use default TTL
+
ctx := context.Background()
+
+
did, err := syntax.ParseDID("did:plc:testmulti")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same DID
+
session1 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "browser1",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_browser1",
+
RefreshToken: "refresh_browser1",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
+
}
+
+
session2 := oauth.ClientSessionData{
+
AccountDID: did,
+
SessionID: "mobile_app",
+
HostURL: "https://pds.example.com",
+
AuthServerURL: "https://auth.example.com",
+
AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
+
Scopes: []string{"atproto"},
+
AccessToken: "token_mobile",
+
RefreshToken: "refresh_mobile",
+
DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
+
}
+
+
// Save both sessions
+
err = store.SaveSession(ctx, session1)
+
require.NoError(t, err)
+
err = store.SaveSession(ctx, session2)
+
require.NoError(t, err)
+
+
// Retrieve both sessions
+
retrieved1, err := store.GetSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_browser1", retrieved1.AccessToken)
+
+
retrieved2, err := store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
assert.Equal(t, "token_mobile", retrieved2.AccessToken)
+
+
// Delete one session
+
err = store.DeleteSession(ctx, did, "browser1")
+
assert.NoError(t, err)
+
+
// Verify only browser1 is deleted
+
_, err = store.GetSession(ctx, did, "browser1")
+
assert.ErrorIs(t, err, ErrSessionNotFound)
+
+
// mobile_app should still exist
+
_, err = store.GetSession(ctx, did, "mobile_app")
+
assert.NoError(t, err)
+
}
+99
internal/atproto/oauth/transport.go
···
+
package oauth
+
+
import (
+
"fmt"
+
"net"
+
"net/http"
+
"time"
+
)
+
+
// ssrfSafeTransport wraps http.Transport to prevent SSRF attacks
+
type ssrfSafeTransport struct {
+
base *http.Transport
+
allowPrivate bool // For dev/testing only
+
}
+
+
// isPrivateIP checks if an IP is in a private/reserved range
+
func isPrivateIP(ip net.IP) bool {
+
if ip == nil {
+
return false
+
}
+
+
// Check for loopback
+
if ip.IsLoopback() {
+
return true
+
}
+
+
// Check for link-local
+
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+
return true
+
}
+
+
// Check for private ranges
+
privateRanges := []string{
+
"10.0.0.0/8",
+
"172.16.0.0/12",
+
"192.168.0.0/16",
+
"169.254.0.0/16",
+
"::1/128",
+
"fc00::/7",
+
"fe80::/10",
+
}
+
+
for _, cidr := range privateRanges {
+
_, network, err := net.ParseCIDR(cidr)
+
if err == nil && network.Contains(ip) {
+
return true
+
}
+
}
+
+
return false
+
}
+
+
func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+
host := req.URL.Hostname()
+
+
// Resolve hostname to IP
+
ips, err := net.LookupIP(host)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resolve host: %w", err)
+
}
+
+
// Check all resolved IPs
+
if !t.allowPrivate {
+
for _, ip := range ips {
+
if isPrivateIP(ip) {
+
return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip)
+
}
+
}
+
}
+
+
return t.base.RoundTrip(req)
+
}
+
+
// NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections
+
func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client {
+
transport := &ssrfSafeTransport{
+
base: &http.Transport{
+
DialContext: (&net.Dialer{
+
Timeout: 10 * time.Second,
+
KeepAlive: 30 * time.Second,
+
}).DialContext,
+
MaxIdleConns: 100,
+
IdleConnTimeout: 90 * time.Second,
+
TLSHandshakeTimeout: 10 * time.Second,
+
},
+
allowPrivate: allowPrivate,
+
}
+
+
return &http.Client{
+
Timeout: 15 * time.Second,
+
Transport: transport,
+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
+
if len(via) >= 5 {
+
return fmt.Errorf("too many redirects")
+
}
+
return nil
+
},
+
}
+
}
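A short usage sketch for `NewSSRFSafeHTTPClient`, fetching an auth server's metadata document. The environment flag name is a placeholder; only the constructor comes from this file.

```go
// Same package as transport.go; illustrative only.
package oauth

import (
	"fmt"
	"io"
	"os"
)

func fetchAuthServerMetadata(issuer string) ([]byte, error) {
	// assumption: a dev-only flag like this would gate private-IP access
	allowPrivate := os.Getenv("COVES_DEV_ALLOW_PRIVATE_IPS") == "true"
	client := NewSSRFSafeHTTPClient(allowPrivate)

	// Requests to hosts that resolve to loopback/private/link-local ranges are
	// rejected by ssrfSafeTransport before the request is forwarded.
	resp, err := client.Get(issuer + "/.well-known/oauth-authorization-server")
	if err != nil {
		return nil, fmt.Errorf("fetch metadata: %w", err)
	}
	defer resp.Body.Close()
	return io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // cap response size at 1 MiB
}
```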
+132
internal/atproto/oauth/transport_test.go
···
+
package oauth
+
+
import (
+
"net"
+
"net/http"
+
"testing"
+
)
+
+
func TestIsPrivateIP(t *testing.T) {
+
tests := []struct {
+
name string
+
ip string
+
expected bool
+
}{
+
// Loopback addresses
+
{"IPv4 loopback", "127.0.0.1", true},
+
{"IPv6 loopback", "::1", true},
+
+
// Private IPv4 ranges
+
{"Private 10.x.x.x", "10.0.0.1", true},
+
{"Private 10.x.x.x edge", "10.255.255.255", true},
+
{"Private 172.16.x.x", "172.16.0.1", true},
+
{"Private 172.31.x.x edge", "172.31.255.255", true},
+
{"Private 192.168.x.x", "192.168.1.1", true},
+
{"Private 192.168.x.x edge", "192.168.255.255", true},
+
+
// Link-local addresses
+
{"Link-local IPv4", "169.254.1.1", true},
+
{"Link-local IPv6", "fe80::1", true},
+
+
// IPv6 private ranges
+
{"IPv6 unique local fc00", "fc00::1", true},
+
{"IPv6 unique local fd00", "fd00::1", true},
+
+
// Public addresses
+
{"Public IP 1.1.1.1", "1.1.1.1", false},
+
{"Public IP 8.8.8.8", "8.8.8.8", false},
+
{"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12
+
{"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12
+
{"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8
+
{"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
ip := net.ParseIP(tt.ip)
+
if ip == nil {
+
t.Fatalf("Failed to parse IP: %s", tt.ip)
+
}
+
+
result := isPrivateIP(ip)
+
if result != tt.expected {
+
t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
+
}
+
})
+
}
+
}
+
+
func TestIsPrivateIP_NilIP(t *testing.T) {
+
result := isPrivateIP(nil)
+
if result != false {
+
t.Errorf("isPrivateIP(nil) = %v, expected false", result)
+
}
+
}
+
+
func TestNewSSRFSafeHTTPClient(t *testing.T) {
+
tests := []struct {
+
name string
+
allowPrivate bool
+
}{
+
{"Production client (no private IPs)", false},
+
{"Development client (allow private IPs)", true},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(tt.allowPrivate)
+
+
if client == nil {
+
t.Fatal("NewSSRFSafeHTTPClient returned nil")
+
}
+
+
if client.Timeout == 0 {
+
t.Error("Expected timeout to be set")
+
}
+
+
if client.Transport == nil {
+
t.Error("Expected transport to be set")
+
}
+
+
transport, ok := client.Transport.(*ssrfSafeTransport)
+
if !ok {
+
t.Error("Expected ssrfSafeTransport")
+
}
+
+
if transport.allowPrivate != tt.allowPrivate {
+
t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate)
+
}
+
})
+
}
+
}
+
+
func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) {
+
client := NewSSRFSafeHTTPClient(false)
+
+
// Simulate checking redirect limit
+
if client.CheckRedirect == nil {
+
t.Fatal("Expected CheckRedirect to be set")
+
}
+
+
// Test redirect limit (5 redirects)
+
var via []*http.Request
+
for i := 0; i < 5; i++ {
+
req := &http.Request{}
+
via = append(via, req)
+
}
+
+
err := client.CheckRedirect(nil, via)
+
if err == nil {
+
t.Error("Expected error for too many redirects")
+
}
+
if err.Error() != "too many redirects" {
+
t.Errorf("Expected 'too many redirects' error, got: %v", err)
+
}
+
+
// Test within limit (4 redirects)
+
via = via[:4]
+
err = client.CheckRedirect(nil, via)
+
if err != nil {
+
t.Errorf("Expected no error for 4 redirects, got: %v", err)
+
}
+
}
+124
internal/db/migrations/019_update_oauth_for_indigo.sql
···
+
-- +goose Up
+
-- Update OAuth tables to match indigo's ClientAuthStore interface requirements
+
-- This migration adds columns needed for OAuth client sessions and auth requests
+
+
-- Update oauth_requests table
+
-- Add columns for request URI, auth server endpoints, scopes, and DPoP key
+
ALTER TABLE oauth_requests
+
ADD COLUMN request_uri TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_requests ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Make did nullable (indigo's AuthRequestData.AccountDID is a pointer - optional)
+
ALTER TABLE oauth_requests ALTER COLUMN did DROP NOT NULL;
+
+
-- Make handle and pds_url nullable too (derived from DID resolution, not always available at auth request time)
+
ALTER TABLE oauth_requests ALTER COLUMN handle DROP NOT NULL;
+
ALTER TABLE oauth_requests ALTER COLUMN pds_url DROP NOT NULL;
+
+
-- Update existing oauth_requests data
+
-- Convert dpop_private_jwk (JSONB) to multibase format if needed
+
-- Note: This will leave the multibase column NULL for now since conversion requires crypto logic
+
-- The application will need to handle NULL values or regenerate keys on next auth flow
+
UPDATE oauth_requests
+
SET
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE auth_server_token_endpoint IS NULL;
+
+
-- Add indexes for new columns
+
CREATE INDEX idx_oauth_requests_request_uri ON oauth_requests(request_uri) WHERE request_uri IS NOT NULL;
+
+
-- Update oauth_sessions table
+
-- Add session_id column (will become part of composite key)
+
ALTER TABLE oauth_sessions
+
ADD COLUMN session_id TEXT,
+
ADD COLUMN host_url TEXT,
+
ADD COLUMN auth_server_token_endpoint TEXT,
+
ADD COLUMN auth_server_revocation_endpoint TEXT,
+
ADD COLUMN scopes TEXT[],
+
ADD COLUMN dpop_private_key_multibase TEXT;
+
+
-- Make original dpop_private_jwk nullable (we now use dpop_private_key_multibase)
+
ALTER TABLE oauth_sessions ALTER COLUMN dpop_private_jwk DROP NOT NULL;
+
+
-- Populate session_id for existing sessions (use the literal 'default', since each account previously had at most one session)
+
-- New login flows should generate unique session IDs; this placeholder only covers pre-migration rows
+
UPDATE oauth_sessions
+
SET
+
session_id = 'default',
+
host_url = pds_url,
+
auth_server_token_endpoint = auth_server_iss || '/oauth/token',
+
scopes = ARRAY['atproto']::TEXT[]
+
WHERE session_id IS NULL;
+
+
-- Make session_id NOT NULL after populating existing data
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN session_id SET NOT NULL;
+
+
-- Drop old unique constraint on did only
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_key;
+
+
-- Create new composite unique constraint for (did, session_id)
+
-- This allows multiple sessions per account
+
-- Note: UNIQUE constraint automatically creates an index, so no separate index needed
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_session_id_key UNIQUE (did, session_id);
+
+
-- Add comment explaining the schema change
+
COMMENT ON COLUMN oauth_sessions.session_id IS 'Session identifier to support multiple concurrent sessions per account';
+
COMMENT ON CONSTRAINT oauth_sessions_did_session_id_key ON oauth_sessions IS 'Composite key allowing multiple sessions per DID';
+
+
-- +goose Down
+
-- Rollback: Remove added columns and restore original unique constraint
+
+
-- oauth_sessions rollback
+
-- Drop composite unique constraint (this also drops the associated index)
+
ALTER TABLE oauth_sessions
+
DROP CONSTRAINT IF EXISTS oauth_sessions_did_session_id_key;
+
+
-- Delete all but the most recent session per DID before restoring unique constraint
+
-- This ensures the UNIQUE (did) constraint can be added without conflicts
+
DELETE FROM oauth_sessions a
+
USING oauth_sessions b
+
WHERE a.did = b.did
+
AND a.created_at < b.created_at;
+
+
-- Restore old unique constraint
+
ALTER TABLE oauth_sessions
+
ADD CONSTRAINT oauth_sessions_did_key UNIQUE (did);
+
+
-- Restore NOT NULL constraint on dpop_private_jwk (note: this fails if any post-migration rows left it NULL; backfill or delete those rows first)
+
ALTER TABLE oauth_sessions
+
ALTER COLUMN dpop_private_jwk SET NOT NULL;
+
+
ALTER TABLE oauth_sessions
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS host_url,
+
DROP COLUMN IF EXISTS session_id;
+
+
-- oauth_requests rollback
+
DROP INDEX IF EXISTS idx_oauth_requests_request_uri;
+
+
-- Restore NOT NULL constraints
+
ALTER TABLE oauth_requests
+
ALTER COLUMN dpop_private_jwk SET NOT NULL,
+
ALTER COLUMN did SET NOT NULL,
+
ALTER COLUMN handle SET NOT NULL,
+
ALTER COLUMN pds_url SET NOT NULL;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS dpop_private_key_multibase,
+
DROP COLUMN IF EXISTS scopes,
+
DROP COLUMN IF EXISTS auth_server_revocation_endpoint,
+
DROP COLUMN IF EXISTS auth_server_token_endpoint,
+
DROP COLUMN IF EXISTS request_uri;
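For orientation, a session lookup against the new composite key might look like the sketch below. The table, `did`, `session_id`, and `host_url` columns come from this migration; the `access_token` column name and the store wrapper are assumptions.

```go
package main

import (
	"context"
	"database/sql"
)

// getSessionRow reads one OAuth session by (did, session_id), the pair covered by
// the oauth_sessions_did_session_id_key constraint added above.
// access_token is an assumed column name; adjust to the real schema.
func getSessionRow(ctx context.Context, db *sql.DB, did, sessionID string) (accessToken, hostURL string, err error) {
	const q = `SELECT access_token, host_url
	           FROM oauth_sessions
	           WHERE did = $1 AND session_id = $2`
	err = db.QueryRowContext(ctx, q, did, sessionID).Scan(&accessToken, &hostURL)
	return accessToken, hostURL, err
}
```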
+23
internal/db/migrations/020_add_mobile_oauth_csrf.sql
···
+
-- +goose Up
+
-- Add columns for mobile OAuth CSRF protection with server-side state
+
-- This ties the CSRF token to the OAuth state, allowing validation against
+
-- a value that comes back through the OAuth response (the state parameter)
+
-- rather than only validating cookies against each other.
+
+
ALTER TABLE oauth_requests
+
ADD COLUMN mobile_csrf_token TEXT,
+
ADD COLUMN mobile_redirect_uri TEXT;
+
+
-- Index for quick lookup of mobile data when callback is received
+
CREATE INDEX idx_oauth_requests_mobile_csrf ON oauth_requests(state)
+
WHERE mobile_csrf_token IS NOT NULL;
+
+
COMMENT ON COLUMN oauth_requests.mobile_csrf_token IS 'CSRF token for mobile OAuth flows, validated against cookie on callback';
+
COMMENT ON COLUMN oauth_requests.mobile_redirect_uri IS 'Mobile redirect URI (Universal Link) for this OAuth flow';
+
+
-- +goose Down
+
DROP INDEX IF EXISTS idx_oauth_requests_mobile_csrf;
+
+
ALTER TABLE oauth_requests
+
DROP COLUMN IF EXISTS mobile_redirect_uri,
+
DROP COLUMN IF EXISTS mobile_csrf_token;
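At callback time, the stored CSRF token can be fetched by the `state` value returned in the OAuth response and compared to what the client presents (for example, a cookie). The table and columns below come from this migration; the function shape and constant-time comparison are an assumed sketch, not the project's handler.

```go
package main

import (
	"context"
	"crypto/subtle"
	"database/sql"
	"errors"
)

// validateMobileCSRF looks up the mobile CSRF token stored for an OAuth state and
// compares it to the value presented by the client. The partial index
// idx_oauth_requests_mobile_csrf keeps this lookup cheap for mobile flows.
func validateMobileCSRF(ctx context.Context, db *sql.DB, state, presented string) (redirectURI string, err error) {
	var stored, redirect sql.NullString
	const q = `SELECT mobile_csrf_token, mobile_redirect_uri
	           FROM oauth_requests
	           WHERE state = $1 AND mobile_csrf_token IS NOT NULL`
	if err := db.QueryRowContext(ctx, q, state).Scan(&stored, &redirect); err != nil {
		return "", err
	}
	// Constant-time compare avoids leaking token bytes through timing.
	if subtle.ConstantTimeCompare([]byte(stored.String), []byte(presented)) != 1 {
		return "", errors.New("mobile CSRF token mismatch")
	}
	return redirect.String, nil
}
```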
+137
internal/api/handlers/wellknown/universal_links.go
···
+
package wellknown
+
+
import (
+
"encoding/json"
+
"log/slog"
+
"net/http"
+
"os"
+
)
+
+
// HandleAppleAppSiteAssociation serves the iOS Universal Links configuration
+
// GET /.well-known/apple-app-site-association
+
//
+
// Universal Links provide cryptographic binding between the app and domain:
+
// - Requires apple-app-site-association file served over HTTPS
+
// - App must have Associated Domains capability configured
+
// - System verifies domain ownership before routing deep links
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.apple.com/documentation/xcode/supporting-universal-links-in-your-app
+
func HandleAppleAppSiteAssociation(w http.ResponseWriter, r *http.Request) {
+
// Get Apple App ID from environment (format: <Team ID>.<Bundle ID>)
+
// Example: "ABCD1234.social.coves.app"
+
// Find Team ID in Apple Developer Portal -> Membership
+
// Bundle ID is configured in Xcode project
+
appleAppID := os.Getenv("APPLE_APP_ID")
+
if appleAppID == "" {
+
// Development fallback - allows testing without real Team ID
+
// IMPORTANT: This MUST be set in production for Universal Links to work
+
appleAppID = "DEVELOPMENT.social.coves.app"
+
slog.Warn("APPLE_APP_ID not set, using development placeholder",
+
"app_id", appleAppID,
+
"note", "Set APPLE_APP_ID env var for production Universal Links")
+
}
+
+
// Apple requires application/json content type (no charset)
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Apple's spec
+
// See: https://developer.apple.com/documentation/bundleresources/applinks
+
response := map[string]interface{}{
+
"applinks": map[string]interface{}{
+
"apps": []string{}, // Must be empty array per Apple spec
+
"details": []map[string]interface{}{
+
{
+
"appID": appleAppID,
+
// Paths that trigger Universal Links when opened in Safari/other apps
+
// These URLs will open the app instead of the browser
+
"paths": []string{
+
"/app/oauth/callback", // Primary Universal Link OAuth callback
+
"/app/oauth/callback/*", // Catch-all for query params
+
},
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode apple-app-site-association", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served apple-app-site-association", "app_id", appleAppID)
+
}
+
+
// HandleAssetLinks serves the Android App Links configuration
+
// GET /.well-known/assetlinks.json
+
//
+
// App Links provide cryptographic binding between the app and domain:
+
// - Requires assetlinks.json file served over HTTPS
+
// - App must have intent-filter with android:autoVerify="true"
+
// - System verifies domain ownership via SHA-256 certificate fingerprint
+
// - Prevents malicious apps from intercepting deep links
+
//
+
// Spec: https://developer.android.com/training/app-links/verify-android-applinks
+
func HandleAssetLinks(w http.ResponseWriter, r *http.Request) {
+
// Get Android package name from environment
+
// Example: "social.coves.app"
+
androidPackage := os.Getenv("ANDROID_PACKAGE_NAME")
+
if androidPackage == "" {
+
androidPackage = "social.coves.app" // Default for development
+
slog.Warn("ANDROID_PACKAGE_NAME not set, using default",
+
"package", androidPackage,
+
"note", "Set ANDROID_PACKAGE_NAME env var for production App Links")
+
}
+
+
// Get SHA-256 fingerprint from environment
+
// This is the SHA-256 fingerprint of the app's signing certificate
+
//
+
// To get the fingerprint:
+
// Production: keytool -list -v -keystore release.jks -alias release
+
// Debug: keytool -list -v -keystore ~/.android/debug.keystore -alias androiddebugkey -storepass android -keypass android
+
//
+
// Look for "SHA256:" in the output
+
// Format: AA:BB:CC:DD:...:FF (32 colon-separated byte pairs, 64 hex characters in total)
+
androidFingerprint := os.Getenv("ANDROID_SHA256_FINGERPRINT")
+
if androidFingerprint == "" {
+
// Development fallback - this won't work for real App Links verification
+
// IMPORTANT: This MUST be set in production for App Links to work
+
androidFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
+
slog.Warn("ANDROID_SHA256_FINGERPRINT not set, using development placeholder",
+
"fingerprint", androidFingerprint,
+
"note", "Set ANDROID_SHA256_FINGERPRINT env var for production App Links")
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
+
// Construct the response per Google's Digital Asset Links spec
+
// See: https://developers.google.com/digital-asset-links/v1/getting-started
+
response := []map[string]interface{}{
+
{
+
// delegate_permission/common.handle_all_urls grants the app permission
+
// to handle URLs for this domain
+
"relation": []string{"delegate_permission/common.handle_all_urls"},
+
"target": map[string]interface{}{
+
"namespace": "android_app",
+
"package_name": androidPackage,
+
// List of certificate fingerprints that can sign the app
+
// Multiple fingerprints can be provided for different signing keys
+
// (e.g., debug + release)
+
"sha256_cert_fingerprints": []string{
+
androidFingerprint,
+
},
+
},
+
},
+
}
+
+
if err := json.NewEncoder(w).Encode(response); err != nil {
+
slog.Error("failed to encode assetlinks.json", "error", err)
+
http.Error(w, "internal server error", http.StatusInternalServerError)
+
return
+
}
+
+
slog.Debug("served assetlinks.json",
+
"package", androidPackage,
+
"fingerprint", androidFingerprint)
+
}
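A quick httptest round trip is enough to smoke-test these handlers; the sketch below only asserts on behavior visible in this file (status, content type, and the `applinks` payload shape), since the app ID depends on the environment.

```go
package wellknown

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestHandleAppleAppSiteAssociation_Smoke is an illustrative smoke test:
// the handler should answer 200 with application/json and an applinks payload
// regardless of whether APPLE_APP_ID is configured.
func TestHandleAppleAppSiteAssociation_Smoke(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/.well-known/apple-app-site-association", nil)
	w := httptest.NewRecorder()

	HandleAppleAppSiteAssociation(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if ct := w.Header().Get("Content-Type"); ct != "application/json" {
		t.Errorf("expected application/json, got %q", ct)
	}
	if !strings.Contains(w.Body.String(), "applinks") {
		t.Errorf("expected applinks payload, got %s", w.Body.String())
	}
}
```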
+25
internal/api/routes/wellknown.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/wellknown"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterWellKnownRoutes registers RFC 8615 well-known URI endpoints
+
// These endpoints are used for service discovery and mobile app deep linking
+
//
+
// Spec: https://www.rfc-editor.org/rfc/rfc8615.html
+
func RegisterWellKnownRoutes(r chi.Router) {
+
// iOS Universal Links configuration
+
// Required for cryptographically-bound deep linking on iOS
+
// Must be served at exact path /.well-known/apple-app-site-association
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/apple-app-site-association", wellknown.HandleAppleAppSiteAssociation)
+
+
// Android App Links configuration
+
// Required for cryptographically-bound deep linking on Android
+
// Must be served at exact path /.well-known/assetlinks.json
+
// Content-Type: application/json (no redirects allowed)
+
r.Get("/.well-known/assetlinks.json", wellknown.HandleAssetLinks)
+
}
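Wiring is a one-liner on the main router; the sketch below assumes a plain chi mux and an arbitrary listen address.

```go
package main

import (
	"log"
	"net/http"

	"Coves/internal/api/routes"

	"github.com/go-chi/chi/v5"
)

func main() {
	r := chi.NewRouter()
	// Mount the well-known endpoints at the root so platform verifiers can
	// fetch them at the exact paths, with no redirects in between.
	routes.RegisterWellKnownRoutes(r)

	log.Fatal(http.ListenAndServe(":8080", r))
}
```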
+1 -1
internal/api/handlers/comments/middleware.go
···
// The middleware extracts the viewer DID from the Authorization header if present and valid,
// making it available via middleware.GetUserDID(r) in the handler.
// If no valid token is present, the request continues as anonymous (empty DID).
-
func OptionalAuthMiddleware(authMiddleware *middleware.AtProtoAuthMiddleware, next http.HandlerFunc) http.Handler {
+
func OptionalAuthMiddleware(authMiddleware *middleware.OAuthAuthMiddleware, next http.HandlerFunc) http.Handler {
return authMiddleware.OptionalAuth(http.HandlerFunc(next))
}
+164 -312
internal/api/middleware/auth.go
···
package middleware
import (
-
"Coves/internal/atproto/auth"
+
"Coves/internal/atproto/oauth"
"context"
-
"fmt"
+
"encoding/json"
"log"
"net/http"
"strings"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
)
// Context keys for storing user information
···
const (
UserDIDKey contextKey = "user_did"
-
JWTClaimsKey contextKey = "jwt_claims"
-
UserAccessToken contextKey = "user_access_token"
-
DPoPProofKey contextKey = "dpop_proof"
+
OAuthSessionKey contextKey = "oauth_session"
+
UserAccessToken contextKey = "user_access_token" // Kept for backward compatibility
)
-
// AtProtoAuthMiddleware enforces atProto OAuth authentication for protected routes
-
// Validates JWT Bearer tokens from the Authorization header
-
// Supports DPoP (RFC 9449) for token binding verification
-
type AtProtoAuthMiddleware struct {
-
jwksFetcher auth.JWKSFetcher
-
dpopVerifier *auth.DPoPVerifier
-
skipVerify bool // For Phase 1 testing only
+
// SessionUnsealer is an interface for unsealing session tokens
+
// This allows for mocking in tests
+
type SessionUnsealer interface {
+
UnsealSession(token string) (*oauth.SealedSession, error)
}
-
// NewAtProtoAuthMiddleware creates a new atProto auth middleware
-
// skipVerify: if true, only parses JWT without signature verification (Phase 1)
-
//
-
// if false, performs full signature verification (Phase 2)
-
//
-
// IMPORTANT: Call Stop() when shutting down to clean up background goroutines.
-
func NewAtProtoAuthMiddleware(jwksFetcher auth.JWKSFetcher, skipVerify bool) *AtProtoAuthMiddleware {
-
return &AtProtoAuthMiddleware{
-
jwksFetcher: jwksFetcher,
-
dpopVerifier: auth.NewDPoPVerifier(),
-
skipVerify: skipVerify,
-
}
+
// OAuthAuthMiddleware enforces OAuth authentication using sealed session tokens.
+
type OAuthAuthMiddleware struct {
+
unsealer SessionUnsealer
+
store oauthlib.ClientAuthStore
}
-
// Stop stops background goroutines. Call this when shutting down the server.
-
// This prevents goroutine leaks from the DPoP verifier's replay protection cache.
-
func (m *AtProtoAuthMiddleware) Stop() {
-
if m.dpopVerifier != nil {
-
m.dpopVerifier.Stop()
+
// NewOAuthAuthMiddleware creates a new OAuth auth middleware using sealed session tokens.
+
func NewOAuthAuthMiddleware(unsealer SessionUnsealer, store oauthlib.ClientAuthStore) *OAuthAuthMiddleware {
+
return &OAuthAuthMiddleware{
+
unsealer: unsealer,
+
store: store,
}
}
-
// RequireAuth middleware ensures the user is authenticated with a valid JWT
-
// If not authenticated, returns 401
-
// If authenticated, injects user DID and JWT claims into context
+
// RequireAuth middleware ensures the user is authenticated.
+
// Supports sealed session tokens via:
+
// - Authorization: Bearer <sealed_token>
+
// - Cookie: coves_session=<sealed_token>
//
-
// Only accepts DPoP authorization scheme per RFC 9449:
-
// - Authorization: DPoP <token> (DPoP-bound tokens)
-
func (m *AtProtoAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
+
// If not authenticated, returns 401.
+
// If authenticated, injects user DID into context.
+
func (m *OAuthAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
// Extract Authorization header
+
var token string
+
+
// Try Authorization header first (for mobile/API clients)
authHeader := r.Header.Get("Authorization")
-
if authHeader == "" {
-
writeAuthError(w, "Missing Authorization header")
-
return
+
if authHeader != "" {
+
var ok bool
+
token, ok = extractBearerToken(authHeader)
+
if !ok {
+
writeAuthError(w, "Invalid Authorization header format. Expected: Bearer <token>")
+
return
+
}
+
}
+
+
// If no header, try session cookie (for web clients)
+
if token == "" {
+
if cookie, err := r.Cookie("coves_session"); err == nil {
+
token = cookie.Value
+
}
}
-
// Only accept DPoP scheme per RFC 9449
-
// HTTP auth schemes are case-insensitive per RFC 7235
-
token, ok := extractDPoPToken(authHeader)
-
if !ok {
-
writeAuthError(w, "Invalid Authorization header format. Expected: DPoP <token>")
+
// Must have authentication from either source
+
if token == "" {
+
writeAuthError(w, "Missing authentication")
return
}
-
var claims *auth.Claims
-
var err error
+
// Authenticate using sealed token
+
sealedSession, err := m.unsealer.UnsealSession(token)
+
if err != nil {
+
log.Printf("[AUTH_FAILURE] type=unseal_failed ip=%s method=%s path=%s error=%v",
+
r.RemoteAddr, r.Method, r.URL.Path, err)
+
writeAuthError(w, "Invalid or expired token")
+
return
+
}
-
if m.skipVerify {
-
// Phase 1: Parse only (no signature verification)
-
claims, err = auth.ParseJWT(token)
-
if err != nil {
-
log.Printf("[AUTH_FAILURE] type=parse_error ip=%s method=%s path=%s error=%v",
-
r.RemoteAddr, r.Method, r.URL.Path, err)
-
writeAuthError(w, "Invalid token")
-
return
-
}
-
} else {
-
// Phase 2: Full verification with signature check
-
//
-
// SECURITY: The access token MUST be verified before trusting any claims.
-
// DPoP is an ADDITIONAL security layer, not a replacement for signature verification.
-
claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
-
if err != nil {
-
// Token verification failed - REJECT
-
// DO NOT fall back to DPoP-only verification, as that would trust unverified claims
-
issuer := "unknown"
-
if parsedClaims, parseErr := auth.ParseJWT(token); parseErr == nil {
-
issuer = parsedClaims.Issuer
-
}
-
log.Printf("[AUTH_FAILURE] type=verification_failed ip=%s method=%s path=%s issuer=%s error=%v",
-
r.RemoteAddr, r.Method, r.URL.Path, issuer, err)
-
writeAuthError(w, "Invalid or expired token")
-
return
-
}
+
// Parse DID
+
did, err := syntax.ParseDID(sealedSession.DID)
+
if err != nil {
+
log.Printf("[AUTH_FAILURE] type=invalid_did ip=%s method=%s path=%s did=%s error=%v",
+
r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, err)
+
writeAuthError(w, "Invalid DID in token")
+
return
+
}
-
// Token signature verified - now check if DPoP binding is required
-
// If the token has a cnf.jkt claim, DPoP proof is REQUIRED
-
dpopHeader := r.Header.Get("DPoP")
-
hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil
-
-
if hasCnfJkt {
-
// Token has DPoP binding - REQUIRE valid DPoP proof
-
if dpopHeader == "" {
-
log.Printf("[AUTH_FAILURE] type=missing_dpop ip=%s method=%s path=%s error=token has cnf.jkt but no DPoP header",
-
r.RemoteAddr, r.Method, r.URL.Path)
-
writeAuthError(w, "DPoP proof required")
-
return
-
}
-
-
proof, err := m.verifyDPoPBinding(r, claims, dpopHeader, token)
-
if err != nil {
-
log.Printf("[AUTH_FAILURE] type=dpop_verification_failed ip=%s method=%s path=%s error=%v",
-
r.RemoteAddr, r.Method, r.URL.Path, err)
-
writeAuthError(w, "Invalid DPoP proof")
-
return
-
}
-
-
// Store verified DPoP proof in context
-
ctx := context.WithValue(r.Context(), DPoPProofKey, proof)
-
r = r.WithContext(ctx)
-
} else if dpopHeader != "" {
-
// DPoP header present but token doesn't have cnf.jkt - this is suspicious
-
// Log warning but don't reject (could be a misconfigured client)
-
log.Printf("[AUTH_WARNING] type=unexpected_dpop ip=%s method=%s path=%s warning=DPoP header present but token has no cnf.jkt",
-
r.RemoteAddr, r.Method, r.URL.Path)
-
}
+
// Load full OAuth session from database
+
session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID)
+
if err != nil {
+
log.Printf("[AUTH_FAILURE] type=session_not_found ip=%s method=%s path=%s did=%s session_id=%s error=%v",
+
r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID, err)
+
writeAuthError(w, "Session not found or expired")
+
return
}
-
// Extract user DID from 'sub' claim
-
userDID := claims.Subject
-
if userDID == "" {
-
writeAuthError(w, "Missing user DID in token")
+
// Verify session DID matches token DID
+
if session.AccountDID.String() != sealedSession.DID {
+
log.Printf("[AUTH_FAILURE] type=did_mismatch ip=%s method=%s path=%s token_did=%s session_did=%s",
+
r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, session.AccountDID.String())
+
writeAuthError(w, "Session DID mismatch")
return
}
-
// Inject user info and access token into context
-
ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
-
ctx = context.WithValue(ctx, JWTClaimsKey, claims)
-
ctx = context.WithValue(ctx, UserAccessToken, token)
+
log.Printf("[AUTH_SUCCESS] ip=%s method=%s path=%s did=%s session_id=%s",
+
r.RemoteAddr, r.Method, r.URL.Path, sealedSession.DID, sealedSession.SessionID)
+
+
// Inject user info and session into context
+
ctx := context.WithValue(r.Context(), UserDIDKey, sealedSession.DID)
+
ctx = context.WithValue(ctx, OAuthSessionKey, session)
+
// Store access token for backward compatibility
+
ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken)
// Call next handler
next.ServeHTTP(w, r.WithContext(ctx))
})
}
-
// OptionalAuth middleware loads user info if authenticated, but doesn't require it
-
// Useful for endpoints that work for both authenticated and anonymous users
+
// OptionalAuth middleware loads user info if authenticated, but doesn't require it.
+
// Useful for endpoints that work for both authenticated and anonymous users.
+
//
+
// Supports sealed session tokens via:
+
// - Authorization: Bearer <sealed_token>
+
// - Cookie: coves_session=<sealed_token>
//
-
// Only accepts DPoP authorization scheme per RFC 9449:
-
// - Authorization: DPoP <token> (DPoP-bound tokens)
-
func (m *AtProtoAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
+
// If authentication fails, continues without user context (does not return error).
+
func (m *OAuthAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
// Extract Authorization header
+
var token string
+
+
// Try Authorization header first (for mobile/API clients)
authHeader := r.Header.Get("Authorization")
+
if authHeader != "" {
+
var ok bool
+
token, ok = extractBearerToken(authHeader)
+
if !ok {
+
// Invalid format - continue without user context
+
next.ServeHTTP(w, r)
+
return
+
}
+
}
+
+
// If no header, try session cookie (for web clients)
+
if token == "" {
+
if cookie, err := r.Cookie("coves_session"); err == nil {
+
token = cookie.Value
+
}
+
}
-
// Only accept DPoP scheme per RFC 9449
-
// HTTP auth schemes are case-insensitive per RFC 7235
-
token, ok := extractDPoPToken(authHeader)
-
if !ok {
-
// Not authenticated or invalid format - continue without user context
+
// If still no token, continue without authentication
+
if token == "" {
next.ServeHTTP(w, r)
return
}
-
var claims *auth.Claims
-
var err error
+
// Try to authenticate (don't write errors, just continue without user context on failure)
+
sealedSession, err := m.unsealer.UnsealSession(token)
+
if err != nil {
+
next.ServeHTTP(w, r)
+
return
+
}
-
if m.skipVerify {
-
// Phase 1: Parse only
-
claims, err = auth.ParseJWT(token)
-
} else {
-
// Phase 2: Full verification
-
// SECURITY: Token MUST be verified before trusting claims
-
claims, err = auth.VerifyJWT(r.Context(), token, m.jwksFetcher)
+
// Parse DID
+
did, err := syntax.ParseDID(sealedSession.DID)
+
if err != nil {
+
log.Printf("[AUTH_WARNING] Optional auth: invalid DID: %v", err)
+
next.ServeHTTP(w, r)
+
return
}
+
// Load full OAuth session from database
+
session, err := m.store.GetSession(r.Context(), did, sealedSession.SessionID)
if err != nil {
-
// Invalid token - continue without user context
-
log.Printf("Optional auth failed: %v", err)
+
log.Printf("[AUTH_WARNING] Optional auth: session not found: %v", err)
next.ServeHTTP(w, r)
return
}
-
// Check DPoP binding if token has cnf.jkt (after successful verification)
-
// SECURITY: If token has cnf.jkt but no DPoP header, we cannot trust it
-
// (could be a stolen token). Continue as unauthenticated.
-
if !m.skipVerify {
-
dpopHeader := r.Header.Get("DPoP")
-
hasCnfJkt := claims.Confirmation != nil && claims.Confirmation["jkt"] != nil
-
-
if hasCnfJkt {
-
if dpopHeader == "" {
-
// Token requires DPoP binding but no proof provided
-
// Cannot trust this token - continue without auth
-
log.Printf("[AUTH_WARNING] Optional auth: token has cnf.jkt but no DPoP header - treating as unauthenticated (potential token theft)")
-
next.ServeHTTP(w, r)
-
return
-
}
-
-
proof, err := m.verifyDPoPBinding(r, claims, dpopHeader, token)
-
if err != nil {
-
// DPoP verification failed - cannot trust this token
-
log.Printf("[AUTH_WARNING] Optional auth: DPoP verification failed - treating as unauthenticated: %v", err)
-
next.ServeHTTP(w, r)
-
return
-
}
-
-
// DPoP verified - inject proof into context
-
ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
-
ctx = context.WithValue(ctx, JWTClaimsKey, claims)
-
ctx = context.WithValue(ctx, UserAccessToken, token)
-
ctx = context.WithValue(ctx, DPoPProofKey, proof)
-
next.ServeHTTP(w, r.WithContext(ctx))
-
return
-
}
+
// Verify session DID matches token DID
+
if session.AccountDID.String() != sealedSession.DID {
+
log.Printf("[AUTH_WARNING] Optional auth: DID mismatch")
+
next.ServeHTTP(w, r)
+
return
}
-
// No DPoP binding required - inject user info and access token into context
-
ctx := context.WithValue(r.Context(), UserDIDKey, claims.Subject)
-
ctx = context.WithValue(ctx, JWTClaimsKey, claims)
-
ctx = context.WithValue(ctx, UserAccessToken, token)
+
// Build authenticated context
+
ctx := context.WithValue(r.Context(), UserDIDKey, sealedSession.DID)
+
ctx = context.WithValue(ctx, OAuthSessionKey, session)
+
ctx = context.WithValue(ctx, UserAccessToken, session.AccessToken)
-
// Call next handler
next.ServeHTTP(w, r.WithContext(ctx))
})
}
···
return did
}
-
// GetJWTClaims extracts the JWT claims from the request context
+
// GetOAuthSession extracts the OAuth session from the request context
// Returns nil if not authenticated
-
func GetJWTClaims(r *http.Request) *auth.Claims {
-
claims, _ := r.Context().Value(JWTClaimsKey).(*auth.Claims)
-
return claims
-
}
-
-
// SetTestUserDID sets the user DID in the context for testing purposes
-
// This function should ONLY be used in tests to mock authenticated users
-
func SetTestUserDID(ctx context.Context, userDID string) context.Context {
-
return context.WithValue(ctx, UserDIDKey, userDID)
+
// Handlers can use this to make authenticated PDS calls
+
func GetOAuthSession(r *http.Request) *oauthlib.ClientSessionData {
+
session, _ := r.Context().Value(OAuthSessionKey).(*oauthlib.ClientSessionData)
+
return session
}
// GetUserAccessToken extracts the user's access token from the request context
···
return token
}
-
// GetDPoPProof extracts the DPoP proof from the request context
-
// Returns nil if no DPoP proof was verified
-
func GetDPoPProof(r *http.Request) *auth.DPoPProof {
-
proof, _ := r.Context().Value(DPoPProofKey).(*auth.DPoPProof)
-
return proof
-
}
-
-
// verifyDPoPBinding verifies DPoP proof binding for an ALREADY VERIFIED token.
-
//
-
// SECURITY: This function ONLY verifies the DPoP proof and its binding to the token.
-
// The access token MUST be signature-verified BEFORE calling this function.
-
// DPoP is an ADDITIONAL security layer, not a replacement for signature verification.
-
//
-
// This prevents token theft attacks by proving the client possesses the private key
-
// corresponding to the public key thumbprint in the token's cnf.jkt claim.
-
func (m *AtProtoAuthMiddleware) verifyDPoPBinding(r *http.Request, claims *auth.Claims, dpopProofHeader, accessToken string) (*auth.DPoPProof, error) {
-
// Extract the cnf.jkt claim from the already-verified token
-
jkt, err := auth.ExtractCnfJkt(claims)
-
if err != nil {
-
return nil, fmt.Errorf("token requires DPoP but missing cnf.jkt: %w", err)
-
}
-
-
// Build the HTTP URI for DPoP verification
-
// Use the full URL including scheme and host, respecting proxy headers
-
scheme, host := extractSchemeAndHost(r)
-
-
// Use EscapedPath to preserve percent-encoding (P3 fix)
-
// r.URL.Path is decoded, but DPoP proofs contain the raw encoded path
-
path := r.URL.EscapedPath()
-
if path == "" {
-
path = r.URL.Path // Fallback if EscapedPath returns empty
-
}
-
-
httpURI := scheme + "://" + host + path
-
-
// Verify the DPoP proof
-
proof, err := m.dpopVerifier.VerifyDPoPProof(dpopProofHeader, r.Method, httpURI)
-
if err != nil {
-
return nil, fmt.Errorf("DPoP proof verification failed: %w", err)
-
}
-
-
// Verify the binding between the proof and the token (cnf.jkt)
-
if err := m.dpopVerifier.VerifyTokenBinding(proof, jkt); err != nil {
-
return nil, fmt.Errorf("DPoP binding verification failed: %w", err)
-
}
-
-
// Verify the access token hash (ath) if present in the proof
-
// Per RFC 9449 section 4.2, if ath is present, it MUST match the access token
-
if err := m.dpopVerifier.VerifyAccessTokenHash(proof, accessToken); err != nil {
-
return nil, fmt.Errorf("DPoP ath verification failed: %w", err)
-
}
-
-
return proof, nil
-
}
-
-
// extractSchemeAndHost extracts the scheme and host from the request,
-
// respecting proxy headers (X-Forwarded-Proto, X-Forwarded-Host, Forwarded).
-
// This is critical for DPoP verification when behind TLS-terminating proxies.
-
func extractSchemeAndHost(r *http.Request) (scheme, host string) {
-
// Start with request defaults
-
scheme = r.URL.Scheme
-
host = r.Host
-
-
// Check X-Forwarded-Proto for scheme (most common)
-
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
-
parts := strings.Split(forwardedProto, ",")
-
if len(parts) > 0 && strings.TrimSpace(parts[0]) != "" {
-
scheme = strings.ToLower(strings.TrimSpace(parts[0]))
-
}
-
}
-
-
// Check X-Forwarded-Host for host (common with nginx/traefik)
-
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
-
parts := strings.Split(forwardedHost, ",")
-
if len(parts) > 0 && strings.TrimSpace(parts[0]) != "" {
-
host = strings.TrimSpace(parts[0])
-
}
-
}
-
-
// Check standard Forwarded header (RFC 7239) - takes precedence if present
-
// Format: Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43;host=example.com
-
// RFC 7239 allows: mixed-case keys (Proto, PROTO), quoted values (host="example.com")
-
if forwarded := r.Header.Get("Forwarded"); forwarded != "" {
-
// Parse the first entry (comma-separated list)
-
firstEntry := strings.Split(forwarded, ",")[0]
-
for _, part := range strings.Split(firstEntry, ";") {
-
part = strings.TrimSpace(part)
-
// Split on first '=' to properly handle key=value pairs
-
if idx := strings.Index(part, "="); idx != -1 {
-
key := strings.ToLower(strings.TrimSpace(part[:idx]))
-
value := strings.TrimSpace(part[idx+1:])
-
// Strip optional quotes per RFC 7239 section 4
-
value = strings.Trim(value, "\"")
-
-
switch key {
-
case "proto":
-
scheme = strings.ToLower(value)
-
case "host":
-
host = value
-
}
-
}
-
}
-
}
-
-
// Fallback scheme detection from TLS
-
if scheme == "" {
-
if r.TLS != nil {
-
scheme = "https"
-
} else {
-
scheme = "http"
-
}
-
}
-
-
return strings.ToLower(scheme), host
-
}
-
-
// writeAuthError writes a JSON error response for authentication failures
-
func writeAuthError(w http.ResponseWriter, message string) {
-
w.Header().Set("Content-Type", "application/json")
-
w.WriteHeader(http.StatusUnauthorized)
-
// Simple error response matching XRPC error format
-
response := `{"error":"AuthenticationRequired","message":"` + message + `"}`
-
if _, err := w.Write([]byte(response)); err != nil {
-
log.Printf("Failed to write auth error response: %v", err)
-
}
+
// SetTestUserDID sets the user DID in the context for testing purposes
+
// This function should ONLY be used in tests to mock authenticated users
+
func SetTestUserDID(ctx context.Context, userDID string) context.Context {
+
return context.WithValue(ctx, UserDIDKey, userDID)
}
-
// extractDPoPToken extracts the token from a DPoP Authorization header.
-
// HTTP auth schemes are case-insensitive per RFC 7235, so "DPoP", "dpop", "DPOP" are all valid.
-
// Returns the token and true if valid DPoP scheme, empty string and false otherwise.
-
func extractDPoPToken(authHeader string) (string, bool) {
+
// extractBearerToken extracts the token from a Bearer Authorization header.
+
// HTTP auth schemes are case-insensitive per RFC 7235, so "Bearer", "bearer", "BEARER" are all valid.
+
// Returns the token and true if valid Bearer scheme, empty string and false otherwise.
+
func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" {
return "", false
}
-
// Split on first space: "DPoP <token>" -> ["DPoP", "<token>"]
+
// Split on first space: "Bearer <token>" -> ["Bearer", "<token>"]
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 {
return "", false
}
// Case-insensitive scheme comparison per RFC 7235
-
if !strings.EqualFold(parts[0], "DPoP") {
+
if !strings.EqualFold(parts[0], "Bearer") {
return "", false
}
···
return token, true
}
+
+
// writeAuthError writes a JSON error response for authentication failures
+
func writeAuthError(w http.ResponseWriter, message string) {
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusUnauthorized)
+
// Use json.NewEncoder to properly escape the message and prevent injection
+
if err := json.NewEncoder(w).Encode(map[string]string{
+
"error": "AuthenticationRequired",
+
"message": message,
+
}); err != nil {
+
log.Printf("Failed to write auth error response: %v", err)
+
}
+
}
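For reference, this is roughly how a protected route consumes the middleware; constructing the unsealer and store is assumed to happen elsewhere, and the handler body is illustrative.

```go
package main

import (
	"net/http"

	"Coves/internal/api/middleware"
)

// newProtectedMux guards /api/me with RequireAuth and reads the DID and OAuth
// session the middleware injects into the request context.
func newProtectedMux(auth *middleware.OAuthAuthMiddleware) http.Handler {
	mux := http.NewServeMux()
	mux.Handle("/api/me", auth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		did := middleware.GetUserDID(r)
		session := middleware.GetOAuthSession(r)
		if session == nil {
			http.Error(w, "no session in context", http.StatusInternalServerError)
			return
		}
		// session.HostURL and session.AccessToken are what an authenticated
		// PDS call would use; this handler just echoes the caller's DID.
		w.Header().Set("Content-Type", "text/plain")
		_, _ = w.Write([]byte(did))
	})))
	return mux
}
```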
+511 -728
internal/api/middleware/auth_test.go
···
package middleware
import (
-
"Coves/internal/atproto/auth"
+
"Coves/internal/atproto/oauth"
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"crypto/sha256"
"encoding/base64"
+
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
···
"testing"
"time"
-
"github.com/golang-jwt/jwt/v5"
-
"github.com/google/uuid"
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
)
-
// mockJWKSFetcher is a test double for JWKSFetcher
-
type mockJWKSFetcher struct {
-
shouldFail bool
+
// mockOAuthClient is a test double for OAuthClient
+
type mockOAuthClient struct {
+
sealSecret []byte
+
shouldFailSeal bool
}
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
if m.shouldFail {
-
return nil, fmt.Errorf("mock fetch failure")
+
func newMockOAuthClient() *mockOAuthClient {
+
// Create a 32-byte seal secret for testing
+
secret := []byte("test-secret-key-32-bytes-long!!")
+
return &mockOAuthClient{
+
sealSecret: secret,
}
-
// Return nil - we won't actually verify signatures in Phase 1 tests
-
return nil, nil
}
-
// createTestToken creates a test JWT with the given DID
-
func createTestToken(did string) string {
-
claims := jwt.MapClaims{
-
"sub": did,
-
"iss": "https://test.pds.local",
-
"scope": "atproto",
-
"exp": time.Now().Add(1 * time.Hour).Unix(),
-
"iat": time.Now().Unix(),
+
func (m *mockOAuthClient) UnsealSession(token string) (*oauth.SealedSession, error) {
+
if m.shouldFailSeal {
+
return nil, fmt.Errorf("mock unseal failure")
}
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
-
return tokenString
+
// For testing, we'll decode a simple format: base64(did|sessionID|expiresAt)
+
// In production this would be AES-GCM encrypted
+
// Using pipe separator to avoid conflicts with colon in DIDs
+
decoded, err := base64.RawURLEncoding.DecodeString(token)
+
if err != nil {
+
return nil, fmt.Errorf("invalid token encoding: %w", err)
+
}
+
+
parts := strings.Split(string(decoded), "|")
+
if len(parts) != 3 {
+
return nil, fmt.Errorf("invalid token format")
+
}
+
+
var expiresAt int64
+
_, _ = fmt.Sscanf(parts[2], "%d", &expiresAt)
+
+
// Check expiration
+
if expiresAt <= time.Now().Unix() {
+
return nil, fmt.Errorf("token expired")
+
}
+
+
return &oauth.SealedSession{
+
DID: parts[0],
+
SessionID: parts[1],
+
ExpiresAt: expiresAt,
+
}, nil
+
}
+
+
// Helper to create a test sealed token
+
func (m *mockOAuthClient) createTestToken(did, sessionID string, ttl time.Duration) string {
+
expiresAt := time.Now().Add(ttl).Unix()
+
payload := fmt.Sprintf("%s|%s|%d", did, sessionID, expiresAt)
+
return base64.RawURLEncoding.EncodeToString([]byte(payload))
+
}
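The mock above just base64-encodes a pipe-delimited payload; its comment notes that production sealing would use AES-GCM. A minimal sketch of that kind of sealing, assuming a 32-byte key — the payload layout and key handling are illustrative, not the actual oauth package API:

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"fmt"
	"io"
)

// sealPayload encrypts an opaque session payload with AES-256-GCM.
// key must be 32 bytes; the random nonce is prepended to the ciphertext.
func sealPayload(key, payload []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	return gcm.Seal(nonce, nonce, payload, nil), nil
}

// openPayload reverses sealPayload, splitting off the prepended nonce.
func openPayload(key, sealed []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	if len(sealed) < gcm.NonceSize() {
		return nil, fmt.Errorf("sealed payload too short")
	}
	nonce, ct := sealed[:gcm.NonceSize()], sealed[gcm.NonceSize():]
	return gcm.Open(nil, nonce, ct, nil)
}
```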
+
+
// mockOAuthStore is a test double for ClientAuthStore
+
type mockOAuthStore struct {
+
sessions map[string]*oauthlib.ClientSessionData
+
}
+
+
func newMockOAuthStore() *mockOAuthStore {
+
return &mockOAuthStore{
+
sessions: make(map[string]*oauthlib.ClientSessionData),
+
}
+
}
+
+
func (m *mockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) {
+
key := did.String() + ":" + sessionID
+
session, ok := m.sessions[key]
+
if !ok {
+
return nil, fmt.Errorf("session not found")
+
}
+
return session, nil
}
-
// TestRequireAuth_ValidToken tests that valid tokens are accepted with DPoP scheme (Phase 1)
+
func (m *mockOAuthStore) SaveSession(ctx context.Context, session oauthlib.ClientSessionData) error {
+
key := session.AccountDID.String() + ":" + session.SessionID
+
m.sessions[key] = &session
+
return nil
+
}
+
+
func (m *mockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
+
key := did.String() + ":" + sessionID
+
delete(m.sessions, key)
+
return nil
+
}
+
+
func (m *mockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) {
+
return nil, fmt.Errorf("not implemented")
+
}
+
+
func (m *mockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error {
+
return fmt.Errorf("not implemented")
+
}
+
+
func (m *mockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
+
return fmt.Errorf("not implemented")
+
}
+
+
// TestRequireAuth_ValidToken tests that valid sealed tokens are accepted
func TestRequireAuth_ValidToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
+
HostURL: "https://pds.example.com",
+
}
+
_ = store.SaveSession(context.Background(), *session)
+
+
middleware := NewOAuthAuthMiddleware(client, store)
handlerCalled := false
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
// Verify DID was extracted and injected into context
-
did := GetUserDID(r)
-
if did != "did:plc:test123" {
-
t.Errorf("expected DID 'did:plc:test123', got %s", did)
+
extractedDID := GetUserDID(r)
+
if extractedDID != "did:plc:test123" {
+
t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
}
-
// Verify claims were injected
-
claims := GetJWTClaims(r)
-
if claims == nil {
-
t.Error("expected claims to be non-nil")
+
// Verify OAuth session was injected
+
oauthSession := GetOAuthSession(r)
+
if oauthSession == nil {
+
t.Error("expected OAuth session to be non-nil")
return
}
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject)
+
if oauthSession.SessionID != sessionID {
+
t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID)
+
}
+
+
// Verify access token is available
+
accessToken := GetUserAccessToken(r)
+
if accessToken != "test_access_token" {
+
t.Errorf("expected access token 'test_access_token', got %s", accessToken)
}
w.WriteHeader(http.StatusOK)
}))
-
token := createTestToken("did:plc:test123")
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+token)
+
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
func TestRequireAuth_MissingAuthHeader(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
···
}
}
-
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-DPoP tokens are rejected (including Bearer)
+
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
tests := []struct {
name string
header string
}{
{"Basic auth", "Basic dGVzdDp0ZXN0"},
-
{"Bearer scheme", "Bearer some-token"},
+
{"DPoP scheme", "DPoP some-token"},
{"Invalid format", "InvalidFormat"},
}
···
}
}
-
// TestRequireAuth_BearerRejectionErrorMessage verifies that Bearer tokens are rejected
-
// with a helpful error message guiding users to use DPoP scheme
-
func TestRequireAuth_BearerRejectionErrorMessage(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
-
-
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
t.Error("handler should not be called")
-
}))
-
-
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "Bearer some-token")
-
w := httptest.NewRecorder()
-
-
handler.ServeHTTP(w, req)
-
-
if w.Code != http.StatusUnauthorized {
-
t.Errorf("expected status 401, got %d", w.Code)
-
}
+
// TestRequireAuth_CaseInsensitiveScheme verifies that Bearer scheme matching is case-insensitive
+
func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
-
// Verify error message guides user to use DPoP
-
body := w.Body.String()
-
if !strings.Contains(body, "Expected: DPoP") {
-
t.Errorf("error message should guide user to use DPoP, got: %s", body)
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
}
-
}
-
-
// TestRequireAuth_CaseInsensitiveScheme verifies that DPoP scheme matching is case-insensitive
-
// per RFC 7235 which states HTTP auth schemes are case-insensitive
-
func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
_ = store.SaveSession(context.Background(), *session)
-
// Create a valid JWT for testing
-
validToken := createValidJWT(t, "did:plc:test123", time.Hour)
+
middleware := NewOAuthAuthMiddleware(client, store)
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
testCases := []struct {
name string
scheme string
}{
-
{"lowercase", "dpop"},
-
{"uppercase", "DPOP"},
-
{"mixed_case", "DpOp"},
-
{"standard", "DPoP"},
+
{"lowercase", "bearer"},
+
{"uppercase", "BEARER"},
+
{"mixed_case", "BeArEr"},
+
{"standard", "Bearer"},
}
for _, tc := range testCases {
···
}))
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", tc.scheme+" "+validToken)
+
req.Header.Set("Authorization", tc.scheme+" "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
}
}
-
// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected
-
func TestRequireAuth_MalformedToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
// TestRequireAuth_InvalidToken tests that malformed sealed tokens are rejected
+
func TestRequireAuth_InvalidToken(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
}))
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
+
req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
}
}
-
// TestRequireAuth_ExpiredToken tests that expired tokens are rejected
+
// TestRequireAuth_ExpiredToken tests that expired sealed tokens are rejected
func TestRequireAuth_ExpiredToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
+
}
+
_ = store.SaveSession(context.Background(), *session)
+
+
middleware := NewOAuthAuthMiddleware(client, store)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called for expired token")
}))
-
// Create expired token
-
claims := jwt.MapClaims{
-
"sub": "did:plc:test123",
-
"iss": "https://test.pds.local",
-
"scope": "atproto",
-
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
-
"iat": time.Now().Add(-2 * time.Hour).Unix(),
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
+
// Create expired token (expired 1 hour ago)
+
token := client.createTestToken("did:plc:test123", sessionID, -time.Hour)
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
+
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
}
}
-
// TestRequireAuth_MissingDID tests that tokens without DID are rejected
-
func TestRequireAuth_MissingDID(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
// TestRequireAuth_SessionNotFound tests that tokens with non-existent sessions are rejected
+
func TestRequireAuth_SessionNotFound(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
}))
-
// Create token without sub claim
-
claims := jwt.MapClaims{
-
// "sub" missing
-
"iss": "https://test.pds.local",
-
"scope": "atproto",
-
"exp": time.Now().Add(1 * time.Hour).Unix(),
-
"iat": time.Now().Unix(),
+
// Create token for session that doesn't exist in store
+
token := client.createTestToken("did:plc:nonexistent", "session999", time.Hour)
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer "+token)
+
w := httptest.NewRecorder()
+
+
handler.ServeHTTP(w, req)
+
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("expected status 401, got %d", w.Code)
+
}
+
}
+
+
// TestRequireAuth_DIDMismatch tests that session DID must match token DID
+
func TestRequireAuth_DIDMismatch(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
+
// Create a session with different DID than token
+
did := syntax.DID("did:plc:different")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
}
+
// Store with key that matches token DID
+
key := "did:plc:test123:" + sessionID
+
store.sessions[key] = session
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
+
middleware := NewOAuthAuthMiddleware(client, store)
+
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
t.Error("handler should not be called when DID mismatches")
+
}))
+
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
+
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
}
}
-
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid DPoP tokens
+
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid Bearer tokens
func TestOptionalAuth_WithToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
+
}
+
_ = store.SaveSession(context.Background(), *session)
+
+
middleware := NewOAuthAuthMiddleware(client, store)
handlerCalled := false
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
// Verify DID was extracted
-
did := GetUserDID(r)
-
if did != "did:plc:test123" {
-
t.Errorf("expected DID 'did:plc:test123', got %s", did)
+
extractedDID := GetUserDID(r)
+
if extractedDID != "did:plc:test123" {
+
t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
}
w.WriteHeader(http.StatusOK)
}))
-
token := createTestToken("did:plc:test123")
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+token)
+
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
func TestOptionalAuth_WithoutToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
handlerCalled := false
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
···
// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
func TestOptionalAuth_InvalidToken(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
handlerCalled := false
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
···
}))
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
+
req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
}
}
-
// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated
-
func TestGetJWTClaims_NotAuthenticated(t *testing.T) {
+
// TestGetOAuthSession_NotAuthenticated tests that GetOAuthSession returns nil when not authenticated
+
func TestGetOAuthSession_NotAuthenticated(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
-
claims := GetJWTClaims(req)
+
session := GetOAuthSession(req)
-
if claims != nil {
-
t.Errorf("expected nil claims, got %+v", claims)
+
if session != nil {
+
t.Errorf("expected nil session, got %+v", session)
}
}
-
// TestGetDPoPProof_NotAuthenticated tests that GetDPoPProof returns nil when no DPoP was verified
-
func TestGetDPoPProof_NotAuthenticated(t *testing.T) {
+
// TestGetUserAccessToken_NotAuthenticated tests that GetUserAccessToken returns empty when not authenticated
+
func TestGetUserAccessToken_NotAuthenticated(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
-
proof := GetDPoPProof(req)
+
token := GetUserAccessToken(req)
-
if proof != nil {
-
t.Errorf("expected nil proof, got %+v", proof)
+
if token != "" {
+
t.Errorf("expected empty token, got %s", token)
}
}
-
// TestRequireAuth_WithDPoP_SecurityModel tests the correct DPoP security model:
-
// Token MUST be verified first, then DPoP is checked as an additional layer.
-
// DPoP is NOT a fallback for failed token verification.
-
func TestRequireAuth_WithDPoP_SecurityModel(t *testing.T) {
-
// Generate an ECDSA key pair for DPoP
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
-
}
+
// TestSetTestUserDID tests the testing helper function
+
func TestSetTestUserDID(t *testing.T) {
+
ctx := context.Background()
+
ctx = SetTestUserDID(ctx, "did:plc:testuser")
-
// Calculate JWK thumbprint for cnf.jkt
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
+
did, ok := ctx.Value(UserDIDKey).(string)
+
if !ok {
+
t.Error("DID not found in context")
}
+
if did != "did:plc:testuser" {
+
t.Errorf("expected 'did:plc:testuser', got %s", did)
+
}
+
}
-
t.Run("DPoP_is_NOT_fallback_for_failed_verification", func(t *testing.T) {
-
// SECURITY TEST: When token verification fails, DPoP should NOT be used as fallback
-
// This prevents an attacker from forging a token with their own cnf.jkt
-
-
// Create a DPoP-bound access token (unsigned - will fail verification)
-
claims := auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:attacker",
-
Issuer: "https://external.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
+
// TestExtractBearerToken tests the Bearer token extraction logic
+
func TestExtractBearerToken(t *testing.T) {
+
tests := []struct {
+
name string
+
authHeader string
+
expectToken string
+
expectOK bool
+
}{
+
{"valid bearer", "Bearer token123", "token123", true},
+
{"lowercase bearer", "bearer token123", "token123", true},
+
{"uppercase bearer", "BEARER token123", "token123", true},
+
{"mixed case", "BeArEr token123", "token123", true},
+
{"empty header", "", "", false},
+
{"wrong scheme", "DPoP token123", "", false},
+
{"no token", "Bearer", "", false},
+
{"no space", "Bearertoken123", "", false},
+
{"extra spaces", "Bearer token123 ", "token123", true},
+
}
-
// Create valid DPoP proof (attacker has the private key)
-
dpopProof := createDPoPProof(t, privateKey, "GET", "https://test.local/api/endpoint")
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
token, ok := extractBearerToken(tt.authHeader)
+
if ok != tt.expectOK {
+
t.Errorf("expected ok=%v, got %v", tt.expectOK, ok)
+
}
+
if token != tt.expectToken {
+
t.Errorf("expected token '%s', got '%s'", tt.expectToken, token)
+
}
+
})
+
}
+
}
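
A minimal sketch of an `extractBearerToken` consistent with the table above — case-insensitive scheme, surrounding whitespace trimmed, empty token rejected. Illustrative only; the production function in the middleware package may differ.

```go
package middleware

import "strings"

// extractBearerToken parses an Authorization header value and returns the bearer token.
// The scheme match is case-insensitive and the token is trimmed of surrounding whitespace.
func extractBearerToken(authHeader string) (string, bool) {
	const prefix = "Bearer "
	if len(authHeader) <= len(prefix) {
		return "", false // empty header, bare "Bearer", or no room for a token
	}
	if !strings.EqualFold(authHeader[:len(prefix)], prefix) {
		return "", false // wrong scheme (e.g. "DPoP ...") or missing space
	}
	token := strings.TrimSpace(authHeader[len(prefix):])
	if token == "" {
		return "", false
	}
	return token, true
}
```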
-
// Mock fetcher that fails (simulating external PDS without JWKS)
-
fetcher := &mockJWKSFetcher{shouldFail: true}
-
middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false
+
// TestRequireAuth_ValidCookie tests that valid session cookies are accepted
+
func TestRequireAuth_ValidCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
-
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
t.Error("SECURITY VULNERABILITY: handler was called despite token verification failure")
-
}))
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
+
HostURL: "https://pds.example.com",
+
}
+
_ = store.SaveSession(context.Background(), *session)
-
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
-
req.Header.Set("DPoP", dpopProof)
-
w := httptest.NewRecorder()
+
middleware := NewOAuthAuthMiddleware(client, store)
-
handler.ServeHTTP(w, req)
+
handlerCalled := false
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
-
// MUST reject - token verification failed, DPoP cannot substitute for signature verification
-
if w.Code != http.StatusUnauthorized {
-
t.Errorf("SECURITY: expected 401 for unverified token, got %d", w.Code)
+
// Verify DID was extracted and injected into context
+
extractedDID := GetUserDID(r)
+
if extractedDID != "did:plc:test123" {
+
t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
}
-
})
-
t.Run("DPoP_required_when_cnf_jkt_present_in_verified_token", func(t *testing.T) {
-
// When token has cnf.jkt, DPoP header MUST be present
-
// This test uses skipVerify=true to simulate a verified token
-
-
claims := auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
+
// Verify OAuth session was injected
+
oauthSession := GetOAuthSession(r)
+
if oauthSession == nil {
+
t.Error("expected OAuth session to be non-nil")
+
return
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
-
-
// NO DPoP header - should fail when skipVerify is false
-
// Note: with skipVerify=true, DPoP is not checked
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true for parsing
-
-
handlerCalled := false
-
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
handlerCalled = true
-
w.WriteHeader(http.StatusOK)
-
}))
-
-
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
-
// No DPoP header
-
w := httptest.NewRecorder()
-
-
handler.ServeHTTP(w, req)
-
-
// With skipVerify=true, DPoP is not checked, so this should succeed
-
if !handlerCalled {
-
t.Error("handler should be called when skipVerify=true")
+
if oauthSession.SessionID != sessionID {
+
t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID)
}
-
})
-
}
-
-
// TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback is the key security test.
-
// It ensures that DPoP cannot be used as a fallback when token signature verification fails.
-
func TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback(t *testing.T) {
-
// Generate a key pair (attacker's key)
-
attackerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
jwk := ecdsaPublicKeyToJWK(&attackerKey.PublicKey)
-
thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
-
-
// Create a FORGED token claiming to be the victim
-
claims := auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:victim_user", // Attacker claims to be victim
-
Issuer: "https://untrusted.pds",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint, // Attacker uses their own key
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
-
-
// Attacker creates a valid DPoP proof with their key
-
dpopProof := createDPoPProof(t, attackerKey, "POST", "https://api.example.com/protected")
-
// Fetcher fails (external PDS without JWKS)
-
fetcher := &mockJWKSFetcher{shouldFail: true}
-
middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false - REAL verification
-
-
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
t.Fatalf("CRITICAL SECURITY FAILURE: Request authenticated as %s despite forged token!",
-
GetUserDID(r))
+
w.WriteHeader(http.StatusOK)
}))
-
req := httptest.NewRequest("POST", "https://api.example.com/protected", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
-
req.Header.Set("DPoP", dpopProof)
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "coves_session",
+
Value: token,
+
})
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
-
// MUST reject - the token signature was never verified
-
if w.Code != http.StatusUnauthorized {
-
t.Errorf("SECURITY VULNERABILITY: Expected 401, got %d. Token was not properly verified!", w.Code)
+
if !handlerCalled {
+
t.Error("handler was not called")
}
-
}
-
// TestVerifyDPoPBinding_UsesForwardedProto ensures we honor the external HTTPS
-
// scheme when TLS is terminated upstream and X-Forwarded-Proto is present.
-
func TestVerifyDPoPBinding_UsesForwardedProto(t *testing.T) {
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
+
if w.Code != http.StatusOK {
+
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
+
}
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
-
}
+
// TestRequireAuth_HeaderPrecedenceOverCookie tests that the Authorization header takes precedence over the session cookie
+
func TestRequireAuth_HeaderPrecedenceOverCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
-
claims := &auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
+
// Create two test sessions
+
did1 := syntax.DID("did:plc:header")
+
sessionID1 := "session_header"
+
session1 := &oauthlib.ClientSessionData{
+
AccountDID: did1,
+
SessionID: sessionID1,
+
AccessToken: "header_token",
+
HostURL: "https://pds.example.com",
}
+
_ = store.SaveSession(context.Background(), *session1)
-
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
-
defer middleware.Stop()
-
-
externalURI := "https://api.example.com/protected/resource"
-
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
-
-
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
-
req.Host = "api.example.com"
-
req.Header.Set("X-Forwarded-Proto", "https")
-
-
// Pass a fake access token - ath verification will pass since we don't include ath in the DPoP proof
-
fakeAccessToken := "fake-access-token-for-testing"
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
-
if err != nil {
-
t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err)
+
did2 := syntax.DID("did:plc:cookie")
+
sessionID2 := "session_cookie"
+
session2 := &oauthlib.ClientSessionData{
+
AccountDID: did2,
+
SessionID: sessionID2,
+
AccessToken: "cookie_token",
+
HostURL: "https://pds.example.com",
}
+
_ = store.SaveSession(context.Background(), *session2)
-
if proof == nil || proof.Claims == nil {
-
t.Fatal("expected DPoP proof to be returned")
-
}
-
}
+
middleware := NewOAuthAuthMiddleware(client, store)
-
// TestVerifyDPoPBinding_UsesForwardedHost ensures we honor X-Forwarded-Host header
-
// when behind a TLS-terminating proxy that rewrites the Host header.
-
func TestVerifyDPoPBinding_UsesForwardedHost(t *testing.T) {
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
-
}
+
handlerCalled := false
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
-
}
+
// Should get header DID, not cookie DID
+
extractedDID := GetUserDID(r)
+
if extractedDID != "did:plc:header" {
+
t.Errorf("expected header DID 'did:plc:header', got %s", extractedDID)
+
}
-
claims := &auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
+
w.WriteHeader(http.StatusOK)
+
}))
-
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
-
defer middleware.Stop()
+
headerToken := client.createTestToken("did:plc:header", sessionID1, time.Hour)
+
cookieToken := client.createTestToken("did:plc:cookie", sessionID2, time.Hour)
-
// External URI that the client uses
-
externalURI := "https://api.example.com/protected/resource"
-
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer "+headerToken)
+
req.AddCookie(&http.Cookie{
+
Name: "coves_session",
+
Value: cookieToken,
+
})
+
w := httptest.NewRecorder()
-
// Request hits internal service with internal hostname, but X-Forwarded-Host has public hostname
-
req := httptest.NewRequest("GET", "http://internal-service:8080/protected/resource", nil)
-
req.Host = "internal-service:8080" // Internal host after proxy
-
req.Header.Set("X-Forwarded-Proto", "https")
-
req.Header.Set("X-Forwarded-Host", "api.example.com") // Original public host
+
handler.ServeHTTP(w, req)
-
fakeAccessToken := "fake-access-token-for-testing"
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
-
if err != nil {
-
t.Fatalf("expected DPoP verification to succeed with X-Forwarded-Host, got %v", err)
+
if !handlerCalled {
+
t.Error("handler was not called")
}
-
if proof == nil || proof.Claims == nil {
-
t.Fatal("expected DPoP proof to be returned")
+
if w.Code != http.StatusOK {
+
t.Errorf("expected status 200, got %d", w.Code)
}
}
-
// TestVerifyDPoPBinding_UsesStandardForwardedHeader tests RFC 7239 Forwarded header parsing
-
func TestVerifyDPoPBinding_UsesStandardForwardedHeader(t *testing.T) {
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
-
}
-
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
-
}
+
// TestRequireAuth_MissingBothHeaderAndCookie tests that requests missing both auth methods are rejected
+
func TestRequireAuth_MissingBothHeaderAndCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
-
claims := &auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
-
-
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
-
defer middleware.Stop()
-
-
// External URI
-
externalURI := "https://api.example.com/protected/resource"
-
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
t.Error("handler should not be called")
+
}))
-
// Request with standard Forwarded header (RFC 7239)
-
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
-
req.Host = "internal-service"
-
req.Header.Set("Forwarded", "for=192.0.2.60;proto=https;host=api.example.com")
+
req := httptest.NewRequest("GET", "/test", nil)
+
// No Authorization header and no cookie
+
w := httptest.NewRecorder()
-
fakeAccessToken := "fake-access-token-for-testing"
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
-
if err != nil {
-
t.Fatalf("expected DPoP verification to succeed with Forwarded header, got %v", err)
-
}
+
handler.ServeHTTP(w, req)
-
if proof == nil {
-
t.Fatal("expected DPoP proof to be returned")
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("expected status 401, got %d", w.Code)
}
}
-
// TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes tests RFC 7239 edge cases:
-
// mixed-case keys (Proto vs proto) and quoted values (host="example.com")
-
func TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes(t *testing.T) {
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
-
}
-
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
-
}
+
// TestRequireAuth_InvalidCookie tests that malformed cookie tokens are rejected
+
func TestRequireAuth_InvalidCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
-
claims := &auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
-
-
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
-
defer middleware.Stop()
-
-
// External URI that the client uses
-
externalURI := "https://api.example.com/protected/resource"
-
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
t.Error("handler should not be called")
+
}))
-
// Request with RFC 7239 Forwarded header using:
-
// - Mixed-case keys: "Proto" instead of "proto", "Host" instead of "host"
-
// - Quoted value: Host="api.example.com" (legal per RFC 7239 section 4)
-
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
-
req.Host = "internal-service"
-
req.Header.Set("Forwarded", `for=192.0.2.60;Proto=https;Host="api.example.com"`)
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "coves_session",
+
Value: "not-a-valid-sealed-token",
+
})
+
w := httptest.NewRecorder()
-
fakeAccessToken := "fake-access-token-for-testing"
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
-
if err != nil {
-
t.Fatalf("expected DPoP verification to succeed with mixed-case/quoted Forwarded header, got %v", err)
-
}
+
handler.ServeHTTP(w, req)
-
if proof == nil {
-
t.Fatal("expected DPoP proof to be returned")
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("expected status 401, got %d", w.Code)
}
}
-
// TestVerifyDPoPBinding_AthValidation tests access token hash (ath) claim validation
-
func TestVerifyDPoPBinding_AthValidation(t *testing.T) {
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("failed to generate key: %v", err)
-
}
+
// TestOptionalAuth_WithCookie tests that OptionalAuth accepts valid session cookies
+
func TestOptionalAuth_WithCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("failed to calculate thumbprint: %v", err)
+
// Create a test session
+
did := syntax.DID("did:plc:test123")
+
sessionID := "session123"
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
AccessToken: "test_access_token",
}
+
_ = store.SaveSession(context.Background(), *session)
-
claims := &auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
-
-
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
-
defer middleware.Stop()
-
-
accessToken := "real-access-token-12345"
-
-
t.Run("ath_matches_access_token", func(t *testing.T) {
-
// Create DPoP proof with ath claim matching the access token
-
dpopProof := createDPoPProofWithAth(t, privateKey, "GET", "https://api.example.com/resource", accessToken)
+
middleware := NewOAuthAuthMiddleware(client, store)
-
req := httptest.NewRequest("GET", "https://api.example.com/resource", nil)
-
req.Host = "api.example.com"
+
handlerCalled := false
+
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
-
if err != nil {
-
t.Fatalf("expected verification to succeed with matching ath, got %v", err)
-
}
-
if proof == nil {
-
t.Fatal("expected proof to be returned")
+
// Verify DID was extracted
+
extractedDID := GetUserDID(r)
+
if extractedDID != "did:plc:test123" {
+
t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
}
-
})
-
t.Run("ath_mismatch_rejected", func(t *testing.T) {
-
// Create DPoP proof with ath for a DIFFERENT token
-
differentToken := "different-token-67890"
-
dpopProof := createDPoPProofWithAth(t, privateKey, "POST", "https://api.example.com/resource", differentToken)
-
-
req := httptest.NewRequest("POST", "https://api.example.com/resource", nil)
-
req.Host = "api.example.com"
+
w.WriteHeader(http.StatusOK)
+
}))
-
// Try to use with the original access token - should fail
-
_, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
-
if err == nil {
-
t.Fatal("SECURITY: expected verification to fail when ath doesn't match access token")
-
}
-
if !strings.Contains(err.Error(), "ath") {
-
t.Errorf("error should mention ath mismatch, got: %v", err)
-
}
+
token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "coves_session",
+
Value: token,
})
-
}
+
w := httptest.NewRecorder()
-
// TestMiddlewareStop tests that the middleware can be stopped properly
-
func TestMiddlewareStop(t *testing.T) {
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, false)
+
handler.ServeHTTP(w, req)
-
// Stop should not panic and should clean up resources
-
middleware.Stop()
+
if !handlerCalled {
+
t.Error("handler was not called")
+
}
-
// Calling Stop again should also be safe (idempotent-ish)
-
// Note: The underlying DPoPVerifier.Stop() closes a channel, so this might panic
-
// if not handled properly. We test that at least one Stop works.
+
if w.Code != http.StatusOK {
+
t.Errorf("expected status 200, got %d", w.Code)
+
}
}
-
// TestOptionalAuth_DPoPBoundToken_NoDPoPHeader tests that OptionalAuth treats
-
// tokens with cnf.jkt but no DPoP header as unauthenticated (potential token theft)
-
func TestOptionalAuth_DPoPBoundToken_NoDPoPHeader(t *testing.T) {
-
// Generate a key pair for DPoP binding
-
privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
-
-
// Create a DPoP-bound token (has cnf.jkt)
-
claims := auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:user123",
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
Confirmation: map[string]interface{}{
-
"jkt": thumbprint,
-
},
-
}
+
// TestOptionalAuth_InvalidCookie tests that OptionalAuth continues without auth on invalid cookie
+
func TestOptionalAuth_InvalidCookie(t *testing.T) {
+
client := newMockOAuthClient()
+
store := newMockOAuthStore()
+
middleware := NewOAuthAuthMiddleware(client, store)
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
-
-
// Use skipVerify=true to simulate a verified token
-
// (In production, skipVerify would be false and VerifyJWT would be called)
-
// However, for this test we need skipVerify=false to trigger DPoP checking
-
// But the fetcher will fail, so let's use skipVerify=true and verify the logic
-
// Actually, the DPoP check only happens when skipVerify=false
-
-
t.Run("with_skipVerify_false", func(t *testing.T) {
-
// This will fail at JWT verification level, but that's expected
-
// The important thing is the code path for DPoP checking
-
fetcher := &mockJWKSFetcher{shouldFail: true}
-
middleware := NewAtProtoAuthMiddleware(fetcher, false)
-
defer middleware.Stop()
-
-
handlerCalled := false
-
var capturedDID string
-
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
handlerCalled = true
-
capturedDID = GetUserDID(r)
-
w.WriteHeader(http.StatusOK)
-
}))
-
-
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
-
// Deliberately NOT setting DPoP header
-
w := httptest.NewRecorder()
-
-
handler.ServeHTTP(w, req)
-
-
// Handler should be called (optional auth doesn't block)
-
if !handlerCalled {
-
t.Error("handler should be called")
-
}
+
handlerCalled := false
+
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
-
// But since JWT verification fails, user should not be authenticated
-
if capturedDID != "" {
-
t.Errorf("expected empty DID when verification fails, got %s", capturedDID)
+
// Verify no DID is set (invalid cookie ignored)
+
did := GetUserDID(r)
+
if did != "" {
+
t.Errorf("expected empty DID for invalid cookie, got %s", did)
}
-
})
-
t.Run("with_skipVerify_true_dpop_not_checked", func(t *testing.T) {
-
// When skipVerify=true, DPoP is not checked (Phase 1 mode)
-
fetcher := &mockJWKSFetcher{}
-
middleware := NewAtProtoAuthMiddleware(fetcher, true)
-
defer middleware.Stop()
-
-
handlerCalled := false
-
var capturedDID string
-
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
handlerCalled = true
-
capturedDID = GetUserDID(r)
-
w.WriteHeader(http.StatusOK)
-
}))
-
-
req := httptest.NewRequest("GET", "/test", nil)
-
req.Header.Set("Authorization", "DPoP "+tokenString)
-
// No DPoP header
-
w := httptest.NewRecorder()
-
-
handler.ServeHTTP(w, req)
-
-
if !handlerCalled {
-
t.Error("handler should be called")
-
}
+
w.WriteHeader(http.StatusOK)
+
}))
-
// With skipVerify=true, DPoP check is bypassed - token is trusted
-
if capturedDID != "did:plc:user123" {
-
t.Errorf("expected DID when skipVerify=true, got %s", capturedDID)
-
}
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.AddCookie(&http.Cookie{
+
Name: "coves_session",
+
Value: "not-a-valid-sealed-token",
})
-
}
-
-
// TestDPoPReplayProtection tests that the same DPoP proof cannot be used twice
-
func TestDPoPReplayProtection(t *testing.T) {
-
// This tests the NonceCache functionality
-
cache := auth.NewNonceCache(5 * time.Minute)
-
defer cache.Stop()
-
-
jti := "unique-proof-id-123"
-
-
// First use should succeed
-
if !cache.CheckAndStore(jti) {
-
t.Error("First use of jti should succeed")
-
}
-
-
// Second use should fail (replay detected)
-
if cache.CheckAndStore(jti) {
-
t.Error("SECURITY: Replay attack not detected - same jti accepted twice")
-
}
+
w := httptest.NewRecorder()
-
// Different jti should succeed
-
if !cache.CheckAndStore("different-jti-456") {
-
t.Error("Different jti should succeed")
-
}
-
}
+
handler.ServeHTTP(w, req)
-
// Helper: createDPoPProof creates a DPoP proof JWT for testing
-
func createDPoPProof(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri string) string {
-
// Create JWK from public key
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
-
// Create DPoP claims with UUID for jti to ensure uniqueness across tests
-
claims := auth.DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
ID: uuid.New().String(),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
+
if !handlerCalled {
+
t.Error("handler was not called")
}
-
// Create token with custom header
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = jwk
-
-
// Sign with private key
-
signedToken, err := token.SignedString(privateKey)
-
if err != nil {
-
t.Fatalf("failed to sign DPoP proof: %v", err)
+
if w.Code != http.StatusOK {
+
t.Errorf("expected status 200, got %d", w.Code)
}
-
-
return signedToken
}
-
// Helper: createDPoPProofWithAth creates a DPoP proof JWT with ath (access token hash) claim
-
func createDPoPProofWithAth(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri, accessToken string) string {
-
// Create JWK from public key
-
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
-
-
// Calculate ath: base64url(SHA-256(access_token))
-
hash := sha256.Sum256([]byte(accessToken))
-
ath := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
// Create DPoP claims with ath
-
claims := auth.DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
ID: uuid.New().String(),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
AccessTokenHash: ath,
-
}
-
-
// Create token with custom header
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = jwk
-
-
// Sign with private key
-
signedToken, err := token.SignedString(privateKey)
-
if err != nil {
-
t.Fatalf("failed to sign DPoP proof: %v", err)
+
// TestWriteAuthError_JSONEscaping tests that writeAuthError properly escapes messages
+
func TestWriteAuthError_JSONEscaping(t *testing.T) {
+
tests := []struct {
+
name string
+
message string
+
}{
+
{"simple message", "Missing authentication"},
+
{"message with quotes", `Invalid "token" format`},
+
{"message with newlines", "Invalid\ntoken\nformat"},
+
{"message with backslashes", `Invalid \ token`},
+
{"message with special chars", `Invalid <script>alert("xss")</script> token`},
+
{"message with unicode", "Invalid token: \u2028\u2029"},
}
-
return signedToken
-
}
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
w := httptest.NewRecorder()
+
writeAuthError(w, tt.message)
-
// Helper: ecdsaPublicKeyToJWK converts an ECDSA public key to JWK map
-
func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} {
-
// Get curve name
-
var crv string
-
switch pubKey.Curve {
-
case elliptic.P256():
-
crv = "P-256"
-
case elliptic.P384():
-
crv = "P-384"
-
case elliptic.P521():
-
crv = "P-521"
-
default:
-
panic("unsupported curve")
-
}
+
// Verify status code
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("expected status 401, got %d", w.Code)
+
}
-
// Encode coordinates
-
xBytes := pubKey.X.Bytes()
-
yBytes := pubKey.Y.Bytes()
-
-
// Ensure proper byte length (pad if needed)
-
keySize := (pubKey.Curve.Params().BitSize + 7) / 8
-
xPadded := make([]byte, keySize)
-
yPadded := make([]byte, keySize)
-
copy(xPadded[keySize-len(xBytes):], xBytes)
-
copy(yPadded[keySize-len(yBytes):], yBytes)
-
-
return map[string]interface{}{
-
"kty": "EC",
-
"crv": crv,
-
"x": base64.RawURLEncoding.EncodeToString(xPadded),
-
"y": base64.RawURLEncoding.EncodeToString(yPadded),
-
}
-
}
+
// Verify content type
+
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
+
t.Errorf("expected Content-Type 'application/json', got %s", ct)
+
}
-
// Helper: createValidJWT creates a valid unsigned JWT token for testing
-
// This is used with skipVerify=true middleware where signature verification is skipped
-
func createValidJWT(t *testing.T, subject string, expiry time.Duration) string {
-
t.Helper()
-
-
claims := auth.Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: "https://test.pds.local",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto",
-
}
+
// Verify response is valid JSON
+
var response map[string]string
+
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
+
t.Fatalf("response is not valid JSON: %v\nBody: %s", err, w.Body.String())
+
}
-
// Create unsigned token (for skipVerify=true tests)
-
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
-
signedToken, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
-
if err != nil {
-
t.Fatalf("failed to create test JWT: %v", err)
+
// Verify fields
+
if response["error"] != "AuthenticationRequired" {
+
t.Errorf("expected error 'AuthenticationRequired', got %s", response["error"])
+
}
+
if response["message"] != tt.message {
+
t.Errorf("expected message %q, got %q", tt.message, response["message"])
+
}
+
})
}
-
-
return signedToken
}
+1 -1
internal/api/routes/post.go
···
// RegisterPostRoutes registers post-related XRPC endpoints on the router
// Implements social.coves.community.post.* lexicon endpoints
-
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.AtProtoAuthMiddleware) {
+
func RegisterPostRoutes(r chi.Router, service posts.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
// Initialize handlers
createHandler := post.NewCreateHandler(service)
+291
tests/e2e/oauth_ratelimit_e2e_test.go
···
+
package e2e
+
+
import (
+
"Coves/internal/api/middleware"
+
"net/http"
+
"net/http/httptest"
+
"testing"
+
"time"
+
+
"github.com/stretchr/testify/assert"
+
)
+
+
// TestRateLimiting_E2E_OAuthEndpoints tests OAuth-specific rate limiting
+
// OAuth endpoints have stricter rate limits to prevent:
+
// - Credential stuffing attacks on login endpoints (10 req/min)
+
// - OAuth state exhaustion
+
// - Refresh token abuse (20 req/min)
+
func TestRateLimiting_E2E_OAuthEndpoints(t *testing.T) {
+
t.Run("Login endpoints have 10 req/min limit", func(t *testing.T) {
+
// Create rate limiter matching oauth.go config: 10 requests per minute
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
// Mock OAuth login handler
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte("OK"))
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.200:12345"
+
+
// Make exactly 10 requests (at limit)
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 11th request should be rate limited
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Request 11 should be rate limited")
+
assert.Contains(t, rr.Body.String(), "Rate limit exceeded", "Should have rate limit error message")
+
})
+
+
t.Run("Mobile login endpoints have 10 req/min limit", func(t *testing.T) {
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.201:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/mobile/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Mobile login should be rate limited at 10 req/min")
+
})
+
+
t.Run("Refresh endpoint has 20 req/min limit", func(t *testing.T) {
+
// Refresh has higher limit (20 req/min) for legitimate token refresh
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := refreshLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.202:12345"
+
+
// Make 20 requests
+
for i := 0; i < 20; i++ {
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1)
+
}
+
+
// 21st request blocked
+
req := httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Refresh should be rate limited at 20 req/min")
+
})
+
+
t.Run("Logout endpoint has 10 req/min limit", func(t *testing.T) {
+
logoutLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := logoutLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.203:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("POST", "/oauth/logout", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Logout should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth callback has 10 req/min limit", func(t *testing.T) {
+
// Callback uses same limiter as login (part of auth flow)
+
callbackLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
})
+
+
handler := callbackLimiter.Middleware(testHandler)
+
clientIP := "192.168.1.204:12345"
+
+
// Make 10 requests
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// 11th request blocked
+
req := httptest.NewRequest("GET", "/oauth/callback", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
handler.ServeHTTP(rr, req)
+
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Callback should be rate limited at 10 req/min")
+
})
+
+
t.Run("OAuth rate limits are stricter than global limit", func(t *testing.T) {
+
// Verify OAuth limits are more restrictive than global 100 req/min
+
const globalLimit = 100
+
const oauthLoginLimit = 10
+
const oauthRefreshLimit = 20
+
+
assert.Less(t, oauthLoginLimit, globalLimit, "OAuth login limit should be stricter than global")
+
assert.Less(t, oauthRefreshLimit, globalLimit, "OAuth refresh limit should be stricter than global")
+
assert.Greater(t, oauthRefreshLimit, oauthLoginLimit, "Refresh limit should be higher than login (legitimate use case)")
+
})
+
+
t.Run("OAuth limits prevent credential stuffing", func(t *testing.T) {
+
// Simulate credential stuffing attack
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
// Simulate failed login attempts
+
w.WriteHeader(http.StatusUnauthorized)
+
})
+
+
handler := loginLimiter.Middleware(testHandler)
+
attackerIP := "203.0.113.50:12345"
+
+
// Attacker tries 15 login attempts (credential stuffing)
+
successfulAttempts := 0
+
blockedAttempts := 0
+
+
for i := 0; i < 15; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = attackerIP
+
rr := httptest.NewRecorder()
+
+
handler.ServeHTTP(rr, req)
+
+
if rr.Code == http.StatusUnauthorized {
+
successfulAttempts++ // Reached handler (even if auth failed)
+
} else if rr.Code == http.StatusTooManyRequests {
+
blockedAttempts++
+
}
+
}
+
+
// Rate limiter should block 5 attempts after first 10
+
assert.Equal(t, 10, successfulAttempts, "Should allow 10 login attempts")
+
assert.Equal(t, 5, blockedAttempts, "Should block 5 attempts after limit reached")
+
})
+
+
t.Run("OAuth limits are per-endpoint", func(t *testing.T) {
+
// Each endpoint gets its own rate limiter
+
// This test verifies that limits are independent per endpoint
+
loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)
+
refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute)
+
+
loginHandler := loginLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
refreshHandler := refreshLimiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
clientIP := "192.168.1.205:12345"
+
+
// Exhaust login limit
+
for i := 0; i < 10; i++ {
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code)
+
}
+
+
// Login limit exhausted
+
req := httptest.NewRequest("GET", "/oauth/login", nil)
+
req.RemoteAddr = clientIP
+
rr := httptest.NewRecorder()
+
loginHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Login should be rate limited")
+
+
// Refresh endpoint should still work (independent limiter)
+
req = httptest.NewRequest("POST", "/oauth/refresh", nil)
+
req.RemoteAddr = clientIP
+
rr = httptest.NewRecorder()
+
refreshHandler.ServeHTTP(rr, req)
+
assert.Equal(t, http.StatusOK, rr.Code, "Refresh should not be affected by login rate limit")
+
})
+
}
+
+
// OAuth Rate Limiting Configuration Documentation
+
// ================================================
+
// This test file validates OAuth-specific rate limits applied in oauth.go:
+
//
+
// 1. Login Endpoints (Credential Stuffing Protection)
+
// - Endpoints: /oauth/login, /oauth/mobile/login, /oauth/callback
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent brute force and credential stuffing attacks
+
// - Implementation: internal/api/routes/oauth.go:21
+
//
+
// 2. Refresh Endpoint (Token Refresh)
+
// - Endpoint: /oauth/refresh
+
// - Limit: 20 requests per minute per IP
+
// - Reason: Allow legitimate token refresh while preventing abuse
+
// - Implementation: internal/api/routes/oauth.go:24
+
//
+
// 3. Logout Endpoint
+
// - Endpoint: /oauth/logout
+
// - Limit: 10 requests per minute per IP
+
// - Reason: Prevent session exhaustion attacks
+
// - Implementation: internal/api/routes/oauth.go:27
+
//
+
// 4. Metadata Endpoints (No Extra Limit)
+
// - Endpoints: /oauth/client-metadata.json, /oauth/jwks.json
+
// - Limit: Global 100 requests per minute (from main.go)
+
// - Reason: Public metadata, not sensitive to rate abuse
+
//
+
// Security Benefits:
+
// - Credential Stuffing: Limits password guessing to 10 attempts/min
+
// - State Exhaustion: Prevents OAuth state generation spam
+
// - Token Abuse: Limits refresh token usage while allowing legitimate refresh
+
//
+
// Rate Limit Hierarchy:
+
// - OAuth login: 10 req/min (most restrictive)
+
// - OAuth refresh: 20 req/min (moderate)
+
// - Comments: 20 req/min (expensive queries)
+
// - Global: 100 req/min (baseline)
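
For orientation, a hypothetical sketch of how these per-endpoint limits could be wired with chi, using the `NewRateLimiter`/`Middleware` API exercised by the tests above. Handler parameters and the function name are placeholders; the actual wiring lives in `internal/api/routes/oauth.go`.

```go
package routes

import (
	"net/http"
	"time"

	"Coves/internal/api/middleware"

	"github.com/go-chi/chi/v5"
)

// registerOAuthRateLimitsSketch applies stricter per-endpoint limits than the global baseline.
func registerOAuthRateLimitsSketch(r chi.Router, login, callback, refresh, logout http.HandlerFunc) {
	loginLimiter := middleware.NewRateLimiter(10, 1*time.Minute)   // login + callback: 10 req/min
	refreshLimiter := middleware.NewRateLimiter(20, 1*time.Minute) // refresh: 20 req/min
	logoutLimiter := middleware.NewRateLimiter(10, 1*time.Minute)  // logout: 10 req/min

	r.Method(http.MethodGet, "/oauth/login", loginLimiter.Middleware(login))
	r.Method(http.MethodGet, "/oauth/callback", loginLimiter.Middleware(callback))
	r.Method(http.MethodPost, "/oauth/refresh", refreshLimiter.Middleware(refresh))
	r.Method(http.MethodPost, "/oauth/logout", logoutLimiter.Middleware(logout))
}
```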
+910
tests/integration/oauth_e2e_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"encoding/json"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"strings"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
_ "github.com/lib/pq"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_Components tests OAuth component functionality without requiring PDS.
+
// This validates all Coves OAuth code:
+
// - Session storage and retrieval (PostgreSQL)
+
// - Token sealing (AES-GCM encryption)
+
// - Token unsealing (decryption + validation)
+
// - Session cleanup
+
//
+
// NOTE: Full OAuth redirect flow testing requires both HTTPS PDS and HTTPS Coves deployment.
+
// The OAuth redirect flow is handled by the indigo library, which enforces the OAuth 2.0 spec
+
// (HTTPS required for authorization servers and redirect URIs).
+
func TestOAuth_Components(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth component test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations to ensure OAuth tables exist
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”ง Testing OAuth Components")
+
+
ctx := context.Background()
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Use a test DID (doesn't need to exist on PDS for component tests)
+
testDID := "did:plc:componenttest123"
+
+
// Run component tests
+
testOAuthComponentsWithMockedSession(t, ctx, nil, store, client, testDID, "")
+
+
t.Log("")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("โœ… OAuth Component Tests Complete")
+
t.Log(strings.Repeat("=", 60))
+
t.Log("Components validated:")
+
t.Log(" โœ“ Session storage (PostgreSQL)")
+
t.Log(" โœ“ Token sealing (AES-GCM encryption)")
+
t.Log(" โœ“ Token unsealing (decryption + validation)")
+
t.Log(" โœ“ Session cleanup")
+
t.Log("")
+
t.Log("NOTE: Full OAuth redirect flow requires HTTPS PDS + HTTPS Coves")
+
t.Log(strings.Repeat("=", 60))
+
}
+
+
// testOAuthComponentsWithMockedSession tests OAuth components that work without PDS redirect flow.
+
// This is used when testing with localhost PDS, where the indigo library rejects http:// URLs.
+
func testOAuthComponentsWithMockedSession(t *testing.T, ctx context.Context, _ interface{}, store oauthlib.ClientAuthStore, client *oauth.OAuthClient, userDID, _ string) {
+
t.Helper()
+
+
t.Log("๐Ÿ”ง Testing OAuth components with mocked session...")
+
+
// Parse DID
+
parsedDID, err := syntax.ParseDID(userDID)
+
require.NoError(t, err, "Should parse DID")
+
+
// Component 1: Session Storage
+
t.Log(" ๐Ÿ“ฆ Component 1: Testing session storage...")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: fmt.Sprintf("localhost-test-%d", time.Now().UnixNano()),
+
HostURL: "http://localhost:3001",
+
AccessToken: "mocked-access-token",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err, "Should save session")
+
+
retrieved, err := store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should retrieve session")
+
require.Equal(t, testSession.SessionID, retrieved.SessionID)
+
require.Equal(t, testSession.AccessToken, retrieved.AccessToken)
+
t.Log(" โœ… Session storage working")
+
+
// Component 2: Token Sealing
+
t.Log(" ๐Ÿ” Component 2: Testing token sealing...")
+
sealedToken, err := client.SealSession(parsedDID.String(), testSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
tokenPreview := sealedToken
+
if len(tokenPreview) > 50 {
+
tokenPreview = tokenPreview[:50]
+
}
+
t.Logf(" โœ… Token sealed: %s...", tokenPreview)
+
+
// Component 3: Token Unsealing
+
t.Log(" ๐Ÿ”“ Component 3: Testing token unsealing...")
+
unsealed, err := client.UnsealSession(sealedToken)
+
require.NoError(t, err, "Should unseal token")
+
require.Equal(t, userDID, unsealed.DID)
+
require.Equal(t, testSession.SessionID, unsealed.SessionID)
+
t.Log(" โœ… Token unsealing working")
+
+
// Component 4: Session Cleanup
+
t.Log(" ๐Ÿงน Component 4: Testing session cleanup...")
+
err = store.DeleteSession(ctx, parsedDID, testSession.SessionID)
+
require.NoError(t, err, "Should delete session")
+
+
_, err = store.GetSession(ctx, parsedDID, testSession.SessionID)
+
require.Error(t, err, "Session should not exist after deletion")
+
t.Log(" โœ… Session cleanup working")
+
+
t.Log("โœ… All OAuth components verified!")
+
t.Log("")
+
t.Log("๐Ÿ“ Summary: OAuth implementation validated with mocked session")
+
t.Log(" - Session storage: โœ“")
+
t.Log(" - Token sealing: โœ“")
+
t.Log(" - Token unsealing: โœ“")
+
t.Log(" - Session cleanup: โœ“")
+
t.Log("")
+
t.Log("โš ๏ธ To test full OAuth redirect flow, use a production PDS with HTTPS")
+
}
+
+
// TestOAuthE2E_TokenExpiration tests that expired OAuth sessions are rejected and cleaned up
+
func TestOAuthE2E_TokenExpiration(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token expiration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("โฐ Testing OAuth token expiration...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
_ = oauth.NewOAuthHandler(client, store) // Handler created for completeness
+
+
// Create test session with past expiration
+
did, err := syntax.ParseDID("did:plc:expiredtest123")
+
require.NoError(t, err)
+
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "expired-session",
+
HostURL: "http://localhost:3001",
+
AccessToken: "expired-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Manually update expiration to the past
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_sessions SET expires_at = NOW() - INTERVAL '1 day' WHERE did = $1 AND session_id = $2",
+
did.String(), testSession.SessionID)
+
require.NoError(t, err)
+
+
// Try to retrieve expired session
+
_, err = store.GetSession(ctx, did, testSession.SessionID)
+
assert.Error(t, err, "Should not be able to retrieve expired session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound for expired session")
+
+
// Test cleanup of expired sessions
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredSessions(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one session")
+
+
t.Logf("โœ… Expired session handling verified (cleaned %d sessions)", cleaned)
+
}
+
+
// TestOAuthE2E_InvalidToken tests that invalid/tampered tokens are rejected
+
func TestOAuthE2E_InvalidToken(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth invalid token test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
t.Log("๐Ÿ”’ Testing OAuth invalid token rejection...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup test server with protected endpoint
+
r := chi.NewRouter()
+
r.Get("/api/me", func(w http.ResponseWriter, r *http.Request) {
+
sessData, err := handler.GetSessionFromRequest(r)
+
if err != nil {
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
+
return
+
}
+
w.Header().Set("Content-Type", "application/json")
+
_ = json.NewEncoder(w).Encode(map[string]string{"did": sessData.AccountDID.String()})
+
})
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
// Test with invalid token formats
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but invalid content
+
{"Short token", "abc"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
req, _ := http.NewRequest("GET", server.URL+"/api/me", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid token should be rejected with 401")
+
})
+
}
+
+
t.Logf("โœ… Invalid token rejection verified")
+
}
+
+
// TestOAuthE2E_SessionNotFound tests behavior when session doesn't exist in DB
+
func TestOAuthE2E_SessionNotFound(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session not found test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ” Testing OAuth session not found behavior...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Try to retrieve non-existent session
+
nonExistentDID, err := syntax.ParseDID("did:plc:nonexistent123")
+
require.NoError(t, err)
+
+
_, err = store.GetSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error for non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
// Try to delete non-existent session
+
err = store.DeleteSession(ctx, nonExistentDID, "nonexistent-session")
+
assert.Error(t, err, "Should return error when deleting non-existent session")
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "Should return ErrSessionNotFound")
+
+
t.Logf("โœ… Session not found handling verified")
+
}
+
+
// TestOAuthE2E_MultipleSessionsPerUser tests that a user can have multiple active sessions
+
func TestOAuthE2E_MultipleSessionsPerUser(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth multiple sessions test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ‘ฅ Testing multiple OAuth sessions per user...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test DID
+
did, err := syntax.ParseDID("did:plc:multisession123")
+
require.NoError(t, err)
+
+
// Create multiple sessions for the same user
+
sessions := []oauthlib.ClientSessionData{
+
{
+
AccountDID: did,
+
SessionID: "session-1-web",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-1",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-2-mobile",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-2",
+
Scopes: []string{"atproto"},
+
},
+
{
+
AccountDID: did,
+
SessionID: "session-3-tablet",
+
HostURL: "http://localhost:3001",
+
AccessToken: "token-3",
+
Scopes: []string{"atproto"},
+
},
+
}
+
+
// Save all sessions
+
for i, session := range sessions {
+
err := store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should be able to save session %d", i+1)
+
}
+
+
t.Logf("โœ… Created %d sessions for user", len(sessions))
+
+
// Verify all sessions can be retrieved independently
+
for i, session := range sessions {
+
retrieved, err := store.GetSession(ctx, did, session.SessionID)
+
require.NoError(t, err, "Should be able to retrieve session %d", i+1)
+
assert.Equal(t, session.SessionID, retrieved.SessionID, "Session ID should match")
+
assert.Equal(t, session.AccessToken, retrieved.AccessToken, "Access token should match")
+
}
+
+
t.Logf("โœ… All sessions retrieved independently")
+
+
// Delete one session and verify others remain
+
err = store.DeleteSession(ctx, did, sessions[0].SessionID)
+
require.NoError(t, err, "Should be able to delete first session")
+
+
// Verify first session is deleted
+
_, err = store.GetSession(ctx, did, sessions[0].SessionID)
+
assert.Equal(t, oauth.ErrSessionNotFound, err, "First session should be deleted")
+
+
// Verify other sessions still exist
+
for i := 1; i < len(sessions); i++ {
+
_, err := store.GetSession(ctx, did, sessions[i].SessionID)
+
require.NoError(t, err, "Session %d should still exist", i+1)
+
}
+
+
t.Logf("โœ… Multiple sessions per user verified")
+
+
// Cleanup
+
for i := 1; i < len(sessions); i++ {
+
_ = store.DeleteSession(ctx, did, sessions[i].SessionID)
+
}
+
}
+
+
// TestOAuthE2E_AuthRequestStorage tests OAuth auth request storage and retrieval
+
func TestOAuthE2E_AuthRequestStorage(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth auth request storage test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ“ Testing OAuth auth request storage...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create test auth request data
+
did, err := syntax.ParseDID("did:plc:authrequest123")
+
require.NoError(t, err)
+
+
authRequest := oauthlib.AuthRequestData{
+
State: "test-state-12345",
+
AccountDID: &did,
+
PKCEVerifier: "test-pkce-verifier",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
AuthServerURL: "http://localhost:3001",
+
RequestURI: "http://localhost:3001/authorize",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save auth request
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
require.NoError(t, err, "Should be able to save auth request")
+
+
t.Logf("โœ… Auth request saved")
+
+
// Retrieve auth request
+
retrieved, err := store.GetAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to retrieve auth request")
+
assert.Equal(t, authRequest.State, retrieved.State, "State should match")
+
assert.Equal(t, authRequest.PKCEVerifier, retrieved.PKCEVerifier, "PKCE verifier should match")
+
assert.Equal(t, authRequest.AuthServerURL, retrieved.AuthServerURL, "Auth server URL should match")
+
assert.Equal(t, len(authRequest.Scopes), len(retrieved.Scopes), "Scopes length should match")
+
+
t.Logf("โœ… Auth request retrieved and verified")
+
+
// Test duplicate state error
+
err = store.SaveAuthRequestInfo(ctx, authRequest)
+
assert.Error(t, err, "Should not allow duplicate state")
+
assert.Contains(t, err.Error(), "already exists", "Error should indicate duplicate")
+
+
t.Logf("โœ… Duplicate state prevention verified")
+
+
// Delete auth request
+
err = store.DeleteAuthRequestInfo(ctx, authRequest.State)
+
require.NoError(t, err, "Should be able to delete auth request")
+
+
// Verify deletion
+
_, err = store.GetAuthRequestInfo(ctx, authRequest.State)
+
assert.Equal(t, oauth.ErrAuthRequestNotFound, err, "Auth request should be deleted")
+
+
t.Logf("โœ… Auth request deletion verified")
+
+
// Test cleanup of expired auth requests
+
// Create an auth request and manually set created_at to the past
+
oldAuthRequest := oauthlib.AuthRequestData{
+
State: "old-state-12345",
+
PKCEVerifier: "old-verifier",
+
AuthServerURL: "http://localhost:3001",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveAuthRequestInfo(ctx, oldAuthRequest)
+
require.NoError(t, err)
+
+
// Update created_at to 1 hour ago
+
_, err = db.ExecContext(ctx,
+
"UPDATE oauth_requests SET created_at = NOW() - INTERVAL '1 hour' WHERE state = $1",
+
oldAuthRequest.State)
+
require.NoError(t, err)
+
+
// Cleanup expired requests
+
cleaned, err := store.(*oauth.PostgresOAuthStore).CleanupExpiredAuthRequests(ctx)
+
require.NoError(t, err, "Cleanup should succeed")
+
assert.Greater(t, cleaned, int64(0), "Should have cleaned up at least one auth request")
+
+
t.Logf("โœ… Expired auth request cleanup verified (cleaned %d requests)", cleaned)
+
}
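
The cleanup call exercised above runs once in the test; in a running service it would presumably be invoked on a schedule. A minimal sketch, assuming direct access to the concrete `*oauth.PostgresOAuthStore` and standard-library logging; the wrapper function and interval are illustrative, only `CleanupExpiredAuthRequests` comes from the code under test:

```go
package oauthcleanup

import (
	"context"
	"log"
	"time"

	"Coves/internal/atproto/oauth"
)

// runAuthRequestCleanup periodically deletes expired OAuth auth requests using
// the store method exercised in TestOAuthE2E_AuthRequestStorage above.
func runAuthRequestCleanup(ctx context.Context, store *oauth.PostgresOAuthStore, every time.Duration) {
	ticker := time.NewTicker(every)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			cleaned, err := store.CleanupExpiredAuthRequests(ctx)
			if err != nil {
				log.Printf("oauth auth request cleanup failed: %v", err)
				continue
			}
			if cleaned > 0 {
				log.Printf("cleaned %d expired oauth auth requests", cleaned)
			}
		}
	}
}
```
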
+
+
// TestOAuthE2E_TokenRefresh tests the refresh token flow
+
func TestOAuthE2E_TokenRefresh(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth token refresh test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth token refresh flow...")
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Create a test DID and session
+
did, err := syntax.ParseDID("did:plc:refreshtest123")
+
require.NoError(t, err)
+
+
// Create initial session with refresh token
+
initialSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "refresh-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AuthServerRevocationEndpoint: "http://localhost:3001/oauth/revoke",
+
AccessToken: "initial-access-token",
+
RefreshToken: "initial-refresh-token",
+
DPoPPrivateKeyMultibase: "test-dpop-key",
+
DPoPAuthServerNonce: "test-nonce",
+
Scopes: []string{"atproto", "transition:generic"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, initialSession)
+
require.NoError(t, err, "Should save initial session")
+
+
t.Logf("โœ… Initial session created")
+
+
// Create a sealed token for this session
+
sealedToken, err := client.SealSession(did.String(), initialSession.SessionID, time.Hour)
+
require.NoError(t, err, "Should seal session token")
+
require.NotEmpty(t, sealedToken, "Sealed token should not be empty")
+
+
t.Logf("โœ… Session token sealed")
+
+
// Setup test server with refresh endpoint
+
r := chi.NewRouter()
+
r.Post("/oauth/refresh", handler.HandleRefresh)
+
+
server := httptest.NewServer(r)
+
defer server.Close()
+
+
t.Run("Valid refresh request", func(t *testing.T) {
+
// NOTE: This test verifies that the refresh endpoint can be called
+
// In a real scenario, the indigo client's RefreshTokens() would call the PDS
+
// Since we're in a component test, we're testing the Coves handler logic
+
+
// Create refresh request
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": sealedToken,
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
// NOTE: In component testing mode, the indigo client may not have
+
// real PDS credentials, so RefreshTokens() might fail
+
// We're testing that the handler correctly processes the request
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// In component test mode without real PDS, we may get 401
+
// In production with real PDS, this would return 200 with new tokens
+
t.Logf("Refresh response status: %d", resp.StatusCode)
+
+
// The important thing is that the handler doesn't crash
+
// and properly validates the request structure
+
assert.True(t, resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized,
+
"Refresh should return either success or auth failure, got %d", resp.StatusCode)
+
})
+
+
t.Run("Invalid DID format (with valid token)", func(t *testing.T) {
+
// Create a sealed token with an invalid DID format
+
invalidDID := "invalid-did-format"
+
// Create the token with a valid DID first, then we'll try to use it with invalid DID in request
+
validToken, err := client.SealSession(did.String(), initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": invalidDID, // Invalid DID format in request
+
"session_id": initialSession.SessionID,
+
"sealed_token": validToken, // Valid token for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
// Should reject with 401 due to DID mismatch (not 400) since auth happens first
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected with 401 (auth check happens before format validation)")
+
})
+
+
t.Run("Missing sealed_token (security test)", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
// Missing sealed_token - should be rejected for security
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Missing sealed_token should be rejected (proof of possession required)")
+
})
+
+
t.Run("Invalid sealed_token", func(t *testing.T) {
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID,
+
"sealed_token": "invalid-token-data",
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Invalid sealed_token should be rejected")
+
})
+
+
t.Run("DID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token for a different DID
+
wrongDID := "did:plc:wronguser123"
+
wrongToken, err := client.SealSession(wrongDID, initialSession.SessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(), // Claiming original DID
+
"session_id": initialSession.SessionID,
+
"sealed_token": wrongToken, // But token is for different DID
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"DID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Session ID mismatch (security test)", func(t *testing.T) {
+
// Create a sealed token with wrong session ID
+
wrongSessionID := "wrong-session-id"
+
wrongToken, err := client.SealSession(did.String(), wrongSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
// Try to use it to refresh the original session
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": initialSession.SessionID, // Claiming original session
+
"sealed_token": wrongToken, // But token is for different session
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Session ID mismatch should be rejected (prevents session hijacking)")
+
})
+
+
t.Run("Non-existent session", func(t *testing.T) {
+
// Create a valid sealed token for a non-existent session
+
nonExistentSessionID := "nonexistent-session-id"
+
validToken, err := client.SealSession(did.String(), nonExistentSessionID, 30*24*time.Hour)
+
require.NoError(t, err)
+
+
refreshReq := map[string]interface{}{
+
"did": did.String(),
+
"session_id": nonExistentSessionID,
+
"sealed_token": validToken, // Valid token but session doesn't exist
+
}
+
+
reqBody, err := json.Marshal(refreshReq)
+
require.NoError(t, err)
+
+
req, err := http.NewRequest("POST", server.URL+"/oauth/refresh", strings.NewReader(string(reqBody)))
+
require.NoError(t, err)
+
req.Header.Set("Content-Type", "application/json")
+
+
resp, err := http.DefaultClient.Do(req)
+
require.NoError(t, err)
+
defer func() { _ = resp.Body.Close() }()
+
+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
+
"Non-existent session should be rejected with 401")
+
})
+
+
t.Logf("โœ… Token refresh endpoint validation verified")
+
}
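
Taken together, the subtests pin down the contract of the refresh endpoint: the caller must present the sealed token as proof of possession, and any mismatch between the claimed DID or session and the token's contents yields a 401. The request body they exercise has this shape (an illustrative struct; only the JSON field names come from the test request maps):

```go
package oauthclient

// refreshRequest is the body sent to POST /oauth/refresh in the subtests above.
type refreshRequest struct {
	DID         string `json:"did"`
	SessionID   string `json:"session_id"`
	SealedToken string `json:"sealed_token"`
}
```
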
+
+
// TestOAuthE2E_SessionUpdate tests that refresh updates the session in database
+
func TestOAuthE2E_SessionUpdate(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session update test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ’พ Testing OAuth session update on refresh...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:sessionupdate123")
+
require.NoError(t, err)
+
+
originalSession := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: "update-session-1",
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: "original-access-token",
+
RefreshToken: "original-refresh-token",
+
DPoPPrivateKeyMultibase: "original-dpop-key",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save original session
+
err = store.SaveSession(ctx, originalSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Original session saved")
+
+
// Simulate a token refresh by updating the session with new tokens
+
updatedSession := originalSession
+
updatedSession.AccessToken = "new-access-token"
+
updatedSession.RefreshToken = "new-refresh-token"
+
updatedSession.DPoPAuthServerNonce = "new-nonce"
+
+
// Update the session (upsert)
+
err = store.SaveSession(ctx, updatedSession)
+
require.NoError(t, err)
+
+
t.Logf("โœ… Session updated with new tokens")
+
+
// Retrieve the session and verify it was updated
+
retrieved, err := store.GetSession(ctx, did, originalSession.SessionID)
+
require.NoError(t, err, "Should retrieve updated session")
+
+
assert.Equal(t, "new-access-token", retrieved.AccessToken,
+
"Access token should be updated")
+
assert.Equal(t, "new-refresh-token", retrieved.RefreshToken,
+
"Refresh token should be updated")
+
assert.Equal(t, "new-nonce", retrieved.DPoPAuthServerNonce,
+
"DPoP nonce should be updated")
+
+
// Verify session ID and DID remain the same
+
assert.Equal(t, originalSession.SessionID, retrieved.SessionID,
+
"Session ID should remain the same")
+
assert.Equal(t, did, retrieved.AccountDID,
+
"DID should remain the same")
+
+
t.Logf("โœ… Session update verified - tokens refreshed in database")
+
+
// Verify updated_at was changed
+
var updatedAt time.Time
+
err = db.QueryRowContext(ctx,
+
"SELECT updated_at FROM oauth_sessions WHERE did = $1 AND session_id = $2",
+
did.String(), originalSession.SessionID).Scan(&updatedAt)
+
require.NoError(t, err)
+
+
// Updated timestamp should be recent (within last minute)
+
assert.WithinDuration(t, time.Now(), updatedAt, time.Minute,
+
"Session updated_at should be recent")
+
+
t.Logf("โœ… Session timestamp update verified")
+
}
+
+
// TestOAuthE2E_RefreshTokenRotation tests refresh token rotation behavior
+
func TestOAuthE2E_RefreshTokenRotation(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth refresh token rotation test in short mode")
+
}
+
+
db := setupTestDB(t)
+
defer func() { _ = db.Close() }()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
ctx := context.Background()
+
+
t.Log("๐Ÿ”„ Testing OAuth refresh token rotation...")
+
+
// Setup OAuth store
+
store := SetupOAuthTestStore(t, db)
+
+
// Create a test session
+
did, err := syntax.ParseDID("did:plc:rotation123")
+
require.NoError(t, err)
+
+
// Simulate multiple refresh cycles
+
sessionID := "rotation-session-1"
+
tokens := []struct {
+
access string
+
refresh string
+
}{
+
{"access-token-v1", "refresh-token-v1"},
+
{"access-token-v2", "refresh-token-v2"},
+
{"access-token-v3", "refresh-token-v3"},
+
}
+
+
for i, tokenPair := range tokens {
+
session := oauthlib.ClientSessionData{
+
AccountDID: did,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AuthServerURL: "http://localhost:3001",
+
AuthServerTokenEndpoint: "http://localhost:3001/oauth/token",
+
AccessToken: tokenPair.access,
+
RefreshToken: tokenPair.refresh,
+
Scopes: []string{"atproto"},
+
}
+
+
// Save/update session
+
err = store.SaveSession(ctx, session)
+
require.NoError(t, err, "Should save session iteration %d", i+1)
+
+
// Retrieve and verify
+
retrieved, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err, "Should retrieve session iteration %d", i+1)
+
+
assert.Equal(t, tokenPair.access, retrieved.AccessToken,
+
"Access token should match iteration %d", i+1)
+
assert.Equal(t, tokenPair.refresh, retrieved.RefreshToken,
+
"Refresh token should match iteration %d", i+1)
+
+
// Small delay to ensure timestamp differences
+
time.Sleep(10 * time.Millisecond)
+
}
+
+
t.Logf("โœ… Refresh token rotation verified through %d cycles", len(tokens))
+
+
// Verify final state
+
finalSession, err := store.GetSession(ctx, did, sessionID)
+
require.NoError(t, err)
+
+
assert.Equal(t, "access-token-v3", finalSession.AccessToken,
+
"Final access token should be from last rotation")
+
assert.Equal(t, "refresh-token-v3", finalSession.RefreshToken,
+
"Final refresh token should be from last rotation")
+
+
t.Logf("โœ… Token rotation state verified")
+
}
+312
tests/integration/oauth_session_fixation_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/atproto/oauth"
+
"context"
+
"crypto/sha256"
+
"encoding/base64"
+
"net/http"
+
"net/http/httptest"
+
"net/url"
+
"testing"
+
"time"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/go-chi/chi/v5"
+
"github.com/pressly/goose/v3"
+
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
+
)
+
+
// TestOAuth_SessionFixationAttackPrevention tests that the mobile redirect binding
+
// prevents session fixation attacks where an attacker plants a mobile_redirect_uri
+
// cookie, the user later performs a normal web login, and the credentials are sent to the attacker's deep link.
+
//
+
// Attack scenario:
+
// 1. Attacker tricks user into visiting /oauth/mobile/login?redirect_uri=evil://steal
+
// 2. This plants a mobile_redirect_uri cookie (lives 10 minutes)
+
// 3. User later does normal web OAuth login via /oauth/login
+
// 4. HandleCallback sees the stale mobile_redirect_uri cookie
+
// 5. WITHOUT THE FIX: Callback sends sealed token, DID, session_id to attacker's deep link
+
// 6. WITH THE FIX: Binding mismatch is detected, mobile cookies cleared, user gets web session
+
func TestOAuth_SessionFixationAttackPrevention(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping OAuth session fixation test in short mode")
+
}
+
+
// Setup test database
+
db := setupTestDB(t)
+
defer func() {
+
if err := db.Close(); err != nil {
+
t.Logf("Failed to close database: %v", err)
+
}
+
}()
+
+
// Run migrations
+
require.NoError(t, goose.SetDialect("postgres"))
+
require.NoError(t, goose.Up(db, "../../internal/db/migrations"))
+
+
// Setup OAuth client and store
+
store := SetupOAuthTestStore(t, db)
+
client := SetupOAuthTestClient(t, store)
+
require.NotNil(t, client, "OAuth client should be initialized")
+
+
// Setup handler
+
handler := oauth.NewOAuthHandler(client, store)
+
+
// Setup router
+
r := chi.NewRouter()
+
r.Get("/oauth/callback", handler.HandleCallback)
+
+
t.Run("attack scenario - planted mobile cookie without binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Step 1: Simulate a successful OAuth callback (like a user did web login)
+
// We'll create a mock session to simulate what ProcessCallback would return
+
testDID := "did:plc:test123456"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "test-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "test-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session (simulating successful OAuth flow)
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Step 2: Attacker planted a mobile_redirect_uri cookie (without binding)
+
// This simulates the cookie being planted earlier by attacker
+
attackerRedirectURI := "evil://steal"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Plant the attacker's cookie (URL escaped as it would be in real scenario)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
// NOTE: No mobile_redirect_binding cookie! This is the attack scenario.
+
+
rec := httptest.NewRecorder()
+
+
// Step 3: Try to process the callback
+
// ProcessCallback will fail here because the code/state are not from a real
+
// OAuth exchange; what we are verifying is that, regardless of that outcome,
+
// the mobile redirect validation never sends anything to the attacker's URI
+
handler.HandleCallback(rec, req)
+
+
// Step 4: Verify the attack was prevented
+
// The handler should reject the request due to missing binding
+
// Since ProcessCallback will fail first (no real OAuth code), we expect
+
// a 400 error, but the important thing is it doesn't redirect to evil://steal
+
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails")
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI")
+
})
+
+
t.Run("legitimate mobile flow - with valid binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup a legitimate mobile session
+
testDID := "did:plc:mobile123"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "mobile-session-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "mobile-access-token",
+
Scopes: []string{"atproto"},
+
}
+
+
// Save the session
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Create request with BOTH mobile_redirect_uri AND valid binding
+
// Use Universal Link URI that's in the allowlist
+
legitRedirectURI := "https://coves.social/app/oauth/callback"
+
csrfToken := "valid-csrf-token-for-mobile"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
// Add mobile redirect URI cookie
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(legitRedirectURI),
+
Path: "/oauth",
+
})
+
+
// Add CSRF token (required for mobile flow)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: csrfToken,
+
Path: "/oauth",
+
})
+
+
// Add VALID binding cookie (this is what prevents the attack)
+
// In real flow, this would be set by HandleMobileLogin
+
// The binding now includes the CSRF token for double-submit validation
+
mobileBinding := generateMobileRedirectBindingForTest(csrfToken, legitRedirectURI)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: mobileBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// This will also fail at ProcessCallback (no real OAuth code)
+
// but we're verifying the binding validation logic is in place
+
// In a real integration test with PDS, this would succeed
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when ProcessCallback fails (expected in mock test)")
+
})
+
+
t.Run("binding mismatch - attacker tries wrong binding", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:bindingtest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "binding-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "binding-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// Attacker tries to plant evil redirect with a binding from different URI
+
attackerRedirectURI := "evil://steal"
+
attackerCSRF := "attacker-csrf-token"
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(attackerRedirectURI),
+
Path: "/oauth",
+
})
+
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Use binding from a DIFFERENT CSRF token and URI (attacker's attempt to forge)
+
// Even if attacker knows the redirect URI, they don't know the user's CSRF token
+
wrongBinding := generateMobileRedirectBindingForTest("different-csrf", "https://coves.social/app/oauth/callback")
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: wrongBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// ProcessCallback fails here (no real OAuth code), so no redirect is issued.
+
// In the real flow the binding check runs after ProcessCallback succeeds; the
+
// forged binding would be detected and the mobile cookies cleared.
+
assert.NotContains(t, rec.Header().Get("Location"), "evil://",
+
"Should never redirect to attacker's URI on binding mismatch")
+
})
+
+
t.Run("CSRF token value mismatch - attacker tries different CSRF", func(t *testing.T) {
+
ctx := context.Background()
+
+
// Setup session
+
testDID := "did:plc:csrftest"
+
parsedDID, err := syntax.ParseDID(testDID)
+
require.NoError(t, err)
+
+
sessionID := "csrf-test-" + time.Now().Format("20060102150405")
+
testSession := oauthlib.ClientSessionData{
+
AccountDID: parsedDID,
+
SessionID: sessionID,
+
HostURL: "http://localhost:3001",
+
AccessToken: "csrf-test-token",
+
Scopes: []string{"atproto"},
+
}
+
+
err = store.SaveSession(ctx, testSession)
+
require.NoError(t, err)
+
+
// This tests the P1 security fix: CSRF token VALUE must be validated, not just presence
+
// Attack scenario:
+
// 1. User starts mobile login with CSRF token A and redirect URI X
+
// 2. Binding = hash(A + X) is stored in cookie
+
// 3. Attacker somehow gets user to have CSRF token B in cookie (different from A)
+
// 4. Callback receives CSRF token B, redirect URI X, binding = hash(A + X)
+
// 5. hash(B + X) != hash(A + X), so attack is detected
+
+
originalCSRF := "original-csrf-token-set-at-login"
+
redirectURI := "https://coves.social/app/oauth/callback"
+
// Binding was created with original CSRF token
+
originalBinding := generateMobileRedirectBindingForTest(originalCSRF, redirectURI)
+
+
// But attacker managed to change the CSRF cookie
+
attackerCSRF := "attacker-replaced-csrf"
+
+
req := httptest.NewRequest("GET", "/oauth/callback?code=test&state=test&iss=http://localhost:3001", nil)
+
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_uri",
+
Value: url.QueryEscape(redirectURI),
+
Path: "/oauth",
+
})
+
+
// Attacker's CSRF token (different from what created the binding)
+
req.AddCookie(&http.Cookie{
+
Name: "oauth_csrf",
+
Value: attackerCSRF,
+
Path: "/oauth",
+
})
+
+
// Original binding (created with original CSRF token)
+
req.AddCookie(&http.Cookie{
+
Name: "mobile_redirect_binding",
+
Value: originalBinding,
+
Path: "/oauth",
+
})
+
+
rec := httptest.NewRecorder()
+
handler.HandleCallback(rec, req)
+
+
// Should fail because hash(attackerCSRF + redirectURI) != hash(originalCSRF + redirectURI)
+
// This is the key security fix - CSRF token VALUE is now validated
+
assert.NotEqual(t, http.StatusFound, rec.Code,
+
"Should not redirect when CSRF token doesn't match binding")
+
})
+
}
+
+
// generateMobileRedirectBindingForTest generates a binding for testing
+
// This mirrors the actual logic in handlers_security.go:
+
// binding = base64(sha256(csrfToken + "|" + redirectURI)[:16])
+
func generateMobileRedirectBindingForTest(csrfToken, mobileRedirectURI string) string {
+
combined := csrfToken + "|" + mobileRedirectURI
+
hash := sha256.Sum256([]byte(combined))
+
return base64.URLEncoding.EncodeToString(hash[:16])
+
}
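
For readers following the attack scenarios above, the check the callback handler is expected to perform is the inverse of `generateMobileRedirectBindingForTest`: recompute the binding from the `oauth_csrf` and `mobile_redirect_uri` cookies and compare it against the `mobile_redirect_binding` cookie. A sketch under those assumptions; the function name and the constant-time comparison are this sketch's choices, not the actual `handlers_security.go` API:

```go
package oauthsketch

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"net/url"
)

// mobileRedirectIsBound recomputes base64(sha256(csrf + "|" + redirectURI)[:16])
// from the request cookies and compares it to the mobile_redirect_binding cookie.
// Cookie names match the tests above.
func mobileRedirectIsBound(r *http.Request) bool {
	csrf, err := r.Cookie("oauth_csrf")
	if err != nil {
		return false
	}
	redirectCookie, err := r.Cookie("mobile_redirect_uri")
	if err != nil {
		return false
	}
	redirectURI, err := url.QueryUnescape(redirectCookie.Value)
	if err != nil {
		return false
	}
	binding, err := r.Cookie("mobile_redirect_binding")
	if err != nil {
		// Planted redirect cookie without a binding: treat as a plain web login.
		return false
	}
	sum := sha256.Sum256([]byte(csrf.Value + "|" + redirectURI))
	want := base64.URLEncoding.EncodeToString(sum[:16])
	return subtle.ConstantTimeCompare([]byte(want), []byte(binding.Value)) == 1
}
```
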
+169
tests/integration/oauth_token_verification_test.go
···
+
package integration
+
+
import (
+
"Coves/internal/api/middleware"
+
"fmt"
+
"net/http"
+
"net/http/httptest"
+
"os"
+
"testing"
+
"time"
+
)
+
+
// TestOAuthTokenVerification tests end-to-end OAuth token verification
+
// with real PDS-issued OAuth tokens. This replaces the old JWT verification test
+
// since we now use OAuth sealed session tokens instead of raw JWTs.
+
//
+
// Flow:
+
// 1. Create account on local PDS (or use existing)
+
// 2. Authenticate to get OAuth tokens and create sealed session token
+
// 3. Verify our auth middleware can unseal and validate the token
+
// 4. Test token validation and session retrieval
+
//
+
// NOTE: This test uses the E2E OAuth middleware which mocks the session unsealing
+
// for testing purposes. Real OAuth tokens from PDS would be sealed using the
+
// OAuth client's seal secret.
+
func TestOAuthTokenVerification(t *testing.T) {
+
// Skip in short mode since this requires real PDS
+
if testing.Short() {
+
t.Skip("Skipping OAuth token verification test in short mode")
+
}
+
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
// Check if PDS is running
+
healthResp, err := http.Get(pdsURL + "/xrpc/_health")
+
if err != nil {
+
t.Skipf("PDS not running at %s: %v", pdsURL, err)
+
}
+
_ = healthResp.Body.Close()
+
+
t.Run("OAuth token validation and middleware integration", func(t *testing.T) {
+
// Step 1: Create a test account on PDS
+
// Keep handle short to avoid PDS validation errors
+
timestamp := time.Now().Unix() % 100000 // Last 5 digits
+
handle := fmt.Sprintf("oauth%d.local.coves.dev", timestamp)
+
password := "testpass123"
+
email := fmt.Sprintf("oauth%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
t.Logf("โœ“ Created test account: %s (DID: %s)", handle, did)
+
+
// Step 2: Create OAuth middleware with mock unsealer for testing
+
// In production, this would unseal real OAuth tokens from PDS
+
t.Log("Testing OAuth middleware with sealed session tokens...")
+
+
e2eAuth := NewE2EOAuthMiddleware()
+
testToken := e2eAuth.AddUser(did)
+
+
handlerCalled := false
+
var extractedDID string
+
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
extractedDID = middleware.GetUserDID(r)
+
w.WriteHeader(http.StatusOK)
+
_, _ = w.Write([]byte(`{"success": true}`))
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer "+testToken)
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if !handlerCalled {
+
t.Errorf("Handler was not called - auth middleware rejected valid token")
+
t.Logf("Response status: %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if w.Code != http.StatusOK {
+
t.Errorf("Expected status 200, got %d", w.Code)
+
t.Logf("Response body: %s", w.Body.String())
+
}
+
+
if extractedDID != did {
+
t.Errorf("Middleware extracted wrong DID: expected %s, got %s", did, extractedDID)
+
}
+
+
t.Logf("โœ… OAuth middleware with token validation working correctly!")
+
t.Logf(" Handler called: %v", handlerCalled)
+
t.Logf(" Extracted DID: %s", extractedDID)
+
t.Logf(" Response status: %d", w.Code)
+
})
+
+
t.Run("Rejects tampered/invalid sealed tokens", func(t *testing.T) {
+
// Create valid user
+
timestamp := time.Now().Unix() % 100000
+
handle := fmt.Sprintf("tamp%d.local.coves.dev", timestamp)
+
password := "testpass456"
+
email := fmt.Sprintf("tamp%d@test.com", timestamp)
+
+
_, did, err := createPDSAccount(pdsURL, handle, email, password)
+
if err != nil {
+
t.Fatalf("Failed to create PDS account: %v", err)
+
}
+
+
// Create OAuth middleware
+
e2eAuth := NewE2EOAuthMiddleware()
+
validToken := e2eAuth.AddUser(did)
+
+
// Create various invalid tokens to test
+
testCases := []struct {
+
name string
+
token string
+
}{
+
{"Empty token", ""},
+
{"Invalid base64", "not-valid-base64!!!"},
+
{"Tampered token", "dGFtcGVyZWQtdG9rZW4tZGF0YQ=="}, // Valid base64 but not a real sealed session
+
{"Short token", "abc"},
+
{"Modified valid token", validToken + "extra"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
handlerCalled := false
+
testHandler := e2eAuth.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
handlerCalled = true
+
w.WriteHeader(http.StatusOK)
+
}))
+
+
req := httptest.NewRequest("GET", "/test", nil)
+
if tc.token != "" {
+
req.Header.Set("Authorization", "Bearer "+tc.token)
+
}
+
w := httptest.NewRecorder()
+
+
testHandler.ServeHTTP(w, req)
+
+
if handlerCalled {
+
t.Error("Handler was called for invalid token - should have been rejected")
+
}
+
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("Expected status 401 for invalid token, got %d", w.Code)
+
}
+
+
t.Logf("โœ“ Middleware correctly rejected %s with status %d", tc.name, w.Code)
+
})
+
}
+
+
t.Logf("โœ… All invalid token types correctly rejected")
+
})
+
+
t.Run("Session expiration handling", func(t *testing.T) {
+
// OAuth session expiration is handled at the database level
+
// See TestOAuthE2E_TokenExpiration in oauth_e2e_test.go for full expiration testing
+
t.Log("โ„น๏ธ Session expiration testing is covered in oauth_e2e_test.go")
+
t.Log(" OAuth sessions expire based on database timestamps and are cleaned up periodically")
+
t.Log(" This is different from JWT expiration which was timestamp-based in the token itself")
+
t.Skip("Session expiration is tested in oauth_e2e_test.go - see TestOAuthE2E_TokenExpiration")
+
})
+
}
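
The test above leans on `NewE2EOAuthMiddleware`, a helper that mocks session unsealing. For context, a minimal version of such a mock could look like the following; the struct, context key, and token format are assumptions, and only the `AddUser`/`RequireAuth` usage and the 401-on-invalid-token behaviour come from the tests:

```go
package e2eauth

import (
	"context"
	"net/http"
	"strings"
)

type ctxKey string

const userDIDKey ctxKey = "userDID"

// mockOAuthMiddleware maps opaque bearer tokens to DIDs instead of unsealing
// real OAuth session tokens.
type mockOAuthMiddleware struct {
	tokens map[string]string // bearer token -> DID
}

func newMockOAuthMiddleware() *mockOAuthMiddleware {
	return &mockOAuthMiddleware{tokens: make(map[string]string)}
}

// AddUser registers a DID and returns the bearer token a test should send.
func (m *mockOAuthMiddleware) AddUser(did string) string {
	token := "e2e-token-" + did // opaque test token; real sealed tokens are encrypted blobs
	m.tokens[token] = did
	return token
}

// RequireAuth rejects requests without a known token and injects the DID into
// the request context for handlers to read.
func (m *mockOAuthMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authz := r.Header.Get("Authorization")
		if !strings.HasPrefix(authz, "Bearer ") {
			http.Error(w, `{"error":"AuthenticationRequired"}`, http.StatusUnauthorized)
			return
		}
		did, ok := m.tokens[strings.TrimPrefix(authz, "Bearer ")]
		if !ok {
			http.Error(w, `{"error":"AuthenticationRequired"}`, http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userDIDKey, did)))
	})
}
```
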
+16 -20
tests/integration/community_e2e_test.go
···
package integration
import (
-
"Coves/internal/api/middleware"
"Coves/internal/api/routes"
"Coves/internal/atproto/identity"
"Coves/internal/atproto/jetstream"
···
t.Logf("โœ… Authenticated - Instance DID: %s", instanceDID)
-
// Initialize auth middleware with skipVerify=true
-
// IMPORTANT: PDS password authentication returns Bearer tokens (not DPoP-bound tokens).
-
// E2E tests use these Bearer tokens with the DPoP scheme header, which only works
-
// because skipVerify=true bypasses signature and DPoP binding verification.
-
// In production, skipVerify=false requires proper DPoP-bound tokens from OAuth flow.
-
authMiddleware := middleware.NewAtProtoAuthMiddleware(nil, true)
-
defer authMiddleware.Stop() // Clean up DPoP replay cache goroutine
+
// Initialize OAuth auth middleware for E2E testing
+
e2eAuth := NewE2EOAuthMiddleware()
+
// Register the instance user for OAuth authentication
+
token := e2eAuth.AddUser(instanceDID)
// V2.0: Extract instance domain for community provisioning
var instanceDomain string
···
// Setup HTTP server with XRPC routes
r := chi.NewRouter()
-
routes.RegisterCommunityRoutes(r, communityService, authMiddleware, nil) // nil = allow all community creators
+
routes.RegisterCommunityRoutes(r, communityService, e2eAuth.OAuthAuthMiddleware, nil) // nil = allow all community creators
httpServer := httptest.NewServer(r)
defer httpServer.Close()
···
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
-
// Use real PDS access token for E2E authentication
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
// Use OAuth token for Coves API authentication
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
···
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
-
// Use real PDS access token for E2E authentication
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
// Use OAuth token for Coves API authentication
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
···
t.Fatalf("Failed to create request: %v", err)
req.Header.Set("Content-Type", "application/json")
-
// Use real PDS access token for E2E authentication
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
// Use OAuth token for Coves API authentication
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
···
t.Fatalf("Failed to create block request: %v", err)
req.Header.Set("Content-Type", "application/json")
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
···
t.Fatalf("Failed to create block request: %v", err)
blockHttpReq.Header.Set("Content-Type", "application/json")
-
blockHttpReq.Header.Set("Authorization", "DPoP "+accessToken)
+
blockHttpReq.Header.Set("Authorization", "Bearer "+token)
blockResp, err := http.DefaultClient.Do(blockHttpReq)
if err != nil {
···
t.Fatalf("Failed to create unblock request: %v", err)
req.Header.Set("Content-Type", "application/json")
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
···
t.Fatalf("Failed to create request: %v", err)
req.Header.Set("Content-Type", "application/json")
-
// Use real PDS access token for E2E authentication
-
req.Header.Set("Authorization", "DPoP "+accessToken)
+
// Use OAuth token for Coves API authentication
+
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
-73
cmd/genjwks/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"log"
-
"os"
-
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
// genjwks generates an ES256 keypair for OAuth client authentication
-
// The private key is stored in the config/env, public key is served at /oauth/jwks.json
-
//
-
// Usage:
-
//
-
// go run cmd/genjwks/main.go
-
//
-
// This will output a JSON private key that should be stored in OAUTH_PRIVATE_JWK
-
func main() {
-
fmt.Println("Generating ES256 keypair for OAuth client authentication...")
-
-
// Generate ES256 (NIST P-256) private key
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
log.Fatalf("Failed to generate private key: %v", err)
-
}
-
-
// Convert to JWK
-
jwkKey, err := jwk.FromRaw(privateKey)
-
if err != nil {
-
log.Fatalf("Failed to create JWK from private key: %v", err)
-
}
-
-
// Set key parameters
-
if err = jwkKey.Set(jwk.KeyIDKey, "oauth-client-key"); err != nil {
-
log.Fatalf("Failed to set kid: %v", err)
-
}
-
if err = jwkKey.Set(jwk.AlgorithmKey, "ES256"); err != nil {
-
log.Fatalf("Failed to set alg: %v", err)
-
}
-
if err = jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
-
log.Fatalf("Failed to set use: %v", err)
-
}
-
-
// Marshal to JSON
-
jsonData, err := json.MarshalIndent(jwkKey, "", " ")
-
if err != nil {
-
log.Fatalf("Failed to marshal JWK: %v", err)
-
}
-
-
// Output instructions
-
fmt.Println("\nโœ… ES256 keypair generated successfully!")
-
fmt.Println("\n๐Ÿ“ Add this to your .env.dev file:")
-
fmt.Println("\nOAUTH_PRIVATE_JWK='" + string(jsonData) + "'")
-
fmt.Println("\nโš ๏ธ IMPORTANT:")
-
fmt.Println(" - Keep this private key SECRET")
-
fmt.Println(" - Never commit it to version control")
-
fmt.Println(" - Generate a new key for production")
-
fmt.Println(" - The public key will be automatically derived and served at /oauth/jwks.json")
-
-
// Optionally write to a file (not committed)
-
if len(os.Args) > 1 && os.Args[1] == "--save" {
-
filename := "oauth-private-key.json"
-
if err := os.WriteFile(filename, jsonData, 0o600); err != nil {
-
log.Fatalf("Failed to write key file: %v", err)
-
}
-
fmt.Printf("\n๐Ÿ’พ Private key saved to %s (remember to add to .gitignore!)\n", filename)
-
}
-
}
-330
internal/atproto/auth/README.md
···
-
# atProto OAuth Authentication
-
-
This package implements third-party OAuth authentication for Coves, validating DPoP-bound access tokens from mobile apps and other atProto clients.
-
-
## Architecture
-
-
This is **third-party authentication** (validating incoming requests), not first-party authentication (logging users into Coves web frontend).
-
-
### Components
-
-
1. **JWT Parser** (`jwt.go`) - Parses and validates JWT tokens
-
2. **JWKS Fetcher** (`jwks_fetcher.go`) - Fetches and caches public keys from PDS authorization servers
-
3. **Auth Middleware** (`internal/api/middleware/auth.go`) - HTTP middleware that protects endpoints
-
-
### Flow
-
-
```
-
Client Request
-
โ†“
-
Authorization: DPoP <access_token>
-
DPoP: <proof-jwt>
-
โ†“
-
Auth Middleware
-
โ†“
-
Extract JWT โ†’ Parse Claims โ†’ Verify Signature (via JWKS) โ†’ Verify DPoP Proof
-
โ†“
-
Inject DID into Context โ†’ Call Handler
-
```
-
-
## Usage
-
-
### Phase 1: Parse-Only Mode (Testing)
-
-
Set `AUTH_SKIP_VERIFY=true` to only parse JWTs without signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=true
-
```
-
-
This is useful for:
-
- Initial integration testing
-
- Testing with mock tokens
-
- Debugging JWT structure
-
-
### Phase 2: Full Verification (Production)
-
-
Set `AUTH_SKIP_VERIFY=false` (or unset) to enable full JWT signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=false
-
# or just unset it
-
```
-
-
This is **required for production** and validates:
-
- JWT signature using PDS public key
-
- Token expiration
-
- Required claims (sub, iss)
-
- DID format
-
-
## Protected Endpoints
-
-
The following endpoints require authentication:
-
-
- `POST /xrpc/social.coves.community.create`
-
- `POST /xrpc/social.coves.community.update`
-
- `POST /xrpc/social.coves.community.subscribe`
-
- `POST /xrpc/social.coves.community.unsubscribe`
-
-
### Making Authenticated Requests
-
-
Include the JWT in the `Authorization` header:
-
-
```bash
-
curl -X POST https://coves.social/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP eyJhbGc..." \
-
-H "DPoP: eyJhbGc..." \
-
-H "Content-Type: application/json" \
-
-d '{"name":"Gaming","hostedByDid":"did:plc:..."}'
-
```
-
-
### Getting User DID in Handlers
-
-
The middleware injects the authenticated user's DID into the request context:
-
-
```go
-
import "Coves/internal/api/middleware"
-
-
func (h *Handler) HandleCreate(w http.ResponseWriter, r *http.Request) {
-
// Extract authenticated user DID
-
userDID := middleware.GetUserDID(r)
-
if userDID == "" {
-
// Not authenticated (should never happen with RequireAuth middleware)
-
http.Error(w, "Unauthorized", http.StatusUnauthorized)
-
return
-
}
-
-
// Use userDID for authorization checks
-
// ...
-
}
-
```
-
-
## Key Caching
-
-
Public keys are fetched from PDS authorization servers and cached for 1 hour. The cache is automatically cleaned up hourly to remove expired entries.
-
-
### JWKS Discovery Flow
-
-
1. Extract `iss` claim from JWT (e.g., `https://pds.example.com`)
-
2. Fetch `https://pds.example.com/.well-known/oauth-authorization-server`
-
3. Extract `jwks_uri` from metadata
-
4. Fetch JWKS from `jwks_uri`
-
5. Find matching key by `kid` from JWT header
-
6. Cache the JWKS for 1 hour
-
-
## DPoP Token Binding
-
-
DPoP (Demonstrating Proof-of-Possession) binds access tokens to client-controlled cryptographic keys, preventing token theft and replay attacks.
-
-
### What is DPoP?
-
-
DPoP is an OAuth extension (RFC 9449) that adds proof-of-possession semantics to bearer tokens. When a PDS issues a DPoP-bound access token:
-
-
1. Access token contains `cnf.jkt` claim (JWK thumbprint of client's public key)
-
2. Client creates a DPoP proof JWT signed with their private key
-
3. Server verifies the proof signature and checks it matches the token's `cnf.jkt`
-
-
### CRITICAL: DPoP Security Model
-
-
> โš ๏ธ **DPoP is an ADDITIONAL security layer, NOT a replacement for token signature verification.**
-
-
The correct verification order is:
-
1. **ALWAYS verify the access token signature first** (via JWKS, HS256 shared secret, or DID resolution)
-
2. **If the verified token has `cnf.jkt`, REQUIRE valid DPoP proof**
-
3. **NEVER use DPoP as a fallback when signature verification fails**
-
-
**Why This Matters**: An attacker could create a fake token with `sub: "did:plc:victim"` and their own `cnf.jkt`, then present a valid DPoP proof signed with their key. If we accept DPoP as a fallback, the attacker can impersonate any user.
-
-
### How DPoP Works
-
-
```
-
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
-
โ”‚ Client โ”‚ โ”‚ Server โ”‚
-
โ”‚ โ”‚ โ”‚ (Coves) โ”‚
-
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
-
โ”‚ โ”‚
-
โ”‚ 1. Authorization: DPoP <token> โ”‚
-
โ”‚ DPoP: <proof-jwt> โ”‚
-
โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€>โ”‚
-
โ”‚ โ”‚
-
โ”‚ โ”‚ 2. VERIFY token signature
-
โ”‚ โ”‚ (REQUIRED - no fallback!)
-
โ”‚ โ”‚
-
โ”‚ โ”‚ 3. If token has cnf.jkt:
-
โ”‚ โ”‚ - Verify DPoP proof
-
โ”‚ โ”‚ - Check thumbprint match
-
โ”‚ โ”‚
-
โ”‚ 200 OK โ”‚
-
โ”‚<โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚
-
```
-
-
### When DPoP is Required
-
-
DPoP verification is **REQUIRED** when:
-
- Access token signature has been verified AND
-
- Access token contains `cnf.jkt` claim (DPoP-bound)
-
-
If the token has `cnf.jkt` but no DPoP header is present, the request is **REJECTED**.
-
-
### Replay Protection
-
-
DPoP proofs include a unique `jti` (JWT ID) claim. The server tracks seen `jti` values to prevent replay attacks:
-
-
```go
-
// Create a verifier with replay protection (default)
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop() // Stop cleanup goroutine on shutdown
-
-
// The verifier automatically rejects reused jti values within the proof validity window (5 minutes)
-
```
-
-
### DPoP Implementation
-
-
The `dpop.go` module provides:
-
-
```go
-
// Create a verifier with replay protection
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop()
-
-
// Verify the DPoP proof
-
proof, err := verifier.VerifyDPoPProof(dpopHeader, "POST", "https://coves.social/xrpc/...")
-
if err != nil {
-
// Invalid proof (includes replay detection)
-
}
-
-
// Verify it binds to the VERIFIED access token
-
expectedThumbprint, err := auth.ExtractCnfJkt(claims)
-
if err != nil {
-
// Token not DPoP-bound
-
}
-
-
if err := verifier.VerifyTokenBinding(proof, expectedThumbprint); err != nil {
-
// Proof doesn't match token
-
}
-
```
-
-
### DPoP Proof Format
-
-
The DPoP header contains a JWT with:
-
-
**Header**:
-
- `typ`: `"dpop+jwt"` (required)
-
- `alg`: `"ES256"` (or other supported algorithm)
-
- `jwk`: Client's public key (JWK format)
-
-
**Claims**:
-
- `jti`: Unique proof identifier (tracked for replay protection)
-
- `htm`: HTTP method (e.g., `"POST"`)
-
- `htu`: HTTP URI (without query/fragment)
-
- `iat`: Timestamp (must be recent, within 5 minutes)
-
-
**Example**:
-
```json
-
{
-
"typ": "dpop+jwt",
-
"alg": "ES256",
-
"jwk": {
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "...",
-
"y": "..."
-
}
-
}
-
{
-
"jti": "unique-id-123",
-
"htm": "POST",
-
"htu": "https://coves.social/xrpc/social.coves.community.create",
-
"iat": 1700000000
-
}
-
```
-
-
## Security Considerations
-
-
### โœ… Implemented
-
-
- JWT signature verification with PDS public keys
-
- Token expiration validation
-
- DID format validation
-
- Required claims validation (sub, iss)
-
- Key caching with TTL
-
- Secure error messages (no internal details leaked)
-
- **DPoP proof verification** (proof-of-possession for token binding)
-
- **DPoP thumbprint validation** (prevents token theft attacks)
-
- **DPoP freshness checks** (5-minute proof validity window)
-
- **DPoP replay protection** (jti tracking with in-memory cache)
-
- **Secure DPoP model** (DPoP required AFTER signature verification, never as fallback)
-
-
### โš ๏ธ Not Yet Implemented
-
-
- Server-issued DPoP nonces (additional replay protection)
-
- Scope validation (checking `scope` claim)
-
- Audience validation (checking `aud` claim)
-
- Rate limiting per DID
-
- Token revocation checking
-
-
## Testing
-
-
Run the test suite:
-
-
```bash
-
go test ./internal/atproto/auth/... -v
-
```
-
-
### Manual Testing
-
-
1. **Phase 1 (Parse Only)**:
-
```bash
-
# Create a test JWT (use jwt.io or a tool)
-
export AUTH_SKIP_VERIFY=true
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <test-jwt>" \
-
-H "DPoP: <test-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
2. **Phase 2 (Full Verification)**:
-
```bash
-
# Use a real JWT from a PDS
-
export AUTH_SKIP_VERIFY=false
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <real-jwt>" \
-
-H "DPoP: <real-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
## Error Responses
-
-
### 401 Unauthorized
-
-
Missing or invalid token:
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Missing Authorization header"
-
}
-
```
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Invalid or expired token"
-
}
-
```
-
-
### Common Issues
-
-
1. **Missing Authorization header** โ†’ Add `Authorization: DPoP <token>` and `DPoP: <proof>`
-
2. **Token expired** โ†’ Get a new token from PDS
-
3. **Invalid signature** โ†’ Ensure token is from a valid PDS
-
4. **JWKS fetch fails** โ†’ Check PDS availability and network connectivity
-
-
## Future Enhancements
-
-
- [ ] DPoP nonce validation (server-managed nonce for additional replay protection)
-
- [ ] Scope-based authorization
-
- [ ] Audience claim validation
-
- [ ] Token revocation support
-
- [ ] Rate limiting per DID
-
- [ ] Metrics and monitoring
-122
internal/atproto/auth/did_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"encoding/base64"
-
"fmt"
-
"math/big"
-
"strings"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
"github.com/bluesky-social/indigo/atproto/syntax"
-
)
-
-
// DIDKeyFetcher fetches public keys from DID documents for JWT verification.
-
// This is the primary method for atproto service authentication, where:
-
// - The JWT issuer is the user's DID (e.g., did:plc:abc123)
-
// - The signing key is published in the user's DID document
-
// - Verification happens by resolving the DID and checking the signature
-
type DIDKeyFetcher struct {
-
directory indigoIdentity.Directory
-
}
-
-
// NewDIDKeyFetcher creates a new DID-based key fetcher.
-
func NewDIDKeyFetcher(directory indigoIdentity.Directory) *DIDKeyFetcher {
-
return &DIDKeyFetcher{
-
directory: directory,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer's DID document.
-
// For DID issuers (did:plc: or did:web:), resolves the DID and extracts the signing key.
-
//
-
// Returns:
-
// - indigoCrypto.PublicKey for secp256k1 (ES256K) keys - use indigo for verification
-
// - *ecdsa.PublicKey for NIST curves (P-256, P-384, P-521) - compatible with golang-jwt
-
func (f *DIDKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Only handle DID issuers
-
if !strings.HasPrefix(issuer, "did:") {
-
return nil, fmt.Errorf("DIDKeyFetcher only handles DID issuers, got: %s", issuer)
-
}
-
-
// Parse the DID
-
did, err := syntax.ParseDID(issuer)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DID format: %w", err)
-
}
-
-
// Resolve the DID to get the identity (includes public keys)
-
ident, err := f.directory.LookupDID(ctx, did)
-
if err != nil {
-
return nil, fmt.Errorf("failed to resolve DID %s: %w", issuer, err)
-
}
-
-
// Get the atproto signing key from the DID document
-
pubKey, err := ident.PublicKey()
-
if err != nil {
-
return nil, fmt.Errorf("failed to get public key from DID document: %w", err)
-
}
-
-
// Convert to JWK format to check curve type
-
jwk, err := pubKey.JWK()
-
if err != nil {
-
return nil, fmt.Errorf("failed to convert public key to JWK: %w", err)
-
}
-
-
// For secp256k1 (ES256K), return indigo's PublicKey directly
-
// since Go's crypto/ecdsa doesn't support this curve
-
if jwk.Curve == "secp256k1" {
-
return pubKey, nil
-
}
-
-
// For NIST curves, convert to Go's ecdsa.PublicKey for golang-jwt compatibility
-
return atcryptoJWKToECDSA(jwk)
-
}
-
-
// atcryptoJWKToECDSA converts an indigoCrypto.JWK to a Go ecdsa.PublicKey.
-
// Note: secp256k1 is handled separately in FetchPublicKey by returning indigo's PublicKey directly.
-
func atcryptoJWKToECDSA(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) {
-
if jwk.KeyType != "EC" {
-
return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType)
-
}
-
-
// Decode X and Y coordinates (base64url, no padding)
-
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK X coordinate encoding: %w", err)
-
}
-
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK Y coordinate encoding: %w", err)
-
}
-
-
var ecCurve elliptic.Curve
-
switch jwk.Curve {
-
case "P-256":
-
ecCurve = elliptic.P256()
-
case "P-384":
-
ecCurve = elliptic.P384()
-
case "P-521":
-
ecCurve = elliptic.P521()
-
default:
-
// secp256k1 should be handled before calling this function
-
return nil, fmt.Errorf("unsupported JWK curve for Go ecdsa: %s (secp256k1 uses indigo)", jwk.Curve)
-
}
-
-
// Create the public key
-
pubKey := &ecdsa.PublicKey{
-
Curve: ecCurve,
-
X: new(big.Int).SetBytes(xBytes),
-
Y: new(big.Int).SetBytes(yBytes),
-
}
-
-
// Validate point is on curve
-
if !ecCurve.IsOnCurve(pubKey.X, pubKey.Y) {
-
return nil, fmt.Errorf("invalid public key: point not on curve")
-
}
-
-
return pubKey, nil
-
}
-1308
internal/atproto/auth/dpop_test.go
···
-
package auth
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"strings"
-
"testing"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
"github.com/google/uuid"
-
)
-
-
// === Test Helpers ===
-
-
// testECKey holds a test ES256 key pair
-
type testECKey struct {
-
privateKey *ecdsa.PrivateKey
-
publicKey *ecdsa.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256Key generates a test ES256 key pair and JWK
-
func generateTestES256Key(t *testing.T) *testECKey {
-
t.Helper()
-
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("Failed to generate test key: %v", err)
-
}
-
-
// Encode public key coordinates as base64url
-
xBytes := privateKey.PublicKey.X.Bytes()
-
yBytes := privateKey.PublicKey.Y.Bytes()
-
-
// P-256 coordinates must be 32 bytes (pad if needed)
-
xBytes = padTo32Bytes(xBytes)
-
yBytes = padTo32Bytes(yBytes)
-
-
x := base64.RawURLEncoding.EncodeToString(xBytes)
-
y := base64.RawURLEncoding.EncodeToString(yBytes)
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": x,
-
"y": y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate thumbprint: %v", err)
-
}
-
-
return &testECKey{
-
privateKey: privateKey,
-
publicKey: &privateKey.PublicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// padTo32Bytes pads a byte slice to 32 bytes (required for P-256 coordinates)
-
func padTo32Bytes(b []byte) []byte {
-
if len(b) >= 32 {
-
return b
-
}
-
padded := make([]byte, 32)
-
copy(padded[32-len(b):], b)
-
return padded
-
}
-
-
// createDPoPProof creates a DPoP proof JWT for testing
-
func createDPoPProof(t *testing.T, key *testECKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
tokenString, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create DPoP proof: %v", err)
-
}
-
-
return tokenString
-
}
-
-
// === JWK Thumbprint Tests (RFC 7638) ===
-
-
func TestCalculateJWKThumbprint_EC_P256(t *testing.T) {
-
// Test with known values from RFC 7638 Appendix A (adapted for P-256)
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis",
-
"y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("Thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
func TestCalculateJWKThumbprint_Deterministic(t *testing.T) {
-
// Same key should produce same thumbprint
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x-coordinate",
-
"y": "test-y-coordinate",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 != thumbprint2 {
-
t.Errorf("Thumbprints are not deterministic: %s != %s", thumbprint1, thumbprint2)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_DifferentKeys(t *testing.T) {
-
// Different keys should produce different thumbprints
-
jwk1 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-1",
-
"y": "coordinate-y-1",
-
}
-
-
jwk2 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-2",
-
"y": "coordinate-y-2",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk1)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk2)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 == thumbprint2 {
-
t.Error("Different keys produced same thumbprint (collision)")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_MissingKty(t *testing.T) {
-
jwk := map[string]interface{}{
-
"crv": "P-256",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing kty, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing kty") {
-
t.Errorf("Expected error about missing kty, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingCrv(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing crv, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing crv") {
-
t.Errorf("Expected error about missing crv, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingX(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing x, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing x") {
-
t.Errorf("Expected error about missing x, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingY(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing y, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing y") {
-
t.Errorf("Expected error about missing y, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_RSA(t *testing.T) {
-
// Test RSA key thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "RSA",
-
"e": "AQAB",
-
"n": "test-modulus",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for RSA: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for RSA key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_OKP(t *testing.T) {
-
// Test OKP (Octet Key Pair) thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "OKP",
-
"crv": "Ed25519",
-
"x": "test-x-coordinate",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for OKP: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for OKP key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_UnsupportedKeyType(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "UNKNOWN",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for unsupported key type, got nil")
-
}
-
if err != nil && !contains(err.Error(), "unsupported JWK key type") {
-
t.Errorf("Expected error about unsupported key type, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_CanonicalJSON(t *testing.T) {
-
// RFC 7638 requires lexicographic ordering of keys in canonical JSON
-
// This test verifies that the canonical JSON is correctly ordered
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
// The canonical JSON should be: {"crv":"P-256","kty":"EC","x":"x-coord","y":"y-coord"}
-
// (lexicographically ordered: crv, kty, x, y)
-
-
canonical := map[string]string{
-
"crv": "P-256",
-
"kty": "EC",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
t.Fatalf("Failed to marshal canonical JSON: %v", err)
-
}
-
-
expectedHash := sha256.Sum256(canonicalJSON)
-
expectedThumbprint := base64.RawURLEncoding.EncodeToString(expectedHash[:])
-
-
actualThumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if actualThumbprint != expectedThumbprint {
-
t.Errorf("Thumbprint doesn't match expected canonical JSON hash\nExpected: %s\nGot: %s",
-
expectedThumbprint, actualThumbprint)
-
}
-
}
-
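For readers tracing the canonical-JSON assertion above, here is a minimal standalone sketch of the RFC 7638 procedure the test exercises: keep only the required EC members, marshal them (Go's encoding/json writes map keys in lexicographic order, which matches the RFC's canonical form), hash with SHA-256, and base64url-encode without padding. It is illustrative only, not the package's CalculateJWKThumbprint.

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
)

// thumbprintECP256 computes an RFC 7638 thumbprint for an EC P-256 JWK.
// Only the required members (crv, kty, x, y) are included; json.Marshal
// sorts map keys lexicographically, which is exactly the canonical form.
func thumbprintECP256(crv, x, y string) string {
	canonical, _ := json.Marshal(map[string]string{
		"crv": crv,
		"kty": "EC",
		"x":   x,
		"y":   y,
	})
	sum := sha256.Sum256(canonical)
	return base64.RawURLEncoding.EncodeToString(sum[:])
}

func main() {
	// Prints 43 base64url characters, as the length assertions above expect.
	fmt.Println(thumbprintECP256("P-256", "x-coord", "y-coord"))
}
```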
-
// === DPoP Proof Verification Tests ===
-
-
func TestVerifyDPoPProof_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Claims.ID != jti {
-
t.Errorf("Expected jti %s, got %s", jti, result.Claims.ID)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Parse and modify to use wrong key's JWK in header (signature won't match)
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongHTTPMethod(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
wrongMethod := "GET"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, wrongMethod, uri)
-
if err == nil {
-
t.Error("Expected error for HTTP method mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htm mismatch") {
-
t.Errorf("Expected htm mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongURI(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
wrongURI := "https://api.example.com/different"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, wrongURI)
-
if err == nil {
-
t.Error("Expected error for URI mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htu mismatch") {
-
t.Errorf("Expected htu mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithQuery(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithQuery := baseURI + "?param=value"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because query is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithQuery)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with query: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithFragment(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithFragment := baseURI + "#section"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because fragment is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithFragment)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with fragment: %v", err)
-
}
-
}
-
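The two tests above depend on the verifier stripping the query string and fragment before comparing htu values. The sketch below shows one way to do that normalization with net/url; it illustrates the asserted behaviour and is not the verifier's actual code.

```go
package main

import (
	"fmt"
	"net/url"
)

// normalizeHTU drops the query string and fragment so that
// "https://h/p?a=1#frag" compares equal to "https://h/p".
func normalizeHTU(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	u.RawQuery = ""
	u.Fragment = ""
	return u.String(), nil
}

func main() {
	a, _ := normalizeHTU("https://api.example.com/resource?param=value")
	b, _ := normalizeHTU("https://api.example.com/resource#section")
	fmt.Println(a == b) // true: both normalize to the bare resource URI
}
```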
-
func TestVerifyDPoPProof_ExpiredProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 10 minutes ago (exceeds default MaxProofAge of 5 minutes)
-
iat := time.Now().Add(-10 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "too old") {
-
t.Errorf("Expected 'too old' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_FutureProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 1 minute in the future (exceeds MaxClockSkew)
-
iat := time.Now().Add(1 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for future proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "in the future") {
-
t.Errorf("Expected 'in the future' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WithinClockSkew(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 15 seconds in the future (within MaxClockSkew of 30s)
-
iat := time.Now().Add(15 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for proof within clock skew: %v", err)
-
}
-
}
-
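Taken together, the age and skew tests above describe a single freshness rule: iat may be at most MaxProofAge in the past and at most MaxClockSkew in the future. A small standalone sketch of that check, using the defaults named in the test comments (5 minutes and 30 seconds); the verifier's real implementation may differ in detail.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// checkProofFreshness rejects proofs that are too old or issued in the
// future, mirroring the behaviour the tests above assert.
func checkProofFreshness(iat, now time.Time, maxAge, maxSkew time.Duration) error {
	if now.Sub(iat) > maxAge {
		return errors.New("dpop proof too old")
	}
	if iat.Sub(now) > maxSkew {
		return errors.New("dpop proof issued in the future")
	}
	return nil
}

func main() {
	now := time.Now()
	fmt.Println(checkProofFreshness(now.Add(-10*time.Minute), now, 5*time.Minute, 30*time.Second)) // too old
	fmt.Println(checkProofFreshness(now.Add(15*time.Second), now, 5*time.Minute, 30*time.Second))  // <nil>
}
```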
-
func TestVerifyDPoPProof_MissingJti(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
// No ID (jti)
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jti, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jti") {
-
t.Errorf("Expected missing jti error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
// Don't set typ header
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "JWT" // Wrong typ
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for wrong typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingJWK(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
// Don't include JWK
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jwk header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jwk") {
-
t.Errorf("Expected missing jwk error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_CustomTimeSettings(t *testing.T) {
-
verifier := &DPoPVerifier{
-
MaxClockSkew: 1 * time.Minute,
-
MaxProofAge: 10 * time.Minute,
-
}
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 50 seconds in the future (within custom MaxClockSkew)
-
iat := time.Now().Add(50 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed with custom time settings: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_HTTPMethodCaseInsensitive(t *testing.T) {
-
// HTTP method comparison should be case-insensitive per spec
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "post"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Verify with uppercase method
-
_, err := verifier.VerifyDPoPProof(proof, "POST", uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for case-insensitive method: %v", err)
-
}
-
}
-
-
// === Token Binding Verification Tests ===
-
-
func TestVerifyTokenBinding_Matching(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with matching thumbprint
-
err = verifier.VerifyTokenBinding(result, key.thumbprint)
-
if err != nil {
-
t.Fatalf("VerifyTokenBinding failed for matching thumbprint: %v", err)
-
}
-
}
-
-
func TestVerifyTokenBinding_Mismatch(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with wrong thumbprint
-
err = verifier.VerifyTokenBinding(result, wrongKey.thumbprint)
-
if err == nil {
-
t.Error("Expected error for thumbprint mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "thumbprint mismatch") {
-
t.Errorf("Expected thumbprint mismatch error, got: %v", err)
-
}
-
}
-
-
// === ExtractCnfJkt Tests ===
-
-
func TestExtractCnfJkt_Valid(t *testing.T) {
-
expectedJkt := "test-thumbprint-123"
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": expectedJkt,
-
},
-
}
-
-
jkt, err := ExtractCnfJkt(claims)
-
if err != nil {
-
t.Fatalf("ExtractCnfJkt failed for valid claims: %v", err)
-
}
-
-
if jkt != expectedJkt {
-
t.Errorf("Expected jkt %s, got %s", expectedJkt, jkt)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingCnf(t *testing.T) {
-
claims := &Claims{
-
// No Confirmation
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_NilCnf(t *testing.T) {
-
claims := &Claims{
-
Confirmation: nil,
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for nil cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"other": "value",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_EmptyJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": "",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for empty jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_WrongType(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": 123, // Not a string
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for wrong type jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
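The token-binding and cnf-extraction tests above together spell out the resource-server check: verify the DPoP proof, pull cnf.jkt out of the access-token claims, and require the two thumbprints to match. Below is a hedged in-package sketch of that flow using the signatures these tests exercise; the function name and error handling are illustrative, and accessClaims would come from VerifyJWT.

```go
// checkDPoPBoundToken sketches how the pieces tested above fit together,
// assuming the VerifyDPoPProof, ExtractCnfJkt and VerifyTokenBinding
// signatures shown in these tests.
func checkDPoPBoundToken(verifier *DPoPVerifier, proof, method, uri string, accessClaims *Claims) error {
	result, err := verifier.VerifyDPoPProof(proof, method, uri)
	if err != nil {
		return err // the proof itself is invalid (signature, htm/htu, freshness, ...)
	}
	jkt, err := ExtractCnfJkt(accessClaims)
	if err != nil {
		return err // the access token carries no cnf.jkt, so it is not DPoP-bound
	}
	// The proof key's thumbprint must equal the cnf.jkt baked into the token.
	return verifier.VerifyTokenBinding(result, jkt)
}
```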
-
// === Helper Functions for Tests ===
-
-
// splitJWT splits a JWT into its three parts
-
func splitJWT(token string) []string {
-
return []string{
-
token[:strings.IndexByte(token, '.')],
-
token[strings.IndexByte(token, '.')+1 : strings.LastIndexByte(token, '.')],
-
token[strings.LastIndexByte(token, '.')+1:],
-
}
-
}
-
-
// parseJWTHeader parses a base64url-encoded JWT header
-
func parseJWTHeader(t *testing.T, encoded string) map[string]interface{} {
-
t.Helper()
-
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
-
if err != nil {
-
t.Fatalf("Failed to decode header: %v", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(decoded, &header); err != nil {
-
t.Fatalf("Failed to unmarshal header: %v", err)
-
}
-
-
return header
-
}
-
-
// encodeJSON encodes a value to base64url-encoded JSON
-
func encodeJSON(t *testing.T, v interface{}) string {
-
t.Helper()
-
data, err := json.Marshal(v)
-
if err != nil {
-
t.Fatalf("Failed to marshal JSON: %v", err)
-
}
-
return base64.RawURLEncoding.EncodeToString(data)
-
}
-
-
// === ES256K (secp256k1) Test Helpers ===
-
-
// testES256KKey holds a test ES256K key pair using indigo
-
type testES256KKey struct {
-
privateKey indigoCrypto.PrivateKey
-
publicKey indigoCrypto.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256KKey generates a test ES256K (secp256k1) key pair and JWK
-
func generateTestES256KKey(t *testing.T) *testES256KKey {
-
t.Helper()
-
-
privateKey, err := indigoCrypto.GeneratePrivateKeyK256()
-
if err != nil {
-
t.Fatalf("Failed to generate ES256K test key: %v", err)
-
}
-
-
publicKey, err := privateKey.PublicKey()
-
if err != nil {
-
t.Fatalf("Failed to get public key from ES256K private key: %v", err)
-
}
-
-
// Get the JWK representation
-
jwkStruct, err := publicKey.JWK()
-
if err != nil {
-
t.Fatalf("Failed to get JWK from ES256K public key: %v", err)
-
}
-
jwk := map[string]interface{}{
-
"kty": jwkStruct.KeyType,
-
"crv": jwkStruct.Curve,
-
"x": jwkStruct.X,
-
"y": jwkStruct.Y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate ES256K thumbprint: %v", err)
-
}
-
-
return &testES256KKey{
-
privateKey: privateKey,
-
publicKey: publicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// createES256KDPoPProof creates a DPoP proof JWT using ES256K for testing
-
func createES256KDPoPProof(t *testing.T, key *testES256KKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256K",
-
"jwk": key.jwk,
-
}
-
-
// Encode header and claims
-
headerJSON, err := json.Marshal(header)
-
if err != nil {
-
t.Fatalf("Failed to marshal header: %v", err)
-
}
-
claimsJSON, err := json.Marshal(claims)
-
if err != nil {
-
t.Fatalf("Failed to marshal claims: %v", err)
-
}
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
// Sign with indigo
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign ES256K proof: %v", err)
-
}
-
-
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
-
return signingInput + "." + signatureB64
-
}
-
-
// === ES256K Tests ===
-
-
func TestVerifyDPoPProof_ES256K_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid ES256K proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_ES256K_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
wrongKey := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
// Tamper by replacing JWK with wrong key
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid ES256K signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_ES256K(t *testing.T) {
-
// Test thumbprint calculation for secp256k1 keys
-
key := generateTestES256KKey(t)
-
-
thumbprint, err := CalculateJWKThumbprint(key.jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for ES256K: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for ES256K key")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("ES256K thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected ES256K thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
// === Algorithm-Curve Binding Tests ===
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256KWithP256Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t) // P-256 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create a proof claiming ES256K but using P-256 key
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["alg"] = "ES256K" // Claim ES256K
-
token.Header["jwk"] = key.jwk // But use P-256 key
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256K algorithm with P-256 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve secp256k1") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256WithSecp256k1Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t) // secp256k1 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header claiming ES256 but using secp256k1 key
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256", // Claim ES256
-
"jwk": key.jwk, // But use secp256k1 key
-
}
-
-
headerJSON, _ := json.Marshal(header)
-
claimsJSON, _ := json.Marshal(claims)
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign: %v", err)
-
}
-
-
proof := signingInput + "." + base64.RawURLEncoding.EncodeToString(signature)
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256 algorithm with secp256k1 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve P-256") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
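The two mismatch tests above pin down a simple rule: the JWS alg must agree with the curve of the key carried in the jwk header (ES256 with P-256, ES256K with secp256k1). Here is a standalone sketch of that binding check; it illustrates the rule and is not the verifier's actual code.

```go
package main

import "fmt"

// expectedCurve maps a JWS algorithm to the curve its key must use.
var expectedCurve = map[string]string{
	"ES256":  "P-256",
	"ES256K": "secp256k1",
}

// checkAlgCurve enforces the alg/curve binding the tests above assert.
func checkAlgCurve(alg string, jwk map[string]interface{}) error {
	want, ok := expectedCurve[alg]
	if !ok {
		return fmt.Errorf("unsupported algorithm %q", alg)
	}
	if crv, _ := jwk["crv"].(string); crv != want {
		return fmt.Errorf("algorithm %s requires curve %s, got %v", alg, want, jwk["crv"])
	}
	return nil
}

func main() {
	p256 := map[string]interface{}{"kty": "EC", "crv": "P-256"}
	fmt.Println(checkAlgCurve("ES256K", p256)) // curve mismatch
	fmt.Println(checkAlgCurve("ES256", p256))  // <nil>
}
```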
-
// === exp/nbf Validation Tests ===
-
-
func TestVerifyDPoPProof_ExpiredWithExpClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now().Add(-2 * time.Minute)
-
exp := time.Now().Add(-1 * time.Minute) // Expired 1 minute ago
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof with exp claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "expired") {
-
t.Errorf("Expected expiration error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_NotYetValidWithNbfClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
nbf := time.Now().Add(5 * time.Minute) // Not valid for another 5 minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
NotBefore: jwt.NewNumericDate(nbf),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for not-yet-valid proof with nbf claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "not valid before") {
-
t.Errorf("Expected not-before error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_ValidWithExpClaimInFuture(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
exp := time.Now().Add(5 * time.Minute) // Valid for 5 more minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof with exp in future: %v", err)
-
}
-
-
if result == nil {
-
t.Error("Expected non-nil result for valid proof")
-
}
-
}
-189
internal/atproto/auth/jwks_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"encoding/json"
-
"fmt"
-
"net/http"
-
"strings"
-
"sync"
-
"time"
-
)
-
-
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
-
type CachedJWKSFetcher struct {
-
cache map[string]*cachedJWKS
-
httpClient *http.Client
-
cacheMutex sync.RWMutex
-
cacheTTL time.Duration
-
}
-
-
type cachedJWKS struct {
-
jwks *JWKS
-
expiresAt time.Time
-
}
-
-
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
-
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
-
return &CachedJWKSFetcher{
-
cache: make(map[string]*cachedJWKS),
-
httpClient: &http.Client{
-
Timeout: 10 * time.Second,
-
},
-
cacheTTL: cacheTTL,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
-
// Implements JWKSFetcher interface
-
// Returns interface{} to support both RSA and ECDSA keys
-
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Extract key ID from token
-
kid, err := ExtractKeyID(token)
-
if err != nil {
-
return nil, fmt.Errorf("failed to extract key ID: %w", err)
-
}
-
-
// Get JWKS from cache or fetch
-
jwks, err := f.getJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Find the key by ID
-
jwk, err := jwks.FindKeyByID(kid)
-
if err != nil {
-
// Key not found in cache - try refreshing
-
jwks, err = f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
-
}
-
f.cacheJWKS(issuer, jwks)
-
-
// Try again with fresh JWKS
-
jwk, err = jwks.FindKeyByID(kid)
-
if err != nil {
-
return nil, err
-
}
-
}
-
-
// Convert JWK to public key (RSA or ECDSA)
-
return jwk.ToPublicKey()
-
}
-
-
// getJWKS gets JWKS from cache or fetches if not cached/expired
-
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Check cache first
-
f.cacheMutex.RLock()
-
cached, exists := f.cache[issuer]
-
f.cacheMutex.RUnlock()
-
-
if exists && time.Now().Before(cached.expiresAt) {
-
return cached.jwks, nil
-
}
-
-
// Not in cache or expired - fetch from issuer
-
jwks, err := f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Cache it
-
f.cacheJWKS(issuer, jwks)
-
-
return jwks, nil
-
}
-
-
// fetchJWKS fetches JWKS from the authorization server
-
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Step 1: Fetch OAuth server metadata to get JWKS URI
-
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
-
-
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create metadata request: %w", err)
-
}
-
-
resp, err := f.httpClient.Do(req)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
-
}
-
defer func() {
-
_ = resp.Body.Close()
-
}()
-
-
if resp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
-
}
-
-
var metadata struct {
-
JWKSURI string `json:"jwks_uri"`
-
}
-
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
-
return nil, fmt.Errorf("failed to decode metadata: %w", err)
-
}
-
-
if metadata.JWKSURI == "" {
-
return nil, fmt.Errorf("jwks_uri not found in metadata")
-
}
-
-
// Step 2: Fetch JWKS from the JWKS URI
-
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
-
}
-
-
jwksResp, err := f.httpClient.Do(jwksReq)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
-
}
-
defer func() {
-
_ = jwksResp.Body.Close()
-
}()
-
-
if jwksResp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
-
}
-
-
var jwks JWKS
-
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
-
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
-
}
-
-
if len(jwks.Keys) == 0 {
-
return nil, fmt.Errorf("no keys found in JWKS")
-
}
-
-
return &jwks, nil
-
}
-
-
// cacheJWKS stores JWKS in the cache
-
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
f.cache[issuer] = &cachedJWKS{
-
jwks: jwks,
-
expiresAt: time.Now().Add(f.cacheTTL),
-
}
-
}
-
-
// ClearCache clears the entire JWKS cache
-
func (f *CachedJWKSFetcher) ClearCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
f.cache = make(map[string]*cachedJWKS)
-
}
-
-
// CleanupExpiredCache removes expired entries from the cache
-
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
now := time.Now()
-
for issuer, cached := range f.cache {
-
if now.After(cached.expiresAt) {
-
delete(f.cache, issuer)
-
}
-
}
-
}
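For context on how this (now removed) fetcher was intended to be used: construct it once with a cache TTL and hand it to anything that needs a JWKSFetcher, such as VerifyJWT as exercised in jwt_test.go below. The sketch assumes an in-package snippet with context and time imported; the 15-minute TTL and the authenticate wrapper are illustrative choices, not code from this diff.

```go
// Sketch only: wiring the cached fetcher into request authentication.
// VerifyJWT's signature is the one exercised in jwt_test.go below.
var jwksFetcher = NewCachedJWKSFetcher(15 * time.Minute) // TTL is an assumed value

// authenticate verifies a bearer token; for non-HS256 tokens VerifyJWT
// resolves the issuer's public key via the fetcher, which caches JWKS
// documents per issuer for the configured TTL.
func authenticate(ctx context.Context, token string) (*Claims, error) {
	return VerifyJWT(ctx, token, jwksFetcher)
}
```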
-496
internal/atproto/auth/jwt_test.go
···
-
package auth
-
-
import (
-
"context"
-
"testing"
-
"time"
-
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
func TestParseJWT(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing
-
parsedClaims, err := ParseJWT(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
-
if parsedClaims.Issuer != "https://test-pds.example.com" {
-
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
-
}
-
-
if parsedClaims.Scope != "atproto transition:generic" {
-
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
-
}
-
}
-
-
func TestParseJWT_MissingSubject(t *testing.T) {
-
// Create a token without subject
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing subject, got nil")
-
}
-
}
-
-
func TestParseJWT_MissingIssuer(t *testing.T) {
-
// Create a token without issuer
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing issuer, got nil")
-
}
-
}
-
-
func TestParseJWT_WithBearerPrefix(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing with Bearer prefix
-
parsedClaims, err := ParseJWT("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
}
-
-
func TestValidateClaims_Expired(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for expired token, got nil")
-
}
-
}
-
-
func TestValidateClaims_InvalidDID(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "invalid-did-format",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for invalid DID format, got nil")
-
}
-
}
-
-
func TestExtractKeyID(t *testing.T) {
-
// Create a test JWT token with kid in header
-
token := jwt.New(jwt.SigningMethodRS256)
-
token.Header["kid"] = "test-key-id"
-
token.Claims = &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
},
-
}
-
-
// Sign with a dummy RSA key (we just need a valid token structure)
-
tokenString, err := token.SignedString([]byte("dummy"))
-
if err == nil {
-
// If it succeeds (shouldn't with wrong key type, but let's handle it)
-
kid, err := ExtractKeyID(tokenString)
-
if err != nil {
-
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
-
} else if kid != "test-key-id" {
-
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
-
}
-
}
-
}
-
-
// === HS256 Verification Tests ===
-
-
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
-
type mockJWKSFetcher struct {
-
publicKey interface{}
-
err error
-
}
-
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
return m.publicKey, m.err
-
}
-
-
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
-
t.Helper()
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte(secret))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
return tokenString
-
}
-
-
func TestVerifyJWT_HS256_Valid(t *testing.T) {
-
// Setup: Configure environment for HS256 verification
-
secret := "test-jwt-secret-key-12345"
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", secret)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
-
-
// Verify token
-
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err != nil {
-
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
-
}
-
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
-
}
-
if claims.Issuer != issuer {
-
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
-
}
-
}
-
-
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
-
// Setup: Configure environment with one secret, sign with another
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "correct-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create token with wrong secret
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
-
-
// Verify should fail
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error for HS256 token with wrong secret, got nil")
-
}
-
}
-
-
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
-
// Setup: Whitelist issuer but don't configure secret
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
-
-
// Verify should fail with descriptive error
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
-
}
-
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
-
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
-
}
-
}
-
-
// === Algorithm Confusion Attack Prevention Tests ===
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
-
// SECURITY TEST: This tests the algorithm confusion attack prevention
-
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create HS256 token with non-whitelisted issuer (simulating attack)
-
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because issuer is not in HS256 whitelist
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
-
}
-
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
-
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
-
}
-
}
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
-
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because no issuers are whitelisted for HS256
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
-
}
-
}
-
-
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
-
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "test-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
-
// This will create an invalid signature but valid header structure
-
// The test should fail at algorithm check, not signature verification
-
tokenString, _ := token.SignedString([]byte("dummy-key"))
-
-
if tokenString != "" {
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when HS256 issuer receives non-HS256 token")
-
}
-
}
-
}
-
-
// === ParseJWTHeader Tests ===
-
-
func TestParseJWTHeader_Valid(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
-
testCases := []struct {
-
name string
-
input string
-
}{
-
{"empty string", ""},
-
{"single part", "abc"},
-
{"two parts", "abc.def"},
-
{"too many parts", "a.b.c.d"},
-
}
-
-
for _, tc := range testCases {
-
t.Run(tc.name, func(t *testing.T) {
-
_, err := ParseJWTHeader(tc.input)
-
if err == nil {
-
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
-
}
-
})
-
}
-
}
-
-
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
-
-
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected pds1 to be whitelisted")
-
}
-
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
-
t.Error("Expected pds2 to be whitelisted")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://attacker.example.com") {
-
t.Error("Expected non-whitelisted issuer to return false")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://any.example.com") {
-
t.Error("Expected false when whitelist is empty (safe default)")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
-
}
-
}
-
-
// === shouldUseHS256 Tests (kid-based logic) ===
-
-
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
-
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "some-key-id", // Has kid
-
}
-
-
// Even whitelisted issuer should not use HS256 if token has kid
-
if shouldUseHS256(header, "https://whitelisted.example.com") {
-
t.Error("Tokens with kid should never use HS256 (supports federation)")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if !shouldUseHS256(header, "https://my-pds.example.com") {
-
t.Error("Token without kid from whitelisted issuer should use HS256")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if shouldUseHS256(header, "https://external-pds.example.com") {
-
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
-
}
-
}
-
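Read together, the kid and whitelist tests above fully determine the routing rule for symmetric verification. A compact sketch of that decision, reconstructed from the asserted behaviour (the package's actual shouldUseHS256 may be written differently):

```go
// shouldUseHS256Sketch mirrors the behaviour the tests above assert:
// a kid always routes to asymmetric (JWKS-based) verification, and HS256
// is only used for issuers explicitly listed in HS256_ISSUERS.
func shouldUseHS256Sketch(header *JWTHeader, issuer string) bool {
	if header.Kid != "" {
		return false // federated tokens carry a kid and must use JWKS keys
	}
	return isHS256IssuerWhitelisted(issuer)
}
```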
-
// Helper function
-
func contains(s, substr string) bool {
-
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-
}
-
-
func containsHelper(s, substr string) bool {
-
for i := 0; i <= len(s)-len(substr); i++ {
-
if s[i:i+len(substr)] == substr {
-
return true
-
}
-
}
-
return false
-
}
+1
docker-compose.prod.yml
···
# Instance identity
INSTANCE_DID: did:web:coves.social
INSTANCE_DOMAIN: coves.social
+
APPVIEW_PUBLIC_URL: https://coves.social
# PDS connection (separate domain!)
PDS_URL: https://coves.me
+6 -5
internal/api/routes/oauth.go
···
// Use login limiter since callback completes the authentication flow
r.With(corsMiddleware(allowedOrigins), loginLimiter.Middleware).Get("/oauth/callback", handler.HandleCallback)
-
// Mobile Universal Link callback route
-
// This route is used for iOS Universal Links and Android App Links
-
// Path must match the path in .well-known/apple-app-site-association
-
// Uses the same handler as web callback - the system routes it to the mobile app
-
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleCallback)
+
// Mobile Universal Link callback route (fallback when app doesn't intercept)
+
// This route exists for iOS Universal Links and Android App Links.
+
// When properly configured, the mobile OS intercepts this URL and opens the app
+
// BEFORE the request reaches the server. If this handler is reached, Universal
+
// Link interception failed and the fallback handler responds instead.
+
r.With(loginLimiter.Middleware).Get("/app/oauth/callback", handler.HandleMobileDeepLinkFallback)
// Session management - dedicated rate limits
r.With(logoutLimiter.Middleware).Post("/oauth/logout", handler.HandleLogout)
+11
static/.well-known/apple-app-site-association
···
+
{
+
"applinks": {
+
"apps": [],
+
"details": [
+
{
+
"appID": "TEAM_ID.social.coves",
+
"paths": ["/app/oauth/callback"]
+
}
+
]
+
}
+
}
+10
static/.well-known/assetlinks.json
···
+
[{
+
"relation": ["delegate_permission/common.handle_all_urls"],
+
"target": {
+
"namespace": "android_app",
+
"package_name": "social.coves",
+
"sha256_cert_fingerprints": [
+
"0B:D8:8C:99:66:25:E5:CD:06:54:80:88:01:6F:B7:38:B9:F4:5B:41:71:F7:95:C8:68:94:87:AD:EA:9F:D9:ED"
+
]
+
}
+
}]
+16 -9
internal/atproto/oauth/handlers_test.go
···
}
// TestIsMobileRedirectURI tests mobile redirect URI validation with EXACT URI matching
-
// Only Universal Links (HTTPS) are allowed - custom schemes are blocked for security
+
// Per the atproto OAuth spec, custom URI schemes must be the client_id hostname in reverse-domain order (e.g. social.coves for coves.social)
func TestIsMobileRedirectURI(t *testing.T) {
tests := []struct {
uri string
expected bool
}{
-
{"https://coves.social/app/oauth/callback", true}, // Universal Link - allowed
-
{"coves-app://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves://oauth/callback", false}, // Custom scheme - blocked (insecure)
-
{"coves-app://callback", false}, // Custom scheme - blocked
-
{"coves://oauth", false}, // Custom scheme - blocked
-
{"myapp://oauth", false}, // Not in allowlist
-
{"https://example.com", false}, // Wrong domain
-
{"http://localhost", false}, // HTTP not allowed
+
// Custom scheme per atproto spec (reverse domain of coves.social)
+
{"social.coves:/callback", true},
+
{"social.coves://callback", true},
+
{"social.coves:/oauth/callback", true},
+
{"social.coves://oauth/callback", true},
+
// Universal Link - allowed (strongest security)
+
{"https://coves.social/app/oauth/callback", true},
+
// Wrong custom schemes - not reverse-domain of coves.social
+
{"coves-app://oauth/callback", false},
+
{"coves://oauth/callback", false},
+
{"coves.social://callback", false}, // Not reversed
+
{"myapp://oauth", false},
+
// Wrong domain/scheme
+
{"https://example.com", false},
+
{"http://localhost", false},
{"", false},
{"not-a-uri", false},
}
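A standalone sketch of the reverse-domain rule the new cases above encode: the custom scheme must be the client hostname with its labels reversed (coves.social becomes social.coves). This is an illustration of the rule, not the isMobileRedirectURI implementation under test.

```go
package main

import (
	"fmt"
	"strings"
)

// reverseDomain turns "coves.social" into "social.coves".
func reverseDomain(host string) string {
	labels := strings.Split(host, ".")
	for i, j := 0, len(labels)-1; i < j; i, j = i+1, j-1 {
		labels[i], labels[j] = labels[j], labels[i]
	}
	return strings.Join(labels, ".")
}

// hasReverseDomainScheme reports whether uri uses the custom scheme the
// atproto OAuth spec requires for the given client hostname.
func hasReverseDomainScheme(uri, clientHost string) bool {
	scheme, _, found := strings.Cut(uri, ":")
	return found && scheme == reverseDomain(clientHost)
}

func main() {
	fmt.Println(hasReverseDomainScheme("social.coves:/callback", "coves.social"))     // true
	fmt.Println(hasReverseDomainScheme("coves-app://oauth/callback", "coves.social")) // false
}
```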
+41
internal/atproto/lexicon/social/coves/feed/vote/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.feed.vote.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a vote on a post or comment",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["subject"],
+
"properties": {
+
"subject": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the post or comment to remove the vote from"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "VoteNotFound",
+
"description": "No vote found for this subject"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this vote"
+
}
+
]
+
}
+
}
+
}
+115
internal/api/handlers/vote/create_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateVoteHandler handles vote creation
+
type CreateVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewCreateVoteHandler creates a new create vote handler
+
func NewCreateVoteHandler(service votes.Service) *CreateVoteHandler {
+
return &CreateVoteHandler{
+
service: service,
+
}
+
}
+
+
// CreateVoteInput represents the request body for creating a vote
+
type CreateVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
Direction string `json:"direction"`
+
}
+
+
// CreateVoteOutput represents the response body for creating a vote
+
type CreateVoteOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreateVote creates a vote on a post or comment
+
// POST /xrpc/social.coves.feed.vote.create
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." }, "direction": "up" }
+
// Response: { "uri": "at://...", "cid": "..." }
+
//
+
// Behavior:
+
// - If no vote exists: creates new vote with given direction
+
// - If vote exists with same direction: deletes vote (toggle off)
+
// - If vote exists with different direction: updates to new direction
+
func (h *CreateVoteHandler) HandleCreateVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input CreateVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
if input.Direction == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction is required")
+
return
+
}
+
+
// Validate direction
+
if input.Direction != "up" && input.Direction != "down" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "direction must be 'up' or 'down'")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create vote request
+
req := votes.CreateVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
Direction: input.Direction,
+
}
+
+
// Call service to create vote
+
response, err := h.service.CreateVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response
+
output := CreateVoteOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
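The behavior documented in the handler comment (create, toggle off, or switch direction) reduces to a three-way decision inside the service. A hedged sketch of that decision logic is shown below, using the CachedVote type introduced later in this diff; the actual votes.Service implementation is not part of this hunk and may structure this differently.

```go
// Sketch of the toggle semantics described above; illustration only, not the
// real service code. CachedVote is the cache entry type from internal/core/votes/cache.go.
func resolveVoteAction(existing *CachedVote, requestedDirection string) string {
	switch {
	case existing == nil:
		return "create" // no prior vote: write a new vote record to the PDS
	case existing.Direction == requestedDirection:
		return "delete" // same direction again: toggle the vote off
	default:
		return "update" // opposite direction: replace the existing vote
	}
}
```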
+93
internal/api/handlers/vote/delete_vote.go
···
+
package vote
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteVoteHandler handles vote deletion
+
type DeleteVoteHandler struct {
+
service votes.Service
+
}
+
+
// NewDeleteVoteHandler creates a new delete vote handler
+
func NewDeleteVoteHandler(service votes.Service) *DeleteVoteHandler {
+
return &DeleteVoteHandler{
+
service: service,
+
}
+
}
+
+
// DeleteVoteInput represents the request body for deleting a vote
+
type DeleteVoteInput struct {
+
Subject struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"subject"`
+
}
+
+
// DeleteVoteOutput represents the response body for deleting a vote
+
// Per lexicon: output is an empty object
+
type DeleteVoteOutput struct{}
+
+
// HandleDeleteVote removes a vote from a post or comment
+
// POST /xrpc/social.coves.feed.vote.delete
+
//
+
// Request body: { "subject": { "uri": "at://...", "cid": "..." } }
+
// Response: {} (empty object, per lexicon)
+
func (h *DeleteVoteHandler) HandleDeleteVote(w http.ResponseWriter, r *http.Request) {
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// Parse request body
+
var input DeleteVoteInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// Validate required fields
+
if input.Subject.URI == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.uri is required")
+
return
+
}
+
if input.Subject.CID == "" {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "subject.cid is required")
+
return
+
}
+
+
// Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// Create delete vote request
+
req := votes.DeleteVoteRequest{
+
Subject: votes.StrongRef{
+
URI: input.Subject.URI,
+
CID: input.Subject.CID,
+
},
+
}
+
+
// Call service to delete vote
+
err := h.service.DeleteVote(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// Return success response (empty object per lexicon)
+
output := DeleteVoteOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+24
internal/api/routes/vote.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/vote"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterVoteRoutes registers vote-related XRPC endpoints on the router
+
// Implements social.coves.feed.vote.* lexicon endpoints
+
func RegisterVoteRoutes(r chi.Router, voteService votes.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := vote.NewCreateVoteHandler(voteService)
+
deleteHandler := vote.NewDeleteVoteHandler(voteService)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.feed.vote.create - create or update a vote on a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.create", createHandler.HandleCreateVote)
+
+
// social.coves.feed.vote.delete - delete a vote from a post/comment
+
r.With(authMiddleware.RequireAuth).Post("/xrpc/social.coves.feed.vote.delete", deleteHandler.HandleDeleteVote)
+
}
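For context, wiring these routes into the server looks roughly like the sketch below. The actual cmd/server composition and the constructors for votes.Service and the OAuth middleware are not part of this diff, so treat the package name and setup as assumptions.

```go
// Hypothetical wiring sketch; not part of the repository.
package app

import (
	"net/http"

	"Coves/internal/api/middleware"
	"Coves/internal/api/routes"
	"Coves/internal/core/votes"

	"github.com/go-chi/chi/v5"
)

// newRouter registers the vote endpoints on a chi router; how voteService and
// authMiddleware are constructed in cmd/server is outside this diff.
func newRouter(voteService votes.Service, authMiddleware *middleware.OAuthAuthMiddleware) http.Handler {
	r := chi.NewRouter()
	routes.RegisterVoteRoutes(r, voteService, authMiddleware)
	return r
}
```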
+3
.beads/beads.left.jsonl
···
+
{"id":"Coves-95q","content_hash":"8ec99d598f067780436b985f9ad57f0fa19632026981038df4f65f192186620b","title":"Add comprehensive API documentation","description":"","status":"open","priority":2,"issue_type":"task","created_at":"2025-11-17T20:30:34.835721854-08:00","updated_at":"2025-11-17T20:30:34.835721854-08:00","source_repo":".","dependencies":[{"issue_id":"Coves-95q","depends_on_id":"Coves-e16","type":"blocks","created_at":"2025-11-17T20:30:46.273899399-08:00","created_by":"daemon"}]}
+
{"id":"Coves-e16","content_hash":"7c5d0fc8f0e7f626be3dad62af0e8412467330bad01a244e5a7e52ac5afff1c1","title":"Complete post creation and moderation features","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:12.885991306-08:00","updated_at":"2025-11-17T20:30:12.885991306-08:00","source_repo":"."}
+
{"id":"Coves-fce","content_hash":"26b3e16b99f827316ee0d741cc959464bd0c813446c95aef8105c7fd1e6b09ff","title":"Implement aggregator feed federation","description":"","status":"open","priority":1,"issue_type":"feature","created_at":"2025-11-17T20:30:21.453326012-08:00","updated_at":"2025-11-17T20:30:21.453326012-08:00","source_repo":"."}
+1
.beads/beads.left.meta.json
···
+
{"version":"0.23.1","timestamp":"2025-12-02T18:25:24.009187871-08:00","commit":"00d7d8d"}
-3
internal/api/handlers/vote/errors.go
···
case errors.Is(err, votes.ErrVoteNotFound):
// Matches: social.coves.feed.vote.delete#VoteNotFound
writeError(w, http.StatusNotFound, "VoteNotFound", "No vote found for this subject")
-
case errors.Is(err, votes.ErrSubjectNotFound):
-
// Matches: social.coves.feed.vote.create#SubjectNotFound
-
writeError(w, http.StatusNotFound, "SubjectNotFound", "The subject post or comment was not found")
case errors.Is(err, votes.ErrInvalidDirection):
writeError(w, http.StatusBadRequest, "InvalidRequest", "Vote direction must be 'up' or 'down'")
case errors.Is(err, votes.ErrInvalidSubject):
+4 -4
internal/atproto/oauth/handlers_security.go
···
// - Android: Verified via /.well-known/assetlinks.json
var allowedMobileRedirectURIs = map[string]bool{
// Custom scheme per atproto spec (reverse-domain of coves.social)
-
"social.coves:/callback": true,
-
"social.coves://callback": true, // Some platforms add double slash
-
"social.coves:/oauth/callback": true, // Alternative path
-
"social.coves://oauth/callback": true,
+
"social.coves:/callback": true,
+
"social.coves://callback": true, // Some platforms add double slash
+
"social.coves:/oauth/callback": true, // Alternative path
+
"social.coves://oauth/callback": true,
// Universal Links - cryptographically bound to app (preferred for security)
"https://coves.social/app/oauth/callback": true,
}
-3
internal/core/votes/errors.go
···
// ErrVoteNotFound indicates the requested vote doesn't exist
ErrVoteNotFound = errors.New("vote not found")
-
// ErrSubjectNotFound indicates the post/comment being voted on doesn't exist
-
ErrSubjectNotFound = errors.New("subject not found")
-
// ErrInvalidDirection indicates the vote direction is not "up" or "down"
ErrInvalidDirection = errors.New("invalid vote direction: must be 'up' or 'down'")
+3 -2
internal/db/postgres/vote_repo.go
···
return nil
}
-
// GetByURI retrieves a vote by its AT-URI
+
// GetByURI retrieves an active vote by its AT-URI
// Used by Jetstream consumer for DELETE operations
+
// Returns ErrVoteNotFound for soft-deleted votes
func (r *postgresVoteRepo) GetByURI(ctx context.Context, uri string) (*votes.Vote, error) {
query := `
SELECT
···
subject_uri, subject_cid, direction,
created_at, indexed_at, deleted_at
FROM votes
-
WHERE uri = $1
+
WHERE uri = $1 AND deleted_at IS NULL
`
var vote votes.Vote
+18
.env.dev
···
#
PLC_DIRECTORY_URL=http://localhost:3002
+
# =============================================================================
+
# Dev Mode Quick Reference
+
# =============================================================================
+
# REQUIRED for local OAuth to work with local PDS:
+
# IS_DEV_ENV=true # Master switch for dev mode
+
# PDS_URL=http://localhost:3001 # Local PDS for handle resolution
+
# PLC_DIRECTORY_URL=http://localhost:3002 # Local PLC directory
+
# APPVIEW_PUBLIC_URL=http://127.0.0.1:8081 # Use IP not localhost (RFC 8252)
+
#
+
# BUILD TAGS:
+
# make run - Runs with -tags dev (includes localhost OAuth resolvers)
+
# make build - Production binary (no dev code)
+
# make build-dev - Dev binary (includes dev code)
+
#
+
# Dev-only code (only compiled with -tags dev):
+
# - internal/atproto/oauth/dev_resolver.go (handle resolution via local PDS)
+
# - internal/atproto/oauth/dev_auth_resolver.go (localhost OAuth bypass)
+
#
# =============================================================================
# Notes
# =============================================================================
+92
.env.dev.example
···
+
# Coves Local Development Environment Configuration
+
# Copy this to .env.dev and fill in your values
+
#
+
# Quick Start:
+
# 1. cp .env.dev.example .env.dev
+
# 2. Generate OAuth key: go run cmd/genjwks/main.go (copy output to OAUTH_PRIVATE_JWK)
+
# 3. Generate cookie secret: openssl rand -hex 32
+
# 4. make dev-up # Start Docker services
+
# 5. make run # Start the server (uses -tags dev)
+
+
# =============================================================================
+
# Dev Mode Quick Reference
+
# =============================================================================
+
# REQUIRED for local OAuth to work with local PDS:
+
# IS_DEV_ENV=true # Master switch for dev mode
+
# PDS_URL=http://localhost:3001 # Local PDS for handle resolution
+
# PLC_DIRECTORY_URL=http://localhost:3002 # Local PLC directory
+
# APPVIEW_PUBLIC_URL=http://127.0.0.1:8081 # Use IP not localhost (RFC 8252)
+
#
+
# BUILD TAGS:
+
# make run - Runs with -tags dev (includes localhost OAuth resolvers)
+
# make build - Production binary (no dev code)
+
# make build-dev - Dev binary (includes dev code)
+
+
# =============================================================================
+
# PostgreSQL Configuration
+
# =============================================================================
+
POSTGRES_HOST=localhost
+
POSTGRES_PORT=5435
+
POSTGRES_DB=coves_dev
+
POSTGRES_USER=dev_user
+
POSTGRES_PASSWORD=dev_password
+
+
# Test database
+
POSTGRES_TEST_DB=coves_test
+
POSTGRES_TEST_USER=test_user
+
POSTGRES_TEST_PASSWORD=test_password
+
POSTGRES_TEST_PORT=5434
+
+
# =============================================================================
+
# PDS Configuration
+
# =============================================================================
+
PDS_HOSTNAME=localhost
+
PDS_PORT=3001
+
PDS_SERVICE_ENDPOINT=http://localhost:3000
+
PDS_DID_PLC_URL=http://plc-directory:3000
+
PDS_JWT_SECRET=local-dev-jwt-secret-change-in-production
+
PDS_ADMIN_PASSWORD=admin
+
PDS_SERVICE_HANDLE_DOMAINS=.local.coves.dev,.community.coves.social
+
PDS_PLC_ROTATION_KEY=<generate-a-random-hex-key>
+
+
# =============================================================================
+
# AppView Configuration
+
# =============================================================================
+
APPVIEW_PORT=8081
+
FIREHOSE_URL=ws://localhost:3001/xrpc/com.atproto.sync.subscribeRepos
+
PDS_URL=http://localhost:3001
+
APPVIEW_PUBLIC_URL=http://127.0.0.1:8081
+
+
# =============================================================================
+
# Jetstream Configuration
+
# =============================================================================
+
JETSTREAM_URL=ws://localhost:6008/subscribe
+
+
# =============================================================================
+
# Identity Resolution
+
# =============================================================================
+
IDENTITY_CACHE_TTL=24h
+
PLC_DIRECTORY_URL=http://localhost:3002
+
+
# =============================================================================
+
# OAuth Configuration (MUST GENERATE YOUR OWN)
+
# =============================================================================
+
# Generate with: go run cmd/genjwks/main.go
+
OAUTH_PRIVATE_JWK=<generate-your-own-jwk>
+
+
# Generate with: openssl rand -hex 32
+
OAUTH_COOKIE_SECRET=<generate-your-own-secret>
+
+
# =============================================================================
+
# Development Settings
+
# =============================================================================
+
ENV=development
+
NODE_ENV=development
+
IS_DEV_ENV=true
+
LOG_LEVEL=debug
+
LOG_ENABLED=true
+
+
# Security settings (ONLY for local dev - set to false in production!)
+
SKIP_DID_WEB_VERIFICATION=true
+
AUTH_SKIP_VERIFY=true
+
HS256_ISSUERS=http://localhost:3001
+25 -3
Makefile
···
-
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean
+
.PHONY: help dev-up dev-down dev-logs dev-status dev-reset test e2e-test clean verify-stack create-test-account mobile-full-setup
# Default target - show help
.DEFAULT_GOAL := help
···
##@ Build & Run
-
build: ## Build the Coves server
-
@echo "$(GREEN)Building Coves server...$(RESET)"
+
build: ## Build the Coves server (production - no dev code)
+
@echo "$(GREEN)Building Coves server (production)...$(RESET)"
@go build -o server ./cmd/server
@echo "$(GREEN)โœ“ Build complete: ./server$(RESET)"
+
build-dev: ## Build the Coves server with dev mode (includes localhost OAuth resolvers)
+
@echo "$(GREEN)Building Coves server (dev mode)...$(RESET)"
+
@go build -tags dev -o server ./cmd/server
+
@echo "$(GREEN)โœ“ Build complete: ./server (with dev tags)$(RESET)"
+
run: ## Run the Coves server with dev environment (requires database running)
@./scripts/dev-run.sh
···
@adb reverse --remove-all || echo "$(YELLOW)No device connected$(RESET)"
@echo "$(GREEN)โœ“ Port forwarding removed$(RESET)"
+
verify-stack: ## Verify local development stack (PLC, PDS, configs)
+
@./scripts/verify-local-stack.sh
+
+
create-test-account: ## Create a test account on local PDS for OAuth testing
+
@./scripts/create-test-account.sh
+
+
mobile-full-setup: verify-stack create-test-account mobile-setup ## Full mobile setup: verify stack, create account, setup ports
+
@echo ""
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo "$(GREEN) Mobile development environment ready! $(RESET)"
+
@echo "$(GREEN)โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•$(RESET)"
+
@echo ""
+
@echo "$(CYAN)Run the Flutter app with:$(RESET)"
+
@echo " $(YELLOW)cd /home/bretton/Code/coves-mobile$(RESET)"
+
@echo " $(YELLOW)flutter run --dart-define=ENVIRONMENT=local$(RESET)"
+
@echo ""
+
ngrok-up: ## Start ngrok tunnels (for iOS or WiFi testing - requires paid plan for 3 tunnels)
@echo "$(GREEN)Starting ngrok tunnels for mobile testing...$(RESET)"
@./scripts/start-ngrok.sh
+5 -1
docker-compose.dev.yml
···
# Bluesky Personal Data Server (PDS)
# Handles user repositories, DIDs, and CAR files
+
# NOTE: When using --profile plc, PDS waits for PLC directory to be healthy
pds:
image: ghcr.io/bluesky-social/pds:latest
container_name: coves-dev-pds
···
PDS_PORT: 3001 # Match external port for correct DID registration
PDS_DATA_DIRECTORY: /pds
PDS_BLOBSTORE_DISK_LOCATION: /pds/blocks
-
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-https://plc.directory}
+
# IMPORTANT: For local E2E testing, this MUST point to local PLC directory
+
# Default to local PLC (http://plc-directory:3000) for full local stack
+
# The container hostname 'plc-directory' is used for Docker network communication
+
PDS_DID_PLC_URL: ${PDS_DID_PLC_URL:-http://plc-directory:3000}
# PDS_CRAWLERS not needed - we're not using a relay for local dev
# Note: PDS uses its own internal SQLite database and CAR file storage
+285
internal/atproto/oauth/dev_auth_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/identity"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// DevAuthResolver is a custom OAuth resolver that allows HTTP localhost URLs for development.
+
// The standard indigo OAuth resolver requires HTTPS and no port numbers, which breaks local testing.
+
type DevAuthResolver struct {
+
Client *http.Client
+
UserAgent string
+
PDSURL string // For resolving handles via local PDS
+
handleResolver *DevHandleResolver
+
}
+
+
// ProtectedResourceMetadata matches the OAuth protected resource metadata document format
+
type ProtectedResourceMetadata struct {
+
Resource string `json:"resource"`
+
AuthorizationServers []string `json:"authorization_servers"`
+
}
+
+
// NewDevAuthResolver creates a resolver that accepts localhost HTTP URLs
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
resolver := &DevAuthResolver{
+
Client: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
UserAgent: "Coves/1.0",
+
PDSURL: pdsURL,
+
}
+
// Create handle resolver for resolving handles via local PDS
+
if pdsURL != "" {
+
resolver.handleResolver = NewDevHandleResolver(pdsURL, allowPrivateIPs)
+
}
+
return resolver
+
}
+
+
// ResolveAuthServerURL resolves a PDS URL to an auth server URL.
+
// Unlike indigo's standard resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerURL(ctx context.Context, hostURL string) (string, error) {
+
u, err := url.Parse(hostURL)
+
if err != nil {
+
return "", err
+
}
+
+
// For localhost, allow HTTP and port numbers
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if !isLocalhost {
+
// For non-localhost, enforce HTTPS and no port (standard rules)
+
if u.Scheme != "https" || u.Port() != "" {
+
return "", fmt.Errorf("not a valid public host URL: %s", hostURL)
+
}
+
}
+
+
// Build the protected resource document URL
+
var docURL string
+
if isLocalhost {
+
// For localhost, preserve the port and use HTTP
+
port := u.Port()
+
if port == "" {
+
port = "3001" // Default PDS port
+
}
+
docURL = fmt.Sprintf("http://%s:%s/.well-known/oauth-protected-resource", u.Hostname(), port)
+
} else {
+
docURL = fmt.Sprintf("https://%s/.well-known/oauth-protected-resource", u.Hostname())
+
}
+
+
// Fetch the protected resource document
+
req, err := http.NewRequestWithContext(ctx, "GET", docURL, nil)
+
if err != nil {
+
return "", err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("fetching protected resource document: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("HTTP error fetching protected resource document: %d", resp.StatusCode)
+
}
+
+
var body ProtectedResourceMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+
return "", fmt.Errorf("invalid protected resource document: %w", err)
+
}
+
+
if len(body.AuthorizationServers) < 1 {
+
return "", fmt.Errorf("no auth server URL in protected resource document")
+
}
+
+
authURL := body.AuthorizationServers[0]
+
+
// Validate the auth server URL (with localhost exception)
+
au, err := url.Parse(authURL)
+
if err != nil {
+
return "", fmt.Errorf("invalid auth server URL: %w", err)
+
}
+
+
authIsLocalhost := au.Hostname() == "localhost" || au.Hostname() == "127.0.0.1"
+
if !authIsLocalhost {
+
if au.Scheme != "https" || au.Port() != "" {
+
return "", fmt.Errorf("invalid auth server URL: %s", authURL)
+
}
+
}
+
+
return authURL, nil
+
}
+
+
// ResolveAuthServerMetadataDev fetches OAuth server metadata from a given auth server URL.
+
// Unlike indigo's resolver, this allows HTTP and ports for localhost.
+
func (r *DevAuthResolver) ResolveAuthServerMetadataDev(ctx context.Context, serverURL string) (*oauthlib.AuthServerMetadata, error) {
+
u, err := url.Parse(serverURL)
+
if err != nil {
+
return nil, err
+
}
+
+
// Build metadata URL - preserve port for localhost
+
var metaURL string
+
isLocalhost := u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1"
+
if isLocalhost && u.Port() != "" {
+
metaURL = fmt.Sprintf("%s://%s:%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname(), u.Port())
+
} else if isLocalhost {
+
metaURL = fmt.Sprintf("%s://%s/.well-known/oauth-authorization-server", u.Scheme, u.Hostname())
+
} else {
+
metaURL = fmt.Sprintf("https://%s/.well-known/oauth-authorization-server", u.Hostname())
+
}
+
+
slog.Debug("dev mode: fetching auth server metadata", "url", metaURL)
+
+
req, err := http.NewRequestWithContext(ctx, "GET", metaURL, nil)
+
if err != nil {
+
return nil, err
+
}
+
if r.UserAgent != "" {
+
req.Header.Set("User-Agent", r.UserAgent)
+
}
+
+
resp, err := r.Client.Do(req)
+
if err != nil {
+
return nil, fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
return nil, fmt.Errorf("HTTP error fetching auth server metadata: %d", resp.StatusCode)
+
}
+
+
var metadata oauthlib.AuthServerMetadata
+
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
+
// Skip validation for localhost (indigo's Validate checks HTTPS)
+
if !isLocalhost {
+
if err := metadata.Validate(serverURL); err != nil {
+
return nil, fmt.Errorf("invalid auth server metadata: %w", err)
+
}
+
}
+
+
return &metadata, nil
+
}
+
+
// StartDevAuthFlow performs OAuth flow for localhost development.
+
// This bypasses indigo's HTTPS validation for the auth server URL.
+
// It resolves the identity, gets the PDS endpoint, fetches auth server metadata,
+
// and returns a redirect URL for the user to approve.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
var accountDID syntax.DID
+
var pdsEndpoint string
+
+
// Check if identifier is a handle or DID
+
if strings.HasPrefix(identifier, "did:") {
+
// It's a DID - look up via directory (PLC)
+
atid, err := syntax.ParseAtIdentifier(identifier)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", identifier, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID (%s): %w", identifier, err)
+
}
+
accountDID = ident.DID
+
pdsEndpoint = ident.PDSEndpoint()
+
} else {
+
// It's a handle - resolve via local PDS first
+
if r.handleResolver == nil {
+
return "", fmt.Errorf("handle resolution not configured (PDS URL not set)")
+
}
+
+
// Resolve handle to DID via local PDS
+
did, err := r.handleResolver.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS (%s): %w", identifier, err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found: %s", identifier)
+
}
+
+
slog.Info("dev mode: resolved handle via local PDS", "handle", identifier, "did", did)
+
+
// Parse the DID
+
parsedDID, err := syntax.ParseDID(did)
+
if err != nil {
+
return "", fmt.Errorf("invalid DID from PDS (%s): %w", did, err)
+
}
+
accountDID = parsedDID
+
+
// Now look up the DID document via PLC to get PDS endpoint
+
atid, err := syntax.ParseAtIdentifier(did)
+
if err != nil {
+
return "", fmt.Errorf("not a valid DID (%s): %w", did, err)
+
}
+
ident, err := dir.Lookup(ctx, *atid)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve DID document (%s): %w", did, err)
+
}
+
pdsEndpoint = ident.PDSEndpoint()
+
}
+
+
if pdsEndpoint == "" {
+
return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
+
}
+
+
slog.Debug("dev mode: resolving auth server",
+
"did", accountDID,
+
"pds", pdsEndpoint)
+
+
// Resolve auth server URL (allowing HTTP for localhost)
+
authServerURL, err := r.ResolveAuthServerURL(ctx, pdsEndpoint)
+
if err != nil {
+
return "", fmt.Errorf("resolving auth server: %w", err)
+
}
+
+
slog.Info("dev mode: resolved auth server", "url", authServerURL)
+
+
// Fetch auth server metadata using our dev-friendly resolver
+
authMeta, err := r.ResolveAuthServerMetadataDev(ctx, authServerURL)
+
if err != nil {
+
return "", fmt.Errorf("fetching auth server metadata: %w", err)
+
}
+
+
slog.Debug("dev mode: got auth server metadata",
+
"issuer", authMeta.Issuer,
+
"authorization_endpoint", authMeta.AuthorizationEndpoint,
+
"token_endpoint", authMeta.TokenEndpoint)
+
+
// Send auth request (PAR) using indigo's method
+
info, err := client.ClientApp.SendAuthRequest(ctx, authMeta, client.Config.Scopes, identifier)
+
if err != nil {
+
return "", fmt.Errorf("auth request failed: %w", err)
+
}
+
+
// Set the account DID
+
info.AccountDID = &accountDID
+
+
// Persist auth request info
+
client.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
+
+
// Build redirect URL
+
params := url.Values{}
+
params.Set("client_id", client.ClientApp.Config.ClientID)
+
params.Set("request_uri", info.RequestURI)
+
+
authEndpoint := authMeta.AuthorizationEndpoint
+
redirectURL := fmt.Sprintf("%s?%s", authEndpoint, params.Encode())
+
+
slog.Info("dev mode: OAuth redirect URL built", "url_prefix", authEndpoint)
+
+
return redirectURL, nil
+
}
+106
internal/atproto/oauth/dev_resolver.go
···
+
//go:build dev
+
+
package oauth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
"time"
+
)
+
+
// DevHandleResolver resolves handles via local PDS for development
+
// This is needed because local handles (e.g., user.local.coves.dev) can't be
+
// resolved via standard DNS/HTTP well-known methods - they only exist on the local PDS.
+
type DevHandleResolver struct {
+
pdsURL string
+
httpClient *http.Client
+
}
+
+
// NewDevHandleResolver creates a resolver that queries local PDS for handle resolution
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return &DevHandleResolver{
+
pdsURL: strings.TrimSuffix(pdsURL, "/"),
+
httpClient: NewSSRFSafeHTTPClient(allowPrivateIPs),
+
}
+
}
+
+
// ResolveHandle queries the local PDS to resolve a handle to a DID
+
// Returns the DID if successful, or empty string if not found
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
if r.pdsURL == "" {
+
return "", fmt.Errorf("PDS URL not configured")
+
}
+
+
// Build the resolve handle URL
+
resolveURL := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s",
+
r.pdsURL, url.QueryEscape(handle))
+
+
// Create request with context and timeout
+
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+
defer cancel()
+
+
req, err := http.NewRequestWithContext(ctx, "GET", resolveURL, nil)
+
if err != nil {
+
return "", fmt.Errorf("failed to create request: %w", err)
+
}
+
req.Header.Set("User-Agent", "Coves/1.0")
+
+
// Execute request
+
resp, err := r.httpClient.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("failed to query PDS: %w", err)
+
}
+
defer resp.Body.Close()
+
+
// Check response status
+
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusBadRequest {
+
return "", nil // Handle not found
+
}
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("PDS returned status %d", resp.StatusCode)
+
}
+
+
// Parse response
+
var result struct {
+
DID string `json:"did"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return "", fmt.Errorf("failed to parse PDS response: %w", err)
+
}
+
+
if result.DID == "" {
+
return "", nil // No DID in response
+
}
+
+
slog.Debug("resolved handle via local PDS",
+
"handle", handle,
+
"did", result.DID,
+
"pds_url", r.pdsURL)
+
+
return result.DID, nil
+
}
+
+
// ResolveIdentifier attempts to resolve a handle to DID, or returns the DID if already provided
+
// This is the main entry point for the handlers
+
func (r *DevHandleResolver) ResolveIdentifier(ctx context.Context, identifier string) (string, error) {
+
// If it's already a DID, return as-is
+
if strings.HasPrefix(identifier, "did:") {
+
return identifier, nil
+
}
+
+
// Try to resolve the handle via local PDS
+
did, err := r.ResolveHandle(ctx, identifier)
+
if err != nil {
+
return "", fmt.Errorf("failed to resolve handle via PDS: %w", err)
+
}
+
if did == "" {
+
return "", fmt.Errorf("handle not found on local PDS: %s", identifier)
+
}
+
+
return did, nil
+
}
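A hedged usage sketch of the dev resolver against a local PDS follows; the handle and URL are placeholders, and the snippet only compiles under -tags dev since the resolver itself is dev-only.

```go
//go:build dev

// Hypothetical example file; not part of the repository.
package oauth

import (
	"context"
	"log/slog"
)

// resolveLocalHandleExample shows the intended dev-mode flow: ask the local PDS
// to map a handle to a DID instead of going through DNS / HTTP well-known.
func resolveLocalHandleExample(ctx context.Context) {
	r := NewDevHandleResolver("http://localhost:3001", true)
	did, err := r.ResolveHandle(ctx, "alice.local.coves.dev")
	if err != nil || did == "" {
		slog.Warn("dev handle resolution failed", "error", err)
		return
	}
	slog.Info("resolved local handle", "did", did)
}
```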
+41
internal/atproto/oauth/dev_stubs.go
···
+
//go:build !dev
+
+
package oauth
+
+
import (
+
"context"
+
+
"github.com/bluesky-social/indigo/atproto/identity"
+
)
+
+
// DevHandleResolver is a stub for production builds.
+
// The actual implementation is in dev_resolver.go (only compiled with -tags dev).
+
type DevHandleResolver struct{}
+
+
// NewDevHandleResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevHandleResolver(pdsURL string, allowPrivateIPs bool) *DevHandleResolver {
+
return nil
+
}
+
+
// ResolveHandle is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevHandleResolver) ResolveHandle(ctx context.Context, handle string) (string, error) {
+
panic("dev mode: ResolveHandle called in production build - this should never happen")
+
}
+
+
// DevAuthResolver is a stub for production builds.
+
// The actual implementation is in dev_auth_resolver.go (only compiled with -tags dev).
+
type DevAuthResolver struct{}
+
+
// NewDevAuthResolver returns nil in production builds.
+
// Dev mode features are only available when built with -tags dev.
+
func NewDevAuthResolver(pdsURL string, allowPrivateIPs bool) *DevAuthResolver {
+
return nil
+
}
+
+
// StartDevAuthFlow is a stub that should never be called in production.
+
// The nil check in handlers.go prevents this from being reached.
+
func (r *DevAuthResolver) StartDevAuthFlow(ctx context.Context, client *OAuthClient, identifier string, dir identity.Directory) (string, error) {
+
panic("dev mode: StartDevAuthFlow called in production build - this should never happen")
+
}
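The stub/real pairing above follows the standard Go build-tag pattern: two files declare the same identifiers and exactly one is compiled, depending on whether -tags dev is passed (see the Makefile changes below). A minimal illustration of the pattern using a hypothetical feature package, not part of this repo:

```go
//go:build dev

// flag_dev.go (hypothetical illustration, not in this repo)
package feature

// Enabled reports whether dev-only code paths are compiled in.
// A sibling flag_prod.go with "//go:build !dev" returns false, mirroring the
// dev_resolver.go / dev_stubs.go split above.
func Enabled() bool { return true }
```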
+107 -15
internal/atproto/oauth/handlers.go
···
"log/slog"
"net/http"
"net/url"
+
"strings"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax"
···
// OAuthHandler handles OAuth-related HTTP endpoints
type OAuthHandler struct {
-
client *OAuthClient
-
store oauth.ClientAuthStore
-
mobileStore MobileOAuthStore // For server-side CSRF validation
+
client *OAuthClient
+
store oauth.ClientAuthStore
+
mobileStore MobileOAuthStore // For server-side CSRF validation
+
devResolver *DevHandleResolver // For dev mode: resolve handles via local PDS
+
devAuthResolver *DevAuthResolver // For dev mode: bypass HTTPS validation for localhost OAuth
}
// NewOAuthHandler creates a new OAuth handler
···
handler.mobileStore = mobileStore
}
+
// In dev mode, create resolvers for local PDS/PLC
+
// This is needed because:
+
// 1. Local handles (e.g., user.local.coves.dev) can't be resolved via DNS/HTTP
+
// 2. Indigo's OAuth library requires HTTPS, which localhost doesn't have
+
if client.Config.DevMode {
+
if client.Config.PDSURL != "" {
+
handler.devResolver = NewDevHandleResolver(client.Config.PDSURL, client.Config.AllowPrivateIPs)
+
slog.Info("dev mode: handle resolution via local PDS enabled", "pds_url", client.Config.PDSURL)
+
}
+
// Create dev auth resolver to bypass HTTPS validation (pass PDS URL for handle resolution)
+
handler.devAuthResolver = NewDevAuthResolver(client.Config.PDSURL, client.Config.AllowPrivateIPs)
+
slog.Info("dev mode: localhost OAuth auth resolver enabled", "pds_url", client.Config.PDSURL)
+
}
+
return handler
}
···
return
}
-
// Start OAuth flow
-
redirectURL, err := h.client.ClientApp.StartAuthFlow(ctx, identifier)
-
if err != nil {
-
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
-
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
-
return
+
var redirectURL string
+
var err error
+
+
// DEV MODE: Use custom OAuth flow that bypasses HTTPS validation
+
// This is needed because:
+
// 1. Local handles can't be resolved via DNS/HTTP well-known
+
// 2. Indigo's OAuth library requires HTTPS for auth servers
+
if h.devAuthResolver != nil {
+
slog.Info("dev mode: using localhost OAuth flow", "identifier", identifier)
+
redirectURL, err = h.devAuthResolver.StartDevAuthFlow(ctx, h.client, identifier, h.client.ClientApp.Dir)
+
if err != nil {
+
slog.Error("dev mode: failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
+
} else {
+
// Production mode: use standard indigo OAuth flow
+
redirectURL, err = h.client.ClientApp.StartAuthFlow(ctx, identifier)
+
if err != nil {
+
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
}
// Log OAuth flow initiation (sanitized - no full URL to avoid leaking state)
···
func (h *OAuthHandler) HandleMobileLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+
// DEV MODE: Redirect localhost to 127.0.0.1 for cookie consistency
+
// The OAuth callback URL uses 127.0.0.1 (per RFC 8252), so cookies must be set
+
// on 127.0.0.1. If user calls localhost, redirect to 127.0.0.1 first.
+
if h.client.Config.DevMode && strings.Contains(r.Host, "localhost") {
+
// Use the configured PublicURL host for consistency
+
redirectURL := h.client.Config.PublicURL + r.URL.RequestURI()
+
slog.Info("dev mode: redirecting localhost to PublicURL host for cookie consistency",
+
"from", r.Host, "to", h.client.Config.PublicURL)
+
http.Redirect(w, r, redirectURL, http.StatusFound)
+
return
+
}
+
// Get handle or DID from query params
identifier := r.URL.Query().Get("handle")
if identifier == "" {
···
RedirectURI: mobileRedirectURI,
})
-
// Start OAuth flow (the store wrapper will save mobile data when auth request is saved)
-
redirectURL, err := h.client.ClientApp.StartAuthFlow(mobileCtx, identifier)
-
if err != nil {
-
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
-
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
-
return
+
var redirectURL string
+
+
// DEV MODE: Use custom OAuth flow that bypasses HTTPS validation
+
// This is needed because:
+
// 1. Local handles can't be resolved via DNS/HTTP well-known
+
// 2. Indigo's OAuth library requires HTTPS for auth servers
+
if h.devAuthResolver != nil {
+
slog.Info("dev mode: using localhost OAuth flow for mobile", "identifier", identifier)
+
redirectURL, err = h.devAuthResolver.StartDevAuthFlow(mobileCtx, h.client, identifier, h.client.ClientApp.Dir)
+
if err != nil {
+
slog.Error("dev mode: failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
+
} else {
+
// Production mode: use standard indigo OAuth flow
+
redirectURL, err = h.client.ClientApp.StartAuthFlow(mobileCtx, identifier)
+
if err != nil {
+
slog.Error("failed to start OAuth flow", "error", err, "identifier", identifier)
+
http.Error(w, fmt.Sprintf("failed to start OAuth flow: %v", err), http.StatusBadRequest)
+
return
+
}
}
// Log mobile OAuth flow initiation (sanitized - no full URLs or sensitive params)
···
// Check if the handle is the special "handle.invalid" value
// This indicates that bidirectional verification failed (DID->handle->DID roundtrip failed)
if ident.Handle.String() == "handle.invalid" {
+
// DEV MODE: For local handles, verify via PDS instead of DNS/HTTP
+
// Local handles like "user.local.coves.dev" can't be resolved via DNS
+
if h.devResolver != nil {
+
// Get the handle from DID document (alsoKnownAs)
+
declaredHandle := ""
+
if len(ident.AlsoKnownAs) > 0 {
+
// Extract handle from at:// URI
+
for _, aka := range ident.AlsoKnownAs {
+
if len(aka) > 5 && aka[:5] == "at://" {
+
declaredHandle = aka[5:]
+
break
+
}
+
}
+
}
+
+
if declaredHandle != "" {
+
// Verify handle via PDS
+
resolvedDID, err := h.devResolver.ResolveHandle(ctx, declaredHandle)
+
if err == nil && resolvedDID == sessData.AccountDID.String() {
+
slog.Info("OAuth callback successful (dev mode: handle verified via PDS)",
+
"did", sessData.AccountDID, "handle", declaredHandle)
+
goto handleVerificationPassed
+
}
+
slog.Warn("dev mode: PDS handle verification failed",
+
"did", sessData.AccountDID, "handle", declaredHandle,
+
"resolved_did", resolvedDID, "error", err)
+
}
+
}
+
slog.Warn("OAuth callback: bidirectional handle verification failed",
"did", sessData.AccountDID,
"handle", "handle.invalid",
···
"did", sessData.AccountDID)
slog.Info("OAuth callback successful (no handle verification)", "did", sessData.AccountDID)
}
+
handleVerificationPassed:
// Check if this is a mobile callback (check for mobile_redirect_uri cookie)
mobileRedirect, err := r.Cookie("mobile_redirect_uri")
+5 -1
scripts/dev-run.sh
···
#!/bin/bash
# Development server runner - loads .env.dev before starting
+
# Uses -tags dev to include dev-only code (localhost OAuth resolvers, etc.)
set -a # automatically export all variables
source .env.dev
···
echo " IS_DEV_ENV: $IS_DEV_ENV"
echo " PLC_DIRECTORY_URL: $PLC_DIRECTORY_URL"
echo " JETSTREAM_URL: $JETSTREAM_URL"
+
echo " APPVIEW_PUBLIC_URL: $APPVIEW_PUBLIC_URL"
+
echo " PDS_URL: $PDS_URL"
+
echo " Build tags: dev"
echo ""
-
go run ./cmd/server
+
go run -tags dev ./cmd/server
+125
internal/atproto/pds/factory.go
···
+
package pds
+
+
import (
+
"context"
+
"fmt"
+
"net/http"
+
+
"github.com/bluesky-social/indigo/atproto/atclient"
+
"github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
)
+
+
// NewFromOAuthSession creates a PDS client from an OAuth session.
+
// This uses DPoP authentication - the correct method for OAuth tokens.
+
//
+
// The oauthClient is used to resume the session and get a properly configured
+
// APIClient that handles DPoP proof generation and nonce rotation automatically.
+
func NewFromOAuthSession(ctx context.Context, oauthClient *oauth.ClientApp, sessionData *oauth.ClientSessionData) (Client, error) {
+
if oauthClient == nil {
+
return nil, fmt.Errorf("oauthClient is required")
+
}
+
if sessionData == nil {
+
return nil, fmt.Errorf("sessionData is required")
+
}
+
+
// ResumeSession reconstructs the OAuth session with DPoP key
+
// and returns a ClientSession that can generate authenticated requests
+
sess, err := oauthClient.ResumeSession(ctx, sessionData.AccountDID, sessionData.SessionID)
+
if err != nil {
+
return nil, fmt.Errorf("failed to resume OAuth session: %w", err)
+
}
+
+
// APIClient() returns an *atclient.APIClient configured with DPoP auth
+
apiClient := sess.APIClient()
+
+
return &client{
+
apiClient: apiClient,
+
did: sessionData.AccountDID.String(),
+
host: sessionData.HostURL,
+
}, nil
+
}
+
+
// NewFromPasswordAuth creates a PDS client using password authentication.
+
// This uses Bearer token authentication from com.atproto.server.createSession.
+
//
+
// Primarily used for:
+
// - E2E tests with local PDS
+
// - Development/debugging tools
+
// - Non-OAuth clients
+
//
+
// Note: This establishes a new session with the PDS. For repeated calls,
+
// consider using NewFromAccessToken if you already have a valid access token.
+
func NewFromPasswordAuth(ctx context.Context, host, handle, password string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if handle == "" {
+
return nil, fmt.Errorf("handle is required")
+
}
+
if password == "" {
+
return nil, fmt.Errorf("password is required")
+
}
+
+
// LoginWithPasswordHost creates a session and returns an authenticated APIClient
+
// This handles the createSession call and Bearer token setup
+
apiClient, err := atclient.LoginWithPasswordHost(ctx, host, handle, password, "", nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to login with password: %w", err)
+
}
+
+
// Get DID from the authenticated client
+
did := ""
+
if apiClient.AccountDID != nil {
+
did = apiClient.AccountDID.String()
+
}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// NewFromAccessToken creates a PDS client from an existing access token.
+
// This is useful when you already have a valid Bearer token (e.g., from createSession)
+
// and don't want to re-authenticate.
+
//
+
// WARNING: This creates a client with Bearer auth only. Do NOT use this with
+
// OAuth access tokens - those require DPoP proofs. Use NewFromOAuthSession instead.
+
func NewFromAccessToken(host, did, accessToken string) (Client, error) {
+
if host == "" {
+
return nil, fmt.Errorf("host is required")
+
}
+
if did == "" {
+
return nil, fmt.Errorf("did is required")
+
}
+
if accessToken == "" {
+
return nil, fmt.Errorf("accessToken is required")
+
}
+
+
// Create APIClient with Bearer auth
+
apiClient := atclient.NewAPIClient(host)
+
apiClient.Auth = &bearerAuth{token: accessToken}
+
+
return &client{
+
apiClient: apiClient,
+
did: did,
+
host: host,
+
}, nil
+
}
+
+
// bearerAuth implements atclient.AuthMethod for simple Bearer token auth.
+
// This is used for password-based sessions where DPoP is not required.
+
type bearerAuth struct {
+
token string
+
}
+
+
// Ensure bearerAuth implements atclient.AuthMethod.
+
var _ atclient.AuthMethod = (*bearerAuth)(nil)
+
+
// DoWithAuth adds the Bearer token to the request and executes it.
+
func (b *bearerAuth) DoWithAuth(c *http.Client, req *http.Request, _ syntax.NSID) (*http.Response, error) {
+
req.Header.Set("Authorization", "Bearer "+b.token)
+
return c.Do(req)
+
}
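A hedged sketch of using the password-auth factory from an E2E test follows; the host, handle, and password are local-dev placeholders, and the test is skipped so it only documents the call pattern.

```go
// Hypothetical example file; illustration only.
package pds_test

import (
	"context"
	"testing"

	"Coves/internal/atproto/pds"
)

// TestPasswordAuthSketch shows the intended call pattern for local E2E tests;
// the host, handle, and password below are placeholders for a local PDS account.
func TestPasswordAuthSketch(t *testing.T) {
	t.Skip("illustration only; remove to run against a local PDS")
	ctx := context.Background()
	client, err := pds.NewFromPasswordAuth(ctx, "http://localhost:3001", "alice.local.coves.dev", "dev-password")
	if err != nil {
		t.Fatalf("login failed: %v", err)
	}
	t.Logf("authenticated as %s", client.DID())
}
```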
+18
tests/integration/helpers.go
···
import (
"Coves/internal/api/middleware"
"Coves/internal/atproto/oauth"
+
"Coves/internal/atproto/pds"
"Coves/internal/core/users"
+
"Coves/internal/core/votes"
"bytes"
"context"
"database/sql"
···
e.store.AddSessionWithPDS(did, sessionID, pdsAccessToken, pdsURL)
return token
}
+
+
// PasswordAuthPDSClientFactory creates a PDSClientFactory that uses password-based Bearer auth.
+
// This is for E2E tests that use createSession instead of OAuth.
+
// The factory extracts the access token and host URL from the session data.
+
func PasswordAuthPDSClientFactory() votes.PDSClientFactory {
+
return func(ctx context.Context, session *oauthlib.ClientSessionData) (pds.Client, error) {
+
if session.AccessToken == "" {
+
return nil, fmt.Errorf("session has no access token")
+
}
+
if session.HostURL == "" {
+
return nil, fmt.Errorf("session has no host URL")
+
}
+
+
return pds.NewFromAccessToken(session.HostURL, session.AccountDID.String(), session.AccessToken)
+
}
+
}
+267
cmd/reindex-votes/main.go
···
+
// cmd/reindex-votes/main.go
+
// Quick tool to reindex votes from PDS to AppView database
+
package main
+
+
import (
+
"context"
+
"database/sql"
+
"encoding/json"
+
"fmt"
+
"log"
+
"net/http"
+
"net/url"
+
"os"
+
"strings"
+
"time"
+
+
_ "github.com/lib/pq"
+
)
+
+
type ListRecordsResponse struct {
+
Records []Record `json:"records"`
+
Cursor string `json:"cursor"`
+
}
+
+
type Record struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
Value map[string]interface{} `json:"value"`
+
}
+
+
func main() {
+
// Get config from env
+
dbURL := os.Getenv("DATABASE_URL")
+
if dbURL == "" {
+
dbURL = "postgres://dev_user:dev_password@localhost:5435/coves_dev?sslmode=disable"
+
}
+
pdsURL := os.Getenv("PDS_URL")
+
if pdsURL == "" {
+
pdsURL = "http://localhost:3001"
+
}
+
+
log.Printf("Connecting to database...")
+
db, err := sql.Open("postgres", dbURL)
+
if err != nil {
+
log.Fatalf("Failed to connect to database: %v", err)
+
}
+
defer db.Close()
+
+
ctx := context.Background()
+
+
// Get all accounts directly from the PDS
+
log.Printf("Fetching accounts from PDS (%s)...", pdsURL)
+
dids, err := fetchAllAccountsFromPDS(pdsURL)
+
if err != nil {
+
log.Fatalf("Failed to fetch accounts from PDS: %v", err)
+
}
+
log.Printf("Found %d accounts on PDS to check for votes", len(dids))
+
+
// Reset vote counts first
+
log.Printf("Resetting all vote counts...")
+
if _, err := db.ExecContext(ctx, "DELETE FROM votes"); err != nil {
+
log.Fatalf("Failed to clear votes table: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE posts SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset post vote counts: %v", err)
+
}
+
if _, err := db.ExecContext(ctx, "UPDATE comments SET upvote_count = 0, downvote_count = 0, score = 0"); err != nil {
+
log.Fatalf("Failed to reset comment vote counts: %v", err)
+
}
+
+
// For each user, fetch their votes from PDS
+
totalVotes := 0
+
for _, did := range dids {
+
votes, err := fetchVotesFromPDS(pdsURL, did)
+
if err != nil {
+
log.Printf("Warning: failed to fetch votes for %s: %v", did, err)
+
continue
+
}
+
+
if len(votes) == 0 {
+
continue
+
}
+
+
log.Printf("Found %d votes for %s", len(votes), did)
+
+
// Index each vote
+
for _, vote := range votes {
+
if err := indexVote(ctx, db, did, vote); err != nil {
+
log.Printf("Warning: failed to index vote %s: %v", vote.URI, err)
+
continue
+
}
+
totalVotes++
+
}
+
}
+
+
log.Printf("โœ“ Reindexed %d votes from PDS", totalVotes)
+
}
+
+
// fetchAllAccountsFromPDS queries the PDS sync API to get all repo DIDs
+
func fetchAllAccountsFromPDS(pdsURL string) ([]string, error) {
+
// Use com.atproto.sync.listRepos to get all repos on this PDS
+
var allDIDs []string
+
cursor := ""
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.sync.listRepos?limit=100", pdsURL)
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result struct {
+
Repos []struct {
+
DID string `json:"did"`
+
} `json:"repos"`
+
Cursor string `json:"cursor"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
for _, repo := range result.Repos {
+
allDIDs = append(allDIDs, repo.DID)
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allDIDs, nil
+
}
+
+
func fetchVotesFromPDS(pdsURL, did string) ([]Record, error) {
+
var allRecords []Record
+
cursor := ""
+
collection := "social.coves.feed.vote"
+
+
for {
+
reqURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.listRecords?repo=%s&collection=%s&limit=100",
+
pdsURL, url.QueryEscape(did), url.QueryEscape(collection))
+
if cursor != "" {
+
reqURL += "&cursor=" + url.QueryEscape(cursor)
+
}
+
+
resp, err := http.Get(reqURL)
+
if err != nil {
+
return nil, fmt.Errorf("HTTP request failed: %w", err)
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode == 400 {
+
// User doesn't exist on this PDS or has no records - that's OK
+
return nil, nil
+
}
+
if resp.StatusCode != 200 {
+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+
}
+
+
var result ListRecordsResponse
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+
return nil, fmt.Errorf("failed to decode response: %w", err)
+
}
+
+
allRecords = append(allRecords, result.Records...)
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return allRecords, nil
+
}
+
+
func indexVote(ctx context.Context, db *sql.DB, voterDID string, record Record) error {
+
// Extract vote data from record
+
subject, ok := record.Value["subject"].(map[string]interface{})
+
if !ok {
+
return fmt.Errorf("missing subject")
+
}
+
subjectURI, _ := subject["uri"].(string)
+
subjectCID, _ := subject["cid"].(string)
+
direction, _ := record.Value["direction"].(string)
+
createdAtStr, _ := record.Value["createdAt"].(string)
+
+
if subjectURI == "" || direction == "" {
+
return fmt.Errorf("invalid vote record: missing required fields")
+
}
+
+
// Parse created_at
+
createdAt, err := time.Parse(time.RFC3339, createdAtStr)
+
if err != nil {
+
createdAt = time.Now()
+
}
+
+
// Extract rkey from URI (at://did/collection/rkey)
+
parts := strings.Split(record.URI, "/")
+
if len(parts) < 5 {
+
return fmt.Errorf("invalid URI format: %s", record.URI)
+
}
+
rkey := parts[len(parts)-1]
+
+
// Start transaction
+
tx, err := db.BeginTx(ctx, nil)
+
if err != nil {
+
return fmt.Errorf("failed to begin transaction: %w", err)
+
}
+
defer tx.Rollback()
+
+
// Insert vote
+
_, err = tx.ExecContext(ctx, `
+
INSERT INTO votes (uri, cid, rkey, voter_did, subject_uri, subject_cid, direction, created_at, indexed_at)
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
+
ON CONFLICT (uri) DO NOTHING
+
`, record.URI, record.CID, rkey, voterDID, subjectURI, subjectCID, direction, createdAt)
+
if err != nil {
+
return fmt.Errorf("failed to insert vote: %w", err)
+
}
+
+
// Update post/comment counts
+
collection := extractCollectionFromURI(subjectURI)
+
var updateQuery string
+
+
switch collection {
+
case "social.coves.community.post":
+
if direction == "up" {
+
updateQuery = `UPDATE posts SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE posts SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
case "social.coves.community.comment":
+
if direction == "up" {
+
updateQuery = `UPDATE comments SET upvote_count = upvote_count + 1, score = upvote_count + 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else {
+
updateQuery = `UPDATE comments SET downvote_count = downvote_count + 1, score = upvote_count - (downvote_count + 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
default:
+
// Unknown collection, just index the vote
+
return tx.Commit()
+
}
+
+
if _, err := tx.ExecContext(ctx, updateQuery, subjectURI); err != nil {
+
return fmt.Errorf("failed to update vote counts: %w", err)
+
}
+
+
return tx.Commit()
+
}
+
+
func extractCollectionFromURI(uri string) string {
+
// at://did:plc:xxx/social.coves.community.post/rkey
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 4 {
+
return parts[3]
+
}
+
return ""
+
}
+7 -5
internal/api/routes/communityFeed.go
···
import (
"Coves/internal/api/handlers/communityFeed"
+
"Coves/internal/api/middleware"
"Coves/internal/core/communityFeeds"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
func RegisterCommunityFeedRoutes(
r chi.Router,
feedService communityFeeds.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService)
+
getCommunityHandler := communityFeed.NewGetCommunityHandler(feedService, voteService)
// GET /xrpc/social.coves.communityFeed.getCommunity
-
// Public endpoint - basic community sorting only for Alpha
-
// TODO(feed-generator): Add OptionalAuth middleware when implementing viewer-specific state
-
// (blocks, upvotes, saves, etc.) in feed generator skeleton
-
r.Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
+
// Public endpoint with optional auth for viewer-specific state (vote state)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.communityFeed.getCommunity", getCommunityHandler.HandleGetCommunity)
}
+221
internal/core/votes/cache.go
···
+
package votes
+
+
import (
+
"context"
+
"fmt"
+
"log/slog"
+
"strings"
+
"sync"
+
"time"
+
+
"Coves/internal/atproto/pds"
+
)
+
+
// CachedVote represents a vote stored in the cache
+
type CachedVote struct {
+
Direction string // "up" or "down"
+
URI string // vote record URI (at://did/collection/rkey)
+
RKey string // record key
+
}
+
+
// VoteCache provides an in-memory cache of user votes fetched from their PDS.
+
// This avoids eventual consistency issues with the AppView database.
+
type VoteCache struct {
+
mu sync.RWMutex
+
votes map[string]map[string]*CachedVote // userDID -> subjectURI -> vote
+
expiry map[string]time.Time // userDID -> expiry time
+
ttl time.Duration
+
logger *slog.Logger
+
}
+
+
// NewVoteCache creates a new vote cache with the specified TTL
+
func NewVoteCache(ttl time.Duration, logger *slog.Logger) *VoteCache {
+
if logger == nil {
+
logger = slog.Default()
+
}
+
return &VoteCache{
+
votes: make(map[string]map[string]*CachedVote),
+
expiry: make(map[string]time.Time),
+
ttl: ttl,
+
logger: logger,
+
}
+
}
+
+
// GetVotesForUser returns all cached votes for a user.
+
// Returns nil if cache is empty or expired for this user.
+
func (c *VoteCache) GetVotesForUser(userDID string) map[string]*CachedVote {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
// Check if cache exists and is not expired
+
expiry, exists := c.expiry[userDID]
+
if !exists || time.Now().After(expiry) {
+
return nil
+
}
+
+
return c.votes[userDID]
+
}
+
+
// GetVote returns the cached vote for a specific subject, or nil if not found/expired
+
func (c *VoteCache) GetVote(userDID, subjectURI string) *CachedVote {
+
votes := c.GetVotesForUser(userDID)
+
if votes == nil {
+
return nil
+
}
+
return votes[subjectURI]
+
}
+
+
// IsCached returns true if the user's votes are cached and not expired
+
func (c *VoteCache) IsCached(userDID string) bool {
+
c.mu.RLock()
+
defer c.mu.RUnlock()
+
+
expiry, exists := c.expiry[userDID]
+
return exists && time.Now().Before(expiry)
+
}
+
+
// SetVotesForUser replaces all cached votes for a user
+
func (c *VoteCache) SetVotesForUser(userDID string, votes map[string]*CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
c.votes[userDID] = votes
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cache updated",
+
"user", userDID,
+
"vote_count", len(votes),
+
"expires_at", c.expiry[userDID])
+
}
+
+
// SetVote adds or updates a single vote in the cache
+
func (c *VoteCache) SetVote(userDID, subjectURI string, vote *CachedVote) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] == nil {
+
c.votes[userDID] = make(map[string]*CachedVote)
+
}
+
+
c.votes[userDID][subjectURI] = vote
+
+
// Always extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote cached",
+
"user", userDID,
+
"subject", subjectURI,
+
"direction", vote.Direction)
+
}
+
+
// RemoveVote removes a vote from the cache (for toggle-off)
+
func (c *VoteCache) RemoveVote(userDID, subjectURI string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
if c.votes[userDID] != nil {
+
delete(c.votes[userDID], subjectURI)
+
+
// Extend expiry on vote action - active users keep their cache fresh
+
c.expiry[userDID] = time.Now().Add(c.ttl)
+
+
c.logger.Debug("vote removed from cache",
+
"user", userDID,
+
"subject", subjectURI)
+
}
+
}
+
+
// Invalidate removes all cached votes for a user
+
func (c *VoteCache) Invalidate(userDID string) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
delete(c.votes, userDID)
+
delete(c.expiry, userDID)
+
+
c.logger.Debug("vote cache invalidated", "user", userDID)
+
}
+
+
// FetchAndCacheFromPDS fetches all votes from the user's PDS and caches them.
+
// This should be called on first authenticated request or when cache is expired.
+
func (c *VoteCache) FetchAndCacheFromPDS(ctx context.Context, pdsClient pds.Client) error {
+
userDID := pdsClient.DID()
+
+
c.logger.Debug("fetching votes from PDS",
+
"user", userDID,
+
"pds", pdsClient.HostURL())
+
+
votes, err := c.fetchAllVotesFromPDS(ctx, pdsClient)
+
if err != nil {
+
return fmt.Errorf("failed to fetch votes from PDS: %w", err)
+
}
+
+
c.SetVotesForUser(userDID, votes)
+
+
c.logger.Info("vote cache populated from PDS",
+
"user", userDID,
+
"vote_count", len(votes))
+
+
return nil
+
}
+
+
// fetchAllVotesFromPDS paginates through all vote records on the user's PDS
+
func (c *VoteCache) fetchAllVotesFromPDS(ctx context.Context, pdsClient pds.Client) (map[string]*CachedVote, error) {
+
votes := make(map[string]*CachedVote)
+
cursor := ""
+
const pageSize = 100
+
const collection = "social.coves.feed.vote"
+
+
for {
+
result, err := pdsClient.ListRecords(ctx, collection, pageSize, cursor)
+
if err != nil {
+
if pds.IsAuthError(err) {
+
return nil, ErrNotAuthorized
+
}
+
return nil, fmt.Errorf("listRecords failed: %w", err)
+
}
+
+
for _, rec := range result.Records {
+
// Extract subject from record value
+
subject, ok := rec.Value["subject"].(map[string]any)
+
if !ok {
+
continue
+
}
+
+
subjectURI, ok := subject["uri"].(string)
+
if !ok || subjectURI == "" {
+
continue
+
}
+
+
direction, _ := rec.Value["direction"].(string)
+
if direction == "" {
+
continue
+
}
+
+
// Extract rkey from URI
+
rkey := extractRKeyFromURI(rec.URI)
+
+
votes[subjectURI] = &CachedVote{
+
Direction: direction,
+
URI: rec.URI,
+
RKey: rkey,
+
}
+
}
+
+
if result.Cursor == "" {
+
break
+
}
+
cursor = result.Cursor
+
}
+
+
return votes, nil
+
}
+
+
// extractRKeyFromURI extracts the rkey from an AT-URI (at://did/collection/rkey)
+
func extractRKeyFromURI(uri string) string {
+
parts := strings.Split(uri, "/")
+
if len(parts) >= 5 {
+
return parts[len(parts)-1]
+
}
+
return ""
+
}
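A hedged usage sketch for the cache above, written as a test; the DIDs, URIs, and rkeys are placeholder values, and the real call sites are in the vote service below:

```go
package votes_test

import (
	"log/slog"
	"testing"
	"time"

	"Coves/internal/core/votes"
)

func TestVoteCacheSketch(t *testing.T) {
	cache := votes.NewVoteCache(5*time.Minute, slog.Default())

	userDID := "did:plc:exampleviewer"
	subjectURI := "at://did:plc:exampleauthor/social.coves.community.post/3k2apost"

	// Nothing cached yet for this user.
	if cache.IsCached(userDID) {
		t.Fatal("expected empty cache")
	}

	// Record a vote; this also starts/extends the user's TTL window.
	cache.SetVote(userDID, subjectURI, &votes.CachedVote{
		Direction: "up",
		URI:       "at://" + userDID + "/social.coves.feed.vote/3k2bvote",
		RKey:      "3k2bvote",
	})

	if v := cache.GetVote(userDID, subjectURI); v == nil || v.Direction != "up" {
		t.Fatalf("expected cached upvote, got %+v", v)
	}

	// Toggle-off removes the entry but keeps the user's cache window alive.
	cache.RemoveVote(userDID, subjectURI)
	if cache.GetVote(userDID, subjectURI) != nil {
		t.Fatal("expected vote to be removed")
	}
}
```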
+14
internal/core/votes/service.go
···
// - Deletes the user's vote record from their PDS
// - AppView will soft-delete via Jetstream consumer
DeleteVote(ctx context.Context, session *oauthlib.ClientSessionData, req DeleteVoteRequest) error
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
// This should be called before rendering feeds to ensure vote state is available.
+
// If cache is already populated and not expired, this is a no-op.
+
EnsureCachePopulated(ctx context.Context, session *oauthlib.ClientSessionData) error
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
// Returns from cache if available, otherwise returns nil (caller should ensure cache is populated).
+
GetViewerVote(userDID, subjectURI string) *CachedVote
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
// Returns a map of subjectURI -> CachedVote for subjects the user has voted on.
+
// This is efficient for batch lookups when rendering feeds.
+
GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote
}
// CreateVoteRequest contains the parameters for creating a vote
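A hedged sketch of how a feed handler might use the two new read methods before rendering. Shown as a fragment without imports because the OAuth session package path is outside this diff; the `oauth.ClientSessionData` alias and `postURIs` are assumptions:

```go
// attachViewerVotes warms the viewer's vote cache and returns their votes for
// the posts about to be rendered. Sketch only; names are illustrative.
func attachViewerVotes(
	ctx context.Context,
	voteService votes.Service,
	session *oauth.ClientSessionData,
	postURIs []string,
) (map[string]*votes.CachedVote, error) {
	// No-op when the cache is already populated and unexpired.
	if err := voteService.EnsureCachePopulated(ctx, session); err != nil {
		return nil, err
	}

	// One batch lookup for the whole feed instead of a query per post.
	return voteService.GetViewerVotesForSubjects(session.AccountDID.String(), postURIs), nil
}
```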
+84 -2
internal/core/votes/service_impl.go
···
oauthStore oauth.ClientAuthStore
logger *slog.Logger
pdsClientFactory PDSClientFactory // Optional, for testing. If nil, uses OAuth.
+
cache *VoteCache // In-memory cache of user votes from PDS
}
// NewService creates a new vote service instance
-
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, logger *slog.Logger) Service {
+
func NewService(repo Repository, oauthClient *oauthclient.OAuthClient, oauthStore oauth.ClientAuthStore, cache *VoteCache, logger *slog.Logger) Service {
if logger == nil {
logger = slog.Default()
}
···
repo: repo,
oauthClient: oauthClient,
oauthStore: oauthStore,
+
cache: cache,
logger: logger,
}
}
// NewServiceWithPDSFactory creates a vote service with a custom PDS client factory.
// This is primarily for testing with password-based authentication.
-
func NewServiceWithPDSFactory(repo Repository, logger *slog.Logger, factory PDSClientFactory) Service {
+
func NewServiceWithPDSFactory(repo Repository, cache *VoteCache, logger *slog.Logger, factory PDSClientFactory) Service {
if logger == nil {
logger = slog.Default()
}
return &voteService{
repo: repo,
+
cache: cache,
logger: logger,
pdsClientFactory: factory,
}
···
"subject", req.Subject.URI,
"direction", req.Direction)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
// Return empty response to indicate deletion
return &CreateVoteResponse{
URI: "",
···
"uri", uri,
"cid", cid)
+
// Update cache - add the new vote
+
if s.cache != nil {
+
s.cache.SetVote(session.AccountDID.String(), req.Subject.URI, &CachedVote{
+
Direction: req.Direction,
+
URI: uri,
+
RKey: extractRKeyFromURI(uri),
+
})
+
}
+
return &CreateVoteResponse{
URI: uri,
CID: cid,
···
"subject", req.Subject.URI,
"uri", existing.URI)
+
// Update cache - remove the vote
+
if s.cache != nil {
+
s.cache.RemoveVote(session.AccountDID.String(), req.Subject.URI)
+
}
+
return nil
}
···
// No vote found for this subject after checking all pages
return nil, nil
}
+
+
// EnsureCachePopulated fetches the user's votes from their PDS if not already cached.
+
func (s *voteService) EnsureCachePopulated(ctx context.Context, session *oauth.ClientSessionData) error {
+
if s.cache == nil {
+
return nil // No cache configured
+
}
+
+
// Check if already cached
+
if s.cache.IsCached(session.AccountDID.String()) {
+
return nil
+
}
+
+
// Create PDS client for this session
+
pdsClient, err := s.getPDSClient(ctx, session)
+
if err != nil {
+
s.logger.Error("failed to create PDS client for cache population",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to create PDS client: %w", err)
+
}
+
+
// Fetch and cache votes from PDS
+
if err := s.cache.FetchAndCacheFromPDS(ctx, pdsClient); err != nil {
+
s.logger.Error("failed to populate vote cache from PDS",
+
"error", err,
+
"user", session.AccountDID)
+
return fmt.Errorf("failed to populate vote cache: %w", err)
+
}
+
+
return nil
+
}
+
+
// GetViewerVote returns the viewer's vote for a specific subject, or nil if not voted.
+
func (s *voteService) GetViewerVote(userDID, subjectURI string) *CachedVote {
+
if s.cache == nil {
+
return nil
+
}
+
return s.cache.GetVote(userDID, subjectURI)
+
}
+
+
// GetViewerVotesForSubjects returns the viewer's votes for multiple subjects.
+
func (s *voteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*CachedVote {
+
result := make(map[string]*CachedVote)
+
if s.cache == nil {
+
return result
+
}
+
+
allVotes := s.cache.GetVotesForUser(userDID)
+
if allVotes == nil {
+
return result
+
}
+
+
for _, uri := range subjectURIs {
+
if vote, exists := allVotes[uri]; exists {
+
result[uri] = vote
+
}
+
}
+
+
return result
+
}
+118 -111
CLAUDE.md
···
-
Project: Coves PR Reviewer
-
You are a distinguished senior architect conducting a thorough code review for Coves, a forum-like atProto social media platform.
-
-
## Review Mindset
-
- Be constructive but thorough - catch issues before they reach production
-
- Question assumptions and look for edge cases
-
- Prioritize security, performance, and maintainability concerns
-
- Suggest alternatives when identifying problems
-
- Ensure there is proper test coverage that adequately tests atproto write forward architecture
-
-
-
## Special Attention Areas for Coves
-
- **atProto architecture**: - Ensure architecture follows atProto recommendations with WRITE FORWARD ARCHITECTURE (Appview -> PDS -> Relay -> Appview -> App DB (if necessary))
-
- Ensure HTTP Endpoints match the Lexicon data contract
-
- **Federation**: Check for proper DID resolution and identity verification
-
-
## Review Checklist
-
-
### 1. Architecture Compliance
-
**MUST VERIFY:**
-
- [ ] NO SQL queries in handlers (automatic rejection if found)
-
- [ ] Proper layer separation: Handler โ†’ Service โ†’ Repository โ†’ Database
-
- [ ] Services use repository interfaces, not concrete implementations
-
- [ ] Dependencies injected via constructors, not globals
-
- [ ] No database packages imported in handlers
-
-
### 2. Security Review
-
**CHECK FOR:**
-
- SQL injection vulnerabilities (even with prepared statements, verify)
-
- Proper input validation and sanitization
-
- Authentication/authorization checks on all protected endpoints
-
- No sensitive data in logs or error messages
-
- Rate limiting on public endpoints
-
- CSRF protection where applicable
-
- Proper atProto identity verification
-
-
### 3. Error Handling Audit
-
**VERIFY:**
-
- All errors are handled, not ignored
-
- Error wrapping provides context: `fmt.Errorf("service: %w", err)`
-
- Domain errors defined in core/errors/
-
- HTTP status codes correctly map to error types
-
- No internal error details exposed to API consumers
-
- Nil pointer checks before dereferencing
-
-
### 4. Performance Considerations
-
**LOOK FOR:**
-
- N+1 query problems
-
- Missing database indexes for frequently queried fields
-
- Unnecessary database round trips
-
- Large unbounded queries without pagination
-
- Memory leaks in goroutines
-
- Proper connection pool usage
-
- Efficient atProto federation calls
-
-
### 5. Testing Coverage
-
**REQUIRE:**
-
- Unit tests for all new service methods
-
- Integration tests for new API endpoints
-
- Edge case coverage (empty inputs, max values, special characters)
-
- Error path testing
-
- Mock verification in unit tests
-
- No flaky tests (check for time dependencies, random values)
-
-
### 6. Code Quality
-
**ASSESS:**
-
- Naming follows conventions (full words, not abbreviations)
-
- Functions do one thing well
-
- No code duplication (DRY principle)
-
- Consistent error handling patterns
-
- Proper use of Go idioms
-
- No commented-out code
-
-
### 7. Breaking Changes
-
**IDENTIFY:**
-
- API contract changes
-
- Database schema modifications affecting existing data
-
- Changes to core interfaces
-
- Modified error codes or response formats
-
-
### 8. Documentation
-
**ENSURE:**
-
- API endpoints have example requests/responses
-
- Complex business logic is explained
-
- Database migrations include rollback scripts
-
- README updated if setup process changes
-
- Swagger/OpenAPI specs updated if applicable
-
-
## Review Process
-
-
1. **First Pass - Automatic Rejections**
-
- SQL in handlers
-
- Missing tests
-
- Security vulnerabilities
-
- Broken layer separation
-
-
2. **Second Pass - Deep Dive**
-
- Business logic correctness
-
- Edge case handling
-
- Performance implications
-
- Code maintainability
-
-
3. **Third Pass - Suggestions**
-
- Better patterns or approaches
-
- Refactoring opportunities
-
- Future considerations
-
-
Then provide detailed feedback organized by: 1. 🚨 **Critical Issues** (must fix) 2. ⚠️ **Important Issues** (should fix) 3. 💡 **Suggestions** (consider for improvement) 4. ✅ **Good Practices Observed** (reinforce positive patterns)
-
-
-
Remember: The goal is to ship quality code quickly. Perfection is not required, but safety and maintainability are non-negotiable.
+
Project: Coves Builder

You are a distinguished developer actively building Coves, a forum-like atProto social media platform. Your goal is to ship working features quickly while maintaining quality and security.
+
+
## Builder Mindset
+
+
- Ship working code today, refactor tomorrow
+
- Security is built-in, not bolted-on
+
- Test-driven: write the test, then make it pass
+
- ASK QUESTIONS if you need context surrounding the product; DON'T ASSUME
+
+
## No Stubs, No Shortcuts
+
- **NEVER** use `unimplemented!()`, `todo!()`, or stub implementations
+
- **NEVER** leave placeholder code or incomplete implementations
+
- **NEVER** skip functionality because it seems complex
+
- Every function must be fully implemented and working
+
- Every feature must be complete before moving on
+
- E2E tests must test REAL infrastructure - not mocks
+
+
## Issue Tracking
+
+
**This project uses [bd (beads)](https://github.com/steveyegge/beads) for ALL issue tracking.**
+
+
- Use `bd` commands, NOT markdown TODOs or task lists
+
- Check `bd ready` for unblocked work
+
- Always commit `.beads/issues.jsonl` with code changes
+
- See [AGENTS.md](AGENTS.md) for full workflow details
+
+
Quick commands:
+
- `bd ready --json` - Show ready work
+
- `bd create "Title" -t bug|feature|task -p 0-4 --json` - Create issue
+
- `bd update <id> --status in_progress --json` - Claim work
+
- `bd close <id> --reason "Done" --json` - Complete work
+
## Break Down Complex Tasks
+
- Large files or complex features should be broken into manageable chunks
+
- If a file is too large, discuss breaking it into smaller modules
+
- If a task seems overwhelming, ask the user how to break it down
+
- Work incrementally, but each increment must be complete and functional
+
+
#### Human & LLM Readability Guidelines:
+
- Descriptive Naming: Use full words over abbreviations (e.g., CommunityGovernance not CommGov)
+
+
## atProto Essentials for Coves
+
+
### Architecture
+
+
- **PDS is Self-Contained**: Uses internal SQLite + CAR files (in Docker volume)
+
- **PostgreSQL for AppView Only**: One database for Coves AppView indexing
+
- **Don't Touch PDS Internals**: PDS manages its own storage, we just read from firehose
+
- **Data Flow**: Client → PDS → Firehose → AppView → PostgreSQL
+
+
### Always Consider:
+
+
- [ ] **Identity**: Every action needs DID verification
+
- [ ] **Record Types**: Define custom lexicons (e.g., `social.coves.post`, `social.coves.community`)
+
- [ ] **Is it federated-friendly?** (Can other PDSs interact with it?)
+
- [ ] **Does the Lexicon make sense?** (Would it work for other forums?)
+
- [ ] **AppView only indexes**: We don't write to CAR files, only read from firehose
+
+
## Security-First Building
+
+
### Every Feature MUST:
+
+
- [ ] **Validate all inputs** at the handler level
+
- [ ] **Use parameterized queries** (never string concatenation)
+
- [ ] **Check authorization** before any operation
+
- [ ] **Limit resource access** (pagination, rate limits)
+
- [ ] **Log security events** (failed auth, invalid inputs)
+
- [ ] **Never log sensitive data** (passwords, tokens, PII)
+
+
### Red Flags to Avoid:
+
+
- `fmt.Sprintf` in SQL queries → Use parameterized queries
+
- Missing `context.Context` → Need it for timeouts/cancellation
+
- No input validation → Add it immediately
+
- Error messages with internal details → Wrap errors properly
+
- Unbounded queries → Add limits/pagination
+
+
### "How should I structure this?"
+
+
1. One domain, one package
+
2. Interfaces for testability
+
3. Services coordinate repos
+
4. Handlers only handle XRPC
+
+
## Comprehensive Testing
+
- Write comprehensive unit tests for every module
+
- Aim for high test coverage (all major code paths)
+
- Test edge cases, error conditions, and boundary values
+
- Include doc tests for public APIs
+
- All tests must pass before considering a file "complete"
+
- Test both success and failure cases
+
## Pre-Production Advantages
+
+
Since we're pre-production:
+
+
- **Break things**: Delete and rebuild rather than complex migrations
+
- **Experiment**: Try approaches, keep what works
+
- **Simplify**: Remove unused code aggressively
+
- **But never compromise security basics**
+
+
## Success Metrics
+
+
Your code is ready when:
+
+
- [ ] Tests pass (including security tests)
+
- [ ] Follows atProto patterns
+
- [ ] Handles errors gracefully
+
- [ ] Works end-to-end with auth
+
+
## Quick Checks Before Committing
+
+
1. **Will it work?** (Integration test proves it)
+
2. **Is it secure?** (Auth, validation, parameterized queries)
+
3. **Is it simple?** (Could you explain to a junior?)
+
4. **Is it complete?** (Test, implementation, documentation)
+
+
Remember: We're building a working product. Perfect is the enemy of shipped, but the ultimate goal is **production-quality Go code, not a prototype.**
+
+
Every line of code should be something you'd be proud to ship in a production system. Quality over speed. Completeness over convenience.
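To make the parameterized-query red flag concrete, a minimal hedged sketch against the `comments` table from this change set (the function name is illustrative):

```go
package examples

import (
	"context"
	"database/sql"
	"fmt"
)

// Red flag:  fmt.Sprintf("SELECT content FROM comments WHERE uri = '%s'", uri)
// Preferred: a parameterized query plus context for timeouts/cancellation.
func getCommentContent(ctx context.Context, db *sql.DB, uri string) (string, error) {
	var content string
	err := db.QueryRowContext(ctx,
		`SELECT content FROM comments WHERE uri = $1 AND deleted_at IS NULL`,
		uri,
	).Scan(&content)
	if err != nil {
		return "", fmt.Errorf("failed to load comment: %w", err)
	}
	return content, nil
}
```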
+76 -16
internal/atproto/jetstream/vote_consumer.go
···
}
// Atomically: Index vote + Update post counts
-
if err := c.indexVoteAndUpdateCounts(ctx, vote); err != nil {
+
wasNew, err := c.indexVoteAndUpdateCounts(ctx, vote)
+
if err != nil {
return fmt.Errorf("failed to index vote and update counts: %w", err)
}
-
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
if wasNew {
+
log.Printf("โœ“ Indexed vote: %s (%s on %s)", uri, vote.Direction, vote.SubjectURI)
+
}
return nil
}
···
}
// indexVoteAndUpdateCounts atomically indexes a vote and updates post vote counts
-
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) error {
+
// Returns (true, nil) if vote was newly inserted, (false, nil) if already existed (idempotent)
+
func (c *VoteEventConsumer) indexVoteAndUpdateCounts(ctx context.Context, vote *votes.Vote) (bool, error) {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
-
return fmt.Errorf("failed to begin transaction: %w", err)
+
return false, fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if rollbackErr := tx.Rollback(); rollbackErr != nil && rollbackErr != sql.ErrTxDone {
···
}
}()
-
// 1. Index the vote (idempotent with ON CONFLICT DO NOTHING)
+
// 1. Check for existing active vote with different URI (stale record)
+
// This handles cases where:
+
// - User voted on another client and we missed the delete event
+
// - Vote was reindexed but user created a new vote with different rkey
+
// - Any other state mismatch between PDS and AppView
+
var existingDirection sql.NullString
+
checkQuery := `
+
SELECT direction FROM votes
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
LIMIT 1
+
`
+
if err := tx.QueryRowContext(ctx, checkQuery, vote.VoterDID, vote.SubjectURI, vote.URI).Scan(&existingDirection); err != nil && err != sql.ErrNoRows {
+
return false, fmt.Errorf("failed to check existing vote: %w", err)
+
}
+
+
// If there's a stale vote, soft-delete it and adjust counts
+
if existingDirection.Valid {
+
softDeleteQuery := `
+
UPDATE votes
+
SET deleted_at = NOW()
+
WHERE voter_did = $1
+
AND subject_uri = $2
+
AND deleted_at IS NULL
+
AND uri != $3
+
`
+
if _, err := tx.ExecContext(ctx, softDeleteQuery, vote.VoterDID, vote.SubjectURI, vote.URI); err != nil {
+
return false, fmt.Errorf("failed to soft-delete existing votes: %w", err)
+
}
+
+
// Decrement the old vote's count (will be re-incremented below if same direction)
+
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
+
var decrementQuery string
+
if existingDirection.String == "up" {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET upvote_count = GREATEST(0, upvote_count - 1), score = upvote_count - 1 - downvote_count WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
} else {
+
if collection == "social.coves.community.post" {
+
decrementQuery = `UPDATE posts SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
} else if collection == "social.coves.community.comment" {
+
decrementQuery = `UPDATE comments SET downvote_count = GREATEST(0, downvote_count - 1), score = upvote_count - (downvote_count - 1) WHERE uri = $1 AND deleted_at IS NULL`
+
}
+
}
+
if decrementQuery != "" {
+
if _, err := tx.ExecContext(ctx, decrementQuery, vote.SubjectURI); err != nil {
+
return false, fmt.Errorf("failed to decrement old vote count: %w", err)
+
}
+
}
+
log.Printf("Cleaned up stale vote for %s on %s (was %s)", vote.VoterDID, vote.SubjectURI, existingDirection.String)
+
}
+
+
// 2. Index the vote (idempotent with ON CONFLICT DO NOTHING)
query := `
INSERT INTO votes (
uri, cid, rkey, voter_did,
···
// If no rows returned, vote already exists (idempotent - OK for Jetstream replays)
if err == sql.ErrNoRows {
-
log.Printf("Vote already indexed: %s (idempotent)", vote.URI)
+
// Silently handle idempotent case - no log needed for replayed events
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return false, nil // Vote already existed
}
if err != nil {
-
return fmt.Errorf("failed to insert vote: %w", err)
+
return false, fmt.Errorf("failed to insert vote: %w", err)
}
-
// 2. Update vote counts on the subject (post or comment)
+
// 3. Update vote counts on the subject (post or comment)
// Parse collection from subject URI to determine target table
collection := utils.ExtractCollectionFromURI(vote.SubjectURI)
···
// Vote is still indexed in votes table, we just don't update denormalized counts
log.Printf("Vote subject has unsupported collection: %s (vote indexed, counts not updated)", collection)
if commitErr := tx.Commit(); commitErr != nil {
-
return fmt.Errorf("failed to commit transaction: %w", commitErr)
+
return false, fmt.Errorf("failed to commit transaction: %w", commitErr)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
result, err := tx.ExecContext(ctx, updateQuery, vote.SubjectURI)
if err != nil {
-
return fmt.Errorf("failed to update vote counts: %w", err)
+
return false, fmt.Errorf("failed to update vote counts: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
-
return fmt.Errorf("failed to check update result: %w", err)
+
return false, fmt.Errorf("failed to check update result: %w", err)
}
// If subject doesn't exist or is deleted, that's OK (vote still indexed)
···
// Commit transaction
if err := tx.Commit(); err != nil {
-
return fmt.Errorf("failed to commit transaction: %w", err)
+
return false, fmt.Errorf("failed to commit transaction: %w", err)
}
-
return nil
+
return true, nil // Vote was newly indexed
}
// deleteVoteAndUpdateCounts atomically soft-deletes a vote and updates post vote counts
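The net count change from the stale-vote cleanup plus the normal indexing step can be summarized with a small illustrative helper (not part of the consumer):

```go
package examples

// netVoteDelta mirrors the decrement-then-increment flow above: the stale vote
// with oldDirection is removed from the counts, then the newly indexed vote
// with newDirection is added back in.
func netVoteDelta(oldDirection, newDirection string) (upDelta, downDelta int) {
	switch oldDirection { // stale vote soft-deleted
	case "up":
		upDelta--
	case "down":
		downDelta--
	}
	switch newDirection { // replacement vote indexed
	case "up":
		upDelta++
	case "down":
		downDelta++
	}
	// e.g. "up" -> "up"   yields (0, 0): counts unchanged
	//      "up" -> "down" yields (-1, +1): the vote flipped
	return upDelta, downDelta
}
```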
+109
internal/atproto/lexicon/social/coves/community/comment/create.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.create",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Create a comment on a post or another comment. Comments support nested threading, rich text, embeds, and self-labeling.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["reply", "content"],
+
"properties": {
+
"reply": {
+
"type": "object",
+
"description": "References for maintaining thread structure. Root always points to the original post, parent points to the immediate parent (post or comment).",
+
"required": ["root", "parent"],
+
"properties": {
+
"root": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the original post that started the thread"
+
},
+
"parent": {
+
"type": "ref",
+
"ref": "com.atproto.repo.strongRef",
+
"description": "Strong reference to the immediate parent (post or comment) being replied to"
+
}
+
}
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the created comment"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "CID of the created comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "InvalidReply",
+
"description": "The reply reference is invalid, malformed, or refers to non-existent content"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to create comments on this content"
+
}
+
]
+
}
+
}
+
}
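A hedged example of calling this procedure once it is mounted (the route registration appears later in this change set). URIs, CIDs, and `appviewURL` are placeholders, and the OAuth credentials the endpoint requires are omitted from the sketch:

```go
package examples

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// createCommentExample sends a social.coves.community.comment.create request
// matching the lexicon schema above (a top-level reply to a post, so root and
// parent point at the same record).
func createCommentExample(appviewURL string) error {
	body := map[string]any{
		"reply": map[string]any{
			"root":   map[string]string{"uri": "at://did:plc:author/social.coves.community.post/3k2apost", "cid": "bafyexamplecid"},
			"parent": map[string]string{"uri": "at://did:plc:author/social.coves.community.post/3k2apost", "cid": "bafyexamplecid"},
		},
		"content": "Nice write-up!",
		"langs":   []string{"en"},
	}

	payload, err := json.Marshal(body)
	if err != nil {
		return err
	}

	resp, err := http.Post(appviewURL+"/xrpc/social.coves.community.comment.create",
		"application/json", bytes.NewReader(payload))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	var out struct {
		URI string `json:"uri"`
		CID string `json:"cid"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return err
	}
	fmt.Printf("created %s (%s)\n", out.URI, out.CID)
	return nil
}
```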
+41
internal/atproto/lexicon/social/coves/community/comment/delete.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.delete",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Delete a comment. Only the comment author can delete their own comments.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to delete"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"properties": {}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to delete this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+97
internal/atproto/lexicon/social/coves/community/comment/update.json
···
+
{
+
"lexicon": 1,
+
"id": "social.coves.community.comment.update",
+
"defs": {
+
"main": {
+
"type": "procedure",
+
"description": "Update an existing comment's content, facets, embed, languages, or labels. Threading references (reply.root and reply.parent) are immutable and cannot be changed.",
+
"input": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "content"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the comment to update"
+
},
+
"content": {
+
"type": "string",
+
"maxGraphemes": 10000,
+
"maxLength": 100000,
+
"description": "Updated comment text content"
+
},
+
"facets": {
+
"type": "array",
+
"description": "Updated annotations for rich text (mentions, links, etc.)",
+
"items": {
+
"type": "ref",
+
"ref": "social.coves.richtext.facet"
+
}
+
},
+
"embed": {
+
"type": "union",
+
"description": "Updated embedded media or quoted posts",
+
"refs": [
+
"social.coves.embed.images",
+
"social.coves.embed.post"
+
]
+
},
+
"langs": {
+
"type": "array",
+
"description": "Updated languages used in the comment content (ISO 639-1)",
+
"maxLength": 3,
+
"items": {
+
"type": "string",
+
"format": "language"
+
}
+
},
+
"labels": {
+
"type": "ref",
+
"ref": "com.atproto.label.defs#selfLabels",
+
"description": "Updated self-applied content labels"
+
}
+
}
+
}
+
},
+
"output": {
+
"encoding": "application/json",
+
"schema": {
+
"type": "object",
+
"required": ["uri", "cid"],
+
"properties": {
+
"uri": {
+
"type": "string",
+
"format": "at-uri",
+
"description": "AT-URI of the updated comment (unchanged from input)"
+
},
+
"cid": {
+
"type": "string",
+
"format": "cid",
+
"description": "New CID of the updated comment record"
+
}
+
}
+
}
+
},
+
"errors": [
+
{
+
"name": "CommentNotFound",
+
"description": "Comment with the specified URI does not exist"
+
},
+
{
+
"name": "ContentTooLong",
+
"description": "Updated comment content exceeds maximum length constraints"
+
},
+
{
+
"name": "ContentEmpty",
+
"description": "Updated comment content is empty or contains only whitespace"
+
},
+
{
+
"name": "NotAuthorized",
+
"description": "User is not authorized to update this comment (not the author)"
+
}
+
]
+
}
+
}
+
}
+38
internal/core/comments/types.go
···
+
package comments
+
+
// CreateCommentRequest contains parameters for creating a comment
+
type CreateCommentRequest struct {
+
Reply ReplyRef `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// CreateCommentResponse contains the result of creating a comment
+
type CreateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// UpdateCommentRequest contains parameters for updating a comment
+
type UpdateCommentRequest struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels *SelfLabels `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentResponse contains the result of updating a comment
+
type UpdateCommentResponse struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// DeleteCommentRequest contains parameters for deleting a comment
+
type DeleteCommentRequest struct {
+
URI string `json:"uri"`
+
}
+130
internal/api/handlers/comments/create_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// CreateCommentHandler handles comment creation requests
+
type CreateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewCreateCommentHandler creates a new handler for creating comments
+
func NewCreateCommentHandler(service comments.Service) *CreateCommentHandler {
+
return &CreateCommentHandler{
+
service: service,
+
}
+
}
+
+
// CreateCommentInput matches the lexicon input schema for social.coves.community.comment.create
+
type CreateCommentInput struct {
+
Reply struct {
+
Root struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"root"`
+
Parent struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
} `json:"parent"`
+
} `json:"reply"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// CreateCommentOutput matches the lexicon output schema
+
type CreateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleCreate handles comment creation requests
+
// POST /xrpc/social.coves.community.comment.create
+
//
+
// Request body: { "reply": { "root": {...}, "parent": {...} }, "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *CreateCommentHandler) HandleCreate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into CreateCommentInput
+
var input CreateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to CreateCommentRequest
+
req := comments.CreateCommentRequest{
+
Reply: comments.ReplyRef{
+
Root: comments.StrongRef{
+
URI: input.Reply.Root.URI,
+
CID: input.Reply.Root.CID,
+
},
+
Parent: comments.StrongRef{
+
URI: input.Reply.Parent.URI,
+
CID: input.Reply.Parent.CID,
+
},
+
},
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to create comment
+
response, err := h.service.CreateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := CreateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+80
internal/api/handlers/comments/delete_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// DeleteCommentHandler handles comment deletion requests
+
type DeleteCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewDeleteCommentHandler creates a new handler for deleting comments
+
func NewDeleteCommentHandler(service comments.Service) *DeleteCommentHandler {
+
return &DeleteCommentHandler{
+
service: service,
+
}
+
}
+
+
// DeleteCommentInput matches the lexicon input schema for social.coves.community.comment.delete
+
type DeleteCommentInput struct {
+
URI string `json:"uri"`
+
}
+
+
// DeleteCommentOutput is empty per lexicon specification
+
type DeleteCommentOutput struct{}
+
+
// HandleDelete handles comment deletion requests
+
// POST /xrpc/social.coves.community.comment.delete
+
//
+
// Request body: { "uri": "at://..." }
+
// Response: {}
+
func (h *DeleteCommentHandler) HandleDelete(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into DeleteCommentInput
+
var input DeleteCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert input to DeleteCommentRequest
+
req := comments.DeleteCommentRequest{
+
URI: input.URI,
+
}
+
+
// 6. Call service to delete comment
+
err := h.service.DeleteComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 7. Return empty JSON object per lexicon specification
+
output := DeleteCommentOutput{}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+34 -2
internal/api/handlers/comments/errors.go
···
import (
"Coves/internal/core/comments"
"encoding/json"
+
"errors"
"log"
"net/http"
)
···
func handleServiceError(w http.ResponseWriter, err error) {
switch {
case comments.IsNotFound(err):
-
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
// Map specific not found errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrCommentNotFound):
+
writeError(w, http.StatusNotFound, "CommentNotFound", "Comment not found")
+
case errors.Is(err, comments.ErrParentNotFound):
+
writeError(w, http.StatusNotFound, "ParentNotFound", "Parent post or comment not found")
+
case errors.Is(err, comments.ErrRootNotFound):
+
writeError(w, http.StatusNotFound, "RootNotFound", "Root post not found")
+
default:
+
writeError(w, http.StatusNotFound, "NotFound", err.Error())
+
}
case comments.IsValidationError(err):
-
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
// Map specific validation errors to appropriate messages
+
switch {
+
case errors.Is(err, comments.ErrInvalidReply):
+
writeError(w, http.StatusBadRequest, "InvalidReply", "The reply reference is invalid or malformed")
+
case errors.Is(err, comments.ErrContentTooLong):
+
writeError(w, http.StatusBadRequest, "ContentTooLong", "Comment content exceeds 10000 graphemes")
+
case errors.Is(err, comments.ErrContentEmpty):
+
writeError(w, http.StatusBadRequest, "ContentEmpty", "Comment content is required")
+
default:
+
writeError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
+
}
+
+
case errors.Is(err, comments.ErrNotAuthorized):
+
writeError(w, http.StatusForbidden, "NotAuthorized", "User is not authorized to perform this action")
+
+
case errors.Is(err, comments.ErrBanned):
+
writeError(w, http.StatusForbidden, "Banned", "User is banned from this community")
+
+
// NOTE: IsConflict case removed - the PDS handles duplicate detection via CreateRecord,
+
// so ErrCommentAlreadyExists is never returned from the service layer. If the PDS rejects
+
// a duplicate record, it returns an auth/validation error which is handled by other cases.
+
// Keeping this code would be dead code that never executes.
default:
// Don't leak internal error details to clients
+112
internal/api/handlers/comments/update_comment.go
···
+
package comments
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/comments"
+
"encoding/json"
+
"log"
+
"net/http"
+
)
+
+
// UpdateCommentHandler handles comment update requests
+
type UpdateCommentHandler struct {
+
service comments.Service
+
}
+
+
// NewUpdateCommentHandler creates a new handler for updating comments
+
func NewUpdateCommentHandler(service comments.Service) *UpdateCommentHandler {
+
return &UpdateCommentHandler{
+
service: service,
+
}
+
}
+
+
// UpdateCommentInput matches the lexicon input schema for social.coves.community.comment.update
+
type UpdateCommentInput struct {
+
URI string `json:"uri"`
+
Content string `json:"content"`
+
Facets []interface{} `json:"facets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Langs []string `json:"langs,omitempty"`
+
Labels interface{} `json:"labels,omitempty"`
+
}
+
+
// UpdateCommentOutput matches the lexicon output schema
+
type UpdateCommentOutput struct {
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
}
+
+
// HandleUpdate handles comment update requests
+
// POST /xrpc/social.coves.community.comment.update
+
//
+
// Request body: { "uri": "at://...", "content": "..." }
+
// Response: { "uri": "at://...", "cid": "..." }
+
func (h *UpdateCommentHandler) HandleUpdate(w http.ResponseWriter, r *http.Request) {
+
// 1. Check method is POST
+
if r.Method != http.MethodPost {
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+
return
+
}
+
+
// 2. Limit request body size to prevent DoS attacks (100KB should be plenty for comments)
+
r.Body = http.MaxBytesReader(w, r.Body, 100*1024)
+
+
// 3. Parse JSON body into UpdateCommentInput
+
var input UpdateCommentInput
+
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidRequest", "Invalid request body")
+
return
+
}
+
+
// 4. Get OAuth session from context (injected by auth middleware)
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
writeError(w, http.StatusUnauthorized, "AuthRequired", "Authentication required")
+
return
+
}
+
+
// 5. Convert labels interface{} to *comments.SelfLabels if provided
+
var labels *comments.SelfLabels
+
if input.Labels != nil {
+
labelsJSON, err := json.Marshal(input.Labels)
+
if err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels format")
+
return
+
}
+
var selfLabels comments.SelfLabels
+
if err := json.Unmarshal(labelsJSON, &selfLabels); err != nil {
+
writeError(w, http.StatusBadRequest, "InvalidLabels", "Invalid labels structure")
+
return
+
}
+
labels = &selfLabels
+
}
+
+
// 6. Convert input to UpdateCommentRequest
+
req := comments.UpdateCommentRequest{
+
URI: input.URI,
+
Content: input.Content,
+
Facets: input.Facets,
+
Embed: input.Embed,
+
Langs: input.Langs,
+
Labels: labels,
+
}
+
+
// 7. Call service to update comment
+
response, err := h.service.UpdateComment(r.Context(), session, req)
+
if err != nil {
+
handleServiceError(w, err)
+
return
+
}
+
+
// 8. Return JSON response with URI and CID
+
output := UpdateCommentOutput{
+
URI: response.URI,
+
CID: response.CID,
+
}
+
+
w.Header().Set("Content-Type", "application/json")
+
w.WriteHeader(http.StatusOK)
+
if err := json.NewEncoder(w).Encode(output); err != nil {
+
log.Printf("Failed to encode response: %v", err)
+
}
+
}
+35
internal/api/routes/comment.go
···
+
package routes
+
+
import (
+
"Coves/internal/api/handlers/comments"
+
"Coves/internal/api/middleware"
+
commentsCore "Coves/internal/core/comments"
+
+
"github.com/go-chi/chi/v5"
+
)
+
+
// RegisterCommentRoutes registers comment-related XRPC endpoints on the router
+
// Implements social.coves.community.comment.* lexicon endpoints
+
// All write operations (create, update, delete) require authentication
+
func RegisterCommentRoutes(r chi.Router, service commentsCore.Service, authMiddleware *middleware.OAuthAuthMiddleware) {
+
// Initialize handlers
+
createHandler := comments.NewCreateCommentHandler(service)
+
updateHandler := comments.NewUpdateCommentHandler(service)
+
deleteHandler := comments.NewDeleteCommentHandler(service)
+
+
// Procedure endpoints (POST) - require authentication
+
// social.coves.community.comment.create - create a new comment on a post or another comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.create",
+
createHandler.HandleCreate)
+
+
// social.coves.community.comment.update - update an existing comment's content
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.update",
+
updateHandler.HandleUpdate)
+
+
// social.coves.community.comment.delete - soft delete a comment
+
r.With(authMiddleware.RequireAuth).Post(
+
"/xrpc/social.coves.community.comment.delete",
+
deleteHandler.HandleDelete)
+
}
+4 -2
tests/integration/comment_query_test.go
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
return comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
return comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
}
// Helper: createTestCommentWithScore creates a comment with specific vote counts
···
postRepo := postgres.NewPostRepository(db)
userRepo := postgres.NewUserRepository(db)
communityRepo := postgres.NewCommunityRepository(db)
-
service := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - these tests only use the read path (GetComments)
+
service := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
return &testCommentServiceAdapter{service: service}
}
+6 -3
tests/integration/comment_vote_test.go
···
}
// Query comments with viewer authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
}
// Query with authentication but no vote
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
···
t.Run("Unauthenticated request has no viewer state", func(t *testing.T) {
// Query without authentication
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: testPostURI,
Sort: "new",
+2 -1
tests/integration/concurrent_scenarios_test.go
···
}
// Verify all comments are retrievable via service
-
commentService := comments.NewCommentService(commentRepo, userRepo, postRepo, communityRepo)
+
// Use factory constructor with nil factory - this test only uses the read path (GetComments)
+
commentService := comments.NewCommentServiceWithPDSFactory(commentRepo, userRepo, postRepo, communityRepo, nil, nil)
response, err := commentService.GetComments(ctx, &comments.GetCommentsRequest{
PostURI: postURI,
Sort: "new",
+1 -1
go.mod
···
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.45.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
-
github.com/rivo/uniseg v0.1.0 // indirect
+
github.com/rivo/uniseg v0.4.7 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
+66
internal/db/migrations/021_add_comment_deletion_metadata.sql
···
+
-- +goose Up
+
-- Add deletion reason tracking to preserve thread structure while respecting privacy
+
-- When comments are deleted, we blank content but keep the record for threading
+
+
-- Create enum type for deletion reasons
+
CREATE TYPE deletion_reason AS ENUM ('author', 'moderator');
+
+
-- Add new columns to comments table
+
ALTER TABLE comments ADD COLUMN deletion_reason deletion_reason;
+
ALTER TABLE comments ADD COLUMN deleted_by TEXT;
+
+
-- Add comments for new columns
+
COMMENT ON COLUMN comments.deletion_reason IS 'Reason for deletion: author (user deleted), moderator (community mod removed)';
+
COMMENT ON COLUMN comments.deleted_by IS 'DID of the actor who performed the deletion';
+
+
-- Backfill existing deleted comments as author-deleted
+
-- This handles existing soft-deleted comments gracefully
+
UPDATE comments
+
SET deletion_reason = 'author',
+
deleted_by = commenter_did
+
WHERE deleted_at IS NOT NULL AND deletion_reason IS NULL;
+
+
-- Modify existing indexes to NOT filter deleted_at IS NULL
+
-- This allows deleted comments to appear in thread queries for structure preservation
+
-- Note: We drop and recreate to change the partial index condition
+
+
-- Drop old partial indexes that exclude deleted comments
+
DROP INDEX IF EXISTS idx_comments_root;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_uri_active;
+
+
-- Recreate indexes without the deleted_at filter (include all comments for threading)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC);
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC);
+
CREATE INDEX idx_comments_uri_lookup ON comments(uri);
+
+
-- Add index for querying by deletion_reason (for moderation dashboard)
+
CREATE INDEX idx_comments_deleted_reason ON comments(deletion_reason, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- Add index for querying by deleted_by (for moderation audit/filtering)
+
CREATE INDEX idx_comments_deleted_by ON comments(deleted_by, deleted_at DESC)
+
WHERE deleted_at IS NOT NULL;
+
+
-- +goose Down
+
-- Remove deletion metadata columns and restore original indexes
+
+
DROP INDEX IF EXISTS idx_comments_deleted_by;
+
DROP INDEX IF EXISTS idx_comments_deleted_reason;
+
DROP INDEX IF EXISTS idx_comments_uri_lookup;
+
DROP INDEX IF EXISTS idx_comments_parent_score;
+
DROP INDEX IF EXISTS idx_comments_parent;
+
DROP INDEX IF EXISTS idx_comments_root;
+
+
-- Restore original partial indexes (excluding deleted comments)
+
CREATE INDEX idx_comments_root ON comments(root_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent ON comments(parent_uri, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_parent_score ON comments(parent_uri, score DESC, created_at DESC) WHERE deleted_at IS NULL;
+
CREATE INDEX idx_comments_uri_active ON comments(uri) WHERE deleted_at IS NULL;
+
+
ALTER TABLE comments DROP COLUMN IF EXISTS deleted_by;
+
ALTER TABLE comments DROP COLUMN IF EXISTS deletion_reason;
+
+
DROP TYPE IF EXISTS deletion_reason;
+17 -13
internal/core/comments/view_models.go
···
// CommentView represents the full view of a comment with all metadata
// Matches social.coves.community.comment.getComments#commentView lexicon
// Used in thread views and get endpoints
+
// For deleted comments, IsDeleted=true and content-related fields are empty/nil
type CommentView struct {
-
Embed interface{} `json:"embed,omitempty"`
-
Record interface{} `json:"record"`
-
Viewer *CommentViewerState `json:"viewer,omitempty"`
-
Author *posts.AuthorView `json:"author"`
-
Post *CommentRef `json:"post"`
-
Parent *CommentRef `json:"parent,omitempty"`
-
Stats *CommentStats `json:"stats"`
-
Content string `json:"content"`
-
CreatedAt string `json:"createdAt"`
-
IndexedAt string `json:"indexedAt"`
-
URI string `json:"uri"`
-
CID string `json:"cid"`
-
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
Embed interface{} `json:"embed,omitempty"`
+
Record interface{} `json:"record"`
+
Viewer *CommentViewerState `json:"viewer,omitempty"`
+
Author *posts.AuthorView `json:"author"`
+
Post *CommentRef `json:"post"`
+
Parent *CommentRef `json:"parent,omitempty"`
+
Stats *CommentStats `json:"stats"`
+
Content string `json:"content"`
+
CreatedAt string `json:"createdAt"`
+
IndexedAt string `json:"indexedAt"`
+
URI string `json:"uri"`
+
CID string `json:"cid"`
+
ContentFacets []interface{} `json:"contentFacets,omitempty"`
+
IsDeleted bool `json:"isDeleted,omitempty"`
+
DeletionReason *string `json:"deletionReason,omitempty"`
+
DeletedAt *string `json:"deletedAt,omitempty"`
}
// ThreadViewComment represents a comment with its nested replies
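A hedged sketch of a deleted-comment placeholder built with the new fields; the helper name and timestamp values are illustrative, and Author/Post/Stats/Record are omitted for brevity:

```go
package examples

import "Coves/internal/core/comments"

// deletedCommentView shows the shape clients receive for a deleted comment:
// content is blanked, IsDeleted is set, and the deletion metadata is surfaced
// so the comment can render as a "[deleted]" placeholder in the thread.
func deletedCommentView(uri, cid string) *comments.CommentView {
	reason := "author" // or "moderator"
	deletedAt := "2025-01-10T09:30:00Z"
	return &comments.CommentView{
		URI:            uri,
		CID:            cid,
		Content:        "", // blanked on deletion
		IsDeleted:      true,
		DeletionReason: &reason,
		DeletedAt:      &deletedAt,
		CreatedAt:      "2025-01-09T18:00:00Z",
		IndexedAt:      "2025-01-09T18:00:02Z",
	}
}
```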
+23 -1
internal/core/comments/interfaces.go
···
package comments
-
import "context"
+
import (
+
"context"
+
"database/sql"
+
)
// Repository defines the data access interface for comments
// Used by Jetstream consumer to index comments from firehose
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
Delete(ctx context.Context, uri string) error
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// reason: "author" (user deleted) or "moderator" (mod removed)
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error
+
// ListByRoot retrieves all comments in a thread (flat)
// Used for fetching entire comment threads on posts
ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*Comment, error)
···
limitPerParent int,
) (map[string][]*Comment, error)
}
+
+
// RepositoryTx provides transaction-aware operations for consumers that need atomicity
+
// Used by Jetstream consumer to perform atomic delete + count updates
+
// Implementations that support transactions should also implement this interface
+
type RepositoryTx interface {
+
// SoftDeleteWithReasonTx performs a soft delete within a transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// reason: must be DeletionReasonAuthor or DeletionReasonModerator
+
// deletedByDID: DID of the actor who performed the deletion
+
SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error)
+
}
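A hedged sketch of a consumer-style caller using `SoftDeleteWithReasonTx` inside a transaction; the follow-up statement is a placeholder for the real count adjustments:

```go
package examples

import (
	"context"
	"database/sql"
	"fmt"

	"Coves/internal/core/comments"
)

// softDeleteInTx soft-deletes a comment and commits, treating "already
// deleted" (zero rows affected) as an idempotent no-op.
func softDeleteInTx(ctx context.Context, db *sql.DB, repo comments.RepositoryTx, uri, actorDID string) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer func() { _ = tx.Rollback() }() // no-op after a successful commit

	rows, err := repo.SoftDeleteWithReasonTx(ctx, tx, uri, comments.DeletionReasonAuthor, actorDID)
	if err != nil {
		return fmt.Errorf("failed to soft delete comment: %w", err)
	}
	if rows == 0 {
		// Already deleted: nothing else to adjust.
		return tx.Commit()
	}

	// ... other statements on tx (e.g., reply-count adjustments) would go here ...

	return tx.Commit()
}
```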
+87 -27
internal/db/postgres/comment_repo.go
···
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE uri = $1
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
···
// Delete soft-deletes a comment (sets deleted_at)
// Called by Jetstream consumer after comment is deleted from PDS
// Idempotent: Returns success if comment already deleted
+
// Deprecated: Use SoftDeleteWithReason for new code to preserve thread structure
func (r *postgresCommentRepo) Delete(ctx context.Context, uri string) error {
query := `
UPDATE comments
···
return nil
}
-
// ListByRoot retrieves all active comments in a thread (flat)
+
// SoftDeleteWithReason performs a soft delete that blanks content but preserves thread structure
+
// This allows deleted comments to appear as "[deleted]" placeholders in thread views
+
// Idempotent: Returns success if comment already deleted
+
// Validates that reason is a known deletion reason constant
+
func (r *postgresCommentRepo) SoftDeleteWithReason(ctx context.Context, uri, reason, deletedByDID string) error {
+
// Validate deletion reason
+
if reason != comments.DeletionReasonAuthor && reason != comments.DeletionReasonModerator {
+
return fmt.Errorf("invalid deletion reason: %s", reason)
+
}
+
+
_, err := r.SoftDeleteWithReasonTx(ctx, nil, uri, reason, deletedByDID)
+
return err
+
}
+
+
// SoftDeleteWithReasonTx performs a soft delete within an optional transaction
+
// If tx is nil, executes directly against the database
+
// Returns rows affected count for callers that need to check idempotency
+
// This method is used by both the repository and the Jetstream consumer
+
func (r *postgresCommentRepo) SoftDeleteWithReasonTx(ctx context.Context, tx *sql.Tx, uri, reason, deletedByDID string) (int64, error) {
+
query := `
+
UPDATE comments
+
SET
+
content = '',
+
content_facets = NULL,
+
embed = NULL,
+
content_labels = NULL,
+
deleted_at = NOW(),
+
deletion_reason = $2,
+
deleted_by = $3
+
WHERE uri = $1 AND deleted_at IS NULL
+
`
+
+
var result sql.Result
+
var err error
+
+
if tx != nil {
+
result, err = tx.ExecContext(ctx, query, uri, reason, deletedByDID)
+
} else {
+
result, err = r.db.ExecContext(ctx, query, uri, reason, deletedByDID)
+
}
+
+
if err != nil {
+
return 0, fmt.Errorf("failed to soft delete comment: %w", err)
+
}
+
+
rowsAffected, err := result.RowsAffected()
+
if err != nil {
+
return 0, fmt.Errorf("failed to check delete result: %w", err)
+
}
+
+
return rowsAffected, nil
+
}
+
+
// ListByRoot retrieves all comments in a thread (flat), including deleted ones
// Used for fetching entire comment threads on posts
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByRoot(ctx context.Context, rootURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE root_uri = $1 AND deleted_at IS NULL
+
WHERE root_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
return result, nil
}
-
// ListByParent retrieves direct replies to a post or comment
+
// ListByParent retrieves direct replies to a post or comment, including deleted ones
// Used for building nested/threaded comment views
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
func (r *postgresCommentRepo) ListByParent(ctx context.Context, parentURI string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
-
WHERE parent_uri = $1 AND deleted_at IS NULL
+
WHERE parent_uri = $1
ORDER BY created_at ASC
LIMIT $2 OFFSET $3
`
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
}
// ListByCommenter retrieves all active comments by a specific user
-
// Future: Used for user comment history
+
// Used for user comment history; deleted comments are filtered out
func (r *postgresCommentRepo) ListByCommenter(ctx context.Context, commenterDID string, limit, offset int) ([]*comments.Comment, error) {
query := `
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count
FROM comments
WHERE commenter_did = $1 AND deleted_at IS NULL
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
)
if err != nil {
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle
···
// Build complete query with JOINs and filters
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
%s
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = $1 AND c.deleted_at IS NULL
+
WHERE c.parent_uri = $1
%s
%s
ORDER BY %s
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
···
// GetByURIsBatch retrieves multiple comments by their AT-URIs in a single query
// Returns map[uri]*Comment for efficient lookups without N+1 queries
+
// Includes deleted comments to preserve thread structure
func (r *postgresCommentRepo) GetByURIsBatch(ctx context.Context, uris []string) (map[string]*comments.Comment, error) {
if len(uris) == 0 {
return make(map[string]*comments.Comment), nil
···
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
// COALESCE falls back to DID when handle is NULL (user not yet in users table)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := `
SELECT
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
COALESCE(u.handle, c.commenter_did) as author_handle
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.uri = ANY($1)
`
rows, err := r.db.QueryContext(ctx, query, pq.Array(uris))
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&authorHandle,
)
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
NULL::numeric as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
c.id, c.uri, c.cid, c.rkey, c.commenter_did,
c.root_uri, c.root_cid, c.parent_uri, c.parent_cid,
c.content, c.content_facets, c.embed, c.content_labels, c.langs,
-
c.created_at, c.indexed_at, c.deleted_at,
+
c.created_at, c.indexed_at, c.deleted_at, c.deletion_reason, c.deleted_by,
c.upvote_count, c.downvote_count, c.score, c.reply_count,
log(greatest(2, c.score + 2)) / power(((EXTRACT(EPOCH FROM (NOW() - c.created_at)) / 3600) + 2), 1.8) as hot_rank,
COALESCE(u.handle, c.commenter_did) as author_handle`
···
// Use window function to limit results per parent
// This is more efficient than LIMIT in a subquery per parent
// LEFT JOIN prevents data loss when user record hasn't been indexed yet (out-of-order Jetstream events)
+
// Includes deleted comments to preserve thread structure (shown as "[deleted]" placeholders)
query := fmt.Sprintf(`
WITH ranked_comments AS (
SELECT
···
) as rn
FROM comments c
LEFT JOIN users u ON c.commenter_did = u.did
-
WHERE c.parent_uri = ANY($1) AND c.deleted_at IS NULL
+
WHERE c.parent_uri = ANY($1)
)
SELECT
id, uri, cid, rkey, commenter_did,
root_uri, root_cid, parent_uri, parent_cid,
content, content_facets, embed, content_labels, langs,
-
created_at, indexed_at, deleted_at,
+
created_at, indexed_at, deleted_at, deletion_reason, deleted_by,
upvote_count, downvote_count, score, reply_count,
hot_rank, author_handle
FROM ranked_comments
···
&comment.ID, &comment.URI, &comment.CID, &comment.RKey, &comment.CommenterDID,
&comment.RootURI, &comment.RootCID, &comment.ParentURI, &comment.ParentCID,
&comment.Content, &comment.ContentFacets, &comment.Embed, &comment.ContentLabels, &langs,
-
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt,
+
&comment.CreatedAt, &comment.IndexedAt, &comment.DeletedAt, &comment.DeletionReason, &comment.DeletedBy,
&comment.UpvoteCount, &comment.DownvoteCount, &comment.Score, &comment.ReplyCount,
&hotRank, &authorHandle,
)
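
Because the list queries above now return soft-deleted rows, the presentation layer decides how to render them. A minimal sketch under the assumption that Comment exposes DeletedAt and DeletionReason as nullable pointer fields matching the columns scanned above; the helper name and placeholder strings are illustrative, not part of this change.

```go
package comments

// displayContent is an illustrative view-layer helper: live comments keep
// their content, soft-deleted comments render as placeholders. It assumes
// DeletedAt (*time.Time) and DeletionReason (*string) are nullable pointer
// fields populated from the deleted_at / deletion_reason columns.
func displayContent(c *Comment) string {
	if c.DeletedAt == nil {
		return c.Content
	}
	if c.DeletionReason != nil && *c.DeletionReason == DeletionReasonModerator {
		return "[removed by moderator]"
	}
	return "[deleted]"
}
```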
+5 -6
internal/core/comments/comment_service.go
···
CreatedAt: createdAt, // Preserve original timestamp
}
-
// Update the record on PDS (putRecord)
-
// Note: This creates a new CID even though the URI stays the same
-
// TODO: Use PutRecord instead of CreateRecord for proper update semantics with optimistic locking.
-
// PutRecord should accept the existing CID (existingRecord.CID) to ensure concurrent updates are detected.
-
// However, PutRecord is not yet implemented in internal/atproto/pds/client.go.
-
uri, cid, err := pdsClient.CreateRecord(ctx, commentCollection, rkey, updatedRecord)
+
// Update the record on PDS with optimistic locking via swapRecord CID
+
uri, cid, err := pdsClient.PutRecord(ctx, commentCollection, rkey, updatedRecord, existingRecord.CID)
if err != nil {
s.logger.Error("failed to update comment on PDS",
"error", err,
···
if pds.IsAuthError(err) {
return nil, ErrNotAuthorized
}
+
if errors.Is(err, pds.ErrConflict) {
+
return nil, ErrConcurrentModification
+
}
return nil, fmt.Errorf("failed to update comment: %w", err)
}
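
For context on the swapRecord mechanism behind PutRecord: com.atproto.repo.putRecord lets the caller pass the CID it last read, and the PDS rejects the write when the stored record's CID no longer matches. The sketch below shows only the wire shape, with JSON field names taken from the atproto lexicon; the Go struct names are illustrative, the project's pds client types may differ, and the InvalidSwap-to-ErrConflict mapping is an assumption about how the client surfaces the error.

```go
package pds

// Sketch of the com.atproto.repo.putRecord request/response shape (JSON field
// names per the atproto lexicon; struct names are illustrative). swapRecord
// carries the CID the caller read earlier; on mismatch the PDS answers with an
// InvalidSwap error, which the client can map to ErrConflict and the comment
// service above turns into ErrConcurrentModification.
type putRecordInput struct {
	Repo       string      `json:"repo"`                 // DID of the repo being written to
	Collection string      `json:"collection"`           // e.g. the comment collection NSID
	RKey       string      `json:"rkey"`                 // record key being replaced
	Record     interface{} `json:"record"`               // updated record body
	SwapRecord string      `json:"swapRecord,omitempty"` // expected current CID (optimistic lock)
}

type putRecordOutput struct {
	URI string `json:"uri"` // unchanged AT-URI
	CID string `json:"cid"` // new CID after the update
}
```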
+73
internal/api/handlers/common/viewer_state.go
···
+
package common
+
+
import (
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
+
"context"
+
"log"
+
"net/http"
+
)
+
+
// FeedPostProvider is implemented by any feed post wrapper that contains a PostView.
+
// This allows the helper to work with different feed post types (discover, timeline, communityFeed).
+
type FeedPostProvider interface {
+
GetPost() *posts.PostView
+
}
+
+
// PopulateViewerVoteState enriches feed posts with the authenticated user's vote state.
+
// This is a no-op if voteService is nil or the request is unauthenticated.
+
//
+
// Parameters:
+
// - ctx: Request context for PDS calls
+
// - r: HTTP request (used to extract OAuth session)
+
// - voteService: Vote service for cache lookup (may be nil)
+
// - feedPosts: Posts to enrich with viewer state (must implement FeedPostProvider)
+
//
+
// On error, the function logs and returns without failing the request; viewer state is optional enrichment.
+
func PopulateViewerVoteState[T FeedPostProvider](
+
ctx context.Context,
+
r *http.Request,
+
voteService votes.Service,
+
feedPosts []T,
+
) {
+
if voteService == nil {
+
return
+
}
+
+
session := middleware.GetOAuthSession(r)
+
if session == nil {
+
return
+
}
+
+
userDID := middleware.GetUserDID(r)
+
+
// Ensure vote cache is populated from PDS
+
if err := voteService.EnsureCachePopulated(ctx, session); err != nil {
+
log.Printf("Warning: failed to populate vote cache: %v", err)
+
return
+
}
+
+
// Collect post URIs to batch lookup
+
postURIs := make([]string, 0, len(feedPosts))
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
postURIs = append(postURIs, post.URI)
+
}
+
}
+
+
// Get viewer votes for all posts
+
viewerVotes := voteService.GetViewerVotesForSubjects(userDID, postURIs)
+
+
// Populate viewer state on each post
+
for _, feedPost := range feedPosts {
+
if post := feedPost.GetPost(); post != nil {
+
if vote, exists := viewerVotes[post.URI]; exists {
+
post.Viewer = &posts.ViewerState{
+
Vote: &vote.Direction,
+
VoteURI: &vote.URI,
+
}
+
}
+
}
+
}
+
}
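
Because the helper is generic over FeedPostProvider, any feed wrapper with a GetPost method works with it; the GetPost implementations added to the feed type files below are what satisfy the interface. A small sketch of an optional compile-time assertion in a handler package; the assertion itself is not part of this change and its placement is an assumption.

```go
package discover // illustrative placement: internal/api/handlers/discover

import (
	"Coves/internal/api/handlers/common"

	discoverCore "Coves/internal/core/discover"
)

// Optional compile-time check (not part of this change): *FeedViewPost
// satisfies common.FeedPostProvider via its GetPost method, which is what
// allows response.Feed to be passed straight to PopulateViewerVoteState.
var _ common.FeedPostProvider = (*discoverCore.FeedViewPost)(nil)
```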
+11 -4
internal/api/handlers/discover/get_discover.go
···
package discover
import (
+
"Coves/internal/api/handlers/common"
"Coves/internal/core/discover"
"Coves/internal/core/posts"
+
"Coves/internal/core/votes"
"encoding/json"
"log"
"net/http"
···
// GetDiscoverHandler handles discover feed retrieval
type GetDiscoverHandler struct {
-
service discover.Service
+
service discover.Service
+
voteService votes.Service
}
// NewGetDiscoverHandler creates a new discover handler
-
func NewGetDiscoverHandler(service discover.Service) *GetDiscoverHandler {
+
func NewGetDiscoverHandler(service discover.Service, voteService votes.Service) *GetDiscoverHandler {
return &GetDiscoverHandler{
-
service: service,
+
service: service,
+
voteService: voteService,
}
}
// HandleGetDiscover retrieves posts from all communities (public feed)
// GET /xrpc/social.coves.feed.getDiscover?sort=hot&limit=15&cursor=...
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth; when authenticated, the response includes the viewer's vote state
func (h *GetDiscoverHandler) HandleGetDiscover(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
···
return
}
+
// Populate viewer vote state if authenticated
+
common.PopulateViewerVoteState(r.Context(), r, h.voteService, response.Feed)
+
// Transform blob refs to URLs for all posts
for _, feedPost := range response.Feed {
if feedPost.Post != nil {
+9 -4
internal/api/routes/discover.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
discoverCore "Coves/internal/core/discover"
+
"Coves/internal/core/votes"
"github.com/go-chi/chi/v5"
)
···
// RegisterDiscoverRoutes registers discover-related XRPC endpoints
//
// SECURITY & RATE LIMITING:
-
// - Discover feed is PUBLIC (no authentication required)
+
// - Discover feed is PUBLIC (works without authentication)
+
// - Optional auth: if authenticated, includes viewer vote state on posts
// - Protected by global rate limiter: 100 requests/minute per IP (main.go:84)
// - Query timeout enforced via context (prevents long-running queries)
// - Result limit capped at 50 posts per request (validated in service layer)
···
func RegisterDiscoverRoutes(
r chi.Router,
discoverService discoverCore.Service,
+
voteService votes.Service,
+
authMiddleware *middleware.OAuthAuthMiddleware,
) {
// Create handlers
-
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService)
+
getDiscoverHandler := discover.NewGetDiscoverHandler(discoverService, voteService)
// GET /xrpc/social.coves.feed.getDiscover
-
// Public endpoint - no authentication required
+
// Public endpoint with optional auth for viewer-specific state (vote state)
// Shows posts from ALL communities (not personalized)
// Rate limited: 100 req/min per IP via global middleware
-
r.Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
+
r.With(authMiddleware.OptionalAuth).Get("/xrpc/social.coves.feed.getDiscover", getDiscoverHandler.HandleGetDiscover)
}
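
For reference, the general shape of an OptionalAuth middleware that makes this wiring work: it never rejects a request, and only attaches the session and user DID to the context when they can be resolved, using the same context keys the integration tests below inject. This is a sketch under those assumptions, not the project's implementation; resolveSession is a hypothetical helper.

```go
package middleware

import (
	"context"
	"net/http"
)

// Sketch only: the real OAuthAuthMiddleware.OptionalAuth may differ.
// resolveSession stands in for the project's OAuth/session lookup;
// OAuthSessionKey and UserDIDKey are the context keys the tests populate.
func (m *OAuthAuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if session, did, ok := m.resolveSession(r); ok {
			ctx := context.WithValue(r.Context(), OAuthSessionKey, session)
			ctx = context.WithValue(ctx, UserDIDKey, did)
			r = r.WithContext(ctx)
		}
		next.ServeHTTP(w, r) // unauthenticated requests pass through untouched
	})
}
```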
+5
internal/core/communityFeeds/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Can be reasonRepost or reasonPin
type FeedReason struct {
+5
internal/core/discover/types.go
···
Reply *ReplyRef `json:"reply,omitempty"`
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
type FeedReason struct {
Repost *ReasonRepost `json:"-"`
+5
internal/core/timeline/types.go
···
Reply *ReplyRef `json:"reply,omitempty"` // Reply context
}
+
// GetPost returns the underlying PostView for viewer state enrichment
+
func (f *FeedViewPost) GetPost() *posts.PostView {
+
return f.Post
+
}
+
// FeedReason is a union type for feed context
// Future: Can be reasonRepost or reasonCommunity
type FeedReason struct {
+193 -5
tests/integration/discover_test.go
···
import (
"Coves/internal/api/handlers/discover"
+
"Coves/internal/api/middleware"
+
"Coves/internal/core/votes"
"Coves/internal/db/postgres"
"context"
"encoding/json"
···
discoverCore "Coves/internal/core/discover"
+
oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
+
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+
// mockVoteService implements votes.Service for testing viewer vote state
+
type mockVoteService struct {
+
cachedVotes map[string]*votes.CachedVote // userDID:subjectURI -> vote
+
}
+
+
func newMockVoteService() *mockVoteService {
+
return &mockVoteService{
+
cachedVotes: make(map[string]*votes.CachedVote),
+
}
+
}
+
+
func (m *mockVoteService) AddVote(userDID, subjectURI, direction, voteURI string) {
+
key := userDID + ":" + subjectURI
+
m.cachedVotes[key] = &votes.CachedVote{
+
Direction: direction,
+
URI: voteURI,
+
}
+
}
+
+
func (m *mockVoteService) CreateVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.CreateVoteRequest) (*votes.CreateVoteResponse, error) {
+
return &votes.CreateVoteResponse{}, nil
+
}
+
+
func (m *mockVoteService) DeleteVote(_ context.Context, _ *oauthlib.ClientSessionData, _ votes.DeleteVoteRequest) error {
+
return nil
+
}
+
+
func (m *mockVoteService) EnsureCachePopulated(_ context.Context, _ *oauthlib.ClientSessionData) error {
+
return nil // Mock always succeeds - votes pre-populated via AddVote
+
}
+
+
func (m *mockVoteService) GetViewerVote(userDID, subjectURI string) *votes.CachedVote {
+
key := userDID + ":" + subjectURI
+
return m.cachedVotes[key]
+
}
+
+
func (m *mockVoteService) GetViewerVotesForSubjects(userDID string, subjectURIs []string) map[string]*votes.CachedVote {
+
result := make(map[string]*votes.CachedVote)
+
for _, uri := range subjectURIs {
+
key := userDID + ":" + uri
+
if vote, exists := m.cachedVotes[key]; exists {
+
result[uri] = vote
+
}
+
}
+
return result
+
}
+
// TestGetDiscover_ShowsAllCommunities tests discover feed shows posts from ALL communities
func TestGetDiscover_ShowsAllCommunities(t *testing.T) {
if testing.Short() {
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil) // nil vote service - tests don't need vote state
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
discoverService := discoverCore.NewDiscoverService(discoverRepo)
-
handler := discover.NewGetDiscoverHandler(discoverService)
+
handler := discover.NewGetDiscoverHandler(discoverService, nil)
t.Run("Limit exceeds maximum", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=100", nil)
···
assert.Contains(t, errorResp["message"], "limit")
})
}
+
+
// TestGetDiscover_ViewerVoteState tests that authenticated users see their vote state on posts
+
func TestGetDiscover_ViewerVoteState(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and posts
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("votes-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
post1URI := createTestPost(t, db, communityDID, "did:plc:author1", "Post with upvote", 10, time.Now().Add(-1*time.Hour))
+
post2URI := createTestPost(t, db, communityDID, "did:plc:author2", "Post with downvote", 5, time.Now().Add(-2*time.Hour))
+
_ = createTestPost(t, db, communityDID, "did:plc:author3", "Post without vote", 3, time.Now().Add(-3*time.Hour))
+
+
// Setup mock vote service with pre-populated votes
+
viewerDID := "did:plc:viewer123"
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote(viewerDID, post1URI, "up", "at://"+viewerDID+"/social.coves.vote/vote1")
+
mockVotes.AddVote(viewerDID, post2URI, "down", "at://"+viewerDID+"/social.coves.vote/vote2")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request with authenticated user context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
+
// Inject OAuth session into context (simulates OptionalAuth middleware)
+
did, _ := syntax.ParseDID(viewerDID)
+
session := &oauthlib.ClientSessionData{
+
AccountDID: did,
+
AccessToken: "test_token",
+
}
+
reqCtx := context.WithValue(req.Context(), middleware.UserDIDKey, viewerDID)
+
reqCtx = context.WithValue(reqCtx, middleware.OAuthSessionKey, session)
+
req = req.WithContext(reqCtx)
+
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Assertions
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our test posts and verify vote state
+
var foundPost1, foundPost2, foundPost3 bool
+
for _, feedPost := range response.Feed {
+
switch feedPost.Post.URI {
+
case post1URI:
+
foundPost1 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post1 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post1 should have vote direction")
+
assert.Equal(t, "up", *feedPost.Post.Viewer.Vote, "Post1 should show upvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post1 should have vote URI")
+
assert.Contains(t, *feedPost.Post.Viewer.VoteURI, "vote1", "Post1 should have correct vote URI")
+
+
case post2URI:
+
foundPost2 = true
+
require.NotNil(t, feedPost.Post.Viewer, "Post2 should have viewer state")
+
require.NotNil(t, feedPost.Post.Viewer.Vote, "Post2 should have vote direction")
+
assert.Equal(t, "down", *feedPost.Post.Viewer.Vote, "Post2 should show downvote")
+
require.NotNil(t, feedPost.Post.Viewer.VoteURI, "Post2 should have vote URI")
+
+
default:
+
// Posts without votes should have nil Viewer or nil Vote
+
if feedPost.Post.Viewer != nil && feedPost.Post.Viewer.Vote != nil {
+
// This post has a vote from our viewer - it's not post3
+
continue
+
}
+
foundPost3 = true
+
}
+
}
+
+
assert.True(t, foundPost1, "Should find post1 with upvote")
+
assert.True(t, foundPost2, "Should find post2 with downvote")
+
assert.True(t, foundPost3, "Should find post3 without vote")
+
}
+
+
// TestGetDiscover_NoViewerStateWithoutAuth tests that unauthenticated users don't get viewer state
+
func TestGetDiscover_NoViewerStateWithoutAuth(t *testing.T) {
+
if testing.Short() {
+
t.Skip("Skipping integration test in short mode")
+
}
+
+
db := setupTestDB(t)
+
t.Cleanup(func() { _ = db.Close() })
+
+
ctx := context.Background()
+
testID := time.Now().UnixNano()
+
+
// Create community and post
+
communityDID, err := createFeedTestCommunity(db, ctx, fmt.Sprintf("noauth-%d", testID), fmt.Sprintf("alice-%d.test", testID))
+
require.NoError(t, err)
+
+
postURI := createTestPost(t, db, communityDID, "did:plc:author", "Some post", 10, time.Now())
+
+
// Setup mock vote service with a vote (but request will be unauthenticated)
+
mockVotes := newMockVoteService()
+
mockVotes.AddVote("did:plc:someuser", postURI, "up", "at://did:plc:someuser/social.coves.vote/vote1")
+
+
// Setup handler with mock vote service
+
discoverRepo := postgres.NewDiscoverRepository(db, "test-cursor-secret")
+
discoverService := discoverCore.NewDiscoverService(discoverRepo)
+
handler := discover.NewGetDiscoverHandler(discoverService, mockVotes)
+
+
// Create request WITHOUT auth context
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getDiscover?sort=new&limit=50", nil)
+
rec := httptest.NewRecorder()
+
handler.HandleGetDiscover(rec, req)
+
+
// Should succeed
+
assert.Equal(t, http.StatusOK, rec.Code)
+
+
var response discoverCore.DiscoverResponse
+
err = json.Unmarshal(rec.Body.Bytes(), &response)
+
require.NoError(t, err)
+
+
// Find our post and verify NO viewer state (unauthenticated)
+
for _, feedPost := range response.Feed {
+
if feedPost.Post.URI == postURI {
+
assert.Nil(t, feedPost.Post.Viewer, "Unauthenticated request should not have viewer state")
+
return
+
}
+
}
+
t.Fatal("Test post not found in response")
+
}
+11 -11
tests/integration/feed_test.go
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data: community, users, and posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data with many posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Request feed for non-existent community
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.communityFeed.getCommunity?community=did:plc:nonexistent&sort=hot&limit=10", nil)
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Create community with no posts
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test community
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
···
nil,
)
feedService := communityFeeds.NewCommunityFeedService(feedRepo, communityService)
-
handler := communityFeed.NewGetCommunityHandler(feedService)
+
handler := communityFeed.NewGetCommunityHandler(feedService, nil)
// Setup test data
ctx := context.Background()
+7 -7
tests/integration/timeline_test.go
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
// Request timeline WITHOUT auth context
req := httptest.NewRequest(http.MethodGet, "/xrpc/social.coves.feed.getTimeline?sort=new&limit=10", nil)
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
···
// Setup services
timelineRepo := postgres.NewTimelineRepository(db, "test-cursor-secret")
timelineService := timelineCore.NewTimelineService(timelineRepo)
-
handler := timeline.NewGetTimelineHandler(timelineService)
+
handler := timeline.NewGetTimelineHandler(timelineService, nil)
ctx := context.Background()
testID := time.Now().UnixNano()
+1 -1
tests/integration/user_journey_e2e_test.go
···
r := chi.NewRouter()
routes.RegisterCommunityRoutes(r, communityService, e2eAuth.OAuthAuthMiddleware, nil) // nil = allow all community creators
routes.RegisterPostRoutes(r, postService, e2eAuth.OAuthAuthMiddleware)
-
routes.RegisterTimelineRoutes(r, timelineService, e2eAuth.OAuthAuthMiddleware)
+
routes.RegisterTimelineRoutes(r, timelineService, nil, e2eAuth.OAuthAuthMiddleware)
httpServer := httptest.NewServer(r)
defer httpServer.Close()