A community-based topic aggregation platform built on atproto.
1package oauth
2
import (
	"encoding/base64"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/atproto/atcrypto"
	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	"github.com/bluesky-social/indigo/atproto/identity"
)
13
// OAuthClient wraps indigo's OAuth ClientApp with Coves-specific configuration.
// NewOAuthClient constructs instances of it.
type OAuthClient struct {
	// ClientApp is the underlying indigo OAuth client application.
	ClientApp *oauth.ClientApp
	// Config is the Coves configuration the client was built from.
	Config *OAuthConfig
	// SealSecret is the decoded key used for sealing mobile tokens. As
	// populated by NewOAuthClient it is exactly 32 bytes when
	// OAuthConfig.SealSecret was set, and nil otherwise.
	SealSecret []byte
}
20
// OAuthConfig holds Coves OAuth client configuration. NewOAuthClient
// consumes it and fills in default TTLs in place when they are zero.
type OAuthConfig struct {
	// PublicURL is the public base URL of the service. Outside DevMode it is
	// passed as the first argument of oauth.NewPublicConfig and used to build
	// the "/oauth/callback" redirect URL, the client URI and, for
	// confidential clients, the JWKS URI.
	PublicURL string
	// ClientSecret is the client's private key in multibase form (parsed
	// with atcrypto.ParsePrivateMultibase). Together with ClientKID it sets
	// up a confidential client outside DevMode.
	ClientSecret string
	// ClientKID is the key ID registered alongside ClientSecret.
	ClientKID string
	// SealSecret is an optional standard-base64 string that must decode to
	// exactly 32 bytes; it is used for sealing mobile tokens.
	SealSecret string
	// PLCURL optionally overrides the PLC directory URL, e.g. to point at a
	// local PLC directory during development.
	PLCURL string
	// Scopes are the requested OAuth scopes; they must include "atproto".
	Scopes []string
	// SessionTTL defaults to 7 days when left at zero.
	SessionTTL time.Duration
	// SealedTokenTTL defaults to 14 days when left at zero.
	SealedTokenTTL time.Duration
	// DevMode selects a localhost client config whose callback is
	// http://localhost:3000/oauth/callback.
	DevMode bool
	// AllowPrivateIPs is forwarded to NewSSRFSafeHTTPClient.
	AllowPrivateIPs bool
}
34
35// NewOAuthClient creates a new OAuth client for Coves
36func NewOAuthClient(config *OAuthConfig, store oauth.ClientAuthStore) (*OAuthClient, error) {
37 if config == nil {
38 return nil, fmt.Errorf("config is required")
39 }
40
41 // Validate seal secret
42 var sealSecret []byte
43 if config.SealSecret != "" {
44 decoded, err := base64.StdEncoding.DecodeString(config.SealSecret)
45 if err != nil {
46 return nil, fmt.Errorf("failed to decode seal secret: %w", err)
47 }
48 if len(decoded) != 32 {
49 return nil, fmt.Errorf("seal secret must be 32 bytes, got %d", len(decoded))
50 }
51 sealSecret = decoded
52 }
53
54 // Validate scopes
55 if len(config.Scopes) == 0 {
56 return nil, fmt.Errorf("scopes are required")
57 }
58 hasAtproto := false
59 for _, scope := range config.Scopes {
60 if scope == "atproto" {
61 hasAtproto = true
62 break
63 }
64 }
65 if !hasAtproto {
66 return nil, fmt.Errorf("scopes must include 'atproto'")
67 }
68
69 // Set default TTL values if not specified
70 // Per atproto OAuth spec:
71 // - Public clients: 2-week (14 day) maximum session lifetime
72 // - Confidential clients: 180-day maximum session lifetime
73 if config.SessionTTL == 0 {
74 config.SessionTTL = 7 * 24 * time.Hour // 7 days default
75 }
76 if config.SealedTokenTTL == 0 {
77 config.SealedTokenTTL = 14 * 24 * time.Hour // 14 days (public client limit)
78 }
79
80 // Create indigo client config
81 var clientConfig oauth.ClientConfig
82 if config.DevMode {
83 // Dev mode: localhost with HTTP
84 callbackURL := "http://localhost:3000/oauth/callback"
85 clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes)
86 } else {
87 // Production mode: HTTPS with client secret
88 callbackURL := config.PublicURL + "/oauth/callback"
89 clientConfig = oauth.NewPublicConfig(config.PublicURL, callbackURL, config.Scopes)
90
91 // Set up confidential client if client secret is provided
92 if config.ClientSecret != "" && config.ClientKID != "" {
93 privKey, err := atcrypto.ParsePrivateMultibase(config.ClientSecret)
94 if err != nil {
95 return nil, fmt.Errorf("failed to parse client secret: %w", err)
96 }
97
98 if err := clientConfig.SetClientSecret(privKey, config.ClientKID); err != nil {
99 return nil, fmt.Errorf("failed to set client secret: %w", err)
100 }
101 }
102 }
103
104 // Set user agent
105 clientConfig.UserAgent = "Coves/1.0"
106
107 // Create the indigo OAuth ClientApp
108 clientApp := oauth.NewClientApp(&clientConfig, store)
109
110 // Override the default HTTP client with our SSRF-safe client
111 // This protects against SSRF attacks via malicious PDS URLs, DID documents, and JWKS URIs
112 clientApp.Client = NewSSRFSafeHTTPClient(config.AllowPrivateIPs)
113
114 // Override the directory if a custom PLC URL is configured
115 // This is necessary for local development with a local PLC directory
116 if config.PLCURL != "" {
117 // Use SSRF-safe HTTP client for PLC directory requests
118 httpClient := NewSSRFSafeHTTPClient(config.AllowPrivateIPs)
119 baseDir := &identity.BaseDirectory{
120 PLCURL: config.PLCURL,
121 HTTPClient: *httpClient,
122 UserAgent: "Coves/1.0",
123 }
124 // Wrap in cache directory for better performance
125 // Use pointer since CacheDirectory methods have pointer receivers
126 cacheDir := identity.NewCacheDirectory(baseDir, 100_000, time.Hour*24, time.Minute*2, time.Minute*5)
127 clientApp.Dir = &cacheDir
128 }
129
130 return &OAuthClient{
131 ClientApp: clientApp,
132 Config: config,
133 SealSecret: sealSecret,
134 }, nil
135}
136
137// ClientMetadata returns the OAuth client metadata document
138func (c *OAuthClient) ClientMetadata() oauth.ClientMetadata {
139 metadata := c.ClientApp.Config.ClientMetadata()
140
141 // Add additional metadata for Coves
142 metadata.ClientName = strPtr("Coves")
143 if !c.Config.DevMode {
144 metadata.ClientURI = strPtr(c.Config.PublicURL)
145 }
146
147 // For confidential clients, set JWKS URI
148 if c.ClientApp.Config.IsConfidential() && !c.Config.DevMode {
149 jwksURI := c.Config.PublicURL + "/.well-known/oauth-jwks.json"
150 metadata.JWKSURI = &jwksURI
151 }
152
153 return metadata
154}
155
156// PublicJWKS returns the public JWKS for this client (for confidential clients)
157func (c *OAuthClient) PublicJWKS() oauth.JWKS {
158 return c.ClientApp.Config.PublicJWKS()
159}
160
161// IsConfidential returns true if this is a confidential OAuth client
162func (c *OAuthClient) IsConfidential() bool {
163 return c.ClientApp.Config.IsConfidential()
164}
165
// strPtr returns a pointer to a fresh copy of s. It is handy for filling
// optional *string fields.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
170
171// ValidateCallbackURL validates that a callback URL matches the expected callback URL
172func (c *OAuthClient) ValidateCallbackURL(callbackURL string) error {
173 expectedCallback := c.ClientApp.Config.CallbackURL
174
175 // Parse both URLs
176 expected, err := url.Parse(expectedCallback)
177 if err != nil {
178 return fmt.Errorf("invalid expected callback URL: %w", err)
179 }
180
181 actual, err := url.Parse(callbackURL)
182 if err != nil {
183 return fmt.Errorf("invalid callback URL: %w", err)
184 }
185
186 // Compare scheme, host, and path (ignore query params)
187 if expected.Scheme != actual.Scheme {
188 return fmt.Errorf("callback URL scheme mismatch: expected %s, got %s", expected.Scheme, actual.Scheme)
189 }
190 if expected.Host != actual.Host {
191 return fmt.Errorf("callback URL host mismatch: expected %s, got %s", expected.Host, actual.Host)
192 }
193 if expected.Path != actual.Path {
194 return fmt.Errorf("callback URL path mismatch: expected %s, got %s", expected.Path, actual.Path)
195 }
196
197 return nil
198}