1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 atpclient "github.com/bluesky-social/indigo/atproto/client"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 xrpc "github.com/bluesky-social/indigo/xrpc"
15 "github.com/gorilla/sessions"
16 "github.com/lestrrat-go/jwx/v2/jwk"
17 "github.com/posthog/posthog-go"
18 "tangled.org/core/appview/config"
19 "tangled.org/core/appview/db"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
// OAuth bundles the OAuth client application, the cookie session store and
// the shared services that are handed to New.
type OAuth struct {
	// ClientApp is the OAuth client used to resume and log out sessions.
	ClientApp *oauth.ClientApp
	// SessStore is the cookie store holding the per-user session values.
	SessStore *sessions.CookieStore
	// Config is the appview configuration passed to New.
	Config *config.Config
	// JwksUri is the client URI with "/oauth/jwks.json" appended.
	JwksUri string
	// Posthog is the PostHog client passed to New.
	Posthog posthog.Client
	// Db is the database handle passed to New.
	Db *db.DB
	// Enforcer is the RBAC enforcer passed to New.
	Enforcer *rbac.Enforcer
	// IdResolver is the identity resolver passed to New; New also installs
	// its Directory on ClientApp.
	IdResolver *idresolver.Resolver
	// Logger is the structured logger passed to New.
	Logger *slog.Logger
}
35
36func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
37
38 var oauthConfig oauth.ClientConfig
39 var clientUri string
40
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
45 } else {
46 clientUri = config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
50 }
51
52 jwksUri := clientUri + "/oauth/jwks.json"
53
54 authStore, err := NewRedisStore(config.Redis.ToURL())
55 if err != nil {
56 return nil, err
57 }
58
59 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
60
61 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
62 clientApp.Dir = res.Directory()
63
64 return &OAuth{
65 ClientApp: clientApp,
66 Config: config,
67 SessStore: sessStore,
68 JwksUri: jwksUri,
69 Posthog: ph,
70 Db: db,
71 Enforcer: enforcer,
72 IdResolver: res,
73 Logger: logger,
74 }, nil
75}
76
77func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
78 // first we save the did in the user session
79 userSession, err := o.SessStore.Get(r, SessionName)
80 if err != nil {
81 return err
82 }
83
84 userSession.Values[SessionDid] = sessData.AccountDID.String()
85 userSession.Values[SessionPds] = sessData.HostURL
86 userSession.Values[SessionId] = sessData.SessionID
87 userSession.Values[SessionAuthenticated] = true
88 return userSession.Save(r, w)
89}
90
91func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
92 userSession, err := o.SessStore.Get(r, SessionName)
93 if err != nil {
94 return nil, fmt.Errorf("error getting user session: %w", err)
95 }
96 if userSession.IsNew {
97 return nil, fmt.Errorf("no session available for user")
98 }
99
100 d := userSession.Values[SessionDid].(string)
101 sessDid, err := syntax.ParseDID(d)
102 if err != nil {
103 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
104 }
105
106 sessId := userSession.Values[SessionId].(string)
107
108 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
109 if err != nil {
110 return nil, fmt.Errorf("failed to resume session: %w", err)
111 }
112
113 return clientSess, nil
114}
115
116func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
117 userSession, err := o.SessStore.Get(r, SessionName)
118 if err != nil {
119 return fmt.Errorf("error getting user session: %w", err)
120 }
121 if userSession.IsNew {
122 return fmt.Errorf("no session available for user")
123 }
124
125 d := userSession.Values[SessionDid].(string)
126 sessDid, err := syntax.ParseDID(d)
127 if err != nil {
128 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
129 }
130
131 sessId := userSession.Values[SessionId].(string)
132
133 // delete the session
134 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
135
136 // remove the cookie
137 userSession.Options.MaxAge = -1
138 err2 := o.SessStore.Save(r, w, userSession)
139
140 return errors.Join(err1, err2)
141}
142
143func pubKeyFromJwk(jwks string) (jwk.Key, error) {
144 k, err := jwk.ParseKey([]byte(jwks))
145 if err != nil {
146 return nil, err
147 }
148 pubKey, err := k.PublicKey()
149 if err != nil {
150 return nil, err
151 }
152 return pubKey, nil
153}
154
// User is the identity read back from the session cookie by GetUser.
type User struct {
	// Did is the value stored under SessionDid in the cookie session.
	Did string
	// Pds is the value stored under SessionPds in the cookie session.
	Pds string
}
159
160func (o *OAuth) GetUser(r *http.Request) *User {
161 sess, err := o.SessStore.Get(r, SessionName)
162
163 if err != nil || sess.IsNew {
164 return nil
165 }
166
167 return &User{
168 Did: sess.Values[SessionDid].(string),
169 Pds: sess.Values[SessionPds].(string),
170 }
171}
172
173func (o *OAuth) GetDid(r *http.Request) string {
174 if u := o.GetUser(r); u != nil {
175 return u.Did
176 }
177
178 return ""
179}
180
181func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
182 session, err := o.ResumeSession(r)
183 if err != nil {
184 return nil, fmt.Errorf("error getting session: %w", err)
185 }
186 return session.APIClient(), nil
187}
188
// ServiceClientOpts collects the settings for ServiceClient, which is a
// higher-level abstraction on ServerGetServiceAuth.
type ServiceClientOpts struct {
	// service is the host of the target service; it is used to build both
	// the did:web audience and the request host.
	service string
	// exp is the token expiry as a Unix timestamp in seconds.
	exp int64
	// lxm is passed through to ServerGetServiceAuth unchanged.
	lxm string
	// dev selects the "http://" scheme instead of "https://" in Host.
	dev bool
}
196
// ServiceClientOpt is a functional option that modifies a ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)
198
199func WithService(service string) ServiceClientOpt {
200 return func(s *ServiceClientOpts) {
201 s.service = service
202 }
203}
204
205// Specify the Duration in seconds for the expiry of this token
206//
207// The time of expiry is calculated as time.Now().Unix() + exp
208func WithExp(exp int64) ServiceClientOpt {
209 return func(s *ServiceClientOpts) {
210 s.exp = time.Now().Unix() + exp
211 }
212}
213
214func WithLxm(lxm string) ServiceClientOpt {
215 return func(s *ServiceClientOpts) {
216 s.lxm = lxm
217 }
218}
219
220func WithDev(dev bool) ServiceClientOpt {
221 return func(s *ServiceClientOpts) {
222 s.dev = dev
223 }
224}
225
226func (s *ServiceClientOpts) Audience() string {
227 return fmt.Sprintf("did:web:%s", s.service)
228}
229
230func (s *ServiceClientOpts) Host() string {
231 scheme := "https://"
232 if s.dev {
233 scheme = "http://"
234 }
235
236 return scheme + s.service
237}
238
239func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
240 opts := ServiceClientOpts{}
241 for _, o := range os {
242 o(&opts)
243 }
244
245 client, err := o.AuthorizedClient(r)
246 if err != nil {
247 return nil, err
248 }
249
250 // force expiry to atleast 60 seconds in the future
251 sixty := time.Now().Unix() + 60
252 if opts.exp < sixty {
253 opts.exp = sixty
254 }
255
256 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
257 if err != nil {
258 return nil, err
259 }
260
261 return &xrpc.Client{
262 Auth: &xrpc.AuthInfo{
263 AccessJwt: resp.Token,
264 },
265 Host: opts.Host(),
266 Client: &http.Client{
267 Timeout: time.Second * 5,
268 },
269 }, nil
270}