forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "net/http"
7 "time"
8
9 comatproto "github.com/bluesky-social/indigo/api/atproto"
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 atpclient "github.com/bluesky-social/indigo/atproto/client"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 xrpc "github.com/bluesky-social/indigo/xrpc"
14 "github.com/gorilla/sessions"
15 "github.com/lestrrat-go/jwx/v2/jwk"
16 "tangled.org/core/appview/config"
17)
18
19func New(config *config.Config) (*OAuth, error) {
20
21 var oauthConfig oauth.ClientConfig
22 var clientUri string
23
24 if config.Core.Dev {
25 clientUri = "http://127.0.0.1:3000"
26 callbackUri := clientUri + "/oauth/callback"
27 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
28 } else {
29 clientUri = config.Core.AppviewHost
30 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
31 callbackUri := clientUri + "/oauth/callback"
32 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
33 }
34
35 jwksUri := clientUri + "/oauth/jwks.json"
36
37 authStore, err := NewRedisStore(config.Redis.ToURL())
38 if err != nil {
39 return nil, err
40 }
41
42 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
43
44 return &OAuth{
45 ClientApp: oauth.NewClientApp(&oauthConfig, authStore),
46 Config: config,
47 SessStore: sessStore,
48 JwksUri: jwksUri,
49 }, nil
50}
51
// OAuth bundles the appview's atproto OAuth client with the cookie store that
// records which account a browser is logged in as.
type OAuth struct {
	// ClientApp is used here to resume and log out stored OAuth sessions.
	ClientApp *oauth.ClientApp
	// SessStore holds the per-browser session cookie (DID, PDS, session ID).
	SessStore *sessions.CookieStore
	// Config is the appview configuration this helper was built from.
	Config *config.Config
	// JwksUri is the URL of the client's JWKS document
	// (<client URI>/oauth/jwks.json).
	JwksUri string
}
58
59func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
60 // first we save the did in the user session
61 userSession, err := o.SessStore.Get(r, SessionName)
62 if err != nil {
63 return err
64 }
65
66 userSession.Values[SessionDid] = sessData.AccountDID.String()
67 userSession.Values[SessionPds] = sessData.HostURL
68 userSession.Values[SessionId] = sessData.SessionID
69 userSession.Values[SessionAuthenticated] = true
70 return userSession.Save(r, w)
71}
72
73func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
74 userSession, err := o.SessStore.Get(r, SessionName)
75 if err != nil {
76 return nil, fmt.Errorf("error getting user session: %w", err)
77 }
78 if userSession.IsNew {
79 return nil, fmt.Errorf("no session available for user")
80 }
81
82 d := userSession.Values[SessionDid].(string)
83 sessDid, err := syntax.ParseDID(d)
84 if err != nil {
85 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
86 }
87
88 sessId := userSession.Values[SessionId].(string)
89
90 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
91 if err != nil {
92 return nil, fmt.Errorf("failed to resume session: %w", err)
93 }
94
95 return clientSess, nil
96}
97
98func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
99 userSession, err := o.SessStore.Get(r, SessionName)
100 if err != nil {
101 return fmt.Errorf("error getting user session: %w", err)
102 }
103 if userSession.IsNew {
104 return fmt.Errorf("no session available for user")
105 }
106
107 d := userSession.Values[SessionDid].(string)
108 sessDid, err := syntax.ParseDID(d)
109 if err != nil {
110 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
111 }
112
113 sessId := userSession.Values[SessionId].(string)
114
115 // delete the session
116 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
117
118 // remove the cookie
119 userSession.Options.MaxAge = -1
120 err2 := o.SessStore.Save(r, w, userSession)
121
122 return errors.Join(err1, err2)
123}
124
125func pubKeyFromJwk(jwks string) (jwk.Key, error) {
126 k, err := jwk.ParseKey([]byte(jwks))
127 if err != nil {
128 return nil, err
129 }
130 pubKey, err := k.PublicKey()
131 if err != nil {
132 return nil, err
133 }
134 return pubKey, nil
135}
136
// User identifies the account recorded in a request's session cookie.
type User struct {
	// Did is the account DID, as stored by SaveSession.
	Did string
	// Pds is the host URL stored by SaveSession (the OAuth session's HostURL).
	Pds string
}
141
142func (o *OAuth) GetUser(r *http.Request) *User {
143 sess, err := o.SessStore.Get(r, SessionName)
144
145 if err != nil || sess.IsNew {
146 return nil
147 }
148
149 return &User{
150 Did: sess.Values[SessionDid].(string),
151 Pds: sess.Values[SessionPds].(string),
152 }
153}
154
155func (o *OAuth) GetDid(r *http.Request) string {
156 if u := o.GetUser(r); u != nil {
157 return u.Did
158 }
159
160 return ""
161}
162
163func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
164 session, err := o.ResumeSession(r)
165 if err != nil {
166 return nil, fmt.Errorf("error getting session: %w", err)
167 }
168 return session.APIClient(), nil
169}
170
// ServiceClientOpts collects the options for ServiceClient, a higher-level
// wrapper around comatproto.ServerGetServiceAuth.
type ServiceClientOpts struct {
	// service is the host of the target service. It forms both the did:web
	// audience (Audience) and the client's base URL (Host).
	service string
	// exp is the absolute token expiry as a Unix timestamp in seconds
	// (see WithExp).
	exp int64
	// lxm is passed unchanged as the lxm argument of ServerGetServiceAuth.
	lxm string
	// dev makes Host use http:// instead of https://.
	dev bool
}
178
// ServiceClientOpt is a functional option that sets fields of ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)
180
181func WithService(service string) ServiceClientOpt {
182 return func(s *ServiceClientOpts) {
183 s.service = service
184 }
185}
186
187// Specify the Duration in seconds for the expiry of this token
188//
189// The time of expiry is calculated as time.Now().Unix() + exp
190func WithExp(exp int64) ServiceClientOpt {
191 return func(s *ServiceClientOpts) {
192 s.exp = time.Now().Unix() + exp
193 }
194}
195
196func WithLxm(lxm string) ServiceClientOpt {
197 return func(s *ServiceClientOpts) {
198 s.lxm = lxm
199 }
200}
201
202func WithDev(dev bool) ServiceClientOpt {
203 return func(s *ServiceClientOpts) {
204 s.dev = dev
205 }
206}
207
208func (s *ServiceClientOpts) Audience() string {
209 return fmt.Sprintf("did:web:%s", s.service)
210}
211
212func (s *ServiceClientOpts) Host() string {
213 scheme := "https://"
214 if s.dev {
215 scheme = "http://"
216 }
217
218 return scheme + s.service
219}
220
221func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
222 opts := ServiceClientOpts{}
223 for _, o := range os {
224 o(&opts)
225 }
226
227 client, err := o.AuthorizedClient(r)
228 if err != nil {
229 return nil, err
230 }
231
232 // force expiry to atleast 60 seconds in the future
233 sixty := time.Now().Unix() + 60
234 if opts.exp < sixty {
235 opts.exp = sixty
236 }
237
238 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
239 if err != nil {
240 return nil, err
241 }
242
243 return &xrpc.Client{
244 Auth: &xrpc.AuthInfo{
245 AccessJwt: resp.Token,
246 },
247 Host: opts.Host(),
248 Client: &http.Client{
249 Timeout: time.Second * 5,
250 },
251 }, nil
252}