1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 indigo_xrpc "github.com/bluesky-social/indigo/xrpc"
11 "github.com/gorilla/sessions"
12 sessioncache "tangled.org/core/appview/cache/session"
13 "tangled.org/core/appview/config"
14 "tangled.org/core/appview/oauth/client"
15 xrpc "tangled.org/core/appview/xrpcclient"
16 oauth "tangled.sh/icyphox.sh/atproto-oauth"
17 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
18)
19
// OAuth ties together the browser-side user session cookie and the
// server-side OAuth session cache for the appview.
type OAuth struct {
	// store is the cookie store (keyed with Core.CookieSecret in NewOAuth)
	// holding the user session cookie: DID, handle, PDS URL and an
	// authenticated flag.
	store *sessions.CookieStore
	// config supplies the cookie secret, appview host, dev flag and OAuth JWKS.
	config *config.Config
	// sess persists the full OAuth session (tokens, DPoP key/nonces, expiry),
	// looked up and deleted by DID.
	sess *sessioncache.SessionStore
}
25
26func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth {
27 return &OAuth{
28 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
29 config: config,
30 sess: sess,
31 }
32}
33
34func (o *OAuth) Stores() *sessions.CookieStore {
35 return o.store
36}
37
38func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error {
39 // first we save the did in the user session
40 userSession, err := o.store.Get(r, SessionName)
41 if err != nil {
42 return err
43 }
44
45 userSession.Values[SessionDid] = oreq.Did
46 userSession.Values[SessionHandle] = oreq.Handle
47 userSession.Values[SessionPds] = oreq.PdsUrl
48 userSession.Values[SessionAuthenticated] = true
49 err = userSession.Save(r, w)
50 if err != nil {
51 return fmt.Errorf("error saving user session: %w", err)
52 }
53
54 // then save the whole thing in the db
55 session := sessioncache.OAuthSession{
56 Did: oreq.Did,
57 Handle: oreq.Handle,
58 PdsUrl: oreq.PdsUrl,
59 DpopAuthserverNonce: oreq.DpopAuthserverNonce,
60 AuthServerIss: oreq.AuthserverIss,
61 DpopPrivateJwk: oreq.DpopPrivateJwk,
62 AccessJwt: oresp.AccessToken,
63 RefreshJwt: oresp.RefreshToken,
64 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
65 }
66
67 return o.sess.SaveSession(r.Context(), session)
68}
69
70func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
71 userSession, err := o.store.Get(r, SessionName)
72 if err != nil || userSession.IsNew {
73 return fmt.Errorf("error getting user session (or new session?): %w", err)
74 }
75
76 did := userSession.Values[SessionDid].(string)
77
78 err = o.sess.DeleteSession(r.Context(), did)
79 if err != nil {
80 return fmt.Errorf("error deleting oauth session: %w", err)
81 }
82
83 userSession.Options.MaxAge = -1
84
85 return userSession.Save(r, w)
86}
87
88func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) {
89 userSession, err := o.store.Get(r, SessionName)
90 if err != nil || userSession.IsNew {
91 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
92 }
93
94 did := userSession.Values[SessionDid].(string)
95 auth := userSession.Values[SessionAuthenticated].(bool)
96
97 session, err := o.sess.GetSession(r.Context(), did)
98 if err != nil {
99 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
100 }
101
102 expiry, err := time.Parse(time.RFC3339, session.Expiry)
103 if err != nil {
104 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
105 }
106 if time.Until(expiry) <= 5*time.Minute {
107 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
108 if err != nil {
109 return nil, false, err
110 }
111
112 self := o.ClientMetadata()
113
114 oauthClient, err := client.NewClient(
115 self.ClientID,
116 o.config.OAuth.Jwks,
117 self.RedirectURIs[0],
118 )
119
120 if err != nil {
121 return nil, false, err
122 }
123
124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
125 if err != nil {
126 return nil, false, err
127 }
128
129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
130 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry)
131 if err != nil {
132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
133 }
134
135 // update the current session
136 session.AccessJwt = resp.AccessToken
137 session.RefreshJwt = resp.RefreshToken
138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
139 session.Expiry = newExpiry
140 }
141
142 return session, auth, nil
143}
144
// User is the identity recorded in the user session cookie by SaveSession.
type User struct {
	// Handle is the account handle saved at login.
	Handle string
	// Did is the account DID.
	Did string
	// Pds is the PDS URL saved at login.
	Pds string
}
150
151func (a *OAuth) GetUser(r *http.Request) *User {
152 clientSession, err := a.store.Get(r, SessionName)
153
154 if err != nil || clientSession.IsNew {
155 return nil
156 }
157
158 return &User{
159 Handle: clientSession.Values[SessionHandle].(string),
160 Did: clientSession.Values[SessionDid].(string),
161 Pds: clientSession.Values[SessionPds].(string),
162 }
163}
164
165func (a *OAuth) GetDid(r *http.Request) string {
166 clientSession, err := a.store.Get(r, SessionName)
167
168 if err != nil || clientSession.IsNew {
169 return ""
170 }
171
172 return clientSession.Values[SessionDid].(string)
173}
174
175func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
176 session, auth, err := o.GetSession(r)
177 if err != nil {
178 return nil, fmt.Errorf("error getting session: %w", err)
179 }
180 if !auth {
181 return nil, fmt.Errorf("not authorized")
182 }
183
184 client := &oauth.XrpcClient{
185 OnDpopPdsNonceChanged: func(did, newNonce string) {
186 err := o.sess.UpdateNonce(r.Context(), did, newNonce)
187 if err != nil {
188 log.Printf("error updating dpop pds nonce: %v", err)
189 }
190 },
191 }
192
193 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
194 if err != nil {
195 return nil, fmt.Errorf("error parsing private jwk: %w", err)
196 }
197
198 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{
199 Did: session.Did,
200 PdsUrl: session.PdsUrl,
201 DpopPdsNonce: session.PdsUrl,
202 AccessToken: session.AccessJwt,
203 Issuer: session.AuthServerIss,
204 DpopPrivateJwk: privateJwk,
205 })
206
207 return xrpcClient, nil
208}
209
// ServiceClientOpts configures the service-auth client that
// (*OAuth).ServiceClient builds for talking to knots or spindles. That
// function is a higher-level wrapper around ServerGetServiceAuth.
//
// Populate it through the With* ServiceClientOpt constructors.
type ServiceClientOpts struct {
	service string // host of the target service; used in the did:web audience and the URL
	exp     int64  // absolute token expiry as Unix seconds (see WithExp)
	lxm     string // value forwarded as the lxm argument of ServerGetServiceAuth
	dev     bool   // when true, Host uses http:// instead of https://
}

// ServiceClientOpt mutates a ServiceClientOpts in place.
type ServiceClientOpt func(*ServiceClientOpts)

// WithService sets the host of the service the client will talk to.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token lifetime to exp seconds. The absolute expiry,
// time.Now().Unix() + exp, is computed at the moment the option is applied.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value requested for the service-auth token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev toggles dev mode, in which Host uses plain http.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// Audience returns the did:web identifier of the configured service.
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the base URL of the configured service: https by default,
// or http when dev mode is on.
func (s *ServiceClientOpts) Host() string {
	scheme := "https"
	if s.dev {
		scheme = "http"
	}
	return fmt.Sprintf("%s://%s", scheme, s.service)
}
261
262func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*indigo_xrpc.Client, error) {
263 opts := ServiceClientOpts{}
264 for _, o := range os {
265 o(&opts)
266 }
267
268 authorizedClient, err := o.AuthorizedClient(r)
269 if err != nil {
270 return nil, err
271 }
272
273 // force expiry to atleast 60 seconds in the future
274 sixty := time.Now().Unix() + 60
275 if opts.exp < sixty {
276 opts.exp = sixty
277 }
278
279 resp, err := authorizedClient.ServerGetServiceAuth(r.Context(), opts.Audience(), opts.exp, opts.lxm)
280 if err != nil {
281 return nil, err
282 }
283
284 return &indigo_xrpc.Client{
285 Auth: &indigo_xrpc.AuthInfo{
286 AccessJwt: resp.Token,
287 },
288 Host: opts.Host(),
289 Client: &http.Client{
290 Timeout: time.Second * 5,
291 },
292 }, nil
293}
294
// ClientMetadata describes this appview as an OAuth client. The json tags give
// the wire names of each field; (*OAuth).ClientMetadata fills in the values.
type ClientMetadata struct {
	ClientID                    string   `json:"client_id"`
	ClientName                  string   `json:"client_name"`
	SubjectType                 string   `json:"subject_type"`
	ClientURI                   string   `json:"client_uri"`
	RedirectURIs                []string `json:"redirect_uris"`
	GrantTypes                  []string `json:"grant_types"`
	ResponseTypes               []string `json:"response_types"`
	ApplicationType             string   `json:"application_type"`
	DpopBoundAccessTokens       bool     `json:"dpop_bound_access_tokens"`
	JwksURI                     string   `json:"jwks_uri"`
	Scope                       string   `json:"scope"`
	TokenEndpointAuthMethod     string   `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string   `json:"token_endpoint_auth_signing_alg"`
}
310
311func (o *OAuth) ClientMetadata() ClientMetadata {
312 makeRedirectURIs := func(c string) []string {
313 return []string{fmt.Sprintf("%s/oauth/callback", c)}
314 }
315
316 clientURI := o.config.Core.AppviewHost
317 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
318 redirectURIs := makeRedirectURIs(clientURI)
319
320 if o.config.Core.Dev {
321 clientURI = "http://127.0.0.1:3000"
322 redirectURIs = makeRedirectURIs(clientURI)
323
324 query := url.Values{}
325 query.Add("redirect_uri", redirectURIs[0])
326 query.Add("scope", "atproto transition:generic")
327 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
328 }
329
330 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
331
332 return ClientMetadata{
333 ClientID: clientID,
334 ClientName: "Tangled",
335 SubjectType: "public",
336 ClientURI: clientURI,
337 RedirectURIs: redirectURIs,
338 GrantTypes: []string{"authorization_code", "refresh_token"},
339 ResponseTypes: []string{"code"},
340 ApplicationType: "web",
341 DpopBoundAccessTokens: true,
342 JwksURI: jwksURI,
343 Scope: "atproto transition:generic",
344 TokenEndpointAuthMethod: "private_key_jwt",
345 TokenEndpointAuthSigningAlg: "ES256",
346 }
347}