A community based topic aggregation platform built on atproto

feat(auth): implement JWT validation with JWKS fetching

Add new simplified authentication system:
- JWT parsing and validation against atProto standards
- JWKS fetcher with caching for PDS public keys
- Support for both signature verification and parse-only modes
- Claims extraction (sub, iss, aud, exp, iat)

Dependencies:
- Add github.com/golang-jwt/jwt/v5 for JWT handling

This replaces the complex OAuth/DPoP flow with direct JWT validation,
suitable for alpha phase where we control both the PDS and AppView.

Files:
- internal/atproto/auth/jwt.go: JWT parsing and verification
- internal/atproto/auth/jwks_fetcher.go: Public key fetching
- internal/atproto/auth/jwt_test.go: Test coverage
- internal/atproto/auth/README.md: Documentation

Changed files
+845
internal
+1
go.mod
···
github.com/go-logr/stdr v1.2.2 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
+
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
+3
go.sum
···
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
+
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
+
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
+
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
+194
internal/atproto/auth/README.md
···
+
# atProto OAuth Authentication
+
+
This package implements third-party OAuth authentication for Coves, validating JWT Bearer tokens from mobile apps and other atProto clients.
+
+
## Architecture
+
+
This is **third-party authentication** (validating incoming requests), not first-party authentication (logging users into Coves web frontend).
+
+
### Components
+
+
1. **JWT Parser** (`jwt.go`) - Parses and validates JWT tokens
+
2. **JWKS Fetcher** (`jwks_fetcher.go`) - Fetches and caches public keys from PDS authorization servers
+
3. **Auth Middleware** (`internal/api/middleware/auth.go`) - HTTP middleware that protects endpoints
+
+
### Flow
+
+
```
+
Client Request
+
+
Authorization: Bearer <jwt>
+
+
Auth Middleware
+
+
Extract JWT → Parse Claims → Verify Signature (via JWKS)
+
+
Inject DID into Context → Call Handler
+
```
+
+
## Usage
+
+
### Phase 1: Parse-Only Mode (Testing)
+
+
Set `AUTH_SKIP_VERIFY=true` to only parse JWTs without signature verification:
+
+
```bash
+
export AUTH_SKIP_VERIFY=true
+
```
+
+
This is useful for:
+
- Initial integration testing
+
- Testing with mock tokens
+
- Debugging JWT structure
+
+
### Phase 2: Full Verification (Production)
+
+
Set `AUTH_SKIP_VERIFY=false` (or unset) to enable full JWT signature verification:
+
+
```bash
+
export AUTH_SKIP_VERIFY=false
+
# or just unset it
+
```
+
+
This is **required for production** and validates:
+
- JWT signature using PDS public key
+
- Token expiration
+
- Required claims (sub, iss)
+
- DID format
+
+
## Protected Endpoints
+
+
The following endpoints require authentication:
+
+
- `POST /xrpc/social.coves.community.create`
+
- `POST /xrpc/social.coves.community.update`
+
- `POST /xrpc/social.coves.community.subscribe`
+
- `POST /xrpc/social.coves.community.unsubscribe`
+
+
### Making Authenticated Requests
+
+
Include the JWT in the `Authorization` header:
+
+
```bash
+
curl -X POST https://coves.social/xrpc/social.coves.community.create \
+
-H "Authorization: Bearer eyJhbGc..." \
+
-H "Content-Type: application/json" \
+
-d '{"name":"Gaming","hostedByDid":"did:plc:..."}'
+
```
+
+
### Getting User DID in Handlers
+
+
The middleware injects the authenticated user's DID into the request context:
+
+
```go
+
import "Coves/internal/api/middleware"
+
+
func (h *Handler) HandleCreate(w http.ResponseWriter, r *http.Request) {
+
// Extract authenticated user DID
+
userDID := middleware.GetUserDID(r)
+
if userDID == "" {
+
// Not authenticated (should never happen with RequireAuth middleware)
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
+
return
+
}
+
+
// Use userDID for authorization checks
+
// ...
+
}
+
```
+
+
## Key Caching
+
+
Public keys are fetched from PDS authorization servers and cached for 1 hour. The cache is automatically cleaned up hourly to remove expired entries.
+
+
### JWKS Discovery Flow
+
+
1. Extract `iss` claim from JWT (e.g., `https://pds.example.com`)
+
2. Fetch `https://pds.example.com/.well-known/oauth-authorization-server`
+
3. Extract `jwks_uri` from metadata
+
4. Fetch JWKS from `jwks_uri`
+
5. Find matching key by `kid` from JWT header
+
6. Cache the JWKS for 1 hour
+
+
## Security Considerations
+
+
### ✅ Implemented
+
+
- JWT signature verification with PDS public keys
+
- Token expiration validation
+
- DID format validation
+
- Required claims validation (sub, iss)
+
- Key caching with TTL
+
- Secure error messages (no internal details leaked)
+
+
### ⚠️ Not Yet Implemented
+
+
- DPoP validation (for replay attack prevention)
+
- Scope validation (checking `scope` claim)
+
- Audience validation (checking `aud` claim)
+
- Rate limiting per DID
+
- Token revocation checking
+
+
## Testing
+
+
Run the test suite:
+
+
```bash
+
go test ./internal/atproto/auth/... -v
+
```
+
+
### Manual Testing
+
+
1. **Phase 1 (Parse Only)**:
+
```bash
+
# Create a test JWT (use jwt.io or a tool)
+
export AUTH_SKIP_VERIFY=true
+
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
+
-H "Authorization: Bearer <test-jwt>" \
+
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
+
```
+
+
2. **Phase 2 (Full Verification)**:
+
```bash
+
# Use a real JWT from a PDS
+
export AUTH_SKIP_VERIFY=false
+
curl -X POST http://localhost:8081/xrpc/social.coves.community.create \
+
-H "Authorization: Bearer <real-jwt>" \
+
-d '{"name":"Test","hostedByDid":"did:plc:test"}'
+
```
+
+
## Error Responses
+
+
### 401 Unauthorized
+
+
Missing or invalid token:
+
+
```json
+
{
+
"error": "AuthenticationRequired",
+
"message": "Missing Authorization header"
+
}
+
```
+
+
```json
+
{
+
"error": "AuthenticationRequired",
+
"message": "Invalid or expired token"
+
}
+
```
+
+
### Common Issues
+
+
1. **Missing Authorization header** → Add `Authorization: Bearer <token>`
+
2. **Token expired** → Get a new token from PDS
+
3. **Invalid signature** → Ensure token is from a valid PDS
+
4. **JWKS fetch fails** → Check PDS availability and network connectivity
+
+
## Future Enhancements
+
+
- [ ] DPoP proof validation
+
- [ ] Scope-based authorization
+
- [ ] Audience claim validation
+
- [ ] Token revocation support
+
- [ ] Rate limiting per DID
+
- [ ] Metrics and monitoring
+189
internal/atproto/auth/jwks_fetcher.go
···
+
package auth
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"net/http"
+
"strings"
+
"sync"
+
"time"
+
)
+
+
// CachedJWKSFetcher fetches and caches JWKS from authorization servers
+
type CachedJWKSFetcher struct {
+
cache map[string]*cachedJWKS
+
httpClient *http.Client
+
cacheMutex sync.RWMutex
+
cacheTTL time.Duration
+
}
+
+
type cachedJWKS struct {
+
jwks *JWKS
+
expiresAt time.Time
+
}
+
+
// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
+
func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
+
return &CachedJWKSFetcher{
+
cache: make(map[string]*cachedJWKS),
+
httpClient: &http.Client{
+
Timeout: 10 * time.Second,
+
},
+
cacheTTL: cacheTTL,
+
}
+
}
+
+
// FetchPublicKey fetches the public key for verifying a JWT from the issuer
+
// Implements JWKSFetcher interface
+
// Returns interface{} to support both RSA and ECDSA keys
+
func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
+
// Extract key ID from token
+
kid, err := ExtractKeyID(token)
+
if err != nil {
+
return nil, fmt.Errorf("failed to extract key ID: %w", err)
+
}
+
+
// Get JWKS from cache or fetch
+
jwks, err := f.getJWKS(ctx, issuer)
+
if err != nil {
+
return nil, err
+
}
+
+
// Find the key by ID
+
jwk, err := jwks.FindKeyByID(kid)
+
if err != nil {
+
// Key not found in cache - try refreshing
+
jwks, err = f.fetchJWKS(ctx, issuer)
+
if err != nil {
+
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
+
}
+
f.cacheJWKS(issuer, jwks)
+
+
// Try again with fresh JWKS
+
jwk, err = jwks.FindKeyByID(kid)
+
if err != nil {
+
return nil, err
+
}
+
}
+
+
// Convert JWK to public key (RSA or ECDSA)
+
return jwk.ToPublicKey()
+
}
+
+
// getJWKS gets JWKS from cache or fetches if not cached/expired
+
func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
+
// Check cache first
+
f.cacheMutex.RLock()
+
cached, exists := f.cache[issuer]
+
f.cacheMutex.RUnlock()
+
+
if exists && time.Now().Before(cached.expiresAt) {
+
return cached.jwks, nil
+
}
+
+
// Not in cache or expired - fetch from issuer
+
jwks, err := f.fetchJWKS(ctx, issuer)
+
if err != nil {
+
return nil, err
+
}
+
+
// Cache it
+
f.cacheJWKS(issuer, jwks)
+
+
return jwks, nil
+
}
+
+
// fetchJWKS fetches JWKS from the authorization server
+
func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
+
// Step 1: Fetch OAuth server metadata to get JWKS URI
+
metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
+
+
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create metadata request: %w", err)
+
}
+
+
resp, err := f.httpClient.Do(req)
+
if err != nil {
+
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
+
}
+
defer func() {
+
_ = resp.Body.Close()
+
}()
+
+
if resp.StatusCode != http.StatusOK {
+
return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
+
}
+
+
var metadata struct {
+
JWKSURI string `json:"jwks_uri"`
+
}
+
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+
return nil, fmt.Errorf("failed to decode metadata: %w", err)
+
}
+
+
if metadata.JWKSURI == "" {
+
return nil, fmt.Errorf("jwks_uri not found in metadata")
+
}
+
+
// Step 2: Fetch JWKS from the JWKS URI
+
jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
+
if err != nil {
+
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
+
}
+
+
jwksResp, err := f.httpClient.Do(jwksReq)
+
if err != nil {
+
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
+
}
+
defer func() {
+
_ = jwksResp.Body.Close()
+
}()
+
+
if jwksResp.StatusCode != http.StatusOK {
+
return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
+
}
+
+
var jwks JWKS
+
if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
+
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
+
}
+
+
if len(jwks.Keys) == 0 {
+
return nil, fmt.Errorf("no keys found in JWKS")
+
}
+
+
return &jwks, nil
+
}
+
+
// cacheJWKS stores JWKS in the cache
+
func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
+
f.cacheMutex.Lock()
+
defer f.cacheMutex.Unlock()
+
+
f.cache[issuer] = &cachedJWKS{
+
jwks: jwks,
+
expiresAt: time.Now().Add(f.cacheTTL),
+
}
+
}
+
+
// ClearCache clears the entire JWKS cache
+
func (f *CachedJWKSFetcher) ClearCache() {
+
f.cacheMutex.Lock()
+
defer f.cacheMutex.Unlock()
+
f.cache = make(map[string]*cachedJWKS)
+
}
+
+
// CleanupExpiredCache removes expired entries from the cache
+
func (f *CachedJWKSFetcher) CleanupExpiredCache() {
+
f.cacheMutex.Lock()
+
defer f.cacheMutex.Unlock()
+
+
now := time.Now()
+
for issuer, cached := range f.cache {
+
if now.After(cached.expiresAt) {
+
delete(f.cache, issuer)
+
}
+
}
+
}
+288
internal/atproto/auth/jwt.go
···
+
package auth
+
+
import (
+
"context"
+
"crypto/ecdsa"
+
"crypto/elliptic"
+
"crypto/rsa"
+
"encoding/base64"
+
"encoding/json"
+
"fmt"
+
"math/big"
+
"net/url"
+
"strings"
+
"time"
+
+
"github.com/golang-jwt/jwt/v5"
+
)
+
+
// Claims represents the standard JWT claims we care about
+
type Claims struct {
+
jwt.RegisteredClaims
+
Scope string `json:"scope,omitempty"`
+
}
+
+
// ParseJWT parses a JWT token without verification (Phase 1)
+
// Returns the claims if the token is valid JSON and has required fields
+
func ParseJWT(tokenString string) (*Claims, error) {
+
// Remove "Bearer " prefix if present
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
tokenString = strings.TrimSpace(tokenString)
+
+
// Parse without verification first to extract claims
+
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
+
token, _, err := parser.ParseUnverified(tokenString, &Claims{})
+
if err != nil {
+
return nil, fmt.Errorf("failed to parse JWT: %w", err)
+
}
+
+
claims, ok := token.Claims.(*Claims)
+
if !ok {
+
return nil, fmt.Errorf("invalid claims type")
+
}
+
+
// Validate required fields
+
if claims.Subject == "" {
+
return nil, fmt.Errorf("missing 'sub' claim (user DID)")
+
}
+
+
// atProto PDSes may use 'aud' instead of 'iss' for the authorization server
+
// If 'iss' is missing, use 'aud' as the authorization server identifier
+
if claims.Issuer == "" {
+
if len(claims.Audience) > 0 {
+
claims.Issuer = claims.Audience[0]
+
} else {
+
return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
+
}
+
}
+
+
// Validate claims (even in Phase 1, we need basic validation like expiry)
+
if err := validateClaims(claims); err != nil {
+
return nil, err
+
}
+
+
return claims, nil
+
}
+
+
// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
+
// Fetches the public key from the issuer's JWKS endpoint and validates the signature
+
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
+
// First parse to get the issuer
+
claims, err := ParseJWT(tokenString)
+
if err != nil {
+
return nil, err
+
}
+
+
// Fetch the public key from the issuer
+
publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString)
+
if err != nil {
+
return nil, fmt.Errorf("failed to fetch public key: %w", err)
+
}
+
+
// Now parse and verify with the public key
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
+
// Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
+
switch token.Method.(type) {
+
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
+
// Valid signing methods for atProto
+
default:
+
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+
}
+
return publicKey, nil
+
})
+
if err != nil {
+
return nil, fmt.Errorf("failed to verify JWT: %w", err)
+
}
+
+
if !token.Valid {
+
return nil, fmt.Errorf("token is invalid")
+
}
+
+
verifiedClaims, ok := token.Claims.(*Claims)
+
if !ok {
+
return nil, fmt.Errorf("invalid claims type after verification")
+
}
+
+
// Additional validation
+
if err := validateClaims(verifiedClaims); err != nil {
+
return nil, err
+
}
+
+
return verifiedClaims, nil
+
}
+
+
// validateClaims performs additional validation on JWT claims
+
func validateClaims(claims *Claims) error {
+
now := time.Now()
+
+
// Check expiration
+
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
+
return fmt.Errorf("token has expired")
+
}
+
+
// Check not before
+
if claims.NotBefore != nil && claims.NotBefore.After(now) {
+
return fmt.Errorf("token not yet valid")
+
}
+
+
// Validate DID format in sub claim
+
if !strings.HasPrefix(claims.Subject, "did:") {
+
return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
+
}
+
+
// Validate issuer is either an HTTPS URL or a DID
+
// atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
+
if !strings.HasPrefix(claims.Issuer, "https://") && !strings.HasPrefix(claims.Issuer, "did:") {
+
return fmt.Errorf("issuer must be HTTPS URL or DID, got: %s", claims.Issuer)
+
}
+
+
// Parse to ensure it's a valid URL
+
if _, err := url.Parse(claims.Issuer); err != nil {
+
return fmt.Errorf("invalid issuer URL: %w", err)
+
}
+
+
// Validate scope if present (lenient: allow empty, but reject wrong scopes)
+
if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
+
return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
+
}
+
+
return nil
+
}
+
+
// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints
+
// Returns interface{} to support both RSA and ECDSA keys
+
type JWKSFetcher interface {
+
FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
+
}
+
+
// JWK represents a JSON Web Key from a JWKS endpoint
+
// Supports both RSA and EC (ECDSA) keys
+
type JWK struct {
+
Kid string `json:"kid"` // Key ID
+
Kty string `json:"kty"` // Key type ("RSA" or "EC")
+
Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
+
Use string `json:"use"` // Public key use (should be "sig" for signatures)
+
// RSA fields
+
N string `json:"n,omitempty"` // RSA modulus
+
E string `json:"e,omitempty"` // RSA exponent
+
// EC fields
+
Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
+
X string `json:"x,omitempty"` // EC x coordinate
+
Y string `json:"y,omitempty"` // EC y coordinate
+
}
+
+
// ToPublicKey converts a JWK to a public key (RSA or ECDSA)
+
func (j *JWK) ToPublicKey() (interface{}, error) {
+
switch j.Kty {
+
case "RSA":
+
return j.toRSAPublicKey()
+
case "EC":
+
return j.toECPublicKey()
+
default:
+
return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
+
}
+
}
+
+
// toRSAPublicKey converts a JWK to an RSA public key
+
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
+
// Decode modulus
+
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
+
}
+
+
// Decode exponent
+
eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
+
}
+
+
// Convert exponent to int
+
var eInt int
+
for _, b := range eBytes {
+
eInt = eInt*256 + int(b)
+
}
+
+
return &rsa.PublicKey{
+
N: new(big.Int).SetBytes(nBytes),
+
E: eInt,
+
}, nil
+
}
+
+
// toECPublicKey converts a JWK to an ECDSA public key
+
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
+
// Determine curve
+
var curve elliptic.Curve
+
switch j.Crv {
+
case "P-256":
+
curve = elliptic.P256()
+
case "P-384":
+
curve = elliptic.P384()
+
case "P-521":
+
curve = elliptic.P521()
+
default:
+
return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
+
}
+
+
// Decode X coordinate
+
xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
+
}
+
+
// Decode Y coordinate
+
yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
+
if err != nil {
+
return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
+
}
+
+
return &ecdsa.PublicKey{
+
Curve: curve,
+
X: new(big.Int).SetBytes(xBytes),
+
Y: new(big.Int).SetBytes(yBytes),
+
}, nil
+
}
+
+
// JWKS represents a JSON Web Key Set
+
type JWKS struct {
+
Keys []JWK `json:"keys"`
+
}
+
+
// FindKeyByID finds a key in the JWKS by its key ID
+
func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
+
for _, key := range j.Keys {
+
if key.Kid == kid {
+
return &key, nil
+
}
+
}
+
return nil, fmt.Errorf("key with kid %s not found", kid)
+
}
+
+
// ExtractKeyID extracts the key ID from a JWT token header
+
func ExtractKeyID(tokenString string) (string, error) {
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
parts := strings.Split(tokenString, ".")
+
if len(parts) != 3 {
+
return "", fmt.Errorf("invalid JWT format")
+
}
+
+
// Decode header
+
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
+
if err != nil {
+
return "", fmt.Errorf("failed to decode header: %w", err)
+
}
+
+
var header struct {
+
Kid string `json:"kid"`
+
}
+
if err := json.Unmarshal(headerBytes, &header); err != nil {
+
return "", fmt.Errorf("failed to unmarshal header: %w", err)
+
}
+
+
if header.Kid == "" {
+
return "", fmt.Errorf("missing kid in token header")
+
}
+
+
return header.Kid, nil
+
}
+170
internal/atproto/auth/jwt_test.go
···
+
package auth
+
+
import (
+
"testing"
+
"time"
+
+
"github.com/golang-jwt/jwt/v5"
+
)
+
+
func TestParseJWT(t *testing.T) {
+
// Create a test JWT token
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test-pds.example.com",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
},
+
Scope: "atproto transition:generic",
+
}
+
+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+
tokenString, err := token.SignedString([]byte("test-secret"))
+
if err != nil {
+
t.Fatalf("Failed to create test token: %v", err)
+
}
+
+
// Test parsing
+
parsedClaims, err := ParseJWT(tokenString)
+
if err != nil {
+
t.Fatalf("ParseJWT failed: %v", err)
+
}
+
+
if parsedClaims.Subject != "did:plc:test123" {
+
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
+
}
+
+
if parsedClaims.Issuer != "https://test-pds.example.com" {
+
t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
+
}
+
+
if parsedClaims.Scope != "atproto transition:generic" {
+
t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
+
}
+
}
+
+
func TestParseJWT_MissingSubject(t *testing.T) {
+
// Create a token without subject
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Issuer: "https://test-pds.example.com",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
},
+
}
+
+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+
tokenString, err := token.SignedString([]byte("test-secret"))
+
if err != nil {
+
t.Fatalf("Failed to create test token: %v", err)
+
}
+
+
// Test parsing - should fail
+
_, err = ParseJWT(tokenString)
+
if err == nil {
+
t.Error("Expected error for missing subject, got nil")
+
}
+
}
+
+
func TestParseJWT_MissingIssuer(t *testing.T) {
+
// Create a token without issuer
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
},
+
}
+
+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+
tokenString, err := token.SignedString([]byte("test-secret"))
+
if err != nil {
+
t.Fatalf("Failed to create test token: %v", err)
+
}
+
+
// Test parsing - should fail
+
_, err = ParseJWT(tokenString)
+
if err == nil {
+
t.Error("Expected error for missing issuer, got nil")
+
}
+
}
+
+
func TestParseJWT_WithBearerPrefix(t *testing.T) {
+
// Create a test JWT token
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test-pds.example.com",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
},
+
}
+
+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+
tokenString, err := token.SignedString([]byte("test-secret"))
+
if err != nil {
+
t.Fatalf("Failed to create test token: %v", err)
+
}
+
+
// Test parsing with Bearer prefix
+
parsedClaims, err := ParseJWT("Bearer " + tokenString)
+
if err != nil {
+
t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
+
}
+
+
if parsedClaims.Subject != "did:plc:test123" {
+
t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
+
}
+
}
+
+
func TestValidateClaims_Expired(t *testing.T) {
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test-pds.example.com",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
+
},
+
}
+
+
err := validateClaims(claims)
+
if err == nil {
+
t.Error("Expected error for expired token, got nil")
+
}
+
}
+
+
func TestValidateClaims_InvalidDID(t *testing.T) {
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "invalid-did-format",
+
Issuer: "https://test-pds.example.com",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
},
+
}
+
+
err := validateClaims(claims)
+
if err == nil {
+
t.Error("Expected error for invalid DID format, got nil")
+
}
+
}
+
+
func TestExtractKeyID(t *testing.T) {
+
// Create a test JWT token with kid in header
+
token := jwt.New(jwt.SigningMethodRS256)
+
token.Header["kid"] = "test-key-id"
+
token.Claims = &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test-pds.example.com",
+
},
+
}
+
+
// Sign with a dummy RSA key (we just need a valid token structure)
+
tokenString, err := token.SignedString([]byte("dummy"))
+
if err == nil {
+
// If it succeeds (shouldn't with wrong key type, but let's handle it)
+
kid, err := ExtractKeyID(tokenString)
+
if err != nil {
+
t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
+
} else if kid != "test-key-id" {
+
t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
+
}
+
}
+
}