1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 "github.com/gorilla/sessions"
11 oauth "github.com/haileyok/atproto-oauth-golang"
12 "github.com/haileyok/atproto-oauth-golang/helpers"
13 "tangled.sh/tangled.sh/core/appview"
14 "tangled.sh/tangled.sh/core/appview/db"
15 "tangled.sh/tangled.sh/core/appview/oauth/client"
16 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
17)
18
// OAuthRequest carries the values tracked for one OAuth authorization
// request: the authorization server issuer, the state parameter, the user's
// DID and PDS URL, the PKCE verifier, and the DPoP nonce and private JWK.
//
// NOTE(review): the methods in this part of the file take db.OAuthRequest
// rather than this type — confirm whether this declaration is still used.
type OAuthRequest struct {
	ID uint

	// Identity of the request and of the user it belongs to.
	AuthserverIss string
	State         string
	Did           string
	PdsUrl        string

	// Proof material: PKCE verifier plus DPoP nonce and key.
	PkceVerifier        string
	DpopAuthserverNonce string
	DpopPrivateJwk      string
}
29
30type OAuth struct {
31 Store *sessions.CookieStore
32 Db *db.DB
33 Config *appview.Config
34}
35
36func NewOAuth(db *db.DB, config *appview.Config) *OAuth {
37 return &OAuth{
38 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
39 Db: db,
40 Config: config,
41 }
42}
43
44func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
45 // first we save the did in the user session
46 userSession, err := o.Store.Get(r, appview.SessionName)
47 if err != nil {
48 return err
49 }
50
51 userSession.Values[appview.SessionDid] = oreq.Did
52 userSession.Values[appview.SessionHandle] = oreq.Handle
53 userSession.Values[appview.SessionPds] = oreq.PdsUrl
54 userSession.Values[appview.SessionAuthenticated] = true
55 err = userSession.Save(r, w)
56 if err != nil {
57 return fmt.Errorf("error saving user session: %w", err)
58 }
59
60 // then save the whole thing in the db
61 session := db.OAuthSession{
62 Did: oreq.Did,
63 Handle: oreq.Handle,
64 PdsUrl: oreq.PdsUrl,
65 DpopAuthserverNonce: oreq.DpopAuthserverNonce,
66 AuthServerIss: oreq.AuthserverIss,
67 DpopPrivateJwk: oreq.DpopPrivateJwk,
68 AccessJwt: oresp.AccessToken,
69 RefreshJwt: oresp.RefreshToken,
70 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
71 }
72
73 return db.SaveOAuthSession(o.Db, session)
74}
75
76func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
77 userSession, err := o.Store.Get(r, appview.SessionName)
78 if err != nil || userSession.IsNew {
79 return fmt.Errorf("error getting user session (or new session?): %w", err)
80 }
81
82 did := userSession.Values[appview.SessionDid].(string)
83
84 err = db.DeleteOAuthSessionByDid(o.Db, did)
85 if err != nil {
86 return fmt.Errorf("error deleting oauth session: %w", err)
87 }
88
89 userSession.Options.MaxAge = -1
90
91 return userSession.Save(r, w)
92}
93
94func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
95 userSession, err := o.Store.Get(r, appview.SessionName)
96 if err != nil || userSession.IsNew {
97 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
98 }
99
100 did := userSession.Values[appview.SessionDid].(string)
101 auth := userSession.Values[appview.SessionAuthenticated].(bool)
102
103 session, err := db.GetOAuthSessionByDid(o.Db, did)
104 if err != nil {
105 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
106 }
107
108 expiry, err := time.Parse(time.RFC3339, session.Expiry)
109 if err != nil {
110 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
111 }
112 if expiry.Sub(time.Now()) <= 5*time.Minute {
113 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
114 if err != nil {
115 return nil, false, err
116 }
117
118 self := o.ClientMetadata()
119
120 oauthClient, err := client.NewClient(
121 self.ClientID,
122 o.Config.OAuth.Jwks,
123 self.RedirectURIs[0],
124 )
125
126 if err != nil {
127 return nil, false, err
128 }
129
130 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
131 if err != nil {
132 return nil, false, err
133 }
134
135 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
136 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry)
137 if err != nil {
138 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
139 }
140
141 // update the current session
142 session.AccessJwt = resp.AccessToken
143 session.RefreshJwt = resp.RefreshToken
144 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
145 session.Expiry = newExpiry
146 }
147
148 return session, auth, nil
149}
150
// User is the identity read back from the cookie session by GetUser.
type User struct {
	// Handle is the value stored under appview.SessionHandle.
	Handle string
	// Did is the value stored under appview.SessionDid.
	Did string
	// Pds is the value stored under appview.SessionPds.
	Pds string
}
156
157func (a *OAuth) GetUser(r *http.Request) *User {
158 clientSession, err := a.Store.Get(r, appview.SessionName)
159
160 if err != nil || clientSession.IsNew {
161 return nil
162 }
163
164 return &User{
165 Handle: clientSession.Values[appview.SessionHandle].(string),
166 Did: clientSession.Values[appview.SessionDid].(string),
167 Pds: clientSession.Values[appview.SessionPds].(string),
168 }
169}
170
171func (a *OAuth) GetDid(r *http.Request) string {
172 clientSession, err := a.Store.Get(r, appview.SessionName)
173
174 if err != nil || clientSession.IsNew {
175 return ""
176 }
177
178 return clientSession.Values[appview.SessionDid].(string)
179}
180
181func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
182 session, auth, err := o.GetSession(r)
183 if err != nil {
184 return nil, fmt.Errorf("error getting session: %w", err)
185 }
186 if !auth {
187 return nil, fmt.Errorf("not authorized")
188 }
189
190 client := &oauth.XrpcClient{
191 OnDpopPdsNonceChanged: func(did, newNonce string) {
192 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce)
193 if err != nil {
194 log.Printf("error updating dpop pds nonce: %v", err)
195 }
196 },
197 }
198
199 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
200 if err != nil {
201 return nil, fmt.Errorf("error parsing private jwk: %w", err)
202 }
203
204 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{
205 Did: session.Did,
206 PdsUrl: session.PdsUrl,
207 DpopPdsNonce: session.PdsUrl,
208 AccessToken: session.AccessJwt,
209 Issuer: session.AuthServerIss,
210 DpopPrivateJwk: privateJwk,
211 })
212
213 return xrpcClient, nil
214}
215
// ClientMetadata is the JSON shape of the OAuth client metadata document
// built by (*OAuth).ClientMetadata. Field order and JSON keys determine the
// serialized output.
type ClientMetadata struct {
	// Client identity.
	ClientID    string `json:"client_id"`
	ClientName  string `json:"client_name"`
	SubjectType string `json:"subject_type"`
	ClientURI   string `json:"client_uri"`

	// Supported flows.
	RedirectURIs    []string `json:"redirect_uris"`
	GrantTypes      []string `json:"grant_types"`
	ResponseTypes   []string `json:"response_types"`
	ApplicationType string   `json:"application_type"`

	// Token binding, keys and scope.
	DpopBoundAccessTokens bool   `json:"dpop_bound_access_tokens"`
	JwksURI               string `json:"jwks_uri"`
	Scope                 string `json:"scope"`

	// Token endpoint authentication.
	TokenEndpointAuthMethod     string `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
}
231
232func (o *OAuth) ClientMetadata() ClientMetadata {
233 makeRedirectURIs := func(c string) []string {
234 return []string{fmt.Sprintf("%s/oauth/callback", c)}
235 }
236
237 clientURI := o.Config.Core.AppviewHost
238 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
239 redirectURIs := makeRedirectURIs(clientURI)
240
241 if o.Config.Core.Dev {
242 clientURI = fmt.Sprintf("http://127.0.0.1:3000")
243 redirectURIs = makeRedirectURIs(clientURI)
244
245 query := url.Values{}
246 query.Add("redirect_uri", redirectURIs[0])
247 query.Add("scope", "atproto transition:generic")
248 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
249 }
250
251 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
252
253 return ClientMetadata{
254 ClientID: clientID,
255 ClientName: "Tangled",
256 SubjectType: "public",
257 ClientURI: clientURI,
258 RedirectURIs: redirectURIs,
259 GrantTypes: []string{"authorization_code", "refresh_token"},
260 ResponseTypes: []string{"code"},
261 ApplicationType: "web",
262 DpopBoundAccessTokens: true,
263 JwksURI: jwksURI,
264 Scope: "atproto transition:generic",
265 TokenEndpointAuthMethod: "private_key_jwt",
266 TokenEndpointAuthSigningAlg: "ES256",
267 }
268}