1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 "github.com/gorilla/sessions"
11 oauth "tangled.sh/icyphox.sh/atproto-oauth"
12 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
13 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
14 "tangled.sh/tangled.sh/core/appview/config"
15 "tangled.sh/tangled.sh/core/appview/oauth/client"
16 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
17)
18
19type OAuth struct {
20 store *sessions.CookieStore
21 config *config.Config
22 sess *sessioncache.SessionStore
23}
24
25func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth {
26 return &OAuth{
27 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
28 config: config,
29 sess: sess,
30 }
31}
32
33func (o *OAuth) Stores() *sessions.CookieStore {
34 return o.store
35}
36
37func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error {
38 // first we save the did in the user session
39 userSession, err := o.store.Get(r, SessionName)
40 if err != nil {
41 return err
42 }
43
44 userSession.Values[SessionDid] = oreq.Did
45 userSession.Values[SessionHandle] = oreq.Handle
46 userSession.Values[SessionPds] = oreq.PdsUrl
47 userSession.Values[SessionAuthenticated] = true
48 err = userSession.Save(r, w)
49 if err != nil {
50 return fmt.Errorf("error saving user session: %w", err)
51 }
52
53 // then save the whole thing in the db
54 session := sessioncache.OAuthSession{
55 Did: oreq.Did,
56 Handle: oreq.Handle,
57 PdsUrl: oreq.PdsUrl,
58 DpopAuthserverNonce: oreq.DpopAuthserverNonce,
59 AuthServerIss: oreq.AuthserverIss,
60 DpopPrivateJwk: oreq.DpopPrivateJwk,
61 AccessJwt: oresp.AccessToken,
62 RefreshJwt: oresp.RefreshToken,
63 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
64 }
65
66 return o.sess.SaveSession(r.Context(), session)
67}
68
69func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
70 userSession, err := o.store.Get(r, SessionName)
71 if err != nil || userSession.IsNew {
72 return fmt.Errorf("error getting user session (or new session?): %w", err)
73 }
74
75 did := userSession.Values[SessionDid].(string)
76
77 err = o.sess.DeleteSession(r.Context(), did)
78 if err != nil {
79 return fmt.Errorf("error deleting oauth session: %w", err)
80 }
81
82 userSession.Options.MaxAge = -1
83
84 return userSession.Save(r, w)
85}
86
87func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) {
88 userSession, err := o.store.Get(r, SessionName)
89 if err != nil || userSession.IsNew {
90 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
91 }
92
93 did := userSession.Values[SessionDid].(string)
94 auth := userSession.Values[SessionAuthenticated].(bool)
95
96 session, err := o.sess.GetSession(r.Context(), did)
97 if err != nil {
98 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
99 }
100
101 expiry, err := time.Parse(time.RFC3339, session.Expiry)
102 if err != nil {
103 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
104 }
105 if expiry.Sub(time.Now()) <= 5*time.Minute {
106 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
107 if err != nil {
108 return nil, false, err
109 }
110
111 self := o.ClientMetadata()
112
113 oauthClient, err := client.NewClient(
114 self.ClientID,
115 o.config.OAuth.Jwks,
116 self.RedirectURIs[0],
117 )
118
119 if err != nil {
120 return nil, false, err
121 }
122
123 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
124 if err != nil {
125 return nil, false, err
126 }
127
128 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
129 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry)
130 if err != nil {
131 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
132 }
133
134 // update the current session
135 session.AccessJwt = resp.AccessToken
136 session.RefreshJwt = resp.RefreshToken
137 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
138 session.Expiry = newExpiry
139 }
140
141 return session, auth, nil
142}
143
// User is the identity of the logged-in user as recorded in the session
// cookie.
type User struct {
	// Handle is the value stored under SessionHandle.
	Handle string
	// Did is the value stored under SessionDid.
	Did string
	// Pds is the value stored under SessionPds (the user's PDS URL).
	Pds string
}
149
150func (a *OAuth) GetUser(r *http.Request) *User {
151 clientSession, err := a.store.Get(r, SessionName)
152
153 if err != nil || clientSession.IsNew {
154 return nil
155 }
156
157 return &User{
158 Handle: clientSession.Values[SessionHandle].(string),
159 Did: clientSession.Values[SessionDid].(string),
160 Pds: clientSession.Values[SessionPds].(string),
161 }
162}
163
164func (a *OAuth) GetDid(r *http.Request) string {
165 clientSession, err := a.store.Get(r, SessionName)
166
167 if err != nil || clientSession.IsNew {
168 return ""
169 }
170
171 return clientSession.Values[SessionDid].(string)
172}
173
174func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
175 session, auth, err := o.GetSession(r)
176 if err != nil {
177 return nil, fmt.Errorf("error getting session: %w", err)
178 }
179 if !auth {
180 return nil, fmt.Errorf("not authorized")
181 }
182
183 client := &oauth.XrpcClient{
184 OnDpopPdsNonceChanged: func(did, newNonce string) {
185 err := o.sess.UpdateNonce(r.Context(), did, newNonce)
186 if err != nil {
187 log.Printf("error updating dpop pds nonce: %v", err)
188 }
189 },
190 }
191
192 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
193 if err != nil {
194 return nil, fmt.Errorf("error parsing private jwk: %w", err)
195 }
196
197 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{
198 Did: session.Did,
199 PdsUrl: session.PdsUrl,
200 DpopPdsNonce: session.PdsUrl,
201 AccessToken: session.AccessJwt,
202 Issuer: session.AuthServerIss,
203 DpopPrivateJwk: privateJwk,
204 })
205
206 return xrpcClient, nil
207}
208
// ClientMetadata is the OAuth client metadata document describing this
// application. Field order is kept stable because it determines the order of
// keys in the encoded JSON.
type ClientMetadata struct {
	// Identity of the client.
	ClientID    string `json:"client_id"`
	ClientName  string `json:"client_name"`
	SubjectType string `json:"subject_type"`
	ClientURI   string `json:"client_uri"`

	// Supported flows and the endpoints involved in them.
	RedirectURIs    []string `json:"redirect_uris"`
	GrantTypes      []string `json:"grant_types"`
	ResponseTypes   []string `json:"response_types"`
	ApplicationType string   `json:"application_type"`

	// Token binding, key discovery, and requested scope.
	DpopBoundAccessTokens bool   `json:"dpop_bound_access_tokens"`
	JwksURI               string `json:"jwks_uri"`
	Scope                 string `json:"scope"`

	// How the client authenticates at the token endpoint.
	TokenEndpointAuthMethod     string `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
}
224
225func (o *OAuth) ClientMetadata() ClientMetadata {
226 makeRedirectURIs := func(c string) []string {
227 return []string{fmt.Sprintf("%s/oauth/callback", c)}
228 }
229
230 clientURI := o.config.Core.AppviewHost
231 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
232 redirectURIs := makeRedirectURIs(clientURI)
233
234 if o.config.Core.Dev {
235 clientURI = fmt.Sprintf("http://127.0.0.1:3000")
236 redirectURIs = makeRedirectURIs(clientURI)
237
238 query := url.Values{}
239 query.Add("redirect_uri", redirectURIs[0])
240 query.Add("scope", "atproto transition:generic")
241 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
242 }
243
244 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
245
246 return ClientMetadata{
247 ClientID: clientID,
248 ClientName: "Tangled",
249 SubjectType: "public",
250 ClientURI: clientURI,
251 RedirectURIs: redirectURIs,
252 GrantTypes: []string{"authorization_code", "refresh_token"},
253 ResponseTypes: []string{"code"},
254 ApplicationType: "web",
255 DpopBoundAccessTokens: true,
256 JwksURI: jwksURI,
257 Scope: "atproto transition:generic",
258 TokenEndpointAuthMethod: "private_key_jwt",
259 TokenEndpointAuthSigningAlg: "ES256",
260 }
261}