A community-based topic aggregation platform built on atproto
1package oauth
2
import (
	"encoding/base64"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	"github.com/bluesky-social/indigo/atproto/identity"
)
12
// OAuthClient wraps indigo's OAuth ClientApp with Coves-specific configuration.
type OAuthClient struct {
	// ClientApp is the underlying indigo OAuth client application.
	ClientApp *oauth.ClientApp
	// Config is the configuration this client was built from (NewOAuthClient
	// stores the caller's pointer here, not a copy).
	Config *OAuthConfig
	// SealSecret holds the decoded seal key. As populated by NewOAuthClient it
	// is exactly 32 bytes when OAuthConfig.SealSecret is set, and nil otherwise.
	SealSecret []byte // For sealing mobile tokens
}
19
// OAuthConfig holds Coves OAuth client configuration.
//
// NewOAuthClient writes default TTLs back into this struct when they are
// left as zero.
type OAuthConfig struct {
	// PublicURL is the base URL of the service. Outside DevMode, NewOAuthClient
	// derives the client_id (<PublicURL>/oauth/client-metadata.json) and the
	// callback URL (<PublicURL>/oauth/callback) from it. It is also published
	// as client_uri in the client metadata.
	PublicURL string
	// SealSecret is optional. When set, it must be the standard base64
	// encoding of exactly 32 bytes (the key used for sealing mobile tokens).
	SealSecret string
	// PLCURL, when non-empty, replaces the ClientApp's identity directory with
	// one that resolves did:plc against this URL (e.g. a local PLC directory
	// during development).
	PLCURL string
	// Scopes are the requested OAuth scopes. The list must be non-empty and
	// must include "atproto".
	Scopes []string
	// SessionTTL is the intended OAuth session lifetime. Zero selects a 7-day
	// default in NewOAuthClient.
	SessionTTL time.Duration
	// SealedTokenTTL is the intended lifetime of sealed tokens. Zero selects a
	// 14-day default (the public-client limit) in NewOAuthClient.
	SealedTokenTTL time.Duration
	// DevMode selects an indigo localhost client with the fixed callback
	// http://localhost:3000/oauth/callback instead of a public client derived
	// from PublicURL.
	DevMode bool
	// AllowPrivateIPs is passed to NewSSRFSafeHTTPClient. Presumably it relaxes
	// the SSRF protection to allow private/loopback addresses (e.g. for local
	// development); confirm against NewSSRFSafeHTTPClient.
	AllowPrivateIPs bool
}
31
32// NewOAuthClient creates a new OAuth client for Coves
33func NewOAuthClient(config *OAuthConfig, store oauth.ClientAuthStore) (*OAuthClient, error) {
34 if config == nil {
35 return nil, fmt.Errorf("config is required")
36 }
37
38 // Validate seal secret
39 var sealSecret []byte
40 if config.SealSecret != "" {
41 decoded, err := base64.StdEncoding.DecodeString(config.SealSecret)
42 if err != nil {
43 return nil, fmt.Errorf("failed to decode seal secret: %w", err)
44 }
45 if len(decoded) != 32 {
46 return nil, fmt.Errorf("seal secret must be 32 bytes, got %d", len(decoded))
47 }
48 sealSecret = decoded
49 }
50
51 // Validate scopes
52 if len(config.Scopes) == 0 {
53 return nil, fmt.Errorf("scopes are required")
54 }
55 hasAtproto := false
56 for _, scope := range config.Scopes {
57 if scope == "atproto" {
58 hasAtproto = true
59 break
60 }
61 }
62 if !hasAtproto {
63 return nil, fmt.Errorf("scopes must include 'atproto'")
64 }
65
66 // Set default TTL values if not specified
67 // Per atproto OAuth spec:
68 // - Public clients: 2-week (14 day) maximum session lifetime
69 // - Confidential clients: 180-day maximum session lifetime
70 if config.SessionTTL == 0 {
71 config.SessionTTL = 7 * 24 * time.Hour // 7 days default
72 }
73 if config.SealedTokenTTL == 0 {
74 config.SealedTokenTTL = 14 * 24 * time.Hour // 14 days (public client limit)
75 }
76
77 // Create indigo client config
78 var clientConfig oauth.ClientConfig
79 if config.DevMode {
80 // Dev mode: localhost with HTTP
81 callbackURL := "http://localhost:3000/oauth/callback"
82 clientConfig = oauth.NewLocalhostConfig(callbackURL, config.Scopes)
83 } else {
84 // Production mode: public OAuth client with HTTPS
85 // client_id must be the URL of the client metadata document per atproto OAuth spec
86 clientID := config.PublicURL + "/oauth/client-metadata.json"
87 callbackURL := config.PublicURL + "/oauth/callback"
88 clientConfig = oauth.NewPublicConfig(clientID, callbackURL, config.Scopes)
89 }
90
91 // Set user agent
92 clientConfig.UserAgent = "Coves/1.0"
93
94 // Create the indigo OAuth ClientApp
95 clientApp := oauth.NewClientApp(&clientConfig, store)
96
97 // Override the default HTTP client with our SSRF-safe client
98 // This protects against SSRF attacks via malicious PDS URLs, DID documents, and JWKS URIs
99 clientApp.Client = NewSSRFSafeHTTPClient(config.AllowPrivateIPs)
100
101 // Override the directory if a custom PLC URL is configured
102 // This is necessary for local development with a local PLC directory
103 if config.PLCURL != "" {
104 // Use SSRF-safe HTTP client for PLC directory requests
105 httpClient := NewSSRFSafeHTTPClient(config.AllowPrivateIPs)
106 baseDir := &identity.BaseDirectory{
107 PLCURL: config.PLCURL,
108 HTTPClient: *httpClient,
109 UserAgent: "Coves/1.0",
110 }
111 // Wrap in cache directory for better performance
112 // Use pointer since CacheDirectory methods have pointer receivers
113 cacheDir := identity.NewCacheDirectory(baseDir, 100_000, time.Hour*24, time.Minute*2, time.Minute*5)
114 clientApp.Dir = &cacheDir
115 }
116
117 return &OAuthClient{
118 ClientApp: clientApp,
119 Config: config,
120 SealSecret: sealSecret,
121 }, nil
122}
123
124// ClientMetadata returns the OAuth client metadata document
125func (c *OAuthClient) ClientMetadata() oauth.ClientMetadata {
126 metadata := c.ClientApp.Config.ClientMetadata()
127
128 // Add additional metadata for Coves
129 metadata.ClientName = strPtr("Coves")
130 if !c.Config.DevMode {
131 metadata.ClientURI = strPtr(c.Config.PublicURL)
132 }
133
134 return metadata
135}
136
// strPtr returns a pointer to a new copy of s. It is useful for filling
// optional *string fields inline. Every call yields a distinct pointer.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
141
142// ValidateCallbackURL validates that a callback URL matches the expected callback URL
143func (c *OAuthClient) ValidateCallbackURL(callbackURL string) error {
144 expectedCallback := c.ClientApp.Config.CallbackURL
145
146 // Parse both URLs
147 expected, err := url.Parse(expectedCallback)
148 if err != nil {
149 return fmt.Errorf("invalid expected callback URL: %w", err)
150 }
151
152 actual, err := url.Parse(callbackURL)
153 if err != nil {
154 return fmt.Errorf("invalid callback URL: %w", err)
155 }
156
157 // Compare scheme, host, and path (ignore query params)
158 if expected.Scheme != actual.Scheme {
159 return fmt.Errorf("callback URL scheme mismatch: expected %s, got %s", expected.Scheme, actual.Scheme)
160 }
161 if expected.Host != actual.Host {
162 return fmt.Errorf("callback URL host mismatch: expected %s, got %s", expected.Host, actual.Host)
163 }
164 if expected.Path != actual.Path {
165 return fmt.Errorf("callback URL path mismatch: expected %s, got %s", expected.Path, actual.Path)
166 }
167
168 return nil
169}