A community based topic aggregation platform built on atproto

refactor(oauth): remove dead confidential client and JWT verification code

- Delete internal/atproto/auth/ directory (JWT/DPoP verification - unused)
- Delete cmd/genjwks/ (confidential client key generator - unused)
- Remove ClientSecret/ClientKID from OAuthConfig (public client only)
- Remove HandleJWKS endpoint and routes (not needed for public clients)
- Remove OAUTH_PRIVATE_JWK from docker-compose.prod.yml
- Update tests and integration helpers

Coves is a public OAuth client - this cleanup removes ~1,500 lines of
dead code that was never being used.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

-73
cmd/genjwks/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"log"
-
"os"
-
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
// genjwks generates an ES256 keypair for OAuth client authentication
-
// The private key is stored in the config/env, public key is served at /oauth/jwks.json
-
//
-
// Usage:
-
//
-
// go run cmd/genjwks/main.go
-
//
-
// This will output a JSON private key that should be stored in OAUTH_PRIVATE_JWK
-
func main() {
-
fmt.Println("Generating ES256 keypair for OAuth client authentication...")
-
-
// Generate ES256 (NIST P-256) private key
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
log.Fatalf("Failed to generate private key: %v", err)
-
}
-
-
// Convert to JWK
-
jwkKey, err := jwk.FromRaw(privateKey)
-
if err != nil {
-
log.Fatalf("Failed to create JWK from private key: %v", err)
-
}
-
-
// Set key parameters
-
if err = jwkKey.Set(jwk.KeyIDKey, "oauth-client-key"); err != nil {
-
log.Fatalf("Failed to set kid: %v", err)
-
}
-
if err = jwkKey.Set(jwk.AlgorithmKey, "ES256"); err != nil {
-
log.Fatalf("Failed to set alg: %v", err)
-
}
-
if err = jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil {
-
log.Fatalf("Failed to set use: %v", err)
-
}
-
-
// Marshal to JSON
-
jsonData, err := json.MarshalIndent(jwkKey, "", " ")
-
if err != nil {
-
log.Fatalf("Failed to marshal JWK: %v", err)
-
}
-
-
// Output instructions
-
fmt.Println("\n✅ ES256 keypair generated successfully!")
-
fmt.Println("\n📝 Add this to your .env.dev file:")
-
fmt.Println("\nOAUTH_PRIVATE_JWK='" + string(jsonData) + "'")
-
fmt.Println("\n⚠️ IMPORTANT:")
-
fmt.Println(" - Keep this private key SECRET")
-
fmt.Println(" - Never commit it to version control")
-
fmt.Println(" - Generate a new key for production")
-
fmt.Println(" - The public key will be automatically derived and served at /oauth/jwks.json")
-
-
// Optionally write to a file (not committed)
-
if len(os.Args) > 1 && os.Args[1] == "--save" {
-
filename := "oauth-private-key.json"
-
if err := os.WriteFile(filename, jsonData, 0o600); err != nil {
-
log.Fatalf("Failed to write key file: %v", err)
-
}
-
fmt.Printf("\n💾 Private key saved to %s (remember to add to .gitignore!)\n", filename)
-
}
-
}
-4
cmd/server/main.go
···
oauthConfig.DevMode = true // Force dev mode for localhost
}
-
// Optional: confidential client secret for production
-
oauthConfig.ClientSecret = os.Getenv("OAUTH_CLIENT_SECRET")
-
oauthConfig.ClientKID = os.Getenv("OAUTH_CLIENT_KID")
-
oauthClient, err := oauth.NewOAuthClient(oauthConfig, oauthStore)
if err != nil {
log.Fatalf("Failed to initialize OAuth client: %v", err)
-1
docker-compose.prod.yml
···
# OAuth (for community account provisioning)
OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID}
OAUTH_REDIRECT_URI: ${OAUTH_REDIRECT_URI}
-
OAUTH_PRIVATE_JWK: ${OAUTH_PRIVATE_JWK}
# Application settings
PORT: 8080
-4
internal/api/routes/oauth.go
···
// OAuth metadata endpoints - public, no extra rate limiting (use global limit)
r.Get("/oauth/client-metadata.json", handler.HandleClientMetadata)
-
r.Get("/oauth/jwks.json", handler.HandleJWKS)
-
-
// Alternative well-known paths for OAuth metadata
-
r.Get("/.well-known/oauth-jwks.json", handler.HandleJWKS)
r.Get("/.well-known/oauth-protected-resource", handler.HandleProtectedResourceMetadata)
// OAuth flow endpoints - stricter rate limiting for authentication attempts
-330
internal/atproto/auth/README.md
···
-
# atProto OAuth Authentication
-
-
This package implements third-party OAuth authentication for Coves, validating DPoP-bound access tokens from mobile apps and other atProto clients.
-
-
## Architecture
-
-
This is **third-party authentication** (validating incoming requests), not first-party authentication (logging users into Coves web frontend).
-
-
### Components
-
-
1. **JWT Parser** (`jwt.go`) - Parses and validates JWT tokens
-
2. **JWKS Fetcher** (`jwks_fetcher.go`) - Fetches and caches public keys from PDS authorization servers
-
3. **Auth Middleware** (`internal/api/middleware/auth.go`) - HTTP middleware that protects endpoints
-
-
### Flow
-
-
```
-
Client Request
-
-
Authorization: DPoP <access_token>
-
DPoP: <proof-jwt>
-
-
Auth Middleware
-
-
Extract JWT → Parse Claims → Verify Signature (via JWKS) → Verify DPoP Proof
-
-
Inject DID into Context → Call Handler
-
```
-
-
## Usage
-
-
### Phase 1: Parse-Only Mode (Testing)
-
-
Set `AUTH_SKIP_VERIFY=true` to only parse JWTs without signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=true
-
```
-
-
This is useful for:
-
- Initial integration testing
-
- Testing with mock tokens
-
- Debugging JWT structure
-
-
### Phase 2: Full Verification (Production)
-
-
Set `AUTH_SKIP_VERIFY=false` (or unset) to enable full JWT signature verification:
-
-
```bash
-
export AUTH_SKIP_VERIFY=false
-
# or just unset it
-
```
-
-
This is **required for production** and validates:
-
- JWT signature using PDS public key
-
- Token expiration
-
- Required claims (sub, iss)
-
- DID format
-
-
## Protected Endpoints
-
-
The following endpoints require authentication:
-
-
- `POST /xrpc/social.coves.community.create`
-
- `POST /xrpc/social.coves.community.update`
-
- `POST /xrpc/social.coves.community.subscribe`
-
- `POST /xrpc/social.coves.community.unsubscribe`
-
-
### Making Authenticated Requests
-
-
Include the JWT in the `Authorization` header:
-
-
```bash
-
curl -X POST https://coves.social/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP eyJhbGc..." \
-
-H "DPoP: eyJhbGc..." \
-
-H "Content-Type: application/json" \
-
-d '{"name":"Gaming","hostedByDid":"did:plc:..."}'
-
```
-
-
### Getting User DID in Handlers
-
-
The middleware injects the authenticated user's DID into the request context:
-
-
```go
-
import "Coves/internal/api/middleware"
-
-
func (h *Handler) HandleCreate(w http.ResponseWriter, r *http.Request) {
-
// Extract authenticated user DID
-
userDID := middleware.GetUserDID(r)
-
if userDID == "" {
-
// Not authenticated (should never happen with RequireAuth middleware)
-
http.Error(w, "Unauthorized", http.StatusUnauthorized)
-
return
-
}
-
-
// Use userDID for authorization checks
-
// ...
-
}
-
```
-
-
## Key Caching
-
-
Public keys are fetched from PDS authorization servers and cached for 1 hour. The cache is automatically cleaned up hourly to remove expired entries.
-
-
### JWKS Discovery Flow
-
-
1. Extract `iss` claim from JWT (e.g., `https://pds.example.com`)
-
2. Fetch `https://pds.example.com/.well-known/oauth-authorization-server`
-
3. Extract `jwks_uri` from metadata
-
4. Fetch JWKS from `jwks_uri`
-
5. Find matching key by `kid` from JWT header
-
6. Cache the JWKS for 1 hour
-
-
## DPoP Token Binding
-
-
DPoP (Demonstrating Proof-of-Possession) binds access tokens to client-controlled cryptographic keys, preventing token theft and replay attacks.
-
-
### What is DPoP?
-
-
DPoP is an OAuth extension (RFC 9449) that adds proof-of-possession semantics to bearer tokens. When a PDS issues a DPoP-bound access token:
-
-
1. Access token contains `cnf.jkt` claim (JWK thumbprint of client's public key)
-
2. Client creates a DPoP proof JWT signed with their private key
-
3. Server verifies the proof signature and checks it matches the token's `cnf.jkt`
-
-
### CRITICAL: DPoP Security Model
-
-
> ⚠️ **DPoP is an ADDITIONAL security layer, NOT a replacement for token signature verification.**
-
-
The correct verification order is:
-
1. **ALWAYS verify the access token signature first** (via JWKS, HS256 shared secret, or DID resolution)
-
2. **If the verified token has `cnf.jkt`, REQUIRE valid DPoP proof**
-
3. **NEVER use DPoP as a fallback when signature verification fails**
-
-
**Why This Matters**: An attacker could create a fake token with `sub: "did:plc:victim"` and their own `cnf.jkt`, then present a valid DPoP proof signed with their key. If we accept DPoP as a fallback, the attacker can impersonate any user.
-
-
### How DPoP Works
-
-
```
-
┌─────────────┐ ┌─────────────┐
-
│ Client │ │ Server │
-
│ │ │ (Coves) │
-
└─────────────┘ └─────────────┘
-
│ │
-
│ 1. Authorization: DPoP <token> │
-
│ DPoP: <proof-jwt> │
-
│───────────────────────────────────────>│
-
│ │
-
│ │ 2. VERIFY token signature
-
│ │ (REQUIRED - no fallback!)
-
│ │
-
│ │ 3. If token has cnf.jkt:
-
│ │ - Verify DPoP proof
-
│ │ - Check thumbprint match
-
│ │
-
│ 200 OK │
-
│<───────────────────────────────────────│
-
```
-
-
### When DPoP is Required
-
-
DPoP verification is **REQUIRED** when:
-
- Access token signature has been verified AND
-
- Access token contains `cnf.jkt` claim (DPoP-bound)
-
-
If the token has `cnf.jkt` but no DPoP header is present, the request is **REJECTED**.
-
-
### Replay Protection
-
-
DPoP proofs include a unique `jti` (JWT ID) claim. The server tracks seen `jti` values to prevent replay attacks:
-
-
```go
-
// Create a verifier with replay protection (default)
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop() // Stop cleanup goroutine on shutdown
-
-
// The verifier automatically rejects reused jti values within the proof validity window (5 minutes)
-
```
-
-
### DPoP Implementation
-
-
The `dpop.go` module provides:
-
-
```go
-
// Create a verifier with replay protection
-
verifier := auth.NewDPoPVerifier()
-
defer verifier.Stop()
-
-
// Verify the DPoP proof
-
proof, err := verifier.VerifyDPoPProof(dpopHeader, "POST", "https://coves.social/xrpc/...")
-
if err != nil {
-
// Invalid proof (includes replay detection)
-
}
-
-
// Verify it binds to the VERIFIED access token
-
expectedThumbprint, err := auth.ExtractCnfJkt(claims)
-
if err != nil {
-
// Token not DPoP-bound
-
}
-
-
if err := verifier.VerifyTokenBinding(proof, expectedThumbprint); err != nil {
-
// Proof doesn't match token
-
}
-
```
-
-
### DPoP Proof Format
-
-
The DPoP header contains a JWT with:
-
-
**Header**:
-
- `typ`: `"dpop+jwt"` (required)
-
- `alg`: `"ES256"` (or other supported algorithm)
-
- `jwk`: Client's public key (JWK format)
-
-
**Claims**:
-
- `jti`: Unique proof identifier (tracked for replay protection)
-
- `htm`: HTTP method (e.g., `"POST"`)
-
- `htu`: HTTP URI (without query/fragment)
-
- `iat`: Timestamp (must be recent, within 5 minutes)
-
-
**Example**:
-
```json
-
{
-
"typ": "dpop+jwt",
-
"alg": "ES256",
-
"jwk": {
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "...",
-
"y": "..."
-
}
-
}
-
{
-
"jti": "unique-id-123",
-
"htm": "POST",
-
"htu": "https://coves.social/xrpc/social.coves.community.create",
-
"iat": 1700000000
-
}
-
```
-
-
## Security Considerations
-
-
### ✅ Implemented
-
-
- JWT signature verification with PDS public keys
-
- Token expiration validation
-
- DID format validation
-
- Required claims validation (sub, iss)
-
- Key caching with TTL
-
- Secure error messages (no internal details leaked)
-
- **DPoP proof verification** (proof-of-possession for token binding)
-
- **DPoP thumbprint validation** (prevents token theft attacks)
-
- **DPoP freshness checks** (5-minute proof validity window)
-
- **DPoP replay protection** (jti tracking with in-memory cache)
-
- **Secure DPoP model** (DPoP required AFTER signature verification, never as fallback)
-
-
### ⚠️ Not Yet Implemented
-
-
- Server-issued DPoP nonces (additional replay protection)
-
- Scope validation (checking `scope` claim)
-
- Audience validation (checking `aud` claim)
-
- Rate limiting per DID
-
- Token revocation checking
-
-
## Testing
-
-
Run the test suite:
-
-
```bash
-
go test ./internal/atproto/auth/... -v
-
```
-
-
### Manual Testing
-
-
1. **Phase 1 (Parse Only)**:
-
```bash
-
# Create a test JWT (use jwt.io or a tool)
-
export AUTH_SKIP_VERIFY=true
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <test-jwt>" \
-
-H "DPoP: <test-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
2. **Phase 2 (Full Verification)**:
-
```bash
-
# Use a real JWT from a PDS
-
export AUTH_SKIP_VERIFY=false
-
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
-
-H "Authorization: DPoP <real-jwt>" \
-
-H "DPoP: <real-dpop-proof>" \
-
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
-
```
-
-
## Error Responses
-
-
### 401 Unauthorized
-
-
Missing or invalid token:
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Missing Authorization header"
-
}
-
```
-
-
```json
-
{
-
"error": "AuthenticationRequired",
-
"message": "Invalid or expired token"
-
}
-
```
-
-
### Common Issues
-
-
1. **Missing Authorization header** → Add `Authorization: DPoP <token>` and `DPoP: <proof>`
-
2. **Token expired** → Get a new token from PDS
-
3. **Invalid signature** → Ensure token is from a valid PDS
-
4. **JWKS fetch fails** → Check PDS availability and network connectivity
-
-
## Future Enhancements
-
-
- [ ] DPoP nonce validation (server-managed nonce for additional replay protection)
-
- [ ] Scope-based authorization
-
- [ ] Audience claim validation
-
- [ ] Token revocation support
-
- [ ] Rate limiting per DID
-
- [ ] Metrics and monitoring
-52
internal/atproto/auth/combined_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"fmt"
-
"strings"
-
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
)
-
-
// CombinedKeyFetcher handles JWT public key fetching for both:
-
// - DID issuers (did:plc:, did:web:) → resolves via DID document
-
// - URL issuers (https://) → fetches via JWKS endpoint (legacy/fallback)
-
//
-
// For atproto service authentication, the issuer is typically the user's DID,
-
// and the signing key is published in their DID document.
-
type CombinedKeyFetcher struct {
-
didFetcher *DIDKeyFetcher
-
jwksFetcher JWKSFetcher
-
}
-
-
// NewCombinedKeyFetcher creates a key fetcher that supports both DID and URL issuers.
-
// Parameters:
-
// - directory: Indigo's identity directory for DID resolution
-
// - jwksFetcher: fallback JWKS fetcher for URL issuers (can be nil if not needed)
-
func NewCombinedKeyFetcher(directory indigoIdentity.Directory, jwksFetcher JWKSFetcher) *CombinedKeyFetcher {
-
return &CombinedKeyFetcher{
-
didFetcher: NewDIDKeyFetcher(directory),
-
jwksFetcher: jwksFetcher,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT.
-
// Routes to the appropriate fetcher based on issuer format:
-
// - DID (did:plc:, did:web:) → DIDKeyFetcher
-
// - URL (https://) → JWKSFetcher
-
func (f *CombinedKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Check if issuer is a DID
-
if strings.HasPrefix(issuer, "did:") {
-
return f.didFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
// Check if issuer is a URL (https:// or http:// in dev)
-
if strings.HasPrefix(issuer, "https://") || strings.HasPrefix(issuer, "http://") {
-
if f.jwksFetcher == nil {
-
return nil, fmt.Errorf("URL issuer %s requires JWKS fetcher, but none configured", issuer)
-
}
-
return f.jwksFetcher.FetchPublicKey(ctx, issuer, token)
-
}
-
-
return nil, fmt.Errorf("unsupported issuer format: %s (expected DID or URL)", issuer)
-
}
-122
internal/atproto/auth/did_key_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"encoding/base64"
-
"fmt"
-
"math/big"
-
"strings"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
-
"github.com/bluesky-social/indigo/atproto/syntax"
-
)
-
-
// DIDKeyFetcher fetches public keys from DID documents for JWT verification.
-
// This is the primary method for atproto service authentication, where:
-
// - The JWT issuer is the user's DID (e.g., did:plc:abc123)
-
// - The signing key is published in the user's DID document
-
// - Verification happens by resolving the DID and checking the signature
-
type DIDKeyFetcher struct {
-
directory indigoIdentity.Directory
-
}
-
-
// NewDIDKeyFetcher creates a new DID-based key fetcher.
-
func NewDIDKeyFetcher(directory indigoIdentity.Directory) *DIDKeyFetcher {
-
return &DIDKeyFetcher{
-
directory: directory,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer's DID document.
-
// For DID issuers (did:plc: or did:web:), resolves the DID and extracts the signing key.
-
//
-
// Returns:
-
// - indigoCrypto.PublicKey for secp256k1 (ES256K) keys - use indigo for verification
-
// - *ecdsa.PublicKey for NIST curves (P-256, P-384, P-521) - compatible with golang-jwt
-
func (f *DIDKeyFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Only handle DID issuers
-
if !strings.HasPrefix(issuer, "did:") {
-
return nil, fmt.Errorf("DIDKeyFetcher only handles DID issuers, got: %s", issuer)
-
}
-
-
// Parse the DID
-
did, err := syntax.ParseDID(issuer)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DID format: %w", err)
-
}
-
-
// Resolve the DID to get the identity (includes public keys)
-
ident, err := f.directory.LookupDID(ctx, did)
-
if err != nil {
-
return nil, fmt.Errorf("failed to resolve DID %s: %w", issuer, err)
-
}
-
-
// Get the atproto signing key from the DID document
-
pubKey, err := ident.PublicKey()
-
if err != nil {
-
return nil, fmt.Errorf("failed to get public key from DID document: %w", err)
-
}
-
-
// Convert to JWK format to check curve type
-
jwk, err := pubKey.JWK()
-
if err != nil {
-
return nil, fmt.Errorf("failed to convert public key to JWK: %w", err)
-
}
-
-
// For secp256k1 (ES256K), return indigo's PublicKey directly
-
// since Go's crypto/ecdsa doesn't support this curve
-
if jwk.Curve == "secp256k1" {
-
return pubKey, nil
-
}
-
-
// For NIST curves, convert to Go's ecdsa.PublicKey for golang-jwt compatibility
-
return atcryptoJWKToECDSA(jwk)
-
}
-
-
// atcryptoJWKToECDSA converts an indigoCrypto.JWK to a Go ecdsa.PublicKey.
-
// Note: secp256k1 is handled separately in FetchPublicKey by returning indigo's PublicKey directly.
-
func atcryptoJWKToECDSA(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) {
-
if jwk.KeyType != "EC" {
-
return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType)
-
}
-
-
// Decode X and Y coordinates (base64url, no padding)
-
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK X coordinate encoding: %w", err)
-
}
-
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
-
if err != nil {
-
return nil, fmt.Errorf("invalid JWK Y coordinate encoding: %w", err)
-
}
-
-
var ecCurve elliptic.Curve
-
switch jwk.Curve {
-
case "P-256":
-
ecCurve = elliptic.P256()
-
case "P-384":
-
ecCurve = elliptic.P384()
-
case "P-521":
-
ecCurve = elliptic.P521()
-
default:
-
// secp256k1 should be handled before calling this function
-
return nil, fmt.Errorf("unsupported JWK curve for Go ecdsa: %s (secp256k1 uses indigo)", jwk.Curve)
-
}
-
-
// Create the public key
-
pubKey := &ecdsa.PublicKey{
-
Curve: ecCurve,
-
X: new(big.Int).SetBytes(xBytes),
-
Y: new(big.Int).SetBytes(yBytes),
-
}
-
-
// Validate point is on curve
-
if !ecCurve.IsOnCurve(pubKey.X, pubKey.Y) {
-
return nil, fmt.Errorf("invalid public key: point not on curve")
-
}
-
-
return pubKey, nil
-
}
-616
internal/atproto/auth/dpop.go
···
-
package auth
-
-
import (
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
-
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
-
type NonceCache struct {
-
seen map[string]time.Time // jti -> expiration time
-
stopCh chan struct{}
-
maxAge time.Duration // How long to keep entries
-
cleanup time.Duration // How often to clean up expired entries
-
mu sync.RWMutex
-
}
-
-
// NewNonceCache creates a new nonce cache for DPoP replay protection.
-
// maxAge should match or exceed DPoPVerifier.MaxProofAge.
-
func NewNonceCache(maxAge time.Duration) *NonceCache {
-
nc := &NonceCache{
-
seen: make(map[string]time.Time),
-
maxAge: maxAge,
-
cleanup: maxAge / 2, // Clean up at half the max age
-
stopCh: make(chan struct{}),
-
}
-
-
// Start background cleanup goroutine
-
go nc.cleanupLoop()
-
-
return nc
-
}
-
-
// CheckAndStore checks if a jti has been seen before and stores it if not.
-
// Returns true if the jti is fresh (not a replay), false if it's a replay.
-
func (nc *NonceCache) CheckAndStore(jti string) bool {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
expiry := now.Add(nc.maxAge)
-
-
// Check if already seen
-
if existingExpiry, seen := nc.seen[jti]; seen {
-
// Still valid (not expired) - this is a replay
-
if existingExpiry.After(now) {
-
return false
-
}
-
// Expired entry - allow reuse and update expiry
-
}
-
-
// Store the new jti
-
nc.seen[jti] = expiry
-
return true
-
}
-
-
// cleanupLoop periodically removes expired entries from the cache
-
func (nc *NonceCache) cleanupLoop() {
-
ticker := time.NewTicker(nc.cleanup)
-
defer ticker.Stop()
-
-
for {
-
select {
-
case <-ticker.C:
-
nc.cleanupExpired()
-
case <-nc.stopCh:
-
return
-
}
-
}
-
}
-
-
// cleanupExpired removes expired entries from the cache
-
func (nc *NonceCache) cleanupExpired() {
-
nc.mu.Lock()
-
defer nc.mu.Unlock()
-
-
now := time.Now()
-
for jti, expiry := range nc.seen {
-
if expiry.Before(now) {
-
delete(nc.seen, jti)
-
}
-
}
-
}
-
-
// Stop stops the cleanup goroutine. Call this when done with the cache.
-
func (nc *NonceCache) Stop() {
-
close(nc.stopCh)
-
}
-
-
// Size returns the number of entries in the cache (for testing/monitoring)
-
func (nc *NonceCache) Size() int {
-
nc.mu.RLock()
-
defer nc.mu.RUnlock()
-
return len(nc.seen)
-
}
-
-
// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
-
type DPoPClaims struct {
-
jwt.RegisteredClaims
-
-
// HTTP method of the request (e.g., "GET", "POST")
-
HTTPMethod string `json:"htm"`
-
-
// HTTP URI of the request (without query and fragment parts)
-
HTTPURI string `json:"htu"`
-
-
// Access token hash (optional, for token binding)
-
AccessTokenHash string `json:"ath,omitempty"`
-
}
-
-
// DPoPProof represents a parsed and verified DPoP proof
-
type DPoPProof struct {
-
RawPublicJWK map[string]interface{}
-
Claims *DPoPClaims
-
PublicKey interface{} // *ecdsa.PublicKey or similar
-
Thumbprint string // JWK thumbprint (base64url)
-
}
-
-
// DPoPVerifier verifies DPoP proofs for OAuth token binding
-
type DPoPVerifier struct {
-
// Optional: custom nonce validation function (for server-issued nonces)
-
ValidateNonce func(nonce string) bool
-
-
// NonceCache for replay protection (optional but recommended)
-
// If nil, jti replay protection is disabled
-
NonceCache *NonceCache
-
-
// Maximum allowed clock skew for timestamp validation
-
MaxClockSkew time.Duration
-
-
// Maximum age of DPoP proof (prevents replay with old proofs)
-
MaxProofAge time.Duration
-
}
-
-
// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
-
func NewDPoPVerifier() *DPoPVerifier {
-
maxProofAge := 5 * time.Minute
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: maxProofAge,
-
NonceCache: NewNonceCache(maxProofAge),
-
}
-
}
-
-
// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
-
// This should only be used in testing or when replay protection is handled externally.
-
func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
-
return &DPoPVerifier{
-
MaxClockSkew: 30 * time.Second,
-
MaxProofAge: 5 * time.Minute,
-
NonceCache: nil, // No replay protection
-
}
-
}
-
-
// Stop stops background goroutines. Call this when shutting down.
-
func (v *DPoPVerifier) Stop() {
-
if v.NonceCache != nil {
-
v.NonceCache.Stop()
-
}
-
}
-
-
// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
-
// This supports all atProto-compatible ECDSA algorithms including ES256K (secp256k1).
-
func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
-
// Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
-
header, claims, err := parseJWTHeaderAndClaims(dpopProof)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
-
}
-
-
// Extract and validate the typ header
-
typ, ok := header["typ"].(string)
-
if !ok || typ != "dpop+jwt" {
-
return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
-
}
-
-
alg, ok := header["alg"].(string)
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
-
}
-
-
// Extract the JWK from the header first (needed for algorithm-curve validation)
-
jwkRaw, ok := header["jwk"]
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
-
}
-
-
jwkMap, ok := jwkRaw.(map[string]interface{})
-
if !ok {
-
return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
-
}
-
-
// Validate the algorithm is supported and matches the JWK curve
-
// This is critical for security - prevents algorithm confusion attacks
-
if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof: %w", err)
-
}
-
-
// Parse the public key using indigo's crypto package
-
// This supports all atProto curves including secp256k1 (ES256K)
-
publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
-
}
-
-
// Calculate the JWK thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
-
}
-
-
// Verify the signature using indigo's crypto package
-
// This works for all ECDSA algorithms including ES256K
-
if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
-
return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
-
}
-
-
// Validate the claims
-
if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
-
return nil, err
-
}
-
-
return &DPoPProof{
-
Claims: claims,
-
PublicKey: publicKey,
-
Thumbprint: thumbprint,
-
RawPublicJWK: jwkMap,
-
}, nil
-
}
-
-
// validateDPoPClaims validates the DPoP proof claims
-
func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
-
// Validate jti (unique identifier) is present
-
if claims.ID == "" {
-
return fmt.Errorf("DPoP proof missing jti claim")
-
}
-
-
// Validate htm (HTTP method)
-
if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
-
return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
-
}
-
-
// Validate htu (HTTP URI) - compare without query/fragment
-
expectedURIBase := stripQueryFragment(expectedURI)
-
claimURIBase := stripQueryFragment(claims.HTTPURI)
-
if expectedURIBase != claimURIBase {
-
return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
-
}
-
-
// Validate iat (issued at) is present and recent
-
if claims.IssuedAt == nil {
-
return fmt.Errorf("DPoP proof missing iat claim")
-
}
-
-
now := time.Now()
-
iat := claims.IssuedAt.Time
-
-
// Check clock skew (not too far in the future)
-
if iat.After(now.Add(v.MaxClockSkew)) {
-
return fmt.Errorf("DPoP proof iat is in the future")
-
}
-
-
// Check proof age (not too old)
-
if now.Sub(iat) > v.MaxProofAge {
-
return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
-
}
-
-
// SECURITY: Validate exp claim if present (RFC standard JWT validation)
-
// While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
-
if claims.ExpiresAt != nil {
-
expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
-
if now.After(expWithSkew) {
-
return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
-
}
-
}
-
-
// SECURITY: Validate nbf claim if present (RFC standard JWT validation)
-
if claims.NotBefore != nil {
-
nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
-
if now.Before(nbfWithSkew) {
-
return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
-
}
-
}
-
-
// SECURITY: Check for replay attack using jti
-
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
-
if v.NonceCache != nil {
-
if !v.NonceCache.CheckAndStore(claims.ID) {
-
return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
-
}
-
}
-
-
return nil
-
}
-
-
// VerifyTokenBinding verifies that the DPoP proof binds to the access token
-
// by comparing the proof's thumbprint to the token's cnf.jkt claim
-
func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
-
if proof.Thumbprint != expectedThumbprint {
-
return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
-
expectedThumbprint, proof.Thumbprint)
-
}
-
return nil
-
}
-
-
// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
-
// matches the SHA-256 hash of the presented access token.
-
// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
-
func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
-
// If ath claim is not present, that's acceptable per RFC 9449
-
// (ath is only required when the RS mandates it)
-
if proof.Claims.AccessTokenHash == "" {
-
return nil
-
}
-
-
// Calculate the expected ath: base64url(SHA-256(access_token))
-
hash := sha256.Sum256([]byte(accessToken))
-
expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
if proof.Claims.AccessTokenHash != expectedAth {
-
return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
-
}
-
-
return nil
-
}
-
-
// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638
-
// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation
-
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
-
kty, ok := jwk["kty"].(string)
-
if !ok {
-
return "", fmt.Errorf("JWK missing kty")
-
}
-
-
// Build the canonical JWK representation based on key type
-
// Per RFC 7638, only specific members are included, in lexicographic order
-
var canonical map[string]string
-
-
switch kty {
-
case "EC":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing x")
-
}
-
y, ok := jwk["y"].(string)
-
if !ok {
-
return "", fmt.Errorf("EC JWK missing y")
-
}
-
// Lexicographic order: crv, kty, x, y
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
"y": y,
-
}
-
case "RSA":
-
e, ok := jwk["e"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing e")
-
}
-
n, ok := jwk["n"].(string)
-
if !ok {
-
return "", fmt.Errorf("RSA JWK missing n")
-
}
-
// Lexicographic order: e, kty, n
-
canonical = map[string]string{
-
"e": e,
-
"kty": kty,
-
"n": n,
-
}
-
case "OKP":
-
crv, ok := jwk["crv"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing crv")
-
}
-
x, ok := jwk["x"].(string)
-
if !ok {
-
return "", fmt.Errorf("OKP JWK missing x")
-
}
-
// Lexicographic order: crv, kty, x
-
canonical = map[string]string{
-
"crv": crv,
-
"kty": kty,
-
"x": x,
-
}
-
default:
-
return "", fmt.Errorf("unsupported JWK key type: %s", kty)
-
}
-
-
// Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string)
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
-
}
-
-
// SHA-256 hash
-
hash := sha256.Sum256(canonicalJSON)
-
-
// Base64url encode (no padding)
-
thumbprint := base64.RawURLEncoding.EncodeToString(hash[:])
-
-
return thumbprint, nil
-
}
-
-
// validateAlgorithmCurveBinding validates that the JWT algorithm matches the JWK curve.
-
// This is critical for security - an attacker could claim alg: "ES256K" but provide
-
// a P-256 key, potentially bypassing algorithm binding requirements.
-
func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error {
-
kty, ok := jwkMap["kty"].(string)
-
if !ok {
-
return fmt.Errorf("JWK missing kty")
-
}
-
-
// ECDSA algorithms require EC key type
-
switch alg {
-
case "ES256K", "ES256", "ES384", "ES512":
-
if kty != "EC" {
-
return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty)
-
}
-
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
-
return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
-
default:
-
return fmt.Errorf("unsupported DPoP algorithm: %s", alg)
-
}
-
-
// Validate curve matches algorithm
-
crv, ok := jwkMap["crv"].(string)
-
if !ok {
-
return fmt.Errorf("EC JWK missing crv")
-
}
-
-
var expectedCurve string
-
switch alg {
-
case "ES256K":
-
expectedCurve = "secp256k1"
-
case "ES256":
-
expectedCurve = "P-256"
-
case "ES384":
-
expectedCurve = "P-384"
-
case "ES512":
-
expectedCurve = "P-521"
-
}
-
-
if crv != expectedCurve {
-
return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, expectedCurve, crv)
-
}
-
-
return nil
-
}
-
-
// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey.
-
// This returns indigo's PublicKey interface which supports all atProto curves
-
// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512).
-
func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - this supports all atProto curves
-
// including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt.
-
// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize.
-
func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode header
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first to extract standard claims
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build DPoPClaims struct
-
claims := &DPoPClaims{}
-
-
// Extract jti
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract exp (expiration) if present
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before) if present
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract htm (HTTP method)
-
if htm, ok := rawClaims["htm"].(string); ok {
-
claims.HTTPMethod = htm
-
}
-
-
// Extract htu (HTTP URI)
-
if htu, ok := rawClaims["htu"].(string); ok {
-
claims.HTTPURI = htu
-
}
-
-
// Extract ath (access token hash) if present
-
if ath, ok := rawClaims["ath"].(string); ok {
-
claims.AccessTokenHash = ath
-
}
-
-
return header, claims, nil
-
}
-
-
// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package.
-
// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K).
-
// It parses the JWT, extracts the signing input and signature, and uses indigo's
-
// PublicKey.HashAndVerifyLenient() for verification.
-
//
-
// JWT format: header.payload.signature (all base64url-encoded)
-
// Signature is verified over the raw bytes of "header.payload"
-
// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally)
-
func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
// Split the JWT into parts
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature)
-
if err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// stripQueryFragment removes query and fragment from a URI
-
func stripQueryFragment(uri string) string {
-
if idx := strings.Index(uri, "?"); idx != -1 {
-
uri = uri[:idx]
-
}
-
if idx := strings.Index(uri, "#"); idx != -1 {
-
uri = uri[:idx]
-
}
-
return uri
-
}
-
-
// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
-
func ExtractCnfJkt(claims *Claims) (string, error) {
-
if claims.Confirmation == nil {
-
return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
-
}
-
-
jkt, ok := claims.Confirmation["jkt"].(string)
-
if !ok || jkt == "" {
-
return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
-
}
-
-
return jkt, nil
-
}
-1308
internal/atproto/auth/dpop_test.go
···
-
package auth
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"strings"
-
"testing"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
"github.com/google/uuid"
-
)
-
-
// === Test Helpers ===
-
-
// testECKey holds a test ES256 key pair
-
type testECKey struct {
-
privateKey *ecdsa.PrivateKey
-
publicKey *ecdsa.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256Key generates a test ES256 key pair and JWK
-
func generateTestES256Key(t *testing.T) *testECKey {
-
t.Helper()
-
-
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
t.Fatalf("Failed to generate test key: %v", err)
-
}
-
-
// Encode public key coordinates as base64url
-
xBytes := privateKey.PublicKey.X.Bytes()
-
yBytes := privateKey.PublicKey.Y.Bytes()
-
-
// P-256 coordinates must be 32 bytes (pad if needed)
-
xBytes = padTo32Bytes(xBytes)
-
yBytes = padTo32Bytes(yBytes)
-
-
x := base64.RawURLEncoding.EncodeToString(xBytes)
-
y := base64.RawURLEncoding.EncodeToString(yBytes)
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": x,
-
"y": y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate thumbprint: %v", err)
-
}
-
-
return &testECKey{
-
privateKey: privateKey,
-
publicKey: &privateKey.PublicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// padTo32Bytes pads a byte slice to 32 bytes (required for P-256 coordinates)
-
func padTo32Bytes(b []byte) []byte {
-
if len(b) >= 32 {
-
return b
-
}
-
padded := make([]byte, 32)
-
copy(padded[32-len(b):], b)
-
return padded
-
}
-
-
// createDPoPProof creates a DPoP proof JWT for testing
-
func createDPoPProof(t *testing.T, key *testECKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
tokenString, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create DPoP proof: %v", err)
-
}
-
-
return tokenString
-
}
-
-
// === JWK Thumbprint Tests (RFC 7638) ===
-
-
func TestCalculateJWKThumbprint_EC_P256(t *testing.T) {
-
// Test with known values from RFC 7638 Appendix A (adapted for P-256)
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis",
-
"y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("Thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
func TestCalculateJWKThumbprint_Deterministic(t *testing.T) {
-
// Same key should produce same thumbprint
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x-coordinate",
-
"y": "test-y-coordinate",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 != thumbprint2 {
-
t.Errorf("Thumbprints are not deterministic: %s != %s", thumbprint1, thumbprint2)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_DifferentKeys(t *testing.T) {
-
// Different keys should produce different thumbprints
-
jwk1 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-1",
-
"y": "coordinate-y-1",
-
}
-
-
jwk2 := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "coordinate-x-2",
-
"y": "coordinate-y-2",
-
}
-
-
thumbprint1, err := CalculateJWKThumbprint(jwk1)
-
if err != nil {
-
t.Fatalf("First CalculateJWKThumbprint failed: %v", err)
-
}
-
-
thumbprint2, err := CalculateJWKThumbprint(jwk2)
-
if err != nil {
-
t.Fatalf("Second CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if thumbprint1 == thumbprint2 {
-
t.Error("Different keys produced same thumbprint (collision)")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_MissingKty(t *testing.T) {
-
jwk := map[string]interface{}{
-
"crv": "P-256",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing kty, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing kty") {
-
t.Errorf("Expected error about missing kty, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingCrv(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"x": "test-x",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing crv, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing crv") {
-
t.Errorf("Expected error about missing crv, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingX(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"y": "test-y",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing x, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing x") {
-
t.Errorf("Expected error about missing x, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_EC_MissingY(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "test-x",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for missing y, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing y") {
-
t.Errorf("Expected error about missing y, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_RSA(t *testing.T) {
-
// Test RSA key thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "RSA",
-
"e": "AQAB",
-
"n": "test-modulus",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for RSA: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for RSA key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_OKP(t *testing.T) {
-
// Test OKP (Octet Key Pair) thumbprint calculation
-
jwk := map[string]interface{}{
-
"kty": "OKP",
-
"crv": "Ed25519",
-
"x": "test-x-coordinate",
-
}
-
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for OKP: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for OKP key")
-
}
-
}
-
-
func TestCalculateJWKThumbprint_UnsupportedKeyType(t *testing.T) {
-
jwk := map[string]interface{}{
-
"kty": "UNKNOWN",
-
}
-
-
_, err := CalculateJWKThumbprint(jwk)
-
if err == nil {
-
t.Error("Expected error for unsupported key type, got nil")
-
}
-
if err != nil && !contains(err.Error(), "unsupported JWK key type") {
-
t.Errorf("Expected error about unsupported key type, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_CanonicalJSON(t *testing.T) {
-
// RFC 7638 requires lexicographic ordering of keys in canonical JSON
-
// This test verifies that the canonical JSON is correctly ordered
-
-
jwk := map[string]interface{}{
-
"kty": "EC",
-
"crv": "P-256",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
// The canonical JSON should be: {"crv":"P-256","kty":"EC","x":"x-coord","y":"y-coord"}
-
// (lexicographically ordered: crv, kty, x, y)
-
-
canonical := map[string]string{
-
"crv": "P-256",
-
"kty": "EC",
-
"x": "x-coord",
-
"y": "y-coord",
-
}
-
-
canonicalJSON, err := json.Marshal(canonical)
-
if err != nil {
-
t.Fatalf("Failed to marshal canonical JSON: %v", err)
-
}
-
-
expectedHash := sha256.Sum256(canonicalJSON)
-
expectedThumbprint := base64.RawURLEncoding.EncodeToString(expectedHash[:])
-
-
actualThumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed: %v", err)
-
}
-
-
if actualThumbprint != expectedThumbprint {
-
t.Errorf("Thumbprint doesn't match expected canonical JSON hash\nExpected: %s\nGot: %s",
-
expectedThumbprint, actualThumbprint)
-
}
-
}
-
-
// === DPoP Proof Verification Tests ===
-
-
func TestVerifyDPoPProof_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Claims.ID != jti {
-
t.Errorf("Expected jti %s, got %s", jti, result.Claims.ID)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Parse and modify to use wrong key's JWK in header (signature won't match)
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongHTTPMethod(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
wrongMethod := "GET"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, wrongMethod, uri)
-
if err == nil {
-
t.Error("Expected error for HTTP method mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htm mismatch") {
-
t.Errorf("Expected htm mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongURI(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
wrongURI := "https://api.example.com/different"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, wrongURI)
-
if err == nil {
-
t.Error("Expected error for URI mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "htu mismatch") {
-
t.Errorf("Expected htu mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithQuery(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithQuery := baseURI + "?param=value"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because query is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithQuery)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with query: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_URIWithFragment(t *testing.T) {
-
// URI comparison should strip query and fragment
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
baseURI := "https://api.example.com/resource"
-
uriWithFragment := baseURI + "#section"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, baseURI, iat, jti)
-
-
// Should succeed because fragment is stripped
-
_, err := verifier.VerifyDPoPProof(proof, method, uriWithFragment)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for URI with fragment: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_ExpiredProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 10 minutes ago (exceeds default MaxProofAge of 5 minutes)
-
iat := time.Now().Add(-10 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "too old") {
-
t.Errorf("Expected 'too old' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_FutureProof(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 1 minute in the future (exceeds MaxClockSkew)
-
iat := time.Now().Add(1 * time.Minute)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for future proof, got nil")
-
}
-
if err != nil && !contains(err.Error(), "in the future") {
-
t.Errorf("Expected 'in the future' error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WithinClockSkew(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 15 seconds in the future (within MaxClockSkew of 30s)
-
iat := time.Now().Add(15 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for proof within clock skew: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingJti(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
// No ID (jti)
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jti, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jti") {
-
t.Errorf("Expected missing jti error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
// Don't set typ header
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_WrongTypHeader(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "JWT" // Wrong typ
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for wrong typ header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "typ must be 'dpop+jwt'") {
-
t.Errorf("Expected typ header error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_MissingJWK(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
// Don't include JWK
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for missing jwk header, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jwk") {
-
t.Errorf("Expected missing jwk error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_CustomTimeSettings(t *testing.T) {
-
verifier := &DPoPVerifier{
-
MaxClockSkew: 1 * time.Minute,
-
MaxProofAge: 10 * time.Minute,
-
}
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
// Proof issued 50 seconds in the future (within custom MaxClockSkew)
-
iat := time.Now().Add(50 * time.Second)
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
_, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed with custom time settings: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_HTTPMethodCaseInsensitive(t *testing.T) {
-
// HTTP method comparison should be case-insensitive per spec
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "post"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
// Verify with uppercase method
-
_, err := verifier.VerifyDPoPProof(proof, "POST", uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for case-insensitive method: %v", err)
-
}
-
}
-
-
// === Token Binding Verification Tests ===
-
-
func TestVerifyTokenBinding_Matching(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with matching thumbprint
-
err = verifier.VerifyTokenBinding(result, key.thumbprint)
-
if err != nil {
-
t.Fatalf("VerifyTokenBinding failed for matching thumbprint: %v", err)
-
}
-
}
-
-
func TestVerifyTokenBinding_Mismatch(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
wrongKey := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed: %v", err)
-
}
-
-
// Verify token binding with wrong thumbprint
-
err = verifier.VerifyTokenBinding(result, wrongKey.thumbprint)
-
if err == nil {
-
t.Error("Expected error for thumbprint mismatch, got nil")
-
}
-
if err != nil && !contains(err.Error(), "thumbprint mismatch") {
-
t.Errorf("Expected thumbprint mismatch error, got: %v", err)
-
}
-
}
-
-
// === ExtractCnfJkt Tests ===
-
-
func TestExtractCnfJkt_Valid(t *testing.T) {
-
expectedJkt := "test-thumbprint-123"
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": expectedJkt,
-
},
-
}
-
-
jkt, err := ExtractCnfJkt(claims)
-
if err != nil {
-
t.Fatalf("ExtractCnfJkt failed for valid claims: %v", err)
-
}
-
-
if jkt != expectedJkt {
-
t.Errorf("Expected jkt %s, got %s", expectedJkt, jkt)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingCnf(t *testing.T) {
-
claims := &Claims{
-
// No Confirmation
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_NilCnf(t *testing.T) {
-
claims := &Claims{
-
Confirmation: nil,
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for nil cnf, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing cnf claim") {
-
t.Errorf("Expected missing cnf error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_MissingJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"other": "value",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for missing jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_EmptyJkt(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": "",
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for empty jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
func TestExtractCnfJkt_WrongType(t *testing.T) {
-
claims := &Claims{
-
Confirmation: map[string]interface{}{
-
"jkt": 123, // Not a string
-
},
-
}
-
-
_, err := ExtractCnfJkt(claims)
-
if err == nil {
-
t.Error("Expected error for wrong type jkt, got nil")
-
}
-
if err != nil && !contains(err.Error(), "missing jkt") {
-
t.Errorf("Expected missing jkt error, got: %v", err)
-
}
-
}
-
-
// === Helper Functions for Tests ===
-
-
// splitJWT splits a JWT into its three parts
-
func splitJWT(token string) []string {
-
return []string{
-
token[:strings.IndexByte(token, '.')],
-
token[strings.IndexByte(token, '.')+1 : strings.LastIndexByte(token, '.')],
-
token[strings.LastIndexByte(token, '.')+1:],
-
}
-
}
-
-
// parseJWTHeader parses a base64url-encoded JWT header
-
func parseJWTHeader(t *testing.T, encoded string) map[string]interface{} {
-
t.Helper()
-
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
-
if err != nil {
-
t.Fatalf("Failed to decode header: %v", err)
-
}
-
-
var header map[string]interface{}
-
if err := json.Unmarshal(decoded, &header); err != nil {
-
t.Fatalf("Failed to unmarshal header: %v", err)
-
}
-
-
return header
-
}
-
-
// encodeJSON encodes a value to base64url-encoded JSON
-
func encodeJSON(t *testing.T, v interface{}) string {
-
t.Helper()
-
data, err := json.Marshal(v)
-
if err != nil {
-
t.Fatalf("Failed to marshal JSON: %v", err)
-
}
-
return base64.RawURLEncoding.EncodeToString(data)
-
}
-
-
// === ES256K (secp256k1) Test Helpers ===
-
-
// testES256KKey holds a test ES256K key pair using indigo
-
type testES256KKey struct {
-
privateKey indigoCrypto.PrivateKey
-
publicKey indigoCrypto.PublicKey
-
jwk map[string]interface{}
-
thumbprint string
-
}
-
-
// generateTestES256KKey generates a test ES256K (secp256k1) key pair and JWK
-
func generateTestES256KKey(t *testing.T) *testES256KKey {
-
t.Helper()
-
-
privateKey, err := indigoCrypto.GeneratePrivateKeyK256()
-
if err != nil {
-
t.Fatalf("Failed to generate ES256K test key: %v", err)
-
}
-
-
publicKey, err := privateKey.PublicKey()
-
if err != nil {
-
t.Fatalf("Failed to get public key from ES256K private key: %v", err)
-
}
-
-
// Get the JWK representation
-
jwkStruct, err := publicKey.JWK()
-
if err != nil {
-
t.Fatalf("Failed to get JWK from ES256K public key: %v", err)
-
}
-
jwk := map[string]interface{}{
-
"kty": jwkStruct.KeyType,
-
"crv": jwkStruct.Curve,
-
"x": jwkStruct.X,
-
"y": jwkStruct.Y,
-
}
-
-
// Calculate thumbprint
-
thumbprint, err := CalculateJWKThumbprint(jwk)
-
if err != nil {
-
t.Fatalf("Failed to calculate ES256K thumbprint: %v", err)
-
}
-
-
return &testES256KKey{
-
privateKey: privateKey,
-
publicKey: publicKey,
-
jwk: jwk,
-
thumbprint: thumbprint,
-
}
-
}
-
-
// createES256KDPoPProof creates a DPoP proof JWT using ES256K for testing
-
func createES256KDPoPProof(t *testing.T, key *testES256KKey, method, uri string, iat time.Time, jti string) string {
-
t.Helper()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256K",
-
"jwk": key.jwk,
-
}
-
-
// Encode header and claims
-
headerJSON, err := json.Marshal(header)
-
if err != nil {
-
t.Fatalf("Failed to marshal header: %v", err)
-
}
-
claimsJSON, err := json.Marshal(claims)
-
if err != nil {
-
t.Fatalf("Failed to marshal claims: %v", err)
-
}
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
// Sign with indigo
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign ES256K proof: %v", err)
-
}
-
-
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
-
return signingInput + "." + signatureB64
-
}
-
-
// === ES256K Tests ===
-
-
func TestVerifyDPoPProof_ES256K_Valid(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid ES256K proof: %v", err)
-
}
-
-
if result == nil {
-
t.Fatal("Expected non-nil proof result")
-
}
-
-
if result.Claims.HTTPMethod != method {
-
t.Errorf("Expected method %s, got %s", method, result.Claims.HTTPMethod)
-
}
-
-
if result.Claims.HTTPURI != uri {
-
t.Errorf("Expected URI %s, got %s", uri, result.Claims.HTTPURI)
-
}
-
-
if result.Thumbprint != key.thumbprint {
-
t.Errorf("Expected thumbprint %s, got %s", key.thumbprint, result.Thumbprint)
-
}
-
}
-
-
func TestVerifyDPoPProof_ES256K_InvalidSignature(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t)
-
wrongKey := generateTestES256KKey(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create proof with one key
-
proof := createES256KDPoPProof(t, key, method, uri, iat, jti)
-
-
// Tamper by replacing JWK with wrong key
-
parts := splitJWT(proof)
-
header := parseJWTHeader(t, parts[0])
-
header["jwk"] = wrongKey.jwk
-
modifiedHeader := encodeJSON(t, header)
-
tamperedProof := modifiedHeader + "." + parts[1] + "." + parts[2]
-
-
_, err := verifier.VerifyDPoPProof(tamperedProof, method, uri)
-
if err == nil {
-
t.Error("Expected error for invalid ES256K signature, got nil")
-
}
-
if err != nil && !contains(err.Error(), "signature verification failed") {
-
t.Errorf("Expected signature verification error, got: %v", err)
-
}
-
}
-
-
func TestCalculateJWKThumbprint_ES256K(t *testing.T) {
-
// Test thumbprint calculation for secp256k1 keys
-
key := generateTestES256KKey(t)
-
-
thumbprint, err := CalculateJWKThumbprint(key.jwk)
-
if err != nil {
-
t.Fatalf("CalculateJWKThumbprint failed for ES256K: %v", err)
-
}
-
-
if thumbprint == "" {
-
t.Error("Expected non-empty thumbprint for ES256K key")
-
}
-
-
// Verify it's valid base64url
-
_, err = base64.RawURLEncoding.DecodeString(thumbprint)
-
if err != nil {
-
t.Errorf("ES256K thumbprint is not valid base64url: %v", err)
-
}
-
-
// Verify length (SHA-256 produces 32 bytes = 43 base64url chars)
-
if len(thumbprint) != 43 {
-
t.Errorf("Expected ES256K thumbprint length 43, got %d", len(thumbprint))
-
}
-
}
-
-
// === Algorithm-Curve Binding Tests ===
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256KWithP256Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t) // P-256 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Create a proof claiming ES256K but using P-256 key
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["alg"] = "ES256K" // Claim ES256K
-
token.Header["jwk"] = key.jwk // But use P-256 key
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256K algorithm with P-256 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve secp256k1") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_AlgorithmCurveMismatch_ES256WithSecp256k1Key(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256KKey(t) // secp256k1 key
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
jti := uuid.New().String()
-
-
// Build claims
-
claims := map[string]interface{}{
-
"jti": jti,
-
"iat": iat.Unix(),
-
"htm": method,
-
"htu": uri,
-
}
-
-
// Build header claiming ES256 but using secp256k1 key
-
header := map[string]interface{}{
-
"typ": "dpop+jwt",
-
"alg": "ES256", // Claim ES256
-
"jwk": key.jwk, // But use secp256k1 key
-
}
-
-
headerJSON, _ := json.Marshal(header)
-
claimsJSON, _ := json.Marshal(claims)
-
-
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
-
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
-
-
signingInput := headerB64 + "." + claimsB64
-
signature, err := key.privateKey.HashAndSign([]byte(signingInput))
-
if err != nil {
-
t.Fatalf("Failed to sign: %v", err)
-
}
-
-
proof := signingInput + "." + base64.RawURLEncoding.EncodeToString(signature)
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for ES256 algorithm with secp256k1 curve, got nil")
-
}
-
if err != nil && !contains(err.Error(), "requires curve P-256") {
-
t.Errorf("Expected curve mismatch error, got: %v", err)
-
}
-
}
-
-
// === exp/nbf Validation Tests ===
-
-
func TestVerifyDPoPProof_ExpiredWithExpClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now().Add(-2 * time.Minute)
-
exp := time.Now().Add(-1 * time.Minute) // Expired 1 minute ago
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for expired proof with exp claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "expired") {
-
t.Errorf("Expected expiration error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_NotYetValidWithNbfClaim(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
nbf := time.Now().Add(5 * time.Minute) // Not valid for another 5 minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
NotBefore: jwt.NewNumericDate(nbf),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
_, err = verifier.VerifyDPoPProof(proof, method, uri)
-
if err == nil {
-
t.Error("Expected error for not-yet-valid proof with nbf claim, got nil")
-
}
-
if err != nil && !contains(err.Error(), "not valid before") {
-
t.Errorf("Expected not-before error, got: %v", err)
-
}
-
}
-
-
func TestVerifyDPoPProof_ValidWithExpClaimInFuture(t *testing.T) {
-
verifier := NewDPoPVerifier()
-
key := generateTestES256Key(t)
-
-
method := "POST"
-
uri := "https://api.example.com/resource"
-
iat := time.Now()
-
exp := time.Now().Add(5 * time.Minute) // Valid for 5 more minutes
-
jti := uuid.New().String()
-
-
claims := &DPoPClaims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
ID: jti,
-
IssuedAt: jwt.NewNumericDate(iat),
-
ExpiresAt: jwt.NewNumericDate(exp),
-
},
-
HTTPMethod: method,
-
HTTPURI: uri,
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
-
token.Header["typ"] = "dpop+jwt"
-
token.Header["jwk"] = key.jwk
-
-
proof, err := token.SignedString(key.privateKey)
-
if err != nil {
-
t.Fatalf("Failed to create test proof: %v", err)
-
}
-
-
result, err := verifier.VerifyDPoPProof(proof, method, uri)
-
if err != nil {
-
t.Fatalf("VerifyDPoPProof failed for valid proof with exp in future: %v", err)
-
}
-
-
if result == nil {
-
t.Error("Expected non-nil result for valid proof")
-
}
-
}
-189
internal/atproto/auth/jwks_fetcher.go
···
-
package auth
-
-
import (
-
"context"
-
"encoding/json"
-
"fmt"
-
"net/http"
-
"strings"
-
"sync"
-
"time"
-
)
-
-
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
-
type CachedJWKSFetcher struct {
-
cache map[string]*cachedJWKS
-
httpClient *http.Client
-
cacheMutex sync.RWMutex
-
cacheTTL time.Duration
-
}
-
-
type cachedJWKS struct {
-
jwks *JWKS
-
expiresAt time.Time
-
}
-
-
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
-
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
-
return &CachedJWKSFetcher{
-
cache: make(map[string]*cachedJWKS),
-
httpClient: &http.Client{
-
Timeout: 10 * time.Second,
-
},
-
cacheTTL: cacheTTL,
-
}
-
}
-
-
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
-
// Implements JWKSFetcher interface
-
// Returns interface{} to support both RSA and ECDSA keys
-
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
// Extract key ID from token
-
kid, err := ExtractKeyID(token)
-
if err != nil {
-
return nil, fmt.Errorf("failed to extract key ID: %w", err)
-
}
-
-
// Get JWKS from cache or fetch
-
jwks, err := f.getJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Find the key by ID
-
jwk, err := jwks.FindKeyByID(kid)
-
if err != nil {
-
// Key not found in cache - try refreshing
-
jwks, err = f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
-
}
-
f.cacheJWKS(issuer, jwks)
-
-
// Try again with fresh JWKS
-
jwk, err = jwks.FindKeyByID(kid)
-
if err != nil {
-
return nil, err
-
}
-
}
-
-
// Convert JWK to public key (RSA or ECDSA)
-
return jwk.ToPublicKey()
-
}
-
-
// getJWKS gets JWKS from cache or fetches if not cached/expired
-
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Check cache first
-
f.cacheMutex.RLock()
-
cached, exists := f.cache[issuer]
-
f.cacheMutex.RUnlock()
-
-
if exists && time.Now().Before(cached.expiresAt) {
-
return cached.jwks, nil
-
}
-
-
// Not in cache or expired - fetch from issuer
-
jwks, err := f.fetchJWKS(ctx, issuer)
-
if err != nil {
-
return nil, err
-
}
-
-
// Cache it
-
f.cacheJWKS(issuer, jwks)
-
-
return jwks, nil
-
}
-
-
// fetchJWKS fetches JWKS from the authorization server
-
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
-
// Step 1: Fetch OAuth server metadata to get JWKS URI
-
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
-
-
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create metadata request: %w", err)
-
}
-
-
resp, err := f.httpClient.Do(req)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
-
}
-
defer func() {
-
_ = resp.Body.Close()
-
}()
-
-
if resp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
-
}
-
-
var metadata struct {
-
JWKSURI string `json:"jwks_uri"`
-
}
-
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
-
return nil, fmt.Errorf("failed to decode metadata: %w", err)
-
}
-
-
if metadata.JWKSURI == "" {
-
return nil, fmt.Errorf("jwks_uri not found in metadata")
-
}
-
-
// Step 2: Fetch JWKS from the JWKS URI
-
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
-
if err != nil {
-
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
-
}
-
-
jwksResp, err := f.httpClient.Do(jwksReq)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
-
}
-
defer func() {
-
_ = jwksResp.Body.Close()
-
}()
-
-
if jwksResp.StatusCode != http.StatusOK {
-
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
-
}
-
-
var jwks JWKS
-
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
-
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
-
}
-
-
if len(jwks.Keys) == 0 {
-
return nil, fmt.Errorf("no keys found in JWKS")
-
}
-
-
return &jwks, nil
-
}
-
-
// cacheJWKS stores JWKS in the cache
-
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
f.cache[issuer] = &cachedJWKS{
-
jwks: jwks,
-
expiresAt: time.Now().Add(f.cacheTTL),
-
}
-
}
-
-
// ClearCache clears the entire JWKS cache
-
func (f *CachedJWKSFetcher) ClearCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
f.cache = make(map[string]*cachedJWKS)
-
}
-
-
// CleanupExpiredCache removes expired entries from the cache
-
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
-
f.cacheMutex.Lock()
-
defer f.cacheMutex.Unlock()
-
-
now := time.Now()
-
for issuer, cached := range f.cache {
-
if now.After(cached.expiresAt) {
-
delete(f.cache, issuer)
-
}
-
}
-
}
-709
internal/atproto/auth/jwt.go
···
-
package auth
-
-
import (
-
"context"
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rsa"
-
"encoding/base64"
-
"encoding/json"
-
"fmt"
-
"math/big"
-
"net/url"
-
"os"
-
"strings"
-
"sync"
-
"time"
-
-
indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
// jwtConfig holds cached JWT configuration to avoid reading env vars on every request
-
type jwtConfig struct {
-
hs256Issuers map[string]struct{} // Set of whitelisted HS256 issuers
-
pdsJWTSecret []byte // Cached PDS_JWT_SECRET
-
isDevEnv bool // Cached IS_DEV_ENV
-
}
-
-
var (
-
cachedConfig *jwtConfig
-
configOnce sync.Once
-
)
-
-
// InitJWTConfig initializes the JWT configuration from environment variables.
-
// This should be called once at startup. If not called explicitly, it will be
-
// initialized lazily on first use.
-
func InitJWTConfig() {
-
configOnce.Do(func() {
-
cachedConfig = &jwtConfig{
-
hs256Issuers: make(map[string]struct{}),
-
isDevEnv: os.Getenv("IS_DEV_ENV") == "true",
-
}
-
-
// Parse HS256_ISSUERS into a set for O(1) lookup
-
if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" {
-
for _, issuer := range strings.Split(issuers, ",") {
-
issuer = strings.TrimSpace(issuer)
-
if issuer != "" {
-
cachedConfig.hs256Issuers[issuer] = struct{}{}
-
}
-
}
-
}
-
-
// Cache PDS_JWT_SECRET
-
if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" {
-
cachedConfig.pdsJWTSecret = []byte(secret)
-
}
-
})
-
}
-
-
// getConfig returns the cached config, initializing if needed
-
func getConfig() *jwtConfig {
-
InitJWTConfig()
-
return cachedConfig
-
}
-
-
// ResetJWTConfigForTesting resets the cached config to allow re-initialization.
-
// This should ONLY be used in tests.
-
func ResetJWTConfigForTesting() {
-
cachedConfig = nil
-
configOnce = sync.Once{}
-
}
-
-
// Algorithm constants for JWT signing methods
-
const (
-
AlgorithmHS256 = "HS256"
-
AlgorithmRS256 = "RS256"
-
AlgorithmES256 = "ES256"
-
)
-
-
// JWTHeader represents the parsed JWT header
-
type JWTHeader struct {
-
Alg string `json:"alg"`
-
Kid string `json:"kid"`
-
Typ string `json:"typ,omitempty"`
-
}
-
-
// Claims represents the standard JWT claims we care about
-
type Claims struct {
-
jwt.RegisteredClaims
-
// Confirmation claim for DPoP token binding (RFC 9449)
-
// Contains "jkt" (JWK thumbprint) when token is bound to a DPoP key
-
Confirmation map[string]interface{} `json:"cnf,omitempty"`
-
Scope string `json:"scope,omitempty"`
-
}
-
-
// stripBearerPrefix removes the "Bearer " prefix from a token string
-
func stripBearerPrefix(tokenString string) string {
-
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
-
return strings.TrimSpace(tokenString)
-
}
-
-
// ParseJWTHeader extracts and parses the JWT header from a token string
-
// This is a reusable function for getting algorithm and key ID information
-
func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
-
tokenString = stripBearerPrefix(tokenString)
-
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode JWT header: %w", err)
-
}
-
-
var header JWTHeader
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
return &header, nil
-
}
-
-
// shouldUseHS256 determines if a token should use HS256 verification
-
// This prevents algorithm confusion attacks by using multiple signals:
-
// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
-
// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
-
//
-
// This approach supports open federation because:
-
// - External PDSes publish keys via JWKS and include `kid` in their tokens
-
// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
-
func shouldUseHS256(header *JWTHeader, issuer string) bool {
-
// If token has a key ID, it MUST use asymmetric verification
-
// This is the primary defense against algorithm confusion attacks
-
if header.Kid != "" {
-
return false
-
}
-
-
// No kid - check if issuer is whitelisted for HS256
-
// This should only include your own PDS URL(s)
-
return isHS256IssuerWhitelisted(issuer)
-
}
-
-
// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
-
// Only your own PDS should be in this list - external PDSes should use JWKS
-
func isHS256IssuerWhitelisted(issuer string) bool {
-
cfg := getConfig()
-
_, whitelisted := cfg.hs256Issuers[issuer]
-
return whitelisted
-
}
-
-
// ParseJWT parses a JWT token without verification (Phase 1)
-
// Returns the claims if the token is valid JSON and has required fields
-
func ParseJWT(tokenString string) (*Claims, error) {
-
// Remove "Bearer " prefix if present
-
tokenString = stripBearerPrefix(tokenString)
-
-
// Parse without verification first to extract claims
-
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
-
token, _, err := parser.ParseUnverified(tokenString, &Claims{})
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWT: %w", err)
-
}
-
-
claims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("invalid claims type")
-
}
-
-
// Validate required fields
-
if claims.Subject == "" {
-
return nil, fmt.Errorf("missing 'sub' claim (user DID)")
-
}
-
-
// atProto PDSes may use 'aud' instead of 'iss' for the authorization server
-
// If 'iss' is missing, use 'aud' as the authorization server identifier
-
if claims.Issuer == "" {
-
if len(claims.Audience) > 0 {
-
claims.Issuer = claims.Audience[0]
-
} else {
-
return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
-
}
-
}
-
-
// Validate claims (even in Phase 1, we need basic validation like expiry)
-
if err := validateClaims(claims); err != nil {
-
return nil, err
-
}
-
-
return claims, nil
-
}
-
-
// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
-
// Fetches the public key from the issuer's JWKS endpoint and validates the signature
-
// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
-
//
-
// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
-
// to prevent algorithm confusion attacks where an attacker could re-sign a token
-
// with HS256 using a public key as the secret.
-
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Strip Bearer prefix once at the start
-
tokenString = stripBearerPrefix(tokenString)
-
-
// First parse to get the issuer (needed to determine expected algorithm)
-
claims, err := ParseJWT(tokenString)
-
if err != nil {
-
return nil, err
-
}
-
-
// Parse header to get the claimed algorithm (for validation)
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return nil, err
-
}
-
-
// SECURITY: Determine verification method based on token characteristics
-
// 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
-
// 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
-
useHS256 := shouldUseHS256(header, claims.Issuer)
-
-
if useHS256 {
-
// Verify token actually claims to use HS256
-
if header.Alg != AlgorithmHS256 {
-
return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
-
}
-
return verifyHS256Token(tokenString)
-
}
-
-
// Token must use asymmetric verification
-
// Reject HS256 tokens that don't meet the criteria above
-
if header.Alg == AlgorithmHS256 {
-
if header.Kid != "" {
-
return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
-
}
-
return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
-
}
-
-
// For RSA/ECDSA, fetch public key from JWKS and verify
-
return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
-
}
-
-
// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
-
func verifyHS256Token(tokenString string) (*Claims, error) {
-
cfg := getConfig()
-
if len(cfg.pdsJWTSecret) == 0 {
-
return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
-
}
-
-
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
-
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
-
}
-
return cfg.pdsJWTSecret, nil
-
})
-
if err != nil {
-
return nil, fmt.Errorf("HS256 verification failed: %w", err)
-
}
-
-
if !token.Valid {
-
return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
-
}
-
-
verifiedClaims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
-
}
-
-
if err := validateClaims(verifiedClaims); err != nil {
-
return nil, err
-
}
-
-
return verifiedClaims, nil
-
}
-
-
// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS.
-
// For ES256K (secp256k1), uses indigo's crypto package since golang-jwt doesn't support it.
-
func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Parse header to check algorithm
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWT header: %w", err)
-
}
-
-
// ES256K (secp256k1) requires special handling via indigo's crypto package
-
// golang-jwt doesn't recognize ES256K as a valid signing method
-
if header.Alg == "ES256K" {
-
return verifyES256KToken(ctx, tokenString, issuer, keyFetcher)
-
}
-
-
// For standard algorithms (ES256, ES384, ES512, RS256, etc.), use golang-jwt
-
publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch public key: %w", err)
-
}
-
-
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
-
// Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
-
switch token.Method.(type) {
-
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
-
// Valid signing methods for atProto
-
default:
-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
-
}
-
return publicKey, nil
-
})
-
if err != nil {
-
return nil, fmt.Errorf("asymmetric verification failed: %w", err)
-
}
-
-
if !token.Valid {
-
return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
-
}
-
-
verifiedClaims, ok := token.Claims.(*Claims)
-
if !ok {
-
return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
-
}
-
-
if err := validateClaims(verifiedClaims); err != nil {
-
return nil, err
-
}
-
-
return verifiedClaims, nil
-
}
-
-
// verifyES256KToken verifies a JWT signed with ES256K (secp256k1) using indigo's crypto package.
-
// This is necessary because golang-jwt doesn't support ES256K as a signing method.
-
func verifyES256KToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// Fetch the public key - for ES256K, the fetcher returns a JWK map or indigo PublicKey
-
keyData, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to fetch public key for ES256K: %w", err)
-
}
-
-
// Convert to indigo PublicKey based on what the fetcher returned
-
var pubKey indigoCrypto.PublicKey
-
switch k := keyData.(type) {
-
case indigoCrypto.PublicKey:
-
// Already an indigo PublicKey (from DIDKeyFetcher or updated JWKSFetcher)
-
pubKey = k
-
case map[string]interface{}:
-
// Raw JWK map - parse with indigo
-
pubKey, err = parseJWKMapToIndigoPublicKey(k)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse ES256K JWK: %w", err)
-
}
-
default:
-
return nil, fmt.Errorf("ES256K verification requires indigo PublicKey or JWK map, got %T", keyData)
-
}
-
-
// Verify signature using indigo
-
if err := verifyJWTSignatureWithIndigoKey(tokenString, pubKey); err != nil {
-
return nil, fmt.Errorf("ES256K signature verification failed: %w", err)
-
}
-
-
// Parse claims (signature already verified)
-
claims, err := parseJWTClaimsManually(tokenString)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse ES256K JWT claims: %w", err)
-
}
-
-
if err := validateClaims(claims); err != nil {
-
return nil, err
-
}
-
-
return claims, nil
-
}
-
-
// parseJWKMapToIndigoPublicKey converts a JWK map to an indigo PublicKey.
-
// This uses indigo's crypto package which supports all atProto curves including secp256k1.
-
func parseJWKMapToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
-
// Convert map to JSON bytes for indigo's parser
-
jwkBytes, err := json.Marshal(jwkMap)
-
if err != nil {
-
return nil, fmt.Errorf("failed to serialize JWK: %w", err)
-
}
-
-
// Parse with indigo's crypto package - supports all atProto curves
-
pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse JWK with indigo: %w", err)
-
}
-
-
return pubKey, nil
-
}
-
-
// verifyJWTSignatureWithIndigoKey verifies a JWT signature using indigo's crypto package.
-
// This works for all ECDSA algorithms including ES256K (secp256k1).
-
func verifyJWTSignatureWithIndigoKey(tokenString string, pubKey indigoCrypto.PublicKey) error {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// The signing input is "header.payload" (without decoding)
-
signingInput := parts[0] + "." + parts[1]
-
-
// Decode the signature from base64url
-
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
-
if err != nil {
-
return fmt.Errorf("failed to decode JWT signature: %w", err)
-
}
-
-
// Use indigo's verification - HashAndVerifyLenient handles hashing internally
-
// and accepts both low-S and high-S signatures for maximum compatibility
-
if err := pubKey.HashAndVerifyLenient([]byte(signingInput), signature); err != nil {
-
return fmt.Errorf("signature verification failed: %w", err)
-
}
-
-
return nil
-
}
-
-
// parseJWTClaimsManually parses JWT claims without using golang-jwt.
-
// This is used for ES256K tokens where golang-jwt would reject the algorithm.
-
func parseJWTClaimsManually(tokenString string) (*Claims, error) {
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
-
}
-
-
// Decode claims
-
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
-
}
-
-
// Parse into raw map first
-
var rawClaims map[string]interface{}
-
if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
-
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
-
}
-
-
// Build Claims struct
-
claims := &Claims{}
-
-
// Extract sub (subject/DID)
-
if sub, ok := rawClaims["sub"].(string); ok {
-
claims.Subject = sub
-
}
-
-
// Extract iss (issuer)
-
if iss, ok := rawClaims["iss"].(string); ok {
-
claims.Issuer = iss
-
}
-
-
// Extract aud (audience) - can be string or array
-
switch aud := rawClaims["aud"].(type) {
-
case string:
-
claims.Audience = jwt.ClaimStrings{aud}
-
case []interface{}:
-
for _, a := range aud {
-
if s, ok := a.(string); ok {
-
claims.Audience = append(claims.Audience, s)
-
}
-
}
-
}
-
-
// Extract exp (expiration)
-
if exp, ok := rawClaims["exp"].(float64); ok {
-
t := time.Unix(int64(exp), 0)
-
claims.ExpiresAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract iat (issued at)
-
if iat, ok := rawClaims["iat"].(float64); ok {
-
t := time.Unix(int64(iat), 0)
-
claims.IssuedAt = jwt.NewNumericDate(t)
-
}
-
-
// Extract nbf (not before)
-
if nbf, ok := rawClaims["nbf"].(float64); ok {
-
t := time.Unix(int64(nbf), 0)
-
claims.NotBefore = jwt.NewNumericDate(t)
-
}
-
-
// Extract jti (JWT ID)
-
if jti, ok := rawClaims["jti"].(string); ok {
-
claims.ID = jti
-
}
-
-
// Extract scope
-
if scope, ok := rawClaims["scope"].(string); ok {
-
claims.Scope = scope
-
}
-
-
// Extract cnf (confirmation) for DPoP binding
-
if cnf, ok := rawClaims["cnf"].(map[string]interface{}); ok {
-
claims.Confirmation = cnf
-
}
-
-
return claims, nil
-
}
-
-
// validateClaims performs additional validation on JWT claims
-
func validateClaims(claims *Claims) error {
-
now := time.Now()
-
-
// Check expiration
-
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
-
return fmt.Errorf("token has expired")
-
}
-
-
// Check not before
-
if claims.NotBefore != nil && claims.NotBefore.After(now) {
-
return fmt.Errorf("token not yet valid")
-
}
-
-
// Validate DID format in sub claim
-
if !strings.HasPrefix(claims.Subject, "did:") {
-
return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
-
}
-
-
// Validate issuer is either an HTTPS URL or a DID
-
// atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
-
// In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing
-
isHTTP := strings.HasPrefix(claims.Issuer, "http://")
-
isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
-
isDID := strings.HasPrefix(claims.Issuer, "did:")
-
-
if !isHTTPS && !isDID && !isHTTP {
-
return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
-
}
-
-
// In production, reject HTTP issuers (only for non-dev environments)
-
cfg := getConfig()
-
if isHTTP && !cfg.isDevEnv {
-
return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
-
}
-
-
// Parse to ensure it's a valid URL
-
if _, err := url.Parse(claims.Issuer); err != nil {
-
return fmt.Errorf("invalid issuer URL: %w", err)
-
}
-
-
// Validate scope if present (lenient: allow empty, but reject wrong scopes)
-
if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
-
return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
-
}
-
-
return nil
-
}
-
-
// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints
-
// Returns interface{} to support both RSA and ECDSA keys
-
type JWKSFetcher interface {
-
FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
-
}
-
-
// JWK represents a JSON Web Key from a JWKS endpoint
-
// Supports both RSA and EC (ECDSA) keys
-
type JWK struct {
-
Kid string `json:"kid"` // Key ID
-
Kty string `json:"kty"` // Key type ("RSA" or "EC")
-
Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
-
Use string `json:"use"` // Public key use (should be "sig" for signatures)
-
// RSA fields
-
N string `json:"n,omitempty"` // RSA modulus
-
E string `json:"e,omitempty"` // RSA exponent
-
// EC fields
-
Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
-
X string `json:"x,omitempty"` // EC x coordinate
-
Y string `json:"y,omitempty"` // EC y coordinate
-
}
-
-
// ToPublicKey converts a JWK to a public key (RSA, ECDSA, or indigo for secp256k1).
-
//
-
// Returns:
-
// - *rsa.PublicKey for RSA keys
-
// - *ecdsa.PublicKey for NIST EC curves (P-256, P-384, P-521)
-
// - map[string]interface{} for secp256k1 (ES256K) - parsed by indigo
-
func (j *JWK) ToPublicKey() (interface{}, error) {
-
switch j.Kty {
-
case "RSA":
-
return j.toRSAPublicKey()
-
case "EC":
-
// For secp256k1, return raw JWK map for indigo to parse
-
if j.Crv == "secp256k1" {
-
return j.toJWKMap(), nil
-
}
-
return j.toECPublicKey()
-
default:
-
return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
-
}
-
}
-
-
// toJWKMap converts the JWK struct to a map for indigo parsing
-
func (j *JWK) toJWKMap() map[string]interface{} {
-
m := map[string]interface{}{
-
"kty": j.Kty,
-
}
-
if j.Kid != "" {
-
m["kid"] = j.Kid
-
}
-
if j.Alg != "" {
-
m["alg"] = j.Alg
-
}
-
if j.Use != "" {
-
m["use"] = j.Use
-
}
-
// RSA fields
-
if j.N != "" {
-
m["n"] = j.N
-
}
-
if j.E != "" {
-
m["e"] = j.E
-
}
-
// EC fields
-
if j.Crv != "" {
-
m["crv"] = j.Crv
-
}
-
if j.X != "" {
-
m["x"] = j.X
-
}
-
if j.Y != "" {
-
m["y"] = j.Y
-
}
-
return m
-
}
-
-
// toRSAPublicKey converts a JWK to an RSA public key
-
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
-
// Decode modulus
-
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
-
}
-
-
// Decode exponent
-
eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
-
}
-
-
// Convert exponent to int
-
var eInt int
-
for _, b := range eBytes {
-
eInt = eInt*256 + int(b)
-
}
-
-
return &rsa.PublicKey{
-
N: new(big.Int).SetBytes(nBytes),
-
E: eInt,
-
}, nil
-
}
-
-
// toECPublicKey converts a JWK to an ECDSA public key
-
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
-
// Determine curve
-
var curve elliptic.Curve
-
switch j.Crv {
-
case "P-256":
-
curve = elliptic.P256()
-
case "P-384":
-
curve = elliptic.P384()
-
case "P-521":
-
curve = elliptic.P521()
-
default:
-
return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
-
}
-
-
// Decode X coordinate
-
xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
-
}
-
-
// Decode Y coordinate
-
yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
-
if err != nil {
-
return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
-
}
-
-
return &ecdsa.PublicKey{
-
Curve: curve,
-
X: new(big.Int).SetBytes(xBytes),
-
Y: new(big.Int).SetBytes(yBytes),
-
}, nil
-
}
-
-
// JWKS represents a JSON Web Key Set
-
type JWKS struct {
-
Keys []JWK `json:"keys"`
-
}
-
-
// FindKeyByID finds a key in the JWKS by its key ID
-
func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
-
for _, key := range j.Keys {
-
if key.Kid == kid {
-
return &key, nil
-
}
-
}
-
return nil, fmt.Errorf("key with kid %s not found", kid)
-
}
-
-
// ExtractKeyID extracts the key ID from a JWT token header
-
func ExtractKeyID(tokenString string) (string, error) {
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
return "", err
-
}
-
-
if header.Kid == "" {
-
return "", fmt.Errorf("missing kid in token header")
-
}
-
-
return header.Kid, nil
-
}
-496
internal/atproto/auth/jwt_test.go
···
-
package auth
-
-
import (
-
"context"
-
"testing"
-
"time"
-
-
"github.com/golang-jwt/jwt/v5"
-
)
-
-
func TestParseJWT(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing
-
parsedClaims, err := ParseJWT(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
-
if parsedClaims.Issuer != "https://test-pds.example.com" {
-
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
-
}
-
-
if parsedClaims.Scope != "atproto transition:generic" {
-
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
-
}
-
}
-
-
func TestParseJWT_MissingSubject(t *testing.T) {
-
// Create a token without subject
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing subject, got nil")
-
}
-
}
-
-
func TestParseJWT_MissingIssuer(t *testing.T) {
-
// Create a token without issuer
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing - should fail
-
_, err = ParseJWT(tokenString)
-
if err == nil {
-
t.Error("Expected error for missing issuer, got nil")
-
}
-
}
-
-
func TestParseJWT_WithBearerPrefix(t *testing.T) {
-
// Create a test JWT token
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte("test-secret"))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
-
// Test parsing with Bearer prefix
-
parsedClaims, err := ParseJWT("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
-
}
-
-
if parsedClaims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
-
}
-
}
-
-
func TestValidateClaims_Expired(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for expired token, got nil")
-
}
-
}
-
-
func TestValidateClaims_InvalidDID(t *testing.T) {
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "invalid-did-format",
-
Issuer: "https://test-pds.example.com",
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
-
err := validateClaims(claims)
-
if err == nil {
-
t.Error("Expected error for invalid DID format, got nil")
-
}
-
}
-
-
func TestExtractKeyID(t *testing.T) {
-
// Create a test JWT token with kid in header
-
token := jwt.New(jwt.SigningMethodRS256)
-
token.Header["kid"] = "test-key-id"
-
token.Claims = &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: "https://test-pds.example.com",
-
},
-
}
-
-
// Sign with a dummy RSA key (we just need a valid token structure)
-
tokenString, err := token.SignedString([]byte("dummy"))
-
if err == nil {
-
// If it succeeds (shouldn't with wrong key type, but let's handle it)
-
kid, err := ExtractKeyID(tokenString)
-
if err != nil {
-
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
-
} else if kid != "test-key-id" {
-
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
-
}
-
}
-
}
-
-
// === HS256 Verification Tests ===
-
-
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
-
type mockJWKSFetcher struct {
-
publicKey interface{}
-
err error
-
}
-
-
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
-
return m.publicKey, m.err
-
}
-
-
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
-
t.Helper()
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: subject,
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
-
IssuedAt: jwt.NewNumericDate(time.Now()),
-
},
-
Scope: "atproto transition:generic",
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-
tokenString, err := token.SignedString([]byte(secret))
-
if err != nil {
-
t.Fatalf("Failed to create test token: %v", err)
-
}
-
return tokenString
-
}
-
-
func TestVerifyJWT_HS256_Valid(t *testing.T) {
-
// Setup: Configure environment for HS256 verification
-
secret := "test-jwt-secret-key-12345"
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", secret)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
-
-
// Verify token
-
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err != nil {
-
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
-
}
-
-
if claims.Subject != "did:plc:test123" {
-
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
-
}
-
if claims.Issuer != issuer {
-
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
-
}
-
}
-
-
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
-
// Setup: Configure environment with one secret, sign with another
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "correct-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create token with wrong secret
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
-
-
// Verify should fail
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error for HS256 token with wrong secret, got nil")
-
}
-
}
-
-
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
-
// Setup: Whitelist issuer but don't configure secret
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
-
-
// Verify should fail with descriptive error
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
-
}
-
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
-
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
-
}
-
}
-
-
// === Algorithm Confusion Attack Prevention Tests ===
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
-
// SECURITY TEST: This tests the algorithm confusion attack prevention
-
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create HS256 token with non-whitelisted issuer (simulating attack)
-
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because issuer is not in HS256 whitelist
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
-
}
-
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
-
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
-
}
-
}
-
-
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
-
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "some-secret")
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
-
-
// Verify should fail because no issuers are whitelisted for HS256
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
-
}
-
}
-
-
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
-
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
-
issuer := "https://pds.coves.social"
-
-
ResetJWTConfigForTesting()
-
t.Setenv("PDS_JWT_SECRET", "test-secret")
-
t.Setenv("HS256_ISSUERS", issuer)
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
-
claims := &Claims{
-
RegisteredClaims: jwt.RegisteredClaims{
-
Subject: "did:plc:test123",
-
Issuer: issuer,
-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
-
},
-
}
-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
-
// This will create an invalid signature but valid header structure
-
// The test should fail at algorithm check, not signature verification
-
tokenString, _ := token.SignedString([]byte("dummy-key"))
-
-
if tokenString != "" {
-
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
-
if err == nil {
-
t.Error("Expected error when HS256 issuer receives non-HS256 token")
-
}
-
}
-
}
-
-
// === ParseJWTHeader Tests ===
-
-
func TestParseJWTHeader_Valid(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader(tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
-
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
-
-
header, err := ParseJWTHeader("Bearer " + tokenString)
-
if err != nil {
-
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
-
}
-
-
if header.Alg != AlgorithmHS256 {
-
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
-
}
-
}
-
-
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
-
testCases := []struct {
-
name string
-
input string
-
}{
-
{"empty string", ""},
-
{"single part", "abc"},
-
{"two parts", "abc.def"},
-
{"too many parts", "a.b.c.d"},
-
}
-
-
for _, tc := range testCases {
-
t.Run(tc.name, func(t *testing.T) {
-
_, err := ParseJWTHeader(tc.input)
-
if err == nil {
-
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
-
}
-
})
-
}
-
}
-
-
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
-
-
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected pds1 to be whitelisted")
-
}
-
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
-
t.Error("Expected pds2 to be whitelisted")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://attacker.example.com") {
-
t.Error("Expected non-whitelisted issuer to return false")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "") // Empty whitelist
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if isHS256IssuerWhitelisted("https://any.example.com") {
-
t.Error("Expected false when whitelist is empty (safe default)")
-
}
-
}
-
-
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
-
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
-
}
-
}
-
-
// === shouldUseHS256 Tests (kid-based logic) ===
-
-
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
-
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "some-key-id", // Has kid
-
}
-
-
// Even whitelisted issuer should not use HS256 if token has kid
-
if shouldUseHS256(header, "https://whitelisted.example.com") {
-
t.Error("Tokens with kid should never use HS256 (supports federation)")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if !shouldUseHS256(header, "https://my-pds.example.com") {
-
t.Error("Token without kid from whitelisted issuer should use HS256")
-
}
-
}
-
-
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
-
ResetJWTConfigForTesting()
-
t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
-
t.Cleanup(ResetJWTConfigForTesting)
-
-
header := &JWTHeader{
-
Alg: AlgorithmHS256,
-
Kid: "", // No kid
-
}
-
-
if shouldUseHS256(header, "https://external-pds.example.com") {
-
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
-
}
-
}
-
-
// Helper function
-
func contains(s, substr string) bool {
-
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-
}
-
-
func containsHelper(s, substr string) bool {
-
for i := 0; i <= len(s)-len(substr); i++ {
-
if s[i:i+len(substr)] == substr {
-
return true
-
}
-
}
-
return false
-
}
+1 -32
internal/atproto/oauth/client.go
···
"net/url"
"time"
-
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/identity"
)
···
// OAuthConfig holds Coves OAuth client configuration
type OAuthConfig struct {
PublicURL string
-
ClientSecret string
-
ClientKID string
SealSecret string
PLCURL string
Scopes []string
···
callbackURL := "http://localhost:3000/oauth/callback"
clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes)
} else {
-
// Production mode: HTTPS with client secret
+
// Production mode: public OAuth client with HTTPS
callbackURL := config.PublicURL + "/oauth/callback"
clientConfig = oauth.NewPublicConfig(config.PublicURL, callbackURL, config.Scopes)
-
-
// Set up confidential client if client secret is provided
-
if config.ClientSecret != "" && config.ClientKID != "" {
-
privKey, err := atcrypto.ParsePrivateMultibase(config.ClientSecret)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse client secret: %w", err)
-
}
-
-
if err := clientConfig.SetClientSecret(privKey, config.ClientKID); err != nil {
-
return nil, fmt.Errorf("failed to set client secret: %w", err)
-
}
-
}
}
// Set user agent
···
metadata.ClientURI = strPtr(c.Config.PublicURL)
}
-
// For confidential clients, set JWKS URI
-
if c.ClientApp.Config.IsConfidential() && !c.Config.DevMode {
-
jwksURI := c.Config.PublicURL + "/.well-known/oauth-jwks.json"
-
metadata.JWKSURI = &jwksURI
-
}
-
return metadata
-
}
-
-
// PublicJWKS returns the public JWKS for this client (for confidential clients)
-
func (c *OAuthClient) PublicJWKS() oauth.JWKS {
-
return c.ClientApp.Config.PublicJWKS()
-
}
-
-
// IsConfidential returns true if this is a confidential OAuth client
-
func (c *OAuthClient) IsConfidential() bool {
-
return c.ClientApp.Config.IsConfidential()
}
// strPtr is a helper to get a pointer to a string
-19
internal/atproto/oauth/handlers.go
···
func (h *OAuthHandler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) {
metadata := h.client.ClientMetadata()
-
// For confidential clients in production, set JWKS URI based on request host
-
if h.client.IsConfidential() && !h.client.Config.DevMode {
-
jwksURI := fmt.Sprintf("https://%s/oauth/jwks.json", r.Host)
-
metadata.JWKSURI = &jwksURI
-
}
-
// Validate metadata before returning (skip in dev mode - localhost doesn't need https validation)
if !h.client.Config.DevMode {
if err := metadata.Validate(h.client.ClientApp.Config.ClientID); err != nil {
···
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(metadata); err != nil {
slog.Error("failed to encode client metadata", "error", err)
-
http.Error(w, "internal server error", http.StatusInternalServerError)
-
return
-
}
-
}
-
-
// HandleJWKS serves the public JWKS for confidential clients
-
// GET /oauth/jwks.json
-
func (h *OAuthHandler) HandleJWKS(w http.ResponseWriter, r *http.Request) {
-
jwks := h.client.PublicJWKS()
-
-
w.Header().Set("Content-Type", "application/json")
-
if err := json.NewEncoder(w).Encode(jwks); err != nil {
-
slog.Error("failed to encode JWKS", "error", err)
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
-37
internal/atproto/oauth/handlers_test.go
···
assert.Contains(t, metadata.Scope, "atproto")
}
-
// TestHandleJWKS tests the JWKS endpoint
-
func TestHandleJWKS(t *testing.T) {
-
// Create a test OAuth client configuration (public client, no keys)
-
config := &OAuthConfig{
-
PublicURL: "https://coves.social",
-
Scopes: []string{"atproto"},
-
DevMode: false,
-
AllowPrivateIPs: false,
-
SealSecret: "MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI=",
-
}
-
-
client, err := NewOAuthClient(config, oauth.NewMemStore())
-
require.NoError(t, err)
-
-
handler := NewOAuthHandler(client, oauth.NewMemStore())
-
-
// Create test request
-
req := httptest.NewRequest(http.MethodGet, "/oauth/jwks.json", nil)
-
rec := httptest.NewRecorder()
-
-
// Call handler
-
handler.HandleJWKS(rec, req)
-
-
// Check response
-
assert.Equal(t, http.StatusOK, rec.Code)
-
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
-
-
// Parse response
-
var jwks oauth.JWKS
-
err = json.NewDecoder(rec.Body).Decode(&jwks)
-
require.NoError(t, err)
-
-
// Public client should have empty JWKS
-
assert.NotNil(t, jwks.Keys)
-
assert.Equal(t, 0, len(jwks.Keys))
-
}
-
// TestHandleLogin tests the login endpoint
func TestHandleLogin(t *testing.T) {
config := &OAuthConfig{
-4
tests/integration/oauth_helpers.go
···
// Production mode: HTTPS PDS, use real PLC directory
config = &oauth.OAuthConfig{
PublicURL: "http://localhost:3000", // Test server callback URL
-
ClientSecret: "", // Public client
-
ClientKID: "", // Public client
SealSecret: sealSecretB64, // For sealing mobile tokens
Scopes: []string{"atproto", "transition:generic"},
DevMode: false, // Production mode for HTTPS PDS
···
// Dev mode: localhost PDS with HTTP
config = &oauth.OAuthConfig{
PublicURL: "http://localhost:3000", // Match the callback URL expected by PDS
-
ClientSecret: "", // Empty for public client in dev mode
-
ClientKID: "", // Empty for public client
SealSecret: sealSecretB64, // For sealing mobile tokens
Scopes: []string{"atproto", "transition:generic"},
DevMode: true, // Enable dev mode for localhost testing