forked from tangled.org/core
this repo has no description
at master 15 kB view raw
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "net/url" 11 "slices" 12 "strings" 13 "time" 14 15 "github.com/go-chi/chi/v5" 16 "github.com/gorilla/sessions" 17 "github.com/lestrrat-go/jwx/v2/jwk" 18 "github.com/posthog/posthog-go" 19 tangled "tangled.org/core/api/tangled" 20 sessioncache "tangled.org/core/appview/cache/session" 21 "tangled.org/core/appview/config" 22 "tangled.org/core/appview/db" 23 "tangled.org/core/appview/middleware" 24 "tangled.org/core/appview/oauth" 25 "tangled.org/core/appview/oauth/client" 26 "tangled.org/core/appview/pages" 27 "tangled.org/core/consts" 28 "tangled.org/core/idresolver" 29 "tangled.org/core/rbac" 30 "tangled.org/core/tid" 31 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 32) 33 34const ( 35 oauthScope = "atproto transition:generic" 36) 37 38type OAuthHandler struct { 39 config *config.Config 40 pages *pages.Pages 41 idResolver *idresolver.Resolver 42 sess *sessioncache.SessionStore 43 db *db.DB 44 store *sessions.CookieStore 45 oauth *oauth.OAuth 46 enforcer *rbac.Enforcer 47 posthog posthog.Client 48} 49 50func New( 51 config *config.Config, 52 pages *pages.Pages, 53 idResolver *idresolver.Resolver, 54 db *db.DB, 55 sess *sessioncache.SessionStore, 56 store *sessions.CookieStore, 57 oauth *oauth.OAuth, 58 enforcer *rbac.Enforcer, 59 posthog posthog.Client, 60) *OAuthHandler { 61 return &OAuthHandler{ 62 config: config, 63 pages: pages, 64 idResolver: idResolver, 65 db: db, 66 sess: sess, 67 store: store, 68 oauth: oauth, 69 enforcer: enforcer, 70 posthog: posthog, 71 } 72} 73 74func (o *OAuthHandler) Router() http.Handler { 75 r := chi.NewRouter() 76 77 r.Get("/login", o.login) 78 r.Post("/login", o.login) 79 80 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 81 82 r.Get("/oauth/client-metadata.json", o.clientMetadata) 83 r.Get("/oauth/jwks.json", o.jwks) 84 r.Get("/oauth/callback", o.callback) 85 return r 86} 87 88func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 89 w.Header().Set("Content-Type", "application/json") 90 w.WriteHeader(http.StatusOK) 91 json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 92} 93 94func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 95 jwks := o.config.OAuth.Jwks 96 pubKey, err := pubKeyFromJwk(jwks) 97 if err != nil { 98 log.Printf("error parsing public key: %v", err) 99 http.Error(w, err.Error(), http.StatusInternalServerError) 100 return 101 } 102 103 response := helpers.CreateJwksResponseObject(pubKey) 104 105 w.Header().Set("Content-Type", "application/json") 106 w.WriteHeader(http.StatusOK) 107 json.NewEncoder(w).Encode(response) 108} 109 110func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 111 switch r.Method { 112 case http.MethodGet: 113 returnURL := r.URL.Query().Get("return_url") 114 o.pages.Login(w, pages.LoginParams{ 115 ReturnUrl: returnURL, 116 }) 117 case http.MethodPost: 118 handle := r.FormValue("handle") 119 120 // when users copy their handle from bsky.app, it tends to have these characters around it: 121 // 122 // @nelind.dk: 123 // \u202a ensures that the handle is always rendered left to right and 124 // \u202c reverts that so the rest of the page renders however it should 125 handle = strings.TrimPrefix(handle, "\u202a") 126 handle = strings.TrimSuffix(handle, "\u202c") 127 128 // `@` is harmless 129 handle = strings.TrimPrefix(handle, "@") 130 131 // basic handle validation 132 if !strings.Contains(handle, ".") { 133 log.Println("invalid handle format", "raw", handle) 134 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 135 return 136 } 137 138 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle) 139 if err != nil { 140 log.Println("failed to resolve handle:", err) 141 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 142 return 143 } 144 self := o.oauth.ClientMetadata() 145 oauthClient, err := client.NewClient( 146 self.ClientID, 147 o.config.OAuth.Jwks, 148 self.RedirectURIs[0], 149 ) 150 151 if err != nil { 152 log.Println("failed to create oauth client:", err) 153 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 154 return 155 } 156 157 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 158 if err != nil { 159 log.Println("failed to resolve auth server:", err) 160 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 161 return 162 } 163 164 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 165 if err != nil { 166 log.Println("failed to fetch auth server metadata:", err) 167 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 168 return 169 } 170 171 dpopKey, err := helpers.GenerateKey(nil) 172 if err != nil { 173 log.Println("failed to generate dpop key:", err) 174 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 175 return 176 } 177 178 dpopKeyJson, err := json.Marshal(dpopKey) 179 if err != nil { 180 log.Println("failed to marshal dpop key:", err) 181 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 182 return 183 } 184 185 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 186 if err != nil { 187 log.Println("failed to send par auth request:", err) 188 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 189 return 190 } 191 192 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ 193 Did: resolved.DID.String(), 194 PdsUrl: resolved.PDSEndpoint(), 195 Handle: handle, 196 AuthserverIss: authMeta.Issuer, 197 PkceVerifier: parResp.PkceVerifier, 198 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 199 DpopPrivateJwk: string(dpopKeyJson), 200 State: parResp.State, 201 ReturnUrl: r.FormValue("return_url"), 202 }) 203 if err != nil { 204 log.Println("failed to save oauth request:", err) 205 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 206 return 207 } 208 209 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 210 query := url.Values{} 211 query.Add("client_id", self.ClientID) 212 query.Add("request_uri", parResp.RequestUri) 213 u.RawQuery = query.Encode() 214 o.pages.HxRedirect(w, u.String()) 215 } 216} 217 218func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 219 state := r.FormValue("state") 220 221 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) 222 if err != nil { 223 log.Println("failed to get oauth request:", err) 224 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 225 return 226 } 227 228 defer func() { 229 err := o.sess.DeleteRequestByState(r.Context(), state) 230 if err != nil { 231 log.Println("failed to delete oauth request for state:", state, err) 232 } 233 }() 234 235 error := r.FormValue("error") 236 errorDescription := r.FormValue("error_description") 237 if error != "" || errorDescription != "" { 238 log.Printf("error: %s, %s", error, errorDescription) 239 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 240 return 241 } 242 243 code := r.FormValue("code") 244 if code == "" { 245 log.Println("missing code for state: ", state) 246 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 247 return 248 } 249 250 iss := r.FormValue("iss") 251 if iss == "" { 252 log.Println("missing iss for state: ", state) 253 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 254 return 255 } 256 257 if iss != oauthRequest.AuthserverIss { 258 log.Println("mismatched iss:", iss, "!=", oauthRequest.AuthserverIss, "for state:", state) 259 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 260 return 261 } 262 263 self := o.oauth.ClientMetadata() 264 265 oauthClient, err := client.NewClient( 266 self.ClientID, 267 o.config.OAuth.Jwks, 268 self.RedirectURIs[0], 269 ) 270 271 if err != nil { 272 log.Println("failed to create oauth client:", err) 273 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 274 return 275 } 276 277 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 278 if err != nil { 279 log.Println("failed to parse jwk:", err) 280 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 281 return 282 } 283 284 tokenResp, err := oauthClient.InitialTokenRequest( 285 r.Context(), 286 code, 287 oauthRequest.AuthserverIss, 288 oauthRequest.PkceVerifier, 289 oauthRequest.DpopAuthserverNonce, 290 jwk, 291 ) 292 if err != nil { 293 log.Println("failed to get token:", err) 294 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 295 return 296 } 297 298 if tokenResp.Scope != oauthScope { 299 log.Println("scope doesn't match:", tokenResp.Scope) 300 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 301 return 302 } 303 304 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) 305 if err != nil { 306 log.Println("failed to save session:", err) 307 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 308 return 309 } 310 311 log.Println("session saved successfully") 312 go o.addToDefaultKnot(oauthRequest.Did) 313 go o.addToDefaultSpindle(oauthRequest.Did) 314 315 if !o.config.Core.Dev { 316 err = o.posthog.Enqueue(posthog.Capture{ 317 DistinctId: oauthRequest.Did, 318 Event: "signin", 319 }) 320 if err != nil { 321 log.Println("failed to enqueue posthog event:", err) 322 } 323 } 324 325 returnUrl := oauthRequest.ReturnUrl 326 if returnUrl == "" { 327 returnUrl = "/" 328 } 329 330 http.Redirect(w, r, returnUrl, http.StatusFound) 331} 332 333func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 334 err := o.oauth.ClearSession(r, w) 335 if err != nil { 336 log.Println("failed to clear session:", err) 337 http.Redirect(w, r, "/", http.StatusFound) 338 return 339 } 340 341 log.Println("session cleared successfully") 342 o.pages.HxRedirect(w, "/login") 343} 344 345func pubKeyFromJwk(jwks string) (jwk.Key, error) { 346 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 347 if err != nil { 348 return nil, err 349 } 350 pubKey, err := k.PublicKey() 351 if err != nil { 352 return nil, err 353 } 354 return pubKey, nil 355} 356 357func (o *OAuthHandler) addToDefaultSpindle(did string) { 358 // use the tangled.sh app password to get an accessJwt 359 // and create an sh.tangled.spindle.member record with that 360 spindleMembers, err := db.GetSpindleMembers( 361 o.db, 362 db.FilterEq("instance", "spindle.tangled.sh"), 363 db.FilterEq("subject", did), 364 ) 365 if err != nil { 366 log.Printf("failed to get spindle members for did %s: %v", did, err) 367 return 368 } 369 370 if len(spindleMembers) != 0 { 371 log.Printf("did %s is already a member of the default spindle", did) 372 return 373 } 374 375 log.Printf("adding %s to default spindle", did) 376 session, err := o.createAppPasswordSession(o.config.Core.AppPassword, consts.TangledDid) 377 if err != nil { 378 log.Printf("failed to create session: %s", err) 379 return 380 } 381 382 record := tangled.SpindleMember{ 383 LexiconTypeID: "sh.tangled.spindle.member", 384 Subject: did, 385 Instance: consts.DefaultSpindle, 386 CreatedAt: time.Now().Format(time.RFC3339), 387 } 388 389 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 390 log.Printf("failed to add member to default spindle: %s", err) 391 return 392 } 393 394 log.Printf("successfully added %s to default spindle", did) 395} 396 397func (o *OAuthHandler) addToDefaultKnot(did string) { 398 // use the tangled.sh app password to get an accessJwt 399 // and create an sh.tangled.spindle.member record with that 400 401 allKnots, err := o.enforcer.GetKnotsForUser(did) 402 if err != nil { 403 log.Printf("failed to get knot members for did %s: %v", did, err) 404 return 405 } 406 407 if slices.Contains(allKnots, consts.DefaultKnot) { 408 log.Printf("did %s is already a member of the default knot", did) 409 return 410 } 411 412 log.Printf("adding %s to default knot", did) 413 session, err := o.createAppPasswordSession(o.config.Core.TmpAltAppPassword, consts.IcyDid) 414 if err != nil { 415 log.Printf("failed to create session: %s", err) 416 return 417 } 418 419 record := tangled.KnotMember{ 420 LexiconTypeID: "sh.tangled.knot.member", 421 Subject: did, 422 Domain: consts.DefaultKnot, 423 CreatedAt: time.Now().Format(time.RFC3339), 424 } 425 426 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 427 log.Printf("failed to add member to default knot: %s", err) 428 return 429 } 430 431 if err := o.enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 432 log.Printf("failed to set up enforcer rules: %s", err) 433 return 434 } 435 436 log.Printf("successfully added %s to default Knot", did) 437} 438 439// create a session using apppasswords 440type session struct { 441 AccessJwt string `json:"accessJwt"` 442 PdsEndpoint string 443 Did string 444} 445 446func (o *OAuthHandler) createAppPasswordSession(appPassword, did string) (*session, error) { 447 if appPassword == "" { 448 return nil, fmt.Errorf("no app password configured, skipping member addition") 449 } 450 451 resolved, err := o.idResolver.ResolveIdent(context.Background(), did) 452 if err != nil { 453 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 454 } 455 456 pdsEndpoint := resolved.PDSEndpoint() 457 if pdsEndpoint == "" { 458 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 459 } 460 461 sessionPayload := map[string]string{ 462 "identifier": did, 463 "password": appPassword, 464 } 465 sessionBytes, err := json.Marshal(sessionPayload) 466 if err != nil { 467 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 468 } 469 470 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 471 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 472 if err != nil { 473 return nil, fmt.Errorf("failed to create session request: %v", err) 474 } 475 sessionReq.Header.Set("Content-Type", "application/json") 476 477 client := &http.Client{Timeout: 30 * time.Second} 478 sessionResp, err := client.Do(sessionReq) 479 if err != nil { 480 return nil, fmt.Errorf("failed to create session: %v", err) 481 } 482 defer sessionResp.Body.Close() 483 484 if sessionResp.StatusCode != http.StatusOK { 485 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 486 } 487 488 var session session 489 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 490 return nil, fmt.Errorf("failed to decode session response: %v", err) 491 } 492 493 session.PdsEndpoint = pdsEndpoint 494 session.Did = did 495 496 return &session, nil 497} 498 499func (s *session) putRecord(record any, collection string) error { 500 recordBytes, err := json.Marshal(record) 501 if err != nil { 502 return fmt.Errorf("failed to marshal knot member record: %w", err) 503 } 504 505 payload := map[string]any{ 506 "repo": s.Did, 507 "collection": collection, 508 "rkey": tid.TID(), 509 "record": json.RawMessage(recordBytes), 510 } 511 512 payloadBytes, err := json.Marshal(payload) 513 if err != nil { 514 return fmt.Errorf("failed to marshal request payload: %w", err) 515 } 516 517 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 518 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 519 if err != nil { 520 return fmt.Errorf("failed to create HTTP request: %w", err) 521 } 522 523 req.Header.Set("Content-Type", "application/json") 524 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 525 526 client := &http.Client{Timeout: 30 * time.Second} 527 resp, err := client.Do(req) 528 if err != nil { 529 return fmt.Errorf("failed to add user to default service: %w", err) 530 } 531 defer resp.Body.Close() 532 533 if resp.StatusCode != http.StatusOK { 534 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 535 } 536 537 return nil 538}