1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "net/url"
11 "slices"
12 "strings"
13 "time"
14
15 "github.com/go-chi/chi/v5"
16 "github.com/gorilla/sessions"
17 "github.com/lestrrat-go/jwx/v2/jwk"
18 "github.com/posthog/posthog-go"
19 tangled "tangled.org/core/api/tangled"
20 sessioncache "tangled.org/core/appview/cache/session"
21 "tangled.org/core/appview/config"
22 "tangled.org/core/appview/db"
23 "tangled.org/core/appview/middleware"
24 "tangled.org/core/appview/oauth"
25 "tangled.org/core/appview/oauth/client"
26 "tangled.org/core/appview/pages"
27 "tangled.org/core/consts"
28 "tangled.org/core/idresolver"
29 "tangled.org/core/rbac"
30 "tangled.org/core/tid"
31 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
32)
33
// oauthScope is the scope string login requests from the authorization
// server; callback checks the scope granted in the token response against it.
const oauthScope = "atproto transition:generic"
37
// OAuthHandler serves the OAuth login flow: the login page and the pushed
// authorization request it starts, the authorization-server callback,
// logout, and the client metadata / public JWKS documents. After a
// successful sign-in it also enrolls the user in the default knot and
// spindle.
type OAuthHandler struct {
	config     *config.Config             // OAuth JWKS, dev flag, and app passwords
	pages      *pages.Pages               // login page, notices, and htmx redirects
	idResolver *idresolver.Resolver       // resolves handles/DIDs to identities and PDS endpoints
	sess       *sessioncache.SessionStore // pending OAuth requests, looked up by state
	db         *db.DB                     // queried for existing spindle memberships
	store      *sessions.CookieStore      // NOTE(review): not referenced in the code shown here
	oauth      *oauth.OAuth               // client metadata plus session save/clear
	enforcer   *rbac.Enforcer             // RBAC rules for knot membership
	posthog    posthog.Client             // receives "signin" events outside dev mode
}
49
50func New(
51 config *config.Config,
52 pages *pages.Pages,
53 idResolver *idresolver.Resolver,
54 db *db.DB,
55 sess *sessioncache.SessionStore,
56 store *sessions.CookieStore,
57 oauth *oauth.OAuth,
58 enforcer *rbac.Enforcer,
59 posthog posthog.Client,
60) *OAuthHandler {
61 return &OAuthHandler{
62 config: config,
63 pages: pages,
64 idResolver: idResolver,
65 db: db,
66 sess: sess,
67 store: store,
68 oauth: oauth,
69 enforcer: enforcer,
70 posthog: posthog,
71 }
72}
73
74func (o *OAuthHandler) Router() http.Handler {
75 r := chi.NewRouter()
76
77 r.Get("/login", o.login)
78 r.Post("/login", o.login)
79
80 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout)
81
82 r.Get("/oauth/client-metadata.json", o.clientMetadata)
83 r.Get("/oauth/jwks.json", o.jwks)
84 r.Get("/oauth/callback", o.callback)
85 return r
86}
87
88func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
89 w.Header().Set("Content-Type", "application/json")
90 w.WriteHeader(http.StatusOK)
91 json.NewEncoder(w).Encode(o.oauth.ClientMetadata())
92}
93
94func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
95 jwks := o.config.OAuth.Jwks
96 pubKey, err := pubKeyFromJwk(jwks)
97 if err != nil {
98 log.Printf("error parsing public key: %v", err)
99 http.Error(w, err.Error(), http.StatusInternalServerError)
100 return
101 }
102
103 response := helpers.CreateJwksResponseObject(pubKey)
104
105 w.Header().Set("Content-Type", "application/json")
106 w.WriteHeader(http.StatusOK)
107 json.NewEncoder(w).Encode(response)
108}
109
110func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
111 switch r.Method {
112 case http.MethodGet:
113 returnURL := r.URL.Query().Get("return_url")
114 o.pages.Login(w, pages.LoginParams{
115 ReturnUrl: returnURL,
116 })
117 case http.MethodPost:
118 handle := r.FormValue("handle")
119
120 // when users copy their handle from bsky.app, it tends to have these characters around it:
121 //
122 // @nelind.dk:
123 // \u202a ensures that the handle is always rendered left to right and
124 // \u202c reverts that so the rest of the page renders however it should
125 handle = strings.TrimPrefix(handle, "\u202a")
126 handle = strings.TrimSuffix(handle, "\u202c")
127
128 // `@` is harmless
129 handle = strings.TrimPrefix(handle, "@")
130
131 // basic handle validation
132 if !strings.Contains(handle, ".") {
133 log.Println("invalid handle format", "raw", handle)
134 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle))
135 return
136 }
137
138 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle)
139 if err != nil {
140 log.Println("failed to resolve handle:", err)
141 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
142 return
143 }
144 self := o.oauth.ClientMetadata()
145 oauthClient, err := client.NewClient(
146 self.ClientID,
147 o.config.OAuth.Jwks,
148 self.RedirectURIs[0],
149 )
150
151 if err != nil {
152 log.Println("failed to create oauth client:", err)
153 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
154 return
155 }
156
157 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
158 if err != nil {
159 log.Println("failed to resolve auth server:", err)
160 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
161 return
162 }
163
164 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
165 if err != nil {
166 log.Println("failed to fetch auth server metadata:", err)
167 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
168 return
169 }
170
171 dpopKey, err := helpers.GenerateKey(nil)
172 if err != nil {
173 log.Println("failed to generate dpop key:", err)
174 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
175 return
176 }
177
178 dpopKeyJson, err := json.Marshal(dpopKey)
179 if err != nil {
180 log.Println("failed to marshal dpop key:", err)
181 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
182 return
183 }
184
185 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
186 if err != nil {
187 log.Println("failed to send par auth request:", err)
188 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
189 return
190 }
191
192 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{
193 Did: resolved.DID.String(),
194 PdsUrl: resolved.PDSEndpoint(),
195 Handle: handle,
196 AuthserverIss: authMeta.Issuer,
197 PkceVerifier: parResp.PkceVerifier,
198 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
199 DpopPrivateJwk: string(dpopKeyJson),
200 State: parResp.State,
201 ReturnUrl: r.FormValue("return_url"),
202 })
203 if err != nil {
204 log.Println("failed to save oauth request:", err)
205 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
206 return
207 }
208
209 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
210 query := url.Values{}
211 query.Add("client_id", self.ClientID)
212 query.Add("request_uri", parResp.RequestUri)
213 u.RawQuery = query.Encode()
214 o.pages.HxRedirect(w, u.String())
215 }
216}
217
218func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
219 state := r.FormValue("state")
220
221 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state)
222 if err != nil {
223 log.Println("failed to get oauth request:", err)
224 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
225 return
226 }
227
228 defer func() {
229 err := o.sess.DeleteRequestByState(r.Context(), state)
230 if err != nil {
231 log.Println("failed to delete oauth request for state:", state, err)
232 }
233 }()
234
235 error := r.FormValue("error")
236 errorDescription := r.FormValue("error_description")
237 if error != "" || errorDescription != "" {
238 log.Printf("error: %s, %s", error, errorDescription)
239 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
240 return
241 }
242
243 code := r.FormValue("code")
244 if code == "" {
245 log.Println("missing code for state: ", state)
246 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
247 return
248 }
249
250 iss := r.FormValue("iss")
251 if iss == "" {
252 log.Println("missing iss for state: ", state)
253 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
254 return
255 }
256
257 if iss != oauthRequest.AuthserverIss {
258 log.Println("mismatched iss:", iss, "!=", oauthRequest.AuthserverIss, "for state:", state)
259 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
260 return
261 }
262
263 self := o.oauth.ClientMetadata()
264
265 oauthClient, err := client.NewClient(
266 self.ClientID,
267 o.config.OAuth.Jwks,
268 self.RedirectURIs[0],
269 )
270
271 if err != nil {
272 log.Println("failed to create oauth client:", err)
273 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
274 return
275 }
276
277 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
278 if err != nil {
279 log.Println("failed to parse jwk:", err)
280 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
281 return
282 }
283
284 tokenResp, err := oauthClient.InitialTokenRequest(
285 r.Context(),
286 code,
287 oauthRequest.AuthserverIss,
288 oauthRequest.PkceVerifier,
289 oauthRequest.DpopAuthserverNonce,
290 jwk,
291 )
292 if err != nil {
293 log.Println("failed to get token:", err)
294 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
295 return
296 }
297
298 if tokenResp.Scope != oauthScope {
299 log.Println("scope doesn't match:", tokenResp.Scope)
300 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
301 return
302 }
303
304 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp)
305 if err != nil {
306 log.Println("failed to save session:", err)
307 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
308 return
309 }
310
311 log.Println("session saved successfully")
312 go o.addToDefaultKnot(oauthRequest.Did)
313 go o.addToDefaultSpindle(oauthRequest.Did)
314
315 if !o.config.Core.Dev {
316 err = o.posthog.Enqueue(posthog.Capture{
317 DistinctId: oauthRequest.Did,
318 Event: "signin",
319 })
320 if err != nil {
321 log.Println("failed to enqueue posthog event:", err)
322 }
323 }
324
325 returnUrl := oauthRequest.ReturnUrl
326 if returnUrl == "" {
327 returnUrl = "/"
328 }
329
330 http.Redirect(w, r, returnUrl, http.StatusFound)
331}
332
333func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
334 err := o.oauth.ClearSession(r, w)
335 if err != nil {
336 log.Println("failed to clear session:", err)
337 http.Redirect(w, r, "/", http.StatusFound)
338 return
339 }
340
341 log.Println("session cleared successfully")
342 o.pages.HxRedirect(w, "/login")
343}
344
345func pubKeyFromJwk(jwks string) (jwk.Key, error) {
346 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
347 if err != nil {
348 return nil, err
349 }
350 pubKey, err := k.PublicKey()
351 if err != nil {
352 return nil, err
353 }
354 return pubKey, nil
355}
356
357func (o *OAuthHandler) addToDefaultSpindle(did string) {
358 // use the tangled.sh app password to get an accessJwt
359 // and create an sh.tangled.spindle.member record with that
360 spindleMembers, err := db.GetSpindleMembers(
361 o.db,
362 db.FilterEq("instance", "spindle.tangled.sh"),
363 db.FilterEq("subject", did),
364 )
365 if err != nil {
366 log.Printf("failed to get spindle members for did %s: %v", did, err)
367 return
368 }
369
370 if len(spindleMembers) != 0 {
371 log.Printf("did %s is already a member of the default spindle", did)
372 return
373 }
374
375 log.Printf("adding %s to default spindle", did)
376 session, err := o.createAppPasswordSession(o.config.Core.AppPassword, consts.TangledDid)
377 if err != nil {
378 log.Printf("failed to create session: %s", err)
379 return
380 }
381
382 record := tangled.SpindleMember{
383 LexiconTypeID: "sh.tangled.spindle.member",
384 Subject: did,
385 Instance: consts.DefaultSpindle,
386 CreatedAt: time.Now().Format(time.RFC3339),
387 }
388
389 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
390 log.Printf("failed to add member to default spindle: %s", err)
391 return
392 }
393
394 log.Printf("successfully added %s to default spindle", did)
395}
396
397func (o *OAuthHandler) addToDefaultKnot(did string) {
398 // use the tangled.sh app password to get an accessJwt
399 // and create an sh.tangled.spindle.member record with that
400
401 allKnots, err := o.enforcer.GetKnotsForUser(did)
402 if err != nil {
403 log.Printf("failed to get knot members for did %s: %v", did, err)
404 return
405 }
406
407 if slices.Contains(allKnots, consts.DefaultKnot) {
408 log.Printf("did %s is already a member of the default knot", did)
409 return
410 }
411
412 log.Printf("adding %s to default knot", did)
413 session, err := o.createAppPasswordSession(o.config.Core.TmpAltAppPassword, consts.IcyDid)
414 if err != nil {
415 log.Printf("failed to create session: %s", err)
416 return
417 }
418
419 record := tangled.KnotMember{
420 LexiconTypeID: "sh.tangled.knot.member",
421 Subject: did,
422 Domain: consts.DefaultKnot,
423 CreatedAt: time.Now().Format(time.RFC3339),
424 }
425
426 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
427 log.Printf("failed to add member to default knot: %s", err)
428 return
429 }
430
431 if err := o.enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
432 log.Printf("failed to set up enforcer rules: %s", err)
433 return
434 }
435
436 log.Printf("successfully added %s to default Knot", did)
437}
438
// session is an app-password login on a PDS. It holds:
//   - AccessJwt: decoded from the com.atproto.server.createSession response.
//   - PdsEndpoint and Did: set by createAppPasswordSession after decoding.
//
// putRecord uses all three to write records into that DID's repo.
type session struct {
	AccessJwt   string `json:"accessJwt"`
	PdsEndpoint string
	Did         string
}
445
446func (o *OAuthHandler) createAppPasswordSession(appPassword, did string) (*session, error) {
447 if appPassword == "" {
448 return nil, fmt.Errorf("no app password configured, skipping member addition")
449 }
450
451 resolved, err := o.idResolver.ResolveIdent(context.Background(), did)
452 if err != nil {
453 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
454 }
455
456 pdsEndpoint := resolved.PDSEndpoint()
457 if pdsEndpoint == "" {
458 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
459 }
460
461 sessionPayload := map[string]string{
462 "identifier": did,
463 "password": appPassword,
464 }
465 sessionBytes, err := json.Marshal(sessionPayload)
466 if err != nil {
467 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
468 }
469
470 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
471 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
472 if err != nil {
473 return nil, fmt.Errorf("failed to create session request: %v", err)
474 }
475 sessionReq.Header.Set("Content-Type", "application/json")
476
477 client := &http.Client{Timeout: 30 * time.Second}
478 sessionResp, err := client.Do(sessionReq)
479 if err != nil {
480 return nil, fmt.Errorf("failed to create session: %v", err)
481 }
482 defer sessionResp.Body.Close()
483
484 if sessionResp.StatusCode != http.StatusOK {
485 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
486 }
487
488 var session session
489 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
490 return nil, fmt.Errorf("failed to decode session response: %v", err)
491 }
492
493 session.PdsEndpoint = pdsEndpoint
494 session.Did = did
495
496 return &session, nil
497}
498
499func (s *session) putRecord(record any, collection string) error {
500 recordBytes, err := json.Marshal(record)
501 if err != nil {
502 return fmt.Errorf("failed to marshal knot member record: %w", err)
503 }
504
505 payload := map[string]any{
506 "repo": s.Did,
507 "collection": collection,
508 "rkey": tid.TID(),
509 "record": json.RawMessage(recordBytes),
510 }
511
512 payloadBytes, err := json.Marshal(payload)
513 if err != nil {
514 return fmt.Errorf("failed to marshal request payload: %w", err)
515 }
516
517 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
518 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
519 if err != nil {
520 return fmt.Errorf("failed to create HTTP request: %w", err)
521 }
522
523 req.Header.Set("Content-Type", "application/json")
524 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
525
526 client := &http.Client{Timeout: 30 * time.Second}
527 resp, err := client.Do(req)
528 if err != nil {
529 return fmt.Errorf("failed to add user to default service: %w", err)
530 }
531 defer resp.Body.Close()
532
533 if resp.StatusCode != http.StatusOK {
534 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
535 }
536
537 return nil
538}