forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "net/url" 11 "strings" 12 "time" 13 14 "github.com/go-chi/chi/v5" 15 "github.com/gorilla/sessions" 16 "github.com/lestrrat-go/jwx/v2/jwk" 17 "github.com/posthog/posthog-go" 18 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 19 tangled "tangled.sh/tangled.sh/core/api/tangled" 20 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 21 "tangled.sh/tangled.sh/core/appview/config" 22 "tangled.sh/tangled.sh/core/appview/db" 23 "tangled.sh/tangled.sh/core/appview/middleware" 24 "tangled.sh/tangled.sh/core/appview/oauth" 25 "tangled.sh/tangled.sh/core/appview/oauth/client" 26 "tangled.sh/tangled.sh/core/appview/pages" 27 "tangled.sh/tangled.sh/core/idresolver" 28 "tangled.sh/tangled.sh/core/knotclient" 29 "tangled.sh/tangled.sh/core/rbac" 30 "tangled.sh/tangled.sh/core/tid" 31) 32 33const ( 34 oauthScope = "atproto transition:generic" 35) 36 37type OAuthHandler struct { 38 config *config.Config 39 pages *pages.Pages 40 idResolver *idresolver.Resolver 41 sess *sessioncache.SessionStore 42 db *db.DB 43 store *sessions.CookieStore 44 oauth *oauth.OAuth 45 enforcer *rbac.Enforcer 46 posthog posthog.Client 47} 48 49func New( 50 config *config.Config, 51 pages *pages.Pages, 52 idResolver *idresolver.Resolver, 53 db *db.DB, 54 sess *sessioncache.SessionStore, 55 store *sessions.CookieStore, 56 oauth *oauth.OAuth, 57 enforcer *rbac.Enforcer, 58 posthog posthog.Client, 59) *OAuthHandler { 60 return &OAuthHandler{ 61 config: config, 62 pages: pages, 63 idResolver: idResolver, 64 db: db, 65 sess: sess, 66 store: store, 67 oauth: oauth, 68 enforcer: enforcer, 69 posthog: posthog, 70 } 71} 72 73func (o *OAuthHandler) Router() http.Handler { 74 r := chi.NewRouter() 75 76 r.Get("/login", o.login) 77 r.Post("/login", o.login) 78 79 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 80 81 r.Get("/oauth/client-metadata.json", o.clientMetadata) 82 r.Get("/oauth/jwks.json", o.jwks) 83 r.Get("/oauth/callback", o.callback) 84 return r 85} 86 87func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 88 w.Header().Set("Content-Type", "application/json") 89 w.WriteHeader(http.StatusOK) 90 json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 91} 92 93func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 94 jwks := o.config.OAuth.Jwks 95 pubKey, err := pubKeyFromJwk(jwks) 96 if err != nil { 97 log.Printf("error parsing public key: %v", err) 98 http.Error(w, err.Error(), http.StatusInternalServerError) 99 return 100 } 101 102 response := helpers.CreateJwksResponseObject(pubKey) 103 104 w.Header().Set("Content-Type", "application/json") 105 w.WriteHeader(http.StatusOK) 106 json.NewEncoder(w).Encode(response) 107} 108 109func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 110 switch r.Method { 111 case http.MethodGet: 112 o.pages.Login(w, pages.LoginParams{}) 113 case http.MethodPost: 114 handle := r.FormValue("handle") 115 116 // when users copy their handle from bsky.app, it tends to have these characters around it: 117 // 118 // @nelind.dk: 119 // \u202a ensures that the handle is always rendered left to right and 120 // \u202c reverts that so the rest of the page renders however it should 121 handle = strings.TrimPrefix(handle, "\u202a") 122 handle = strings.TrimSuffix(handle, "\u202c") 123 124 // `@` is harmless 125 handle = strings.TrimPrefix(handle, "@") 126 127 // basic handle validation 128 if !strings.Contains(handle, ".") { 129 log.Println("invalid handle format", "raw", handle) 130 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 131 return 132 } 133 134 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle) 135 if err != nil { 136 log.Println("failed to resolve handle:", err) 137 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 138 return 139 } 140 self := o.oauth.ClientMetadata() 141 oauthClient, err := client.NewClient( 142 self.ClientID, 143 o.config.OAuth.Jwks, 144 self.RedirectURIs[0], 145 ) 146 147 if err != nil { 148 log.Println("failed to create oauth client:", err) 149 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 150 return 151 } 152 153 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 154 if err != nil { 155 log.Println("failed to resolve auth server:", err) 156 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 157 return 158 } 159 160 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 161 if err != nil { 162 log.Println("failed to fetch auth server metadata:", err) 163 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 164 return 165 } 166 167 dpopKey, err := helpers.GenerateKey(nil) 168 if err != nil { 169 log.Println("failed to generate dpop key:", err) 170 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 171 return 172 } 173 174 dpopKeyJson, err := json.Marshal(dpopKey) 175 if err != nil { 176 log.Println("failed to marshal dpop key:", err) 177 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 178 return 179 } 180 181 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 182 if err != nil { 183 log.Println("failed to send par auth request:", err) 184 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 185 return 186 } 187 188 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ 189 Did: resolved.DID.String(), 190 PdsUrl: resolved.PDSEndpoint(), 191 Handle: handle, 192 AuthserverIss: authMeta.Issuer, 193 PkceVerifier: parResp.PkceVerifier, 194 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 195 DpopPrivateJwk: string(dpopKeyJson), 196 State: parResp.State, 197 }) 198 if err != nil { 199 log.Println("failed to save oauth request:", err) 200 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 201 return 202 } 203 204 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 205 query := url.Values{} 206 query.Add("client_id", self.ClientID) 207 query.Add("request_uri", parResp.RequestUri) 208 u.RawQuery = query.Encode() 209 o.pages.HxRedirect(w, u.String()) 210 } 211} 212 213func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 214 state := r.FormValue("state") 215 216 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) 217 if err != nil { 218 log.Println("failed to get oauth request:", err) 219 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 220 return 221 } 222 223 defer func() { 224 err := o.sess.DeleteRequestByState(r.Context(), state) 225 if err != nil { 226 log.Println("failed to delete oauth request for state:", state, err) 227 } 228 }() 229 230 error := r.FormValue("error") 231 errorDescription := r.FormValue("error_description") 232 if error != "" || errorDescription != "" { 233 log.Printf("error: %s, %s", error, errorDescription) 234 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 235 return 236 } 237 238 code := r.FormValue("code") 239 if code == "" { 240 log.Println("missing code for state: ", state) 241 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 242 return 243 } 244 245 iss := r.FormValue("iss") 246 if iss == "" { 247 log.Println("missing iss for state: ", state) 248 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 249 return 250 } 251 252 self := o.oauth.ClientMetadata() 253 254 oauthClient, err := client.NewClient( 255 self.ClientID, 256 o.config.OAuth.Jwks, 257 self.RedirectURIs[0], 258 ) 259 260 if err != nil { 261 log.Println("failed to create oauth client:", err) 262 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 263 return 264 } 265 266 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 267 if err != nil { 268 log.Println("failed to parse jwk:", err) 269 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 270 return 271 } 272 273 tokenResp, err := oauthClient.InitialTokenRequest( 274 r.Context(), 275 code, 276 oauthRequest.AuthserverIss, 277 oauthRequest.PkceVerifier, 278 oauthRequest.DpopAuthserverNonce, 279 jwk, 280 ) 281 if err != nil { 282 log.Println("failed to get token:", err) 283 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 284 return 285 } 286 287 if tokenResp.Scope != oauthScope { 288 log.Println("scope doesn't match:", tokenResp.Scope) 289 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 290 return 291 } 292 293 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) 294 if err != nil { 295 log.Println("failed to save session:", err) 296 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 297 return 298 } 299 300 log.Println("session saved successfully") 301 go o.addToDefaultKnot(oauthRequest.Did) 302 go o.addToDefaultSpindle(oauthRequest.Did) 303 304 if !o.config.Core.Dev { 305 err = o.posthog.Enqueue(posthog.Capture{ 306 DistinctId: oauthRequest.Did, 307 Event: "signin", 308 }) 309 if err != nil { 310 log.Println("failed to enqueue posthog event:", err) 311 } 312 } 313 314 http.Redirect(w, r, "/", http.StatusFound) 315} 316 317func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 318 err := o.oauth.ClearSession(r, w) 319 if err != nil { 320 log.Println("failed to clear session:", err) 321 http.Redirect(w, r, "/", http.StatusFound) 322 return 323 } 324 325 log.Println("session cleared successfully") 326 o.pages.HxRedirect(w, "/login") 327} 328 329func pubKeyFromJwk(jwks string) (jwk.Key, error) { 330 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 331 if err != nil { 332 return nil, err 333 } 334 pubKey, err := k.PublicKey() 335 if err != nil { 336 return nil, err 337 } 338 return pubKey, nil 339} 340 341func (o *OAuthHandler) addToDefaultSpindle(did string) { 342 // use the tangled.sh app password to get an accessJwt 343 // and create an sh.tangled.spindle.member record with that 344 345 defaultSpindle := "spindle.tangled.sh" 346 appPassword := o.config.Core.AppPassword 347 348 spindleMembers, err := db.GetSpindleMembers( 349 o.db, 350 db.FilterEq("instance", "spindle.tangled.sh"), 351 db.FilterEq("subject", did), 352 ) 353 if err != nil { 354 log.Printf("failed to get spindle members for did %s: %v", did, err) 355 return 356 } 357 358 if len(spindleMembers) != 0 { 359 log.Printf("did %s is already a member of the default spindle", did) 360 return 361 } 362 363 // TODO: hardcoded tangled handle and did for now 364 tangledHandle := "tangled.sh" 365 tangledDid := "did:plc:wshs7t2adsemcrrd4snkeqli" 366 367 if appPassword == "" { 368 log.Println("no app password configured, skipping spindle member addition") 369 return 370 } 371 372 log.Printf("adding %s to default spindle", did) 373 374 resolved, err := o.idResolver.ResolveIdent(context.Background(), tangledDid) 375 if err != nil { 376 log.Printf("failed to resolve tangled.sh DID %s: %v", tangledDid, err) 377 return 378 } 379 380 pdsEndpoint := resolved.PDSEndpoint() 381 if pdsEndpoint == "" { 382 log.Printf("no PDS endpoint found for tangled.sh DID %s", tangledDid) 383 return 384 } 385 386 sessionPayload := map[string]string{ 387 "identifier": tangledHandle, 388 "password": appPassword, 389 } 390 sessionBytes, err := json.Marshal(sessionPayload) 391 if err != nil { 392 log.Printf("failed to marshal session payload: %v", err) 393 return 394 } 395 396 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 397 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 398 if err != nil { 399 log.Printf("failed to create session request: %v", err) 400 return 401 } 402 sessionReq.Header.Set("Content-Type", "application/json") 403 404 client := &http.Client{Timeout: 30 * time.Second} 405 sessionResp, err := client.Do(sessionReq) 406 if err != nil { 407 log.Printf("failed to create session: %v", err) 408 return 409 } 410 defer sessionResp.Body.Close() 411 412 if sessionResp.StatusCode != http.StatusOK { 413 log.Printf("failed to create session: HTTP %d", sessionResp.StatusCode) 414 return 415 } 416 417 var session struct { 418 AccessJwt string `json:"accessJwt"` 419 } 420 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 421 log.Printf("failed to decode session response: %v", err) 422 return 423 } 424 425 record := tangled.SpindleMember{ 426 LexiconTypeID: "sh.tangled.spindle.member", 427 Subject: did, 428 Instance: defaultSpindle, 429 CreatedAt: time.Now().Format(time.RFC3339), 430 } 431 432 recordBytes, err := json.Marshal(record) 433 if err != nil { 434 log.Printf("failed to marshal spindle member record: %v", err) 435 return 436 } 437 438 payload := map[string]interface{}{ 439 "repo": tangledDid, 440 "collection": tangled.SpindleMemberNSID, 441 "rkey": tid.TID(), 442 "record": json.RawMessage(recordBytes), 443 } 444 445 payloadBytes, err := json.Marshal(payload) 446 if err != nil { 447 log.Printf("failed to marshal request payload: %v", err) 448 return 449 } 450 451 url := pdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 452 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 453 if err != nil { 454 log.Printf("failed to create HTTP request: %v", err) 455 return 456 } 457 458 req.Header.Set("Content-Type", "application/json") 459 req.Header.Set("Authorization", "Bearer "+session.AccessJwt) 460 461 resp, err := client.Do(req) 462 if err != nil { 463 log.Printf("failed to add user to default spindle: %v", err) 464 return 465 } 466 defer resp.Body.Close() 467 468 if resp.StatusCode != http.StatusOK { 469 log.Printf("failed to add user to default spindle: HTTP %d", resp.StatusCode) 470 return 471 } 472 473 log.Printf("successfully added %s to default spindle", did) 474} 475 476func (o *OAuthHandler) addToDefaultKnot(did string) { 477 defaultKnot := "knot1.tangled.sh" 478 479 log.Printf("adding %s to default knot", did) 480 err := o.enforcer.AddKnotMember(defaultKnot, did) 481 if err != nil { 482 log.Println("failed to add user to knot1.tangled.sh: ", err) 483 return 484 } 485 err = o.enforcer.E.SavePolicy() 486 if err != nil { 487 log.Println("failed to add user to knot1.tangled.sh: ", err) 488 return 489 } 490 491 secret, err := db.GetRegistrationKey(o.db, defaultKnot) 492 if err != nil { 493 log.Println("failed to get registration key for knot1.tangled.sh") 494 return 495 } 496 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev) 497 resp, err := signedClient.AddMember(did) 498 if err != nil { 499 log.Println("failed to add user to knot1.tangled.sh: ", err) 500 return 501 } 502 503 if resp.StatusCode != http.StatusNoContent { 504 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode) 505 return 506 } 507}