1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "net/url"
11 "strings"
12 "time"
13
14 "github.com/go-chi/chi/v5"
15 "github.com/gorilla/sessions"
16 "github.com/lestrrat-go/jwx/v2/jwk"
17 "github.com/posthog/posthog-go"
18 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
19 tangled "tangled.sh/tangled.sh/core/api/tangled"
20 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
21 "tangled.sh/tangled.sh/core/appview/config"
22 "tangled.sh/tangled.sh/core/appview/db"
23 "tangled.sh/tangled.sh/core/appview/middleware"
24 "tangled.sh/tangled.sh/core/appview/oauth"
25 "tangled.sh/tangled.sh/core/appview/oauth/client"
26 "tangled.sh/tangled.sh/core/appview/pages"
27 "tangled.sh/tangled.sh/core/idresolver"
28 "tangled.sh/tangled.sh/core/knotclient"
29 "tangled.sh/tangled.sh/core/rbac"
30 "tangled.sh/tangled.sh/core/tid"
31)
32
// oauthScope is the OAuth scope requested when starting the authorization
// flow and checked against the scope granted in the token response.
const oauthScope = "atproto transition:generic"
36
// OAuthHandler serves the login/logout pages and the OAuth client endpoints
// (client metadata, JWKS and the authorization callback) mounted by Router.
type OAuthHandler struct {
	config     *config.Config             // app config; OAuth JWKS, dev flag and app password are read from it
	pages      *pages.Pages               // page/notice/redirect rendering
	idResolver *idresolver.Resolver       // resolves handles/DIDs to identities (DID + PDS endpoint)
	sess       *sessioncache.SessionStore // pending OAuth requests, saved and looked up by state
	db         *db.DB                     // database used for spindle membership and knot registration keys
	store      *sessions.CookieStore      // NOTE(review): not referenced by any handler in this file
	oauth      *oauth.OAuth               // client metadata plus session save/clear
	enforcer   *rbac.Enforcer             // RBAC policy store; used to add default-knot members
	posthog    posthog.Client             // analytics; a "signin" event is enqueued outside dev mode
}
48
49func New(
50 config *config.Config,
51 pages *pages.Pages,
52 idResolver *idresolver.Resolver,
53 db *db.DB,
54 sess *sessioncache.SessionStore,
55 store *sessions.CookieStore,
56 oauth *oauth.OAuth,
57 enforcer *rbac.Enforcer,
58 posthog posthog.Client,
59) *OAuthHandler {
60 return &OAuthHandler{
61 config: config,
62 pages: pages,
63 idResolver: idResolver,
64 db: db,
65 sess: sess,
66 store: store,
67 oauth: oauth,
68 enforcer: enforcer,
69 posthog: posthog,
70 }
71}
72
73func (o *OAuthHandler) Router() http.Handler {
74 r := chi.NewRouter()
75
76 r.Get("/login", o.login)
77 r.Post("/login", o.login)
78
79 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout)
80
81 r.Get("/oauth/client-metadata.json", o.clientMetadata)
82 r.Get("/oauth/jwks.json", o.jwks)
83 r.Get("/oauth/callback", o.callback)
84 return r
85}
86
87func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
88 w.Header().Set("Content-Type", "application/json")
89 w.WriteHeader(http.StatusOK)
90 json.NewEncoder(w).Encode(o.oauth.ClientMetadata())
91}
92
93func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
94 jwks := o.config.OAuth.Jwks
95 pubKey, err := pubKeyFromJwk(jwks)
96 if err != nil {
97 log.Printf("error parsing public key: %v", err)
98 http.Error(w, err.Error(), http.StatusInternalServerError)
99 return
100 }
101
102 response := helpers.CreateJwksResponseObject(pubKey)
103
104 w.Header().Set("Content-Type", "application/json")
105 w.WriteHeader(http.StatusOK)
106 json.NewEncoder(w).Encode(response)
107}
108
109func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
110 switch r.Method {
111 case http.MethodGet:
112 o.pages.Login(w, pages.LoginParams{})
113 case http.MethodPost:
114 handle := r.FormValue("handle")
115
116 // when users copy their handle from bsky.app, it tends to have these characters around it:
117 //
118 // @nelind.dk:
119 // \u202a ensures that the handle is always rendered left to right and
120 // \u202c reverts that so the rest of the page renders however it should
121 handle = strings.TrimPrefix(handle, "\u202a")
122 handle = strings.TrimSuffix(handle, "\u202c")
123
124 // `@` is harmless
125 handle = strings.TrimPrefix(handle, "@")
126
127 // basic handle validation
128 if !strings.Contains(handle, ".") {
129 log.Println("invalid handle format", "raw", handle)
130 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle))
131 return
132 }
133
134 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle)
135 if err != nil {
136 log.Println("failed to resolve handle:", err)
137 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
138 return
139 }
140 self := o.oauth.ClientMetadata()
141 oauthClient, err := client.NewClient(
142 self.ClientID,
143 o.config.OAuth.Jwks,
144 self.RedirectURIs[0],
145 )
146
147 if err != nil {
148 log.Println("failed to create oauth client:", err)
149 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
150 return
151 }
152
153 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
154 if err != nil {
155 log.Println("failed to resolve auth server:", err)
156 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
157 return
158 }
159
160 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
161 if err != nil {
162 log.Println("failed to fetch auth server metadata:", err)
163 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
164 return
165 }
166
167 dpopKey, err := helpers.GenerateKey(nil)
168 if err != nil {
169 log.Println("failed to generate dpop key:", err)
170 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
171 return
172 }
173
174 dpopKeyJson, err := json.Marshal(dpopKey)
175 if err != nil {
176 log.Println("failed to marshal dpop key:", err)
177 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
178 return
179 }
180
181 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
182 if err != nil {
183 log.Println("failed to send par auth request:", err)
184 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
185 return
186 }
187
188 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{
189 Did: resolved.DID.String(),
190 PdsUrl: resolved.PDSEndpoint(),
191 Handle: handle,
192 AuthserverIss: authMeta.Issuer,
193 PkceVerifier: parResp.PkceVerifier,
194 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
195 DpopPrivateJwk: string(dpopKeyJson),
196 State: parResp.State,
197 })
198 if err != nil {
199 log.Println("failed to save oauth request:", err)
200 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
201 return
202 }
203
204 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
205 query := url.Values{}
206 query.Add("client_id", self.ClientID)
207 query.Add("request_uri", parResp.RequestUri)
208 u.RawQuery = query.Encode()
209 o.pages.HxRedirect(w, u.String())
210 }
211}
212
213func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
214 state := r.FormValue("state")
215
216 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state)
217 if err != nil {
218 log.Println("failed to get oauth request:", err)
219 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
220 return
221 }
222
223 defer func() {
224 err := o.sess.DeleteRequestByState(r.Context(), state)
225 if err != nil {
226 log.Println("failed to delete oauth request for state:", state, err)
227 }
228 }()
229
230 error := r.FormValue("error")
231 errorDescription := r.FormValue("error_description")
232 if error != "" || errorDescription != "" {
233 log.Printf("error: %s, %s", error, errorDescription)
234 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
235 return
236 }
237
238 code := r.FormValue("code")
239 if code == "" {
240 log.Println("missing code for state: ", state)
241 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
242 return
243 }
244
245 iss := r.FormValue("iss")
246 if iss == "" {
247 log.Println("missing iss for state: ", state)
248 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
249 return
250 }
251
252 self := o.oauth.ClientMetadata()
253
254 oauthClient, err := client.NewClient(
255 self.ClientID,
256 o.config.OAuth.Jwks,
257 self.RedirectURIs[0],
258 )
259
260 if err != nil {
261 log.Println("failed to create oauth client:", err)
262 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
263 return
264 }
265
266 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
267 if err != nil {
268 log.Println("failed to parse jwk:", err)
269 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
270 return
271 }
272
273 tokenResp, err := oauthClient.InitialTokenRequest(
274 r.Context(),
275 code,
276 oauthRequest.AuthserverIss,
277 oauthRequest.PkceVerifier,
278 oauthRequest.DpopAuthserverNonce,
279 jwk,
280 )
281 if err != nil {
282 log.Println("failed to get token:", err)
283 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
284 return
285 }
286
287 if tokenResp.Scope != oauthScope {
288 log.Println("scope doesn't match:", tokenResp.Scope)
289 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
290 return
291 }
292
293 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp)
294 if err != nil {
295 log.Println("failed to save session:", err)
296 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
297 return
298 }
299
300 log.Println("session saved successfully")
301 go o.addToDefaultKnot(oauthRequest.Did)
302 go o.addToDefaultSpindle(oauthRequest.Did)
303
304 if !o.config.Core.Dev {
305 err = o.posthog.Enqueue(posthog.Capture{
306 DistinctId: oauthRequest.Did,
307 Event: "signin",
308 })
309 if err != nil {
310 log.Println("failed to enqueue posthog event:", err)
311 }
312 }
313
314 http.Redirect(w, r, "/", http.StatusFound)
315}
316
317func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
318 err := o.oauth.ClearSession(r, w)
319 if err != nil {
320 log.Println("failed to clear session:", err)
321 http.Redirect(w, r, "/", http.StatusFound)
322 return
323 }
324
325 log.Println("session cleared successfully")
326 o.pages.HxRedirect(w, "/login")
327}
328
329func pubKeyFromJwk(jwks string) (jwk.Key, error) {
330 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
331 if err != nil {
332 return nil, err
333 }
334 pubKey, err := k.PublicKey()
335 if err != nil {
336 return nil, err
337 }
338 return pubKey, nil
339}
340
341func (o *OAuthHandler) addToDefaultSpindle(did string) {
342 // use the tangled.sh app password to get an accessJwt
343 // and create an sh.tangled.spindle.member record with that
344
345 defaultSpindle := "spindle.tangled.sh"
346 appPassword := o.config.Core.AppPassword
347
348 spindleMembers, err := db.GetSpindleMembers(
349 o.db,
350 db.FilterEq("instance", "spindle.tangled.sh"),
351 db.FilterEq("subject", did),
352 )
353 if err != nil {
354 log.Printf("failed to get spindle members for did %s: %v", did, err)
355 return
356 }
357
358 if len(spindleMembers) != 0 {
359 log.Printf("did %s is already a member of the default spindle", did)
360 return
361 }
362
363 // TODO: hardcoded tangled handle and did for now
364 tangledHandle := "tangled.sh"
365 tangledDid := "did:plc:wshs7t2adsemcrrd4snkeqli"
366
367 if appPassword == "" {
368 log.Println("no app password configured, skipping spindle member addition")
369 return
370 }
371
372 log.Printf("adding %s to default spindle", did)
373
374 resolved, err := o.idResolver.ResolveIdent(context.Background(), tangledDid)
375 if err != nil {
376 log.Printf("failed to resolve tangled.sh DID %s: %v", tangledDid, err)
377 return
378 }
379
380 pdsEndpoint := resolved.PDSEndpoint()
381 if pdsEndpoint == "" {
382 log.Printf("no PDS endpoint found for tangled.sh DID %s", tangledDid)
383 return
384 }
385
386 sessionPayload := map[string]string{
387 "identifier": tangledHandle,
388 "password": appPassword,
389 }
390 sessionBytes, err := json.Marshal(sessionPayload)
391 if err != nil {
392 log.Printf("failed to marshal session payload: %v", err)
393 return
394 }
395
396 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
397 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
398 if err != nil {
399 log.Printf("failed to create session request: %v", err)
400 return
401 }
402 sessionReq.Header.Set("Content-Type", "application/json")
403
404 client := &http.Client{Timeout: 30 * time.Second}
405 sessionResp, err := client.Do(sessionReq)
406 if err != nil {
407 log.Printf("failed to create session: %v", err)
408 return
409 }
410 defer sessionResp.Body.Close()
411
412 if sessionResp.StatusCode != http.StatusOK {
413 log.Printf("failed to create session: HTTP %d", sessionResp.StatusCode)
414 return
415 }
416
417 var session struct {
418 AccessJwt string `json:"accessJwt"`
419 }
420 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
421 log.Printf("failed to decode session response: %v", err)
422 return
423 }
424
425 record := tangled.SpindleMember{
426 LexiconTypeID: "sh.tangled.spindle.member",
427 Subject: did,
428 Instance: defaultSpindle,
429 CreatedAt: time.Now().Format(time.RFC3339),
430 }
431
432 recordBytes, err := json.Marshal(record)
433 if err != nil {
434 log.Printf("failed to marshal spindle member record: %v", err)
435 return
436 }
437
438 payload := map[string]interface{}{
439 "repo": tangledDid,
440 "collection": tangled.SpindleMemberNSID,
441 "rkey": tid.TID(),
442 "record": json.RawMessage(recordBytes),
443 }
444
445 payloadBytes, err := json.Marshal(payload)
446 if err != nil {
447 log.Printf("failed to marshal request payload: %v", err)
448 return
449 }
450
451 url := pdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
452 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
453 if err != nil {
454 log.Printf("failed to create HTTP request: %v", err)
455 return
456 }
457
458 req.Header.Set("Content-Type", "application/json")
459 req.Header.Set("Authorization", "Bearer "+session.AccessJwt)
460
461 resp, err := client.Do(req)
462 if err != nil {
463 log.Printf("failed to add user to default spindle: %v", err)
464 return
465 }
466 defer resp.Body.Close()
467
468 if resp.StatusCode != http.StatusOK {
469 log.Printf("failed to add user to default spindle: HTTP %d", resp.StatusCode)
470 return
471 }
472
473 log.Printf("successfully added %s to default spindle", did)
474}
475
476func (o *OAuthHandler) addToDefaultKnot(did string) {
477 defaultKnot := "knot1.tangled.sh"
478
479 log.Printf("adding %s to default knot", did)
480 err := o.enforcer.AddKnotMember(defaultKnot, did)
481 if err != nil {
482 log.Println("failed to add user to knot1.tangled.sh: ", err)
483 return
484 }
485 err = o.enforcer.E.SavePolicy()
486 if err != nil {
487 log.Println("failed to add user to knot1.tangled.sh: ", err)
488 return
489 }
490
491 secret, err := db.GetRegistrationKey(o.db, defaultKnot)
492 if err != nil {
493 log.Println("failed to get registration key for knot1.tangled.sh")
494 return
495 }
496 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev)
497 resp, err := signedClient.AddMember(did)
498 if err != nil {
499 log.Println("failed to add user to knot1.tangled.sh: ", err)
500 return
501 }
502
503 if resp.StatusCode != http.StatusNoContent {
504 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode)
505 return
506 }
507}