forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "net/url" 11 "slices" 12 "strings" 13 "time" 14 15 "github.com/go-chi/chi/v5" 16 "github.com/gorilla/sessions" 17 "github.com/lestrrat-go/jwx/v2/jwk" 18 "github.com/posthog/posthog-go" 19 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 20 tangled "tangled.sh/tangled.sh/core/api/tangled" 21 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 22 "tangled.sh/tangled.sh/core/appview/config" 23 "tangled.sh/tangled.sh/core/appview/db" 24 "tangled.sh/tangled.sh/core/appview/middleware" 25 "tangled.sh/tangled.sh/core/appview/oauth" 26 "tangled.sh/tangled.sh/core/appview/oauth/client" 27 "tangled.sh/tangled.sh/core/appview/pages" 28 "tangled.sh/tangled.sh/core/idresolver" 29 "tangled.sh/tangled.sh/core/rbac" 30 "tangled.sh/tangled.sh/core/tid" 31) 32 33const ( 34 oauthScope = "atproto transition:generic" 35) 36 37type OAuthHandler struct { 38 config *config.Config 39 pages *pages.Pages 40 idResolver *idresolver.Resolver 41 sess *sessioncache.SessionStore 42 db *db.DB 43 store *sessions.CookieStore 44 oauth *oauth.OAuth 45 enforcer *rbac.Enforcer 46 posthog posthog.Client 47} 48 49func New( 50 config *config.Config, 51 pages *pages.Pages, 52 idResolver *idresolver.Resolver, 53 db *db.DB, 54 sess *sessioncache.SessionStore, 55 store *sessions.CookieStore, 56 oauth *oauth.OAuth, 57 enforcer *rbac.Enforcer, 58 posthog posthog.Client, 59) *OAuthHandler { 60 return &OAuthHandler{ 61 config: config, 62 pages: pages, 63 idResolver: idResolver, 64 db: db, 65 sess: sess, 66 store: store, 67 oauth: oauth, 68 enforcer: enforcer, 69 posthog: posthog, 70 } 71} 72 73func (o *OAuthHandler) Router() http.Handler { 74 r := chi.NewRouter() 75 76 r.Get("/login", o.login) 77 r.Post("/login", o.login) 78 79 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 80 81 r.Get("/oauth/client-metadata.json", o.clientMetadata) 82 r.Get("/oauth/jwks.json", o.jwks) 83 r.Get("/oauth/callback", o.callback) 84 return r 85} 86 87func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 88 w.Header().Set("Content-Type", "application/json") 89 w.WriteHeader(http.StatusOK) 90 json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 91} 92 93func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 94 jwks := o.config.OAuth.Jwks 95 pubKey, err := pubKeyFromJwk(jwks) 96 if err != nil { 97 log.Printf("error parsing public key: %v", err) 98 http.Error(w, err.Error(), http.StatusInternalServerError) 99 return 100 } 101 102 response := helpers.CreateJwksResponseObject(pubKey) 103 104 w.Header().Set("Content-Type", "application/json") 105 w.WriteHeader(http.StatusOK) 106 json.NewEncoder(w).Encode(response) 107} 108 109func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 110 switch r.Method { 111 case http.MethodGet: 112 returnURL := r.URL.Query().Get("return_url") 113 o.pages.Login(w, pages.LoginParams{ 114 ReturnUrl: returnURL, 115 }) 116 case http.MethodPost: 117 handle := r.FormValue("handle") 118 119 // when users copy their handle from bsky.app, it tends to have these characters around it: 120 // 121 // @nelind.dk: 122 // \u202a ensures that the handle is always rendered left to right and 123 // \u202c reverts that so the rest of the page renders however it should 124 handle = strings.TrimPrefix(handle, "\u202a") 125 handle = strings.TrimSuffix(handle, "\u202c") 126 127 // `@` is harmless 128 handle = strings.TrimPrefix(handle, "@") 129 130 // basic handle validation 131 if !strings.Contains(handle, ".") { 132 log.Println("invalid handle format", "raw", handle) 133 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 134 return 135 } 136 137 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle) 138 if err != nil { 139 log.Println("failed to resolve handle:", err) 140 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 141 return 142 } 143 self := o.oauth.ClientMetadata() 144 oauthClient, err := client.NewClient( 145 self.ClientID, 146 o.config.OAuth.Jwks, 147 self.RedirectURIs[0], 148 ) 149 150 if err != nil { 151 log.Println("failed to create oauth client:", err) 152 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 153 return 154 } 155 156 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 157 if err != nil { 158 log.Println("failed to resolve auth server:", err) 159 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 160 return 161 } 162 163 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 164 if err != nil { 165 log.Println("failed to fetch auth server metadata:", err) 166 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 167 return 168 } 169 170 dpopKey, err := helpers.GenerateKey(nil) 171 if err != nil { 172 log.Println("failed to generate dpop key:", err) 173 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 174 return 175 } 176 177 dpopKeyJson, err := json.Marshal(dpopKey) 178 if err != nil { 179 log.Println("failed to marshal dpop key:", err) 180 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 181 return 182 } 183 184 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 185 if err != nil { 186 log.Println("failed to send par auth request:", err) 187 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 188 return 189 } 190 191 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ 192 Did: resolved.DID.String(), 193 PdsUrl: resolved.PDSEndpoint(), 194 Handle: handle, 195 AuthserverIss: authMeta.Issuer, 196 PkceVerifier: parResp.PkceVerifier, 197 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 198 DpopPrivateJwk: string(dpopKeyJson), 199 State: parResp.State, 200 ReturnUrl: r.FormValue("return_url"), 201 }) 202 if err != nil { 203 log.Println("failed to save oauth request:", err) 204 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 205 return 206 } 207 208 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 209 query := url.Values{} 210 query.Add("client_id", self.ClientID) 211 query.Add("request_uri", parResp.RequestUri) 212 u.RawQuery = query.Encode() 213 o.pages.HxRedirect(w, u.String()) 214 } 215} 216 217func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 218 state := r.FormValue("state") 219 220 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) 221 if err != nil { 222 log.Println("failed to get oauth request:", err) 223 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 224 return 225 } 226 227 defer func() { 228 err := o.sess.DeleteRequestByState(r.Context(), state) 229 if err != nil { 230 log.Println("failed to delete oauth request for state:", state, err) 231 } 232 }() 233 234 error := r.FormValue("error") 235 errorDescription := r.FormValue("error_description") 236 if error != "" || errorDescription != "" { 237 log.Printf("error: %s, %s", error, errorDescription) 238 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 239 return 240 } 241 242 code := r.FormValue("code") 243 if code == "" { 244 log.Println("missing code for state: ", state) 245 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 246 return 247 } 248 249 iss := r.FormValue("iss") 250 if iss == "" { 251 log.Println("missing iss for state: ", state) 252 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 253 return 254 } 255 256 if iss != oauthRequest.AuthserverIss { 257 log.Println("mismatched iss:", iss, "!=", oauthRequest.AuthserverIss, "for state:", state) 258 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 259 return 260 } 261 262 self := o.oauth.ClientMetadata() 263 264 oauthClient, err := client.NewClient( 265 self.ClientID, 266 o.config.OAuth.Jwks, 267 self.RedirectURIs[0], 268 ) 269 270 if err != nil { 271 log.Println("failed to create oauth client:", err) 272 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 273 return 274 } 275 276 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 277 if err != nil { 278 log.Println("failed to parse jwk:", err) 279 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 280 return 281 } 282 283 tokenResp, err := oauthClient.InitialTokenRequest( 284 r.Context(), 285 code, 286 oauthRequest.AuthserverIss, 287 oauthRequest.PkceVerifier, 288 oauthRequest.DpopAuthserverNonce, 289 jwk, 290 ) 291 if err != nil { 292 log.Println("failed to get token:", err) 293 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 294 return 295 } 296 297 if tokenResp.Scope != oauthScope { 298 log.Println("scope doesn't match:", tokenResp.Scope) 299 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 300 return 301 } 302 303 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) 304 if err != nil { 305 log.Println("failed to save session:", err) 306 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 307 return 308 } 309 310 log.Println("session saved successfully") 311 go o.addToDefaultKnot(oauthRequest.Did) 312 go o.addToDefaultSpindle(oauthRequest.Did) 313 314 if !o.config.Core.Dev { 315 err = o.posthog.Enqueue(posthog.Capture{ 316 DistinctId: oauthRequest.Did, 317 Event: "signin", 318 }) 319 if err != nil { 320 log.Println("failed to enqueue posthog event:", err) 321 } 322 } 323 324 returnUrl := oauthRequest.ReturnUrl 325 if returnUrl == "" { 326 returnUrl = "/" 327 } 328 329 http.Redirect(w, r, returnUrl, http.StatusFound) 330} 331 332func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 333 err := o.oauth.ClearSession(r, w) 334 if err != nil { 335 log.Println("failed to clear session:", err) 336 http.Redirect(w, r, "/", http.StatusFound) 337 return 338 } 339 340 log.Println("session cleared successfully") 341 o.pages.HxRedirect(w, "/login") 342} 343 344func pubKeyFromJwk(jwks string) (jwk.Key, error) { 345 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 346 if err != nil { 347 return nil, err 348 } 349 pubKey, err := k.PublicKey() 350 if err != nil { 351 return nil, err 352 } 353 return pubKey, nil 354} 355 356var ( 357 tangledDid = "did:plc:wshs7t2adsemcrrd4snkeqli" 358 icyDid = "did:plc:hwevmowznbiukdf6uk5dwrrq" 359 360 defaultSpindle = "spindle.tangled.sh" 361 defaultKnot = "knot1.tangled.sh" 362) 363 364func (o *OAuthHandler) addToDefaultSpindle(did string) { 365 // use the tangled.sh app password to get an accessJwt 366 // and create an sh.tangled.spindle.member record with that 367 spindleMembers, err := db.GetSpindleMembers( 368 o.db, 369 db.FilterEq("instance", "spindle.tangled.sh"), 370 db.FilterEq("subject", did), 371 ) 372 if err != nil { 373 log.Printf("failed to get spindle members for did %s: %v", did, err) 374 return 375 } 376 377 if len(spindleMembers) != 0 { 378 log.Printf("did %s is already a member of the default spindle", did) 379 return 380 } 381 382 log.Printf("adding %s to default spindle", did) 383 session, err := o.createAppPasswordSession(o.config.Core.AppPassword, tangledDid) 384 if err != nil { 385 log.Printf("failed to create session: %s", err) 386 return 387 } 388 389 record := tangled.SpindleMember{ 390 LexiconTypeID: "sh.tangled.spindle.member", 391 Subject: did, 392 Instance: defaultSpindle, 393 CreatedAt: time.Now().Format(time.RFC3339), 394 } 395 396 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 397 log.Printf("failed to add member to default spindle: %s", err) 398 return 399 } 400 401 log.Printf("successfully added %s to default spindle", did) 402} 403 404func (o *OAuthHandler) addToDefaultKnot(did string) { 405 // use the tangled.sh app password to get an accessJwt 406 // and create an sh.tangled.spindle.member record with that 407 408 allKnots, err := o.enforcer.GetKnotsForUser(did) 409 if err != nil { 410 log.Printf("failed to get knot members for did %s: %v", did, err) 411 return 412 } 413 414 if slices.Contains(allKnots, defaultKnot) { 415 log.Printf("did %s is already a member of the default knot", did) 416 return 417 } 418 419 log.Printf("adding %s to default knot", did) 420 session, err := o.createAppPasswordSession(o.config.Core.TmpAltAppPassword, icyDid) 421 if err != nil { 422 log.Printf("failed to create session: %s", err) 423 return 424 } 425 426 record := tangled.KnotMember{ 427 LexiconTypeID: "sh.tangled.knot.member", 428 Subject: did, 429 Domain: defaultKnot, 430 CreatedAt: time.Now().Format(time.RFC3339), 431 } 432 433 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 434 log.Printf("failed to add member to default knot: %s", err) 435 return 436 } 437 438 if err := o.enforcer.AddKnotMember(defaultKnot, did); err != nil { 439 log.Printf("failed to set up enforcer rules: %s", err) 440 return 441 } 442 443 log.Printf("successfully added %s to default Knot", did) 444} 445 446// create a session using apppasswords 447type session struct { 448 AccessJwt string `json:"accessJwt"` 449 PdsEndpoint string 450 Did string 451} 452 453func (o *OAuthHandler) createAppPasswordSession(appPassword, did string) (*session, error) { 454 if appPassword == "" { 455 return nil, fmt.Errorf("no app password configured, skipping member addition") 456 } 457 458 resolved, err := o.idResolver.ResolveIdent(context.Background(), did) 459 if err != nil { 460 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 461 } 462 463 pdsEndpoint := resolved.PDSEndpoint() 464 if pdsEndpoint == "" { 465 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 466 } 467 468 sessionPayload := map[string]string{ 469 "identifier": did, 470 "password": appPassword, 471 } 472 sessionBytes, err := json.Marshal(sessionPayload) 473 if err != nil { 474 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 475 } 476 477 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 478 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 479 if err != nil { 480 return nil, fmt.Errorf("failed to create session request: %v", err) 481 } 482 sessionReq.Header.Set("Content-Type", "application/json") 483 484 client := &http.Client{Timeout: 30 * time.Second} 485 sessionResp, err := client.Do(sessionReq) 486 if err != nil { 487 return nil, fmt.Errorf("failed to create session: %v", err) 488 } 489 defer sessionResp.Body.Close() 490 491 if sessionResp.StatusCode != http.StatusOK { 492 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 493 } 494 495 var session session 496 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 497 return nil, fmt.Errorf("failed to decode session response: %v", err) 498 } 499 500 session.PdsEndpoint = pdsEndpoint 501 session.Did = did 502 503 return &session, nil 504} 505 506func (s *session) putRecord(record any, collection string) error { 507 recordBytes, err := json.Marshal(record) 508 if err != nil { 509 return fmt.Errorf("failed to marshal knot member record: %w", err) 510 } 511 512 payload := map[string]any{ 513 "repo": s.Did, 514 "collection": collection, 515 "rkey": tid.TID(), 516 "record": json.RawMessage(recordBytes), 517 } 518 519 payloadBytes, err := json.Marshal(payload) 520 if err != nil { 521 return fmt.Errorf("failed to marshal request payload: %w", err) 522 } 523 524 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 525 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 526 if err != nil { 527 return fmt.Errorf("failed to create HTTP request: %w", err) 528 } 529 530 req.Header.Set("Content-Type", "application/json") 531 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 532 533 client := &http.Client{Timeout: 30 * time.Second} 534 resp, err := client.Do(req) 535 if err != nil { 536 return fmt.Errorf("failed to add user to default service: %w", err) 537 } 538 defer resp.Body.Close() 539 540 if resp.StatusCode != http.StatusOK { 541 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 542 } 543 544 return nil 545}