forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "net/http"
7 "time"
8
9 comatproto "github.com/bluesky-social/indigo/api/atproto"
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 atpclient "github.com/bluesky-social/indigo/atproto/client"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 xrpc "github.com/bluesky-social/indigo/xrpc"
14 "github.com/gorilla/sessions"
15 "github.com/lestrrat-go/jwx/v2/jwk"
16 "github.com/posthog/posthog-go"
17 "tangled.org/core/appview/config"
18 "tangled.org/core/appview/db"
19 "tangled.org/core/idresolver"
20 "tangled.org/core/rbac"
21)
22
23func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver) (*OAuth, error) {
24
25 var oauthConfig oauth.ClientConfig
26 var clientUri string
27
28 if config.Core.Dev {
29 clientUri = "http://127.0.0.1:3000"
30 callbackUri := clientUri + "/oauth/callback"
31 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
32 } else {
33 clientUri = config.Core.AppviewHost
34 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
35 callbackUri := clientUri + "/oauth/callback"
36 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
37 }
38
39 jwksUri := clientUri + "/oauth/jwks.json"
40
41 authStore, err := NewRedisStore(config.Redis.ToURL())
42 if err != nil {
43 return nil, err
44 }
45
46 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
47
48 return &OAuth{
49 ClientApp: oauth.NewClientApp(&oauthConfig, authStore),
50 Config: config,
51 SessStore: sessStore,
52 JwksUri: jwksUri,
53 Posthog: ph,
54 Db: db,
55 Enforcer: enforcer,
56 IdResolver: res,
57 }, nil
58}
59
60type OAuth struct {
61 ClientApp *oauth.ClientApp
62 SessStore *sessions.CookieStore
63 Config *config.Config
64 JwksUri string
65 Posthog posthog.Client
66 Db *db.DB
67 Enforcer *rbac.Enforcer
68 IdResolver *idresolver.Resolver
69}
70
71func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
72 // first we save the did in the user session
73 userSession, err := o.SessStore.Get(r, SessionName)
74 if err != nil {
75 return err
76 }
77
78 userSession.Values[SessionDid] = sessData.AccountDID.String()
79 userSession.Values[SessionPds] = sessData.HostURL
80 userSession.Values[SessionId] = sessData.SessionID
81 userSession.Values[SessionAuthenticated] = true
82 return userSession.Save(r, w)
83}
84
85func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
86 userSession, err := o.SessStore.Get(r, SessionName)
87 if err != nil {
88 return nil, fmt.Errorf("error getting user session: %w", err)
89 }
90 if userSession.IsNew {
91 return nil, fmt.Errorf("no session available for user")
92 }
93
94 d := userSession.Values[SessionDid].(string)
95 sessDid, err := syntax.ParseDID(d)
96 if err != nil {
97 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
98 }
99
100 sessId := userSession.Values[SessionId].(string)
101
102 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
103 if err != nil {
104 return nil, fmt.Errorf("failed to resume session: %w", err)
105 }
106
107 return clientSess, nil
108}
109
110func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
111 userSession, err := o.SessStore.Get(r, SessionName)
112 if err != nil {
113 return fmt.Errorf("error getting user session: %w", err)
114 }
115 if userSession.IsNew {
116 return fmt.Errorf("no session available for user")
117 }
118
119 d := userSession.Values[SessionDid].(string)
120 sessDid, err := syntax.ParseDID(d)
121 if err != nil {
122 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
123 }
124
125 sessId := userSession.Values[SessionId].(string)
126
127 // delete the session
128 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
129
130 // remove the cookie
131 userSession.Options.MaxAge = -1
132 err2 := o.SessStore.Save(r, w, userSession)
133
134 return errors.Join(err1, err2)
135}
136
137func pubKeyFromJwk(jwks string) (jwk.Key, error) {
138 k, err := jwk.ParseKey([]byte(jwks))
139 if err != nil {
140 return nil, err
141 }
142 pubKey, err := k.PublicKey()
143 if err != nil {
144 return nil, err
145 }
146 return pubKey, nil
147}
148
// User is the signed-in account as recorded in the session cookie.
type User struct {
	// Did is the account's DID, as a string.
	Did string

	// Pds is the host URL of the account's PDS.
	Pds string
}
153
154func (o *OAuth) GetUser(r *http.Request) *User {
155 sess, err := o.SessStore.Get(r, SessionName)
156
157 if err != nil || sess.IsNew {
158 return nil
159 }
160
161 return &User{
162 Did: sess.Values[SessionDid].(string),
163 Pds: sess.Values[SessionPds].(string),
164 }
165}
166
167func (o *OAuth) GetDid(r *http.Request) string {
168 if u := o.GetUser(r); u != nil {
169 return u.Did
170 }
171
172 return ""
173}
174
175func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
176 session, err := o.ResumeSession(r)
177 if err != nil {
178 return nil, fmt.Errorf("error getting session: %w", err)
179 }
180 return session.APIClient(), nil
181}
182
// ServiceClientOpts configures ServiceClient, a higher-level wrapper around
// ServerGetServiceAuth. Populate it with the With* options.
type ServiceClientOpts struct {
	service string // host (optionally host:port) of the target service
	exp     int64  // absolute Unix time at which the token should expire
	lxm     string // lxm value passed to ServerGetServiceAuth
	dev     bool   // dial the service over http:// instead of https://
}

// ServiceClientOpt sets one field of a ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)

// WithService selects the target service host; it determines both the
// token audience (see Audience) and the client host (see Host).
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp asks for a token that expires exp seconds after the option is
// applied. The stored value is the absolute Unix time time.Now().Unix() + exp,
// evaluated when the option runs rather than when WithExp is called.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value requested for the service token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev makes Host use plain http:// when dev is true.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// Audience returns the did:web DID of the configured service.
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the base URL of the configured service: https:// by default,
// http:// in dev mode.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
232
233func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
234 opts := ServiceClientOpts{}
235 for _, o := range os {
236 o(&opts)
237 }
238
239 client, err := o.AuthorizedClient(r)
240 if err != nil {
241 return nil, err
242 }
243
244 // force expiry to atleast 60 seconds in the future
245 sixty := time.Now().Unix() + 60
246 if opts.exp < sixty {
247 opts.exp = sixty
248 }
249
250 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
251 if err != nil {
252 return nil, err
253 }
254
255 return &xrpc.Client{
256 Auth: &xrpc.AuthInfo{
257 AccessJwt: resp.Token,
258 },
259 Host: opts.Host(),
260 Client: &http.Client{
261 Timeout: time.Second * 5,
262 },
263 }, nil
264}