1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 "github.com/gorilla/sessions"
11 oauth "tangled.sh/icyphox.sh/atproto-oauth"
12 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
13 "tangled.sh/tangled.sh/core/appview"
14 "tangled.sh/tangled.sh/core/appview/config"
15 "tangled.sh/tangled.sh/core/appview/db"
16 "tangled.sh/tangled.sh/core/appview/oauth/client"
17 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
18)
19
// OAuthRequest is a record made of a numeric ID followed by seven string
// fields describing one OAuth request.
//
// NOTE(review): the methods visible in this file take db.OAuthRequest rather
// than this type — confirm whether this declaration is still in use.
type OAuthRequest struct {
	ID uint

	AuthserverIss, State, Did, PdsUrl string

	PkceVerifier string

	DpopAuthserverNonce, DpopPrivateJwk string
}
30
31type OAuth struct {
32 Store *sessions.CookieStore
33 Db *db.DB
34 Config *config.Config
35}
36
37func NewOAuth(db *db.DB, config *config.Config) *OAuth {
38 return &OAuth{
39 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
40 Db: db,
41 Config: config,
42 }
43}
44
45func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
46 // first we save the did in the user session
47 userSession, err := o.Store.Get(r, appview.SessionName)
48 if err != nil {
49 return err
50 }
51
52 userSession.Values[appview.SessionDid] = oreq.Did
53 userSession.Values[appview.SessionHandle] = oreq.Handle
54 userSession.Values[appview.SessionPds] = oreq.PdsUrl
55 userSession.Values[appview.SessionAuthenticated] = true
56 err = userSession.Save(r, w)
57 if err != nil {
58 return fmt.Errorf("error saving user session: %w", err)
59 }
60
61 // then save the whole thing in the db
62 session := db.OAuthSession{
63 Did: oreq.Did,
64 Handle: oreq.Handle,
65 PdsUrl: oreq.PdsUrl,
66 DpopAuthserverNonce: oreq.DpopAuthserverNonce,
67 AuthServerIss: oreq.AuthserverIss,
68 DpopPrivateJwk: oreq.DpopPrivateJwk,
69 AccessJwt: oresp.AccessToken,
70 RefreshJwt: oresp.RefreshToken,
71 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
72 }
73
74 return db.SaveOAuthSession(o.Db, session)
75}
76
77func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
78 userSession, err := o.Store.Get(r, appview.SessionName)
79 if err != nil || userSession.IsNew {
80 return fmt.Errorf("error getting user session (or new session?): %w", err)
81 }
82
83 did := userSession.Values[appview.SessionDid].(string)
84
85 err = db.DeleteOAuthSessionByDid(o.Db, did)
86 if err != nil {
87 return fmt.Errorf("error deleting oauth session: %w", err)
88 }
89
90 userSession.Options.MaxAge = -1
91
92 return userSession.Save(r, w)
93}
94
95func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
96 userSession, err := o.Store.Get(r, appview.SessionName)
97 if err != nil || userSession.IsNew {
98 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
99 }
100
101 did := userSession.Values[appview.SessionDid].(string)
102 auth := userSession.Values[appview.SessionAuthenticated].(bool)
103
104 session, err := db.GetOAuthSessionByDid(o.Db, did)
105 if err != nil {
106 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
107 }
108
109 expiry, err := time.Parse(time.RFC3339, session.Expiry)
110 if err != nil {
111 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
112 }
113 if expiry.Sub(time.Now()) <= 5*time.Minute {
114 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
115 if err != nil {
116 return nil, false, err
117 }
118
119 self := o.ClientMetadata()
120
121 oauthClient, err := client.NewClient(
122 self.ClientID,
123 o.Config.OAuth.Jwks,
124 self.RedirectURIs[0],
125 )
126
127 if err != nil {
128 return nil, false, err
129 }
130
131 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
132 if err != nil {
133 return nil, false, err
134 }
135
136 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
137 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry)
138 if err != nil {
139 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
140 }
141
142 // update the current session
143 session.AccessJwt = resp.AccessToken
144 session.RefreshJwt = resp.RefreshToken
145 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
146 session.Expiry = newExpiry
147 }
148
149 return session, auth, nil
150}
151
// User is the identity read back from the cookie session by GetUser: the
// handle, DID and PDS values stored there at login.
type User struct {
	Handle, Did, Pds string
}
157
158func (a *OAuth) GetUser(r *http.Request) *User {
159 clientSession, err := a.Store.Get(r, appview.SessionName)
160
161 if err != nil || clientSession.IsNew {
162 return nil
163 }
164
165 return &User{
166 Handle: clientSession.Values[appview.SessionHandle].(string),
167 Did: clientSession.Values[appview.SessionDid].(string),
168 Pds: clientSession.Values[appview.SessionPds].(string),
169 }
170}
171
172func (a *OAuth) GetDid(r *http.Request) string {
173 clientSession, err := a.Store.Get(r, appview.SessionName)
174
175 if err != nil || clientSession.IsNew {
176 return ""
177 }
178
179 return clientSession.Values[appview.SessionDid].(string)
180}
181
182func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
183 session, auth, err := o.GetSession(r)
184 if err != nil {
185 return nil, fmt.Errorf("error getting session: %w", err)
186 }
187 if !auth {
188 return nil, fmt.Errorf("not authorized")
189 }
190
191 client := &oauth.XrpcClient{
192 OnDpopPdsNonceChanged: func(did, newNonce string) {
193 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce)
194 if err != nil {
195 log.Printf("error updating dpop pds nonce: %v", err)
196 }
197 },
198 }
199
200 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
201 if err != nil {
202 return nil, fmt.Errorf("error parsing private jwk: %w", err)
203 }
204
205 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{
206 Did: session.Did,
207 PdsUrl: session.PdsUrl,
208 DpopPdsNonce: session.PdsUrl,
209 AccessToken: session.AccessJwt,
210 Issuer: session.AuthServerIss,
211 DpopPrivateJwk: privateJwk,
212 })
213
214 return xrpcClient, nil
215}
216
// ClientMetadata is the JSON-serialisable description of this OAuth client.
// Every field is emitted under the snake_case key given in its struct tag, in
// declaration order; (*OAuth).ClientMetadata fills it in.
type ClientMetadata struct {
	// Plain string identifiers.
	ClientID    string `json:"client_id"`
	ClientName  string `json:"client_name"`
	SubjectType string `json:"subject_type"`
	ClientURI   string `json:"client_uri"`

	// List-valued fields, serialised as JSON arrays.
	RedirectURIs  []string `json:"redirect_uris"`
	GrantTypes    []string `json:"grant_types"`
	ResponseTypes []string `json:"response_types"`

	// Application type, DPoP flag, key-set location and scope.
	ApplicationType       string `json:"application_type"`
	DpopBoundAccessTokens bool   `json:"dpop_bound_access_tokens"`
	JwksURI               string `json:"jwks_uri"`
	Scope                 string `json:"scope"`

	// Token endpoint authentication settings.
	TokenEndpointAuthMethod     string `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
}
232
233func (o *OAuth) ClientMetadata() ClientMetadata {
234 makeRedirectURIs := func(c string) []string {
235 return []string{fmt.Sprintf("%s/oauth/callback", c)}
236 }
237
238 clientURI := o.Config.Core.AppviewHost
239 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
240 redirectURIs := makeRedirectURIs(clientURI)
241
242 if o.Config.Core.Dev {
243 clientURI = fmt.Sprintf("http://127.0.0.1:3000")
244 redirectURIs = makeRedirectURIs(clientURI)
245
246 query := url.Values{}
247 query.Add("redirect_uri", redirectURIs[0])
248 query.Add("scope", "atproto transition:generic")
249 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
250 }
251
252 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
253
254 return ClientMetadata{
255 ClientID: clientID,
256 ClientName: "Tangled",
257 SubjectType: "public",
258 ClientURI: clientURI,
259 RedirectURIs: redirectURIs,
260 GrantTypes: []string{"authorization_code", "refresh_token"},
261 ResponseTypes: []string{"code"},
262 ApplicationType: "web",
263 DpopBoundAccessTokens: true,
264 JwksURI: jwksURI,
265 Scope: "atproto transition:generic",
266 TokenEndpointAuthMethod: "private_key_jwt",
267 TokenEndpointAuthSigningAlg: "ES256",
268 }
269}