1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 indigo_xrpc "github.com/bluesky-social/indigo/xrpc"
11 "github.com/gorilla/sessions"
12 oauth "tangled.sh/icyphox.sh/atproto-oauth"
13 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
14 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
15 "tangled.sh/tangled.sh/core/appview/config"
16 "tangled.sh/tangled.sh/core/appview/oauth/client"
17 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
18)
19
// OAuth ties together the signed per-user session cookie and the cached
// OAuth token state for the appview.
type OAuth struct {
	store  *sessions.CookieStore      // cookie store (signed with Core.CookieSecret) holding the user session
	config *config.Config             // appview configuration
	sess   *sessioncache.SessionStore // cache of full OAuth sessions, looked up by DID
}
25
26func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth {
27 return &OAuth{
28 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
29 config: config,
30 sess: sess,
31 }
32}
33
34func (o *OAuth) Stores() *sessions.CookieStore {
35 return o.store
36}
37
// SaveSession persists a completed OAuth login.
//
// The request state (oreq) and token response (oresp) are stored as an
// OAuthSession in the session cache. Then the DID, handle, PDS URL and an
// authenticated flag are written into the signed user-session cookie.
// The access-token expiry is recorded as now + oresp.ExpiresIn seconds, in
// RFC 3339 format.
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error {
	userSession, err := o.store.Get(r, SessionName)
	if err != nil {
		if userSession == nil {
			return fmt.Errorf("error getting user session: %w", err)
		}
		// gorilla/sessions returns a fresh session alongside a decode error,
		// e.g. for a cookie signed with an old secret. A login must replace
		// such a cookie rather than fail, so continue with the fresh session.
		log.Printf("replacing undecodable user session: %v", err)
	}

	// Persist the token state first. Then a failed cache write can never leave
	// behind a cookie that claims the user is authenticated.
	session := sessioncache.OAuthSession{
		Did:                 oreq.Did,
		Handle:              oreq.Handle,
		PdsUrl:              oreq.PdsUrl,
		DpopAuthserverNonce: oreq.DpopAuthserverNonce,
		AuthServerIss:       oreq.AuthserverIss,
		DpopPrivateJwk:      oreq.DpopPrivateJwk,
		AccessJwt:           oresp.AccessToken,
		RefreshJwt:          oresp.RefreshToken,
		Expiry:              time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
	if err := o.sess.SaveSession(r.Context(), session); err != nil {
		return fmt.Errorf("error saving oauth session: %w", err)
	}

	userSession.Values[SessionDid] = oreq.Did
	userSession.Values[SessionHandle] = oreq.Handle
	userSession.Values[SessionPds] = oreq.PdsUrl
	userSession.Values[SessionAuthenticated] = true
	if err := userSession.Save(r, w); err != nil {
		return fmt.Errorf("error saving user session: %w", err)
	}

	return nil
}
69
// ClearSession logs the user out.
//
// It deletes the cached OAuth session for the DID stored in the user cookie,
// then expires the cookie. It returns an error if there is no decodable
// existing user session, or if the cache deletion or cookie save fails.
func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
	userSession, err := o.store.Get(r, SessionName)
	if err != nil {
		return fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		return fmt.Errorf("error getting user session: no existing session")
	}

	// Use a checked assertion: a cookie without a DID must not panic. There is
	// simply no cached session to delete, and the cookie is still expired below.
	if did, ok := userSession.Values[SessionDid].(string); ok && did != "" {
		if err := o.sess.DeleteSession(r.Context(), did); err != nil {
			return fmt.Errorf("error deleting oauth session: %w", err)
		}
	}

	// A negative MaxAge makes the store emit a deleting cookie on Save.
	userSession.Options.MaxAge = -1

	return userSession.Save(r, w)
}
87
88func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) {
89 userSession, err := o.store.Get(r, SessionName)
90 if err != nil || userSession.IsNew {
91 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
92 }
93
94 did := userSession.Values[SessionDid].(string)
95 auth := userSession.Values[SessionAuthenticated].(bool)
96
97 session, err := o.sess.GetSession(r.Context(), did)
98 if err != nil {
99 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
100 }
101
102 expiry, err := time.Parse(time.RFC3339, session.Expiry)
103 if err != nil {
104 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
105 }
106 if expiry.Sub(time.Now()) <= 5*time.Minute {
107 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
108 if err != nil {
109 return nil, false, err
110 }
111
112 self := o.ClientMetadata()
113
114 oauthClient, err := client.NewClient(
115 self.ClientID,
116 o.config.OAuth.Jwks,
117 self.RedirectURIs[0],
118 )
119
120 if err != nil {
121 return nil, false, err
122 }
123
124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
125 if err != nil {
126 return nil, false, err
127 }
128
129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
130 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry)
131 if err != nil {
132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
133 }
134
135 // update the current session
136 session.AccessJwt = resp.AccessToken
137 session.RefreshJwt = resp.RefreshToken
138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
139 session.Expiry = newExpiry
140 }
141
142 return session, auth, nil
143}
144
// User is the identity recorded in the signed user-session cookie.
type User struct {
	Handle string
	Did    string
	Pds    string // PDS URL, taken from the OAuth request's PdsUrl at login
}
150
// GetUser returns the user recorded in the session cookie. It returns nil
// if there is no decodable existing session or the session carries no DID.
// Missing handle or PDS values come back as empty strings instead of panicking.
func (o *OAuth) GetUser(r *http.Request) *User {
	clientSession, err := o.store.Get(r, SessionName)
	if err != nil || clientSession.IsNew {
		return nil
	}

	did, ok := clientSession.Values[SessionDid].(string)
	if !ok || did == "" {
		return nil
	}
	handle, _ := clientSession.Values[SessionHandle].(string)
	pds, _ := clientSession.Values[SessionPds].(string)

	return &User{
		Handle: handle,
		Did:    did,
		Pds:    pds,
	}
}
164
// GetDid returns the DID stored in the session cookie. It returns "" when
// there is no decodable existing session or the session holds no string DID.
func (o *OAuth) GetDid(r *http.Request) string {
	clientSession, err := o.store.Get(r, SessionName)
	if err != nil || clientSession.IsNew {
		return ""
	}

	// Checked assertion: a session lacking the key yields "" instead of a panic.
	did, _ := clientSession.Values[SessionDid].(string)
	return did
}
174
// AuthorizedClient returns an XRPC client that makes DPoP-bound requests to
// the signed-in user's PDS.
//
// It fails if the session cannot be loaded (see GetSession), if the user
// session is not marked authenticated, or if the stored DPoP key cannot be
// parsed. When the PDS hands out a new DPoP nonce, it is written back to the
// session cache via UpdateNonce. Failures there are logged, not returned.
func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
	session, auth, err := o.GetSession(r)
	if err != nil {
		return nil, fmt.Errorf("error getting session: %w", err)
	}
	if !auth {
		return nil, fmt.Errorf("not authorized")
	}

	privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
	if err != nil {
		return nil, fmt.Errorf("error parsing private jwk: %w", err)
	}

	// Named baseClient so it does not shadow the imported client package.
	baseClient := &oauth.XrpcClient{
		OnDpopPdsNonceChanged: func(did, newNonce string) {
			if err := o.sess.UpdateNonce(r.Context(), did, newNonce); err != nil {
				log.Printf("error updating dpop pds nonce: %v", err)
			}
		},
	}

	// The PDS nonce is the value persisted via UpdateNonce. It was previously
	// filled with the PDS URL by mistake.
	xrpcClient := xrpc.NewClient(baseClient, &oauth.XrpcAuthedRequestArgs{
		Did:            session.Did,
		PdsUrl:         session.PdsUrl,
		DpopPdsNonce:   session.DpopPdsNonce,
		AccessToken:    session.AccessJwt,
		Issuer:         session.AuthServerIss,
		DpopPrivateJwk: privateJwk,
	})

	return xrpcClient, nil
}
209
210// use this to create a client to communicate with knots or spindles
211//
212// this is a higher level abstraction on ServerGetServiceAuth
213type ServiceClientOpts struct {
214 service string
215 exp int64
216 lxm string
217 dev bool
218}
219
220type ServiceClientOpt func(*ServiceClientOpts)
221
222func WithService(service string) ServiceClientOpt {
223 return func(s *ServiceClientOpts) {
224 s.service = service
225 }
226}
227func WithExp(exp int64) ServiceClientOpt {
228 return func(s *ServiceClientOpts) {
229 s.exp = exp
230 }
231}
232
233func WithLxm(lxm string) ServiceClientOpt {
234 return func(s *ServiceClientOpts) {
235 s.lxm = lxm
236 }
237}
238
239func WithDev(dev bool) ServiceClientOpt {
240 return func(s *ServiceClientOpts) {
241 s.dev = dev
242 }
243}
244
245func (s *ServiceClientOpts) Audience() string {
246 return fmt.Sprintf("did:web:%s", s.service)
247}
248
249func (s *ServiceClientOpts) Host() string {
250 scheme := "https://"
251 if s.dev {
252 scheme = "http://"
253 }
254
255 return scheme + s.service
256}
257
// ServiceClient returns an XRPC client for a knot or spindle.
//
// It asks the user's PDS (via AuthorizedClient) for a service-auth token whose
// audience is did:web:<service>. The returned client presents that token when
// talking to the service at Host(). The options configure:
//   - WithService: the target service (required)
//   - WithExp, WithLxm: the token's exp and lxm
//   - WithDev: plain-http hosts
//
// Nil options are ignored.
func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*indigo_xrpc.Client, error) {
	opts := ServiceClientOpts{}
	// Named apply so the loop variable does not shadow the receiver o.
	for _, apply := range os {
		if apply != nil {
			apply(&opts)
		}
	}

	// Without a service the audience would be "did:web:" and the host a bare
	// scheme. Fail before asking the PDS to mint a token for it.
	if opts.service == "" {
		return nil, fmt.Errorf("service client: no service specified")
	}

	authorizedClient, err := o.AuthorizedClient(r)
	if err != nil {
		return nil, err
	}

	resp, err := authorizedClient.ServerGetServiceAuth(r.Context(), opts.Audience(), opts.exp, opts.lxm)
	if err != nil {
		return nil, fmt.Errorf("error getting service auth for %s: %w", opts.Audience(), err)
	}

	return &indigo_xrpc.Client{
		Auth: &indigo_xrpc.AuthInfo{
			AccessJwt: resp.Token,
		},
		Host: opts.Host(),
	}, nil
}
281
// ClientMetadata is the OAuth client metadata document this appview
// publishes. The JSON tags give the wire field names, and the field order
// determines the order of keys in the encoded output.
type ClientMetadata struct {
	ClientID                    string   `json:"client_id"`
	ClientName                  string   `json:"client_name"`
	SubjectType                 string   `json:"subject_type"`
	ClientURI                   string   `json:"client_uri"`
	RedirectURIs                []string `json:"redirect_uris"`
	GrantTypes                  []string `json:"grant_types"`
	ResponseTypes               []string `json:"response_types"`
	ApplicationType             string   `json:"application_type"`
	DpopBoundAccessTokens       bool     `json:"dpop_bound_access_tokens"`
	JwksURI                     string   `json:"jwks_uri"`
	Scope                       string   `json:"scope"`
	TokenEndpointAuthMethod     string   `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string   `json:"token_endpoint_auth_signing_alg"`
}
297
298func (o *OAuth) ClientMetadata() ClientMetadata {
299 makeRedirectURIs := func(c string) []string {
300 return []string{fmt.Sprintf("%s/oauth/callback", c)}
301 }
302
303 clientURI := o.config.Core.AppviewHost
304 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
305 redirectURIs := makeRedirectURIs(clientURI)
306
307 if o.config.Core.Dev {
308 clientURI = fmt.Sprintf("http://127.0.0.1:3000")
309 redirectURIs = makeRedirectURIs(clientURI)
310
311 query := url.Values{}
312 query.Add("redirect_uri", redirectURIs[0])
313 query.Add("scope", "atproto transition:generic")
314 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
315 }
316
317 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
318
319 return ClientMetadata{
320 ClientID: clientID,
321 ClientName: "Tangled",
322 SubjectType: "public",
323 ClientURI: clientURI,
324 RedirectURIs: redirectURIs,
325 GrantTypes: []string{"authorization_code", "refresh_token"},
326 ResponseTypes: []string{"code"},
327 ApplicationType: "web",
328 DpopBoundAccessTokens: true,
329 JwksURI: jwksURI,
330 Scope: "atproto transition:generic",
331 TokenEndpointAuthMethod: "private_key_jwt",
332 TokenEndpointAuthSigningAlg: "ES256",
333 }
334}