1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "net/url"
11 "slices"
12 "strings"
13 "time"
14
15 "github.com/go-chi/chi/v5"
16 "github.com/gorilla/sessions"
17 "github.com/lestrrat-go/jwx/v2/jwk"
18 "github.com/posthog/posthog-go"
19 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
20 tangled "tangled.sh/tangled.sh/core/api/tangled"
21 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
22 "tangled.sh/tangled.sh/core/appview/config"
23 "tangled.sh/tangled.sh/core/appview/db"
24 "tangled.sh/tangled.sh/core/appview/middleware"
25 "tangled.sh/tangled.sh/core/appview/oauth"
26 "tangled.sh/tangled.sh/core/appview/oauth/client"
27 "tangled.sh/tangled.sh/core/appview/pages"
28 "tangled.sh/tangled.sh/core/idresolver"
29 "tangled.sh/tangled.sh/core/rbac"
30 "tangled.sh/tangled.sh/core/tid"
31)
32
// oauthScope is the scope string sent with the authorization request in login
// and compared against the scope of the token response in callback.
const oauthScope = "atproto transition:generic"
36
37type OAuthHandler struct {
38 config *config.Config
39 pages *pages.Pages
40 idResolver *idresolver.Resolver
41 sess *sessioncache.SessionStore
42 db *db.DB
43 store *sessions.CookieStore
44 oauth *oauth.OAuth
45 enforcer *rbac.Enforcer
46 posthog posthog.Client
47}
48
49func New(
50 config *config.Config,
51 pages *pages.Pages,
52 idResolver *idresolver.Resolver,
53 db *db.DB,
54 sess *sessioncache.SessionStore,
55 store *sessions.CookieStore,
56 oauth *oauth.OAuth,
57 enforcer *rbac.Enforcer,
58 posthog posthog.Client,
59) *OAuthHandler {
60 return &OAuthHandler{
61 config: config,
62 pages: pages,
63 idResolver: idResolver,
64 db: db,
65 sess: sess,
66 store: store,
67 oauth: oauth,
68 enforcer: enforcer,
69 posthog: posthog,
70 }
71}
72
73func (o *OAuthHandler) Router() http.Handler {
74 r := chi.NewRouter()
75
76 r.Get("/login", o.login)
77 r.Post("/login", o.login)
78
79 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout)
80
81 r.Get("/oauth/client-metadata.json", o.clientMetadata)
82 r.Get("/oauth/jwks.json", o.jwks)
83 r.Get("/oauth/callback", o.callback)
84 return r
85}
86
87func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
88 w.Header().Set("Content-Type", "application/json")
89 w.WriteHeader(http.StatusOK)
90 json.NewEncoder(w).Encode(o.oauth.ClientMetadata())
91}
92
93func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
94 jwks := o.config.OAuth.Jwks
95 pubKey, err := pubKeyFromJwk(jwks)
96 if err != nil {
97 log.Printf("error parsing public key: %v", err)
98 http.Error(w, err.Error(), http.StatusInternalServerError)
99 return
100 }
101
102 response := helpers.CreateJwksResponseObject(pubKey)
103
104 w.Header().Set("Content-Type", "application/json")
105 w.WriteHeader(http.StatusOK)
106 json.NewEncoder(w).Encode(response)
107}
108
109func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
110 switch r.Method {
111 case http.MethodGet:
112 returnURL := r.URL.Query().Get("return_url")
113 o.pages.Login(w, pages.LoginParams{
114 ReturnUrl: returnURL,
115 })
116 case http.MethodPost:
117 handle := r.FormValue("handle")
118
119 // when users copy their handle from bsky.app, it tends to have these characters around it:
120 //
121 // @nelind.dk:
122 // \u202a ensures that the handle is always rendered left to right and
123 // \u202c reverts that so the rest of the page renders however it should
124 handle = strings.TrimPrefix(handle, "\u202a")
125 handle = strings.TrimSuffix(handle, "\u202c")
126
127 // `@` is harmless
128 handle = strings.TrimPrefix(handle, "@")
129
130 // basic handle validation
131 if !strings.Contains(handle, ".") {
132 log.Println("invalid handle format", "raw", handle)
133 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle))
134 return
135 }
136
137 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle)
138 if err != nil {
139 log.Println("failed to resolve handle:", err)
140 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
141 return
142 }
143 self := o.oauth.ClientMetadata()
144 oauthClient, err := client.NewClient(
145 self.ClientID,
146 o.config.OAuth.Jwks,
147 self.RedirectURIs[0],
148 )
149
150 if err != nil {
151 log.Println("failed to create oauth client:", err)
152 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
153 return
154 }
155
156 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
157 if err != nil {
158 log.Println("failed to resolve auth server:", err)
159 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
160 return
161 }
162
163 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
164 if err != nil {
165 log.Println("failed to fetch auth server metadata:", err)
166 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
167 return
168 }
169
170 dpopKey, err := helpers.GenerateKey(nil)
171 if err != nil {
172 log.Println("failed to generate dpop key:", err)
173 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
174 return
175 }
176
177 dpopKeyJson, err := json.Marshal(dpopKey)
178 if err != nil {
179 log.Println("failed to marshal dpop key:", err)
180 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
181 return
182 }
183
184 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
185 if err != nil {
186 log.Println("failed to send par auth request:", err)
187 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
188 return
189 }
190
191 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{
192 Did: resolved.DID.String(),
193 PdsUrl: resolved.PDSEndpoint(),
194 Handle: handle,
195 AuthserverIss: authMeta.Issuer,
196 PkceVerifier: parResp.PkceVerifier,
197 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
198 DpopPrivateJwk: string(dpopKeyJson),
199 State: parResp.State,
200 ReturnUrl: r.FormValue("return_url"),
201 })
202 if err != nil {
203 log.Println("failed to save oauth request:", err)
204 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
205 return
206 }
207
208 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
209 query := url.Values{}
210 query.Add("client_id", self.ClientID)
211 query.Add("request_uri", parResp.RequestUri)
212 u.RawQuery = query.Encode()
213 o.pages.HxRedirect(w, u.String())
214 }
215}
216
217func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
218 state := r.FormValue("state")
219
220 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state)
221 if err != nil {
222 log.Println("failed to get oauth request:", err)
223 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
224 return
225 }
226
227 defer func() {
228 err := o.sess.DeleteRequestByState(r.Context(), state)
229 if err != nil {
230 log.Println("failed to delete oauth request for state:", state, err)
231 }
232 }()
233
234 error := r.FormValue("error")
235 errorDescription := r.FormValue("error_description")
236 if error != "" || errorDescription != "" {
237 log.Printf("error: %s, %s", error, errorDescription)
238 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
239 return
240 }
241
242 code := r.FormValue("code")
243 if code == "" {
244 log.Println("missing code for state: ", state)
245 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
246 return
247 }
248
249 iss := r.FormValue("iss")
250 if iss == "" {
251 log.Println("missing iss for state: ", state)
252 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
253 return
254 }
255
256 if iss != oauthRequest.AuthserverIss {
257 log.Println("mismatched iss:", iss, "!=", oauthRequest.AuthserverIss, "for state:", state)
258 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
259 return
260 }
261
262 self := o.oauth.ClientMetadata()
263
264 oauthClient, err := client.NewClient(
265 self.ClientID,
266 o.config.OAuth.Jwks,
267 self.RedirectURIs[0],
268 )
269
270 if err != nil {
271 log.Println("failed to create oauth client:", err)
272 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
273 return
274 }
275
276 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
277 if err != nil {
278 log.Println("failed to parse jwk:", err)
279 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
280 return
281 }
282
283 tokenResp, err := oauthClient.InitialTokenRequest(
284 r.Context(),
285 code,
286 oauthRequest.AuthserverIss,
287 oauthRequest.PkceVerifier,
288 oauthRequest.DpopAuthserverNonce,
289 jwk,
290 )
291 if err != nil {
292 log.Println("failed to get token:", err)
293 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
294 return
295 }
296
297 if tokenResp.Scope != oauthScope {
298 log.Println("scope doesn't match:", tokenResp.Scope)
299 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
300 return
301 }
302
303 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp)
304 if err != nil {
305 log.Println("failed to save session:", err)
306 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
307 return
308 }
309
310 log.Println("session saved successfully")
311 go o.addToDefaultKnot(oauthRequest.Did)
312 go o.addToDefaultSpindle(oauthRequest.Did)
313
314 if !o.config.Core.Dev {
315 err = o.posthog.Enqueue(posthog.Capture{
316 DistinctId: oauthRequest.Did,
317 Event: "signin",
318 })
319 if err != nil {
320 log.Println("failed to enqueue posthog event:", err)
321 }
322 }
323
324 returnUrl := oauthRequest.ReturnUrl
325 if returnUrl == "" {
326 returnUrl = "/"
327 }
328
329 http.Redirect(w, r, returnUrl, http.StatusFound)
330}
331
332func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
333 err := o.oauth.ClearSession(r, w)
334 if err != nil {
335 log.Println("failed to clear session:", err)
336 http.Redirect(w, r, "/", http.StatusFound)
337 return
338 }
339
340 log.Println("session cleared successfully")
341 o.pages.HxRedirect(w, "/login")
342}
343
344func pubKeyFromJwk(jwks string) (jwk.Key, error) {
345 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
346 if err != nil {
347 return nil, err
348 }
349 pubKey, err := k.PublicKey()
350 if err != nil {
351 return nil, err
352 }
353 return pubKey, nil
354}
355
// DIDs of the accounts whose app passwords are used to write membership
// records: tangledDid for the default spindle, icyDid for the default knot.
var (
	tangledDid = "did:plc:wshs7t2adsemcrrd4snkeqli"
	icyDid     = "did:plc:hwevmowznbiukdf6uk5dwrrq"
)

// Services that accounts are enrolled in after a successful sign-in.
var (
	defaultSpindle = "spindle.tangled.sh"
	defaultKnot    = "knot1.tangled.sh"
)
363
364func (o *OAuthHandler) addToDefaultSpindle(did string) {
365 // use the tangled.sh app password to get an accessJwt
366 // and create an sh.tangled.spindle.member record with that
367 spindleMembers, err := db.GetSpindleMembers(
368 o.db,
369 db.FilterEq("instance", "spindle.tangled.sh"),
370 db.FilterEq("subject", did),
371 )
372 if err != nil {
373 log.Printf("failed to get spindle members for did %s: %v", did, err)
374 return
375 }
376
377 if len(spindleMembers) != 0 {
378 log.Printf("did %s is already a member of the default spindle", did)
379 return
380 }
381
382 log.Printf("adding %s to default spindle", did)
383 session, err := o.createAppPasswordSession(o.config.Core.AppPassword, tangledDid)
384 if err != nil {
385 log.Printf("failed to create session: %s", err)
386 return
387 }
388
389 record := tangled.SpindleMember{
390 LexiconTypeID: "sh.tangled.spindle.member",
391 Subject: did,
392 Instance: defaultSpindle,
393 CreatedAt: time.Now().Format(time.RFC3339),
394 }
395
396 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
397 log.Printf("failed to add member to default spindle: %s", err)
398 return
399 }
400
401 log.Printf("successfully added %s to default spindle", did)
402}
403
404func (o *OAuthHandler) addToDefaultKnot(did string) {
405 // use the tangled.sh app password to get an accessJwt
406 // and create an sh.tangled.spindle.member record with that
407
408 allKnots, err := o.enforcer.GetKnotsForUser(did)
409 if err != nil {
410 log.Printf("failed to get knot members for did %s: %v", did, err)
411 return
412 }
413
414 if slices.Contains(allKnots, defaultKnot) {
415 log.Printf("did %s is already a member of the default knot", did)
416 return
417 }
418
419 log.Printf("adding %s to default knot", did)
420 session, err := o.createAppPasswordSession(o.config.Core.TmpAltAppPassword, icyDid)
421 if err != nil {
422 log.Printf("failed to create session: %s", err)
423 return
424 }
425
426 record := tangled.KnotMember{
427 LexiconTypeID: "sh.tangled.knot.member",
428 Subject: did,
429 Domain: defaultKnot,
430 CreatedAt: time.Now().Format(time.RFC3339),
431 }
432
433 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
434 log.Printf("failed to add member to default knot: %s", err)
435 return
436 }
437
438 if err := o.enforcer.AddKnotMember(defaultKnot, did); err != nil {
439 log.Printf("failed to set up enforcer rules: %s", err)
440 return
441 }
442
443 log.Printf("successfully added %s to default Knot", did)
444}
445
// session is an app-password session as returned by
// createAppPasswordSession; putRecord uses it to write records.
type session struct {
	// AccessJwt is decoded from the createSession response and sent as the
	// bearer token by putRecord.
	AccessJwt string `json:"accessJwt"`

	// PdsEndpoint is the base URL that XRPC paths are appended to. It is set
	// by createAppPasswordSession after the response has been decoded.
	PdsEndpoint string

	// Did is the account the session belongs to; putRecord uses it as the
	// "repo" of the request. It is set by createAppPasswordSession after the
	// response has been decoded.
	Did string
}
452
453func (o *OAuthHandler) createAppPasswordSession(appPassword, did string) (*session, error) {
454 if appPassword == "" {
455 return nil, fmt.Errorf("no app password configured, skipping member addition")
456 }
457
458 resolved, err := o.idResolver.ResolveIdent(context.Background(), did)
459 if err != nil {
460 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
461 }
462
463 pdsEndpoint := resolved.PDSEndpoint()
464 if pdsEndpoint == "" {
465 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
466 }
467
468 sessionPayload := map[string]string{
469 "identifier": did,
470 "password": appPassword,
471 }
472 sessionBytes, err := json.Marshal(sessionPayload)
473 if err != nil {
474 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
475 }
476
477 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
478 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
479 if err != nil {
480 return nil, fmt.Errorf("failed to create session request: %v", err)
481 }
482 sessionReq.Header.Set("Content-Type", "application/json")
483
484 client := &http.Client{Timeout: 30 * time.Second}
485 sessionResp, err := client.Do(sessionReq)
486 if err != nil {
487 return nil, fmt.Errorf("failed to create session: %v", err)
488 }
489 defer sessionResp.Body.Close()
490
491 if sessionResp.StatusCode != http.StatusOK {
492 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
493 }
494
495 var session session
496 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
497 return nil, fmt.Errorf("failed to decode session response: %v", err)
498 }
499
500 session.PdsEndpoint = pdsEndpoint
501 session.Did = did
502
503 return &session, nil
504}
505
506func (s *session) putRecord(record any, collection string) error {
507 recordBytes, err := json.Marshal(record)
508 if err != nil {
509 return fmt.Errorf("failed to marshal knot member record: %w", err)
510 }
511
512 payload := map[string]any{
513 "repo": s.Did,
514 "collection": collection,
515 "rkey": tid.TID(),
516 "record": json.RawMessage(recordBytes),
517 }
518
519 payloadBytes, err := json.Marshal(payload)
520 if err != nil {
521 return fmt.Errorf("failed to marshal request payload: %w", err)
522 }
523
524 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
525 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
526 if err != nil {
527 return fmt.Errorf("failed to create HTTP request: %w", err)
528 }
529
530 req.Header.Set("Content-Type", "application/json")
531 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
532
533 client := &http.Client{Timeout: 30 * time.Second}
534 resp, err := client.Do(req)
535 if err != nil {
536 return fmt.Errorf("failed to add user to default service: %w", err)
537 }
538 defer resp.Body.Close()
539
540 if resp.StatusCode != http.StatusOK {
541 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
542 }
543
544 return nil
545}