1package oauth
2
import (
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	atpclient "github.com/bluesky-social/indigo/atproto/client"
	atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	xrpc "github.com/bluesky-social/indigo/xrpc"
	"github.com/gorilla/sessions"
	"github.com/posthog/posthog-go"
	"tangled.org/core/appview/config"
	"tangled.org/core/appview/db"
	"tangled.org/core/idresolver"
	"tangled.org/core/rbac"
)
23
// OAuth bundles the appview's atproto OAuth client together with the
// browser-session cookie store and the shared services that the OAuth
// handlers need. Construct it with New.
type OAuth struct {
	// ClientApp is the atproto OAuth client; New backs it with a Redis store
	// for OAuth sessions and pending auth requests.
	ClientApp *oauth.ClientApp
	// SessStore holds the user's browser session cookie (SessionName), which
	// records the DID and OAuth session ID written by SaveSession.
	SessStore *sessions.CookieStore
	Config    *config.Config
	// JwksUri is ClientUri + "/oauth/jwks.json".
	JwksUri string
	// ClientName is taken from config.Core.AppviewName.
	ClientName string
	// ClientUri is "http://127.0.0.1:3000" in dev mode, otherwise
	// config.Core.AppviewHost.
	ClientUri string
	// The remaining fields are passed through from New unchanged.
	Posthog    posthog.Client
	Db         *db.DB
	Enforcer   *rbac.Enforcer
	IdResolver *idresolver.Resolver
	Logger     *slog.Logger
}
37
38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
39 var oauthConfig oauth.ClientConfig
40 var clientUri string
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
45 } else {
46 clientUri = config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
50 }
51
52 // configure client secret
53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
54 if err != nil {
55 return nil, err
56 }
57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
58 return nil, err
59 }
60
61 jwksUri := clientUri + "/oauth/jwks.json"
62
63 authStore, err := NewRedisStore(&RedisStoreConfig{
64 RedisURL: config.Redis.ToURL(),
65 SessionExpiryDuration: time.Hour * 24 * 90,
66 SessionInactivityDuration: time.Hour * 24 * 14,
67 AuthRequestExpiryDuration: time.Minute * 30,
68 })
69 if err != nil {
70 return nil, err
71 }
72
73 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
74
75 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
76 clientApp.Dir = res.Directory()
77
78 clientName := config.Core.AppviewName
79
80 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
81 return &OAuth{
82 ClientApp: clientApp,
83 Config: config,
84 SessStore: sessStore,
85 JwksUri: jwksUri,
86 ClientName: clientName,
87 ClientUri: clientUri,
88 Posthog: ph,
89 Db: db,
90 Enforcer: enforcer,
91 IdResolver: res,
92 Logger: logger,
93 }, nil
94}
95
96func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
97 // first we save the did in the user session
98 userSession, err := o.SessStore.Get(r, SessionName)
99 if err != nil {
100 return err
101 }
102
103 userSession.Values[SessionDid] = sessData.AccountDID.String()
104 userSession.Values[SessionPds] = sessData.HostURL
105 userSession.Values[SessionId] = sessData.SessionID
106 userSession.Values[SessionAuthenticated] = true
107 return userSession.Save(r, w)
108}
109
110func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
111 userSession, err := o.SessStore.Get(r, SessionName)
112 if err != nil {
113 return nil, fmt.Errorf("error getting user session: %w", err)
114 }
115 if userSession.IsNew {
116 return nil, fmt.Errorf("no session available for user")
117 }
118
119 d := userSession.Values[SessionDid].(string)
120 sessDid, err := syntax.ParseDID(d)
121 if err != nil {
122 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
123 }
124
125 sessId := userSession.Values[SessionId].(string)
126
127 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
128 if err != nil {
129 return nil, fmt.Errorf("failed to resume session: %w", err)
130 }
131
132 return clientSess, nil
133}
134
135func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
136 userSession, err := o.SessStore.Get(r, SessionName)
137 if err != nil {
138 return fmt.Errorf("error getting user session: %w", err)
139 }
140 if userSession.IsNew {
141 return fmt.Errorf("no session available for user")
142 }
143
144 d := userSession.Values[SessionDid].(string)
145 sessDid, err := syntax.ParseDID(d)
146 if err != nil {
147 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
148 }
149
150 sessId := userSession.Values[SessionId].(string)
151
152 // delete the session
153 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
154
155 // remove the cookie
156 userSession.Options.MaxAge = -1
157 err2 := o.SessStore.Save(r, w, userSession)
158
159 return errors.Join(err1, err2)
160}
161
// User identifies the signed-in account: its DID and the host URL recorded
// for its OAuth session.
type User struct {
	Did, Pds string
}
166
167func (o *OAuth) GetUser(r *http.Request) *User {
168 sess, err := o.ResumeSession(r)
169 if err != nil {
170 return nil
171 }
172
173 return &User{
174 Did: sess.Data.AccountDID.String(),
175 Pds: sess.Data.HostURL,
176 }
177}
178
179func (o *OAuth) GetDid(r *http.Request) string {
180 if u := o.GetUser(r); u != nil {
181 return u.Did
182 }
183
184 return ""
185}
186
187func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
188 session, err := o.ResumeSession(r)
189 if err != nil {
190 return nil, fmt.Errorf("error getting session: %w", err)
191 }
192 return session.APIClient(), nil
193}
194
// ServiceClientOpts collects the settings ServiceClient uses to request a
// service-auth token and to build the XRPC client that presents it. It is a
// higher-level wrapper around ServerGetServiceAuth; populate it through
// ServiceClientOpt values.
type ServiceClientOpts struct {
	service string // target service host; used for both the did:web audience and the base URL
	exp     int64  // absolute token expiry in Unix seconds (see WithExp)
	lxm     string // lxm value forwarded to ServerGetServiceAuth
	dev     bool   // when true, Host uses http:// instead of https://
}

// ServiceClientOpt is a functional option applied to ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)

// WithService selects the service that the token is minted for and that the
// client talks to.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.service = service
	}
}

// WithExp sets the token lifetime to exp seconds.
//
// The absolute expiry is fixed when the option is applied, as
// time.Now().Unix() + exp.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value forwarded to ServerGetServiceAuth.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.lxm = lxm
	}
}

// WithDev makes Host use plain http:// when dev is true.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.dev = dev
	}
}

// Audience returns the did:web identifier of the configured service.
func (s *ServiceClientOpts) Audience() string {
	return fmt.Sprintf("did:web:%s", s.service)
}

// Host returns the base URL of the configured service: http:// in dev mode,
// https:// otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
244
245func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
246 opts := ServiceClientOpts{}
247 for _, o := range os {
248 o(&opts)
249 }
250
251 client, err := o.AuthorizedClient(r)
252 if err != nil {
253 return nil, err
254 }
255
256 // force expiry to atleast 60 seconds in the future
257 sixty := time.Now().Unix() + 60
258 if opts.exp < sixty {
259 opts.exp = sixty
260 }
261
262 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
263 if err != nil {
264 return nil, err
265 }
266
267 return &xrpc.Client{
268 Auth: &xrpc.AuthInfo{
269 AccessJwt: resp.Token,
270 },
271 Host: opts.Host(),
272 Client: &http.Client{
273 Timeout: time.Second * 5,
274 },
275 }, nil
276}