forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "encoding/json" 5 "fmt" 6 "log" 7 "net/http" 8 "net/url" 9 "strings" 10 11 "github.com/go-chi/chi/v5" 12 "github.com/gorilla/sessions" 13 "github.com/lestrrat-go/jwx/v2/jwk" 14 "github.com/posthog/posthog-go" 15 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 16 "tangled.sh/tangled.sh/core/appview/config" 17 "tangled.sh/tangled.sh/core/appview/db" 18 "tangled.sh/tangled.sh/core/appview/idresolver" 19 "tangled.sh/tangled.sh/core/appview/middleware" 20 "tangled.sh/tangled.sh/core/appview/oauth" 21 "tangled.sh/tangled.sh/core/appview/oauth/client" 22 "tangled.sh/tangled.sh/core/appview/pages" 23 "tangled.sh/tangled.sh/core/knotclient" 24 "tangled.sh/tangled.sh/core/rbac" 25) 26 27const ( 28 oauthScope = "atproto transition:generic" 29) 30 31type OAuthHandler struct { 32 config *config.Config 33 pages *pages.Pages 34 idResolver *idresolver.Resolver 35 db *db.DB 36 store *sessions.CookieStore 37 oauth *oauth.OAuth 38 enforcer *rbac.Enforcer 39 posthog posthog.Client 40} 41 42func New( 43 config *config.Config, 44 pages *pages.Pages, 45 idResolver *idresolver.Resolver, 46 db *db.DB, 47 store *sessions.CookieStore, 48 oauth *oauth.OAuth, 49 enforcer *rbac.Enforcer, 50 posthog posthog.Client, 51) *OAuthHandler { 52 return &OAuthHandler{ 53 config: config, 54 pages: pages, 55 idResolver: idResolver, 56 db: db, 57 store: store, 58 oauth: oauth, 59 enforcer: enforcer, 60 posthog: posthog, 61 } 62} 63 64func (o *OAuthHandler) Router() http.Handler { 65 r := chi.NewRouter() 66 67 r.Get("/login", o.login) 68 r.Post("/login", o.login) 69 70 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 71 72 r.Get("/oauth/client-metadata.json", o.clientMetadata) 73 r.Get("/oauth/jwks.json", o.jwks) 74 r.Get("/oauth/callback", o.callback) 75 return r 76} 77 78func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 79 w.Header().Set("Content-Type", "application/json") 80 w.WriteHeader(http.StatusOK) 81 json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 82} 83 84func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 85 jwks := o.config.OAuth.Jwks 86 pubKey, err := pubKeyFromJwk(jwks) 87 if err != nil { 88 log.Printf("error parsing public key: %v", err) 89 http.Error(w, err.Error(), http.StatusInternalServerError) 90 return 91 } 92 93 response := helpers.CreateJwksResponseObject(pubKey) 94 95 w.Header().Set("Content-Type", "application/json") 96 w.WriteHeader(http.StatusOK) 97 json.NewEncoder(w).Encode(response) 98} 99 100func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 101 switch r.Method { 102 case http.MethodGet: 103 o.pages.Login(w, pages.LoginParams{}) 104 case http.MethodPost: 105 handle := r.FormValue("handle") 106 107 // when users copy their handle from bsky.app, it tends to have these characters around it: 108 // 109 // @nelind.dk: 110 // \u202a ensures that the handle is always rendered left to right and 111 // \u202c reverts that so the rest of the page renders however it should 112 handle = strings.TrimPrefix(handle, "\u202a") 113 handle = strings.TrimSuffix(handle, "\u202c") 114 115 // `@` is harmless 116 handle = strings.TrimPrefix(handle, "@") 117 118 // basic handle validation 119 if !strings.Contains(handle, ".") { 120 log.Println("invalid handle format", "raw", handle) 121 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 122 return 123 } 124 125 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle) 126 if err != nil { 127 log.Println("failed to resolve handle:", err) 128 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 129 return 130 } 131 self := o.oauth.ClientMetadata() 132 oauthClient, err := client.NewClient( 133 self.ClientID, 134 o.config.OAuth.Jwks, 135 self.RedirectURIs[0], 136 ) 137 138 if err != nil { 139 log.Println("failed to create oauth client:", err) 140 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 141 return 142 } 143 144 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 145 if err != nil { 146 log.Println("failed to resolve auth server:", err) 147 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 148 return 149 } 150 151 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 152 if err != nil { 153 log.Println("failed to fetch auth server metadata:", err) 154 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 155 return 156 } 157 158 dpopKey, err := helpers.GenerateKey(nil) 159 if err != nil { 160 log.Println("failed to generate dpop key:", err) 161 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 162 return 163 } 164 165 dpopKeyJson, err := json.Marshal(dpopKey) 166 if err != nil { 167 log.Println("failed to marshal dpop key:", err) 168 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 169 return 170 } 171 172 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 173 if err != nil { 174 log.Println("failed to send par auth request:", err) 175 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 176 return 177 } 178 179 err = db.SaveOAuthRequest(o.db, db.OAuthRequest{ 180 Did: resolved.DID.String(), 181 PdsUrl: resolved.PDSEndpoint(), 182 Handle: handle, 183 AuthserverIss: authMeta.Issuer, 184 PkceVerifier: parResp.PkceVerifier, 185 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 186 DpopPrivateJwk: string(dpopKeyJson), 187 State: parResp.State, 188 }) 189 if err != nil { 190 log.Println("failed to save oauth request:", err) 191 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 192 return 193 } 194 195 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 196 query := url.Values{} 197 query.Add("client_id", self.ClientID) 198 query.Add("request_uri", parResp.RequestUri) 199 u.RawQuery = query.Encode() 200 o.pages.HxRedirect(w, u.String()) 201 } 202} 203 204func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 205 state := r.FormValue("state") 206 207 oauthRequest, err := db.GetOAuthRequestByState(o.db, state) 208 if err != nil { 209 log.Println("failed to get oauth request:", err) 210 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 211 return 212 } 213 214 defer func() { 215 err := db.DeleteOAuthRequestByState(o.db, state) 216 if err != nil { 217 log.Println("failed to delete oauth request for state:", state, err) 218 } 219 }() 220 221 error := r.FormValue("error") 222 errorDescription := r.FormValue("error_description") 223 if error != "" || errorDescription != "" { 224 log.Printf("error: %s, %s", error, errorDescription) 225 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 226 return 227 } 228 229 code := r.FormValue("code") 230 if code == "" { 231 log.Println("missing code for state: ", state) 232 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 233 return 234 } 235 236 iss := r.FormValue("iss") 237 if iss == "" { 238 log.Println("missing iss for state: ", state) 239 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 240 return 241 } 242 243 self := o.oauth.ClientMetadata() 244 245 oauthClient, err := client.NewClient( 246 self.ClientID, 247 o.config.OAuth.Jwks, 248 self.RedirectURIs[0], 249 ) 250 251 if err != nil { 252 log.Println("failed to create oauth client:", err) 253 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 254 return 255 } 256 257 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 258 if err != nil { 259 log.Println("failed to parse jwk:", err) 260 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 261 return 262 } 263 264 tokenResp, err := oauthClient.InitialTokenRequest( 265 r.Context(), 266 code, 267 oauthRequest.AuthserverIss, 268 oauthRequest.PkceVerifier, 269 oauthRequest.DpopAuthserverNonce, 270 jwk, 271 ) 272 if err != nil { 273 log.Println("failed to get token:", err) 274 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 275 return 276 } 277 278 if tokenResp.Scope != oauthScope { 279 log.Println("scope doesn't match:", tokenResp.Scope) 280 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 281 return 282 } 283 284 err = o.oauth.SaveSession(w, r, oauthRequest, tokenResp) 285 if err != nil { 286 log.Println("failed to save session:", err) 287 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 288 return 289 } 290 291 log.Println("session saved successfully") 292 go o.addToDefaultKnot(oauthRequest.Did) 293 294 if !o.config.Core.Dev { 295 err = o.posthog.Enqueue(posthog.Capture{ 296 DistinctId: oauthRequest.Did, 297 Event: "signin", 298 }) 299 if err != nil { 300 log.Println("failed to enqueue posthog event:", err) 301 } 302 } 303 304 http.Redirect(w, r, "/", http.StatusFound) 305} 306 307func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 308 err := o.oauth.ClearSession(r, w) 309 if err != nil { 310 log.Println("failed to clear session:", err) 311 http.Redirect(w, r, "/", http.StatusFound) 312 return 313 } 314 315 log.Println("session cleared successfully") 316 http.Redirect(w, r, "/", http.StatusFound) 317} 318 319func pubKeyFromJwk(jwks string) (jwk.Key, error) { 320 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 321 if err != nil { 322 return nil, err 323 } 324 pubKey, err := k.PublicKey() 325 if err != nil { 326 return nil, err 327 } 328 return pubKey, nil 329} 330 331func (o *OAuthHandler) addToDefaultKnot(did string) { 332 defaultKnot := "knot1.tangled.sh" 333 334 log.Printf("adding %s to default knot", did) 335 err := o.enforcer.AddMember(defaultKnot, did) 336 if err != nil { 337 log.Println("failed to add user to knot1.tangled.sh: ", err) 338 return 339 } 340 err = o.enforcer.E.SavePolicy() 341 if err != nil { 342 log.Println("failed to add user to knot1.tangled.sh: ", err) 343 return 344 } 345 346 secret, err := db.GetRegistrationKey(o.db, defaultKnot) 347 if err != nil { 348 log.Println("failed to get registration key for knot1.tangled.sh") 349 return 350 } 351 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev) 352 resp, err := signedClient.AddMember(did) 353 if err != nil { 354 log.Println("failed to add user to knot1.tangled.sh: ", err) 355 return 356 } 357 358 if resp.StatusCode != http.StatusNoContent { 359 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode) 360 return 361 } 362}