forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "encoding/json" 5 "fmt" 6 "log" 7 "net/http" 8 "net/url" 9 "strings" 10 11 "github.com/go-chi/chi/v5" 12 "github.com/gorilla/sessions" 13 "github.com/lestrrat-go/jwx/v2/jwk" 14 "github.com/posthog/posthog-go" 15 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 16 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 17 "tangled.sh/tangled.sh/core/appview/config" 18 "tangled.sh/tangled.sh/core/appview/db" 19 "tangled.sh/tangled.sh/core/appview/idresolver" 20 "tangled.sh/tangled.sh/core/appview/middleware" 21 "tangled.sh/tangled.sh/core/appview/oauth" 22 "tangled.sh/tangled.sh/core/appview/oauth/client" 23 "tangled.sh/tangled.sh/core/appview/pages" 24 "tangled.sh/tangled.sh/core/knotclient" 25 "tangled.sh/tangled.sh/core/rbac" 26) 27 28const ( 29 oauthScope = "atproto transition:generic" 30) 31 32type OAuthHandler struct { 33 config *config.Config 34 pages *pages.Pages 35 idResolver *idresolver.Resolver 36 sess *sessioncache.SessionStore 37 db *db.DB 38 store *sessions.CookieStore 39 oauth *oauth.OAuth 40 enforcer *rbac.Enforcer 41 posthog posthog.Client 42} 43 44func New( 45 config *config.Config, 46 pages *pages.Pages, 47 idResolver *idresolver.Resolver, 48 db *db.DB, 49 sess *sessioncache.SessionStore, 50 store *sessions.CookieStore, 51 oauth *oauth.OAuth, 52 enforcer *rbac.Enforcer, 53 posthog posthog.Client, 54) *OAuthHandler { 55 return &OAuthHandler{ 56 config: config, 57 pages: pages, 58 idResolver: idResolver, 59 db: db, 60 sess: sess, 61 store: store, 62 oauth: oauth, 63 enforcer: enforcer, 64 posthog: posthog, 65 } 66} 67 68func (o *OAuthHandler) Router() http.Handler { 69 r := chi.NewRouter() 70 71 r.Get("/login", o.login) 72 r.Post("/login", o.login) 73 74 r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout) 75 76 r.Get("/oauth/client-metadata.json", o.clientMetadata) 77 r.Get("/oauth/jwks.json", o.jwks) 78 r.Get("/oauth/callback", o.callback) 79 return r 80} 81 82func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 83 w.Header().Set("Content-Type", "application/json") 84 w.WriteHeader(http.StatusOK) 85 json.NewEncoder(w).Encode(o.oauth.ClientMetadata()) 86} 87 88func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 89 jwks := o.config.OAuth.Jwks 90 pubKey, err := pubKeyFromJwk(jwks) 91 if err != nil { 92 log.Printf("error parsing public key: %v", err) 93 http.Error(w, err.Error(), http.StatusInternalServerError) 94 return 95 } 96 97 response := helpers.CreateJwksResponseObject(pubKey) 98 99 w.Header().Set("Content-Type", "application/json") 100 w.WriteHeader(http.StatusOK) 101 json.NewEncoder(w).Encode(response) 102} 103 104func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 105 switch r.Method { 106 case http.MethodGet: 107 o.pages.Login(w, pages.LoginParams{}) 108 case http.MethodPost: 109 handle := r.FormValue("handle") 110 111 // when users copy their handle from bsky.app, it tends to have these characters around it: 112 // 113 // @nelind.dk: 114 // \u202a ensures that the handle is always rendered left to right and 115 // \u202c reverts that so the rest of the page renders however it should 116 handle = strings.TrimPrefix(handle, "\u202a") 117 handle = strings.TrimSuffix(handle, "\u202c") 118 119 // `@` is harmless 120 handle = strings.TrimPrefix(handle, "@") 121 122 // basic handle validation 123 if !strings.Contains(handle, ".") { 124 log.Println("invalid handle format", "raw", handle) 125 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle)) 126 return 127 } 128 129 resolved, err := o.idResolver.ResolveIdent(r.Context(), handle) 130 if err != nil { 131 log.Println("failed to resolve handle:", err) 132 o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 133 return 134 } 135 self := o.oauth.ClientMetadata() 136 oauthClient, err := client.NewClient( 137 self.ClientID, 138 o.config.OAuth.Jwks, 139 self.RedirectURIs[0], 140 ) 141 142 if err != nil { 143 log.Println("failed to create oauth client:", err) 144 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 145 return 146 } 147 148 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 149 if err != nil { 150 log.Println("failed to resolve auth server:", err) 151 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 152 return 153 } 154 155 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 156 if err != nil { 157 log.Println("failed to fetch auth server metadata:", err) 158 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 159 return 160 } 161 162 dpopKey, err := helpers.GenerateKey(nil) 163 if err != nil { 164 log.Println("failed to generate dpop key:", err) 165 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 166 return 167 } 168 169 dpopKeyJson, err := json.Marshal(dpopKey) 170 if err != nil { 171 log.Println("failed to marshal dpop key:", err) 172 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 173 return 174 } 175 176 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 177 if err != nil { 178 log.Println("failed to send par auth request:", err) 179 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 180 return 181 } 182 183 err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ 184 Did: resolved.DID.String(), 185 PdsUrl: resolved.PDSEndpoint(), 186 Handle: handle, 187 AuthserverIss: authMeta.Issuer, 188 PkceVerifier: parResp.PkceVerifier, 189 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 190 DpopPrivateJwk: string(dpopKeyJson), 191 State: parResp.State, 192 }) 193 if err != nil { 194 log.Println("failed to save oauth request:", err) 195 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 196 return 197 } 198 199 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 200 query := url.Values{} 201 query.Add("client_id", self.ClientID) 202 query.Add("request_uri", parResp.RequestUri) 203 u.RawQuery = query.Encode() 204 o.pages.HxRedirect(w, u.String()) 205 } 206} 207 208func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 209 state := r.FormValue("state") 210 211 oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) 212 if err != nil { 213 log.Println("failed to get oauth request:", err) 214 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 215 return 216 } 217 218 defer func() { 219 err := o.sess.DeleteRequestByState(r.Context(), state) 220 if err != nil { 221 log.Println("failed to delete oauth request for state:", state, err) 222 } 223 }() 224 225 error := r.FormValue("error") 226 errorDescription := r.FormValue("error_description") 227 if error != "" || errorDescription != "" { 228 log.Printf("error: %s, %s", error, errorDescription) 229 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 230 return 231 } 232 233 code := r.FormValue("code") 234 if code == "" { 235 log.Println("missing code for state: ", state) 236 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 237 return 238 } 239 240 iss := r.FormValue("iss") 241 if iss == "" { 242 log.Println("missing iss for state: ", state) 243 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 244 return 245 } 246 247 self := o.oauth.ClientMetadata() 248 249 oauthClient, err := client.NewClient( 250 self.ClientID, 251 o.config.OAuth.Jwks, 252 self.RedirectURIs[0], 253 ) 254 255 if err != nil { 256 log.Println("failed to create oauth client:", err) 257 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 258 return 259 } 260 261 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 262 if err != nil { 263 log.Println("failed to parse jwk:", err) 264 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 265 return 266 } 267 268 tokenResp, err := oauthClient.InitialTokenRequest( 269 r.Context(), 270 code, 271 oauthRequest.AuthserverIss, 272 oauthRequest.PkceVerifier, 273 oauthRequest.DpopAuthserverNonce, 274 jwk, 275 ) 276 if err != nil { 277 log.Println("failed to get token:", err) 278 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 279 return 280 } 281 282 if tokenResp.Scope != oauthScope { 283 log.Println("scope doesn't match:", tokenResp.Scope) 284 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 285 return 286 } 287 288 err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) 289 if err != nil { 290 log.Println("failed to save session:", err) 291 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 292 return 293 } 294 295 log.Println("session saved successfully") 296 go o.addToDefaultKnot(oauthRequest.Did) 297 298 if !o.config.Core.Dev { 299 err = o.posthog.Enqueue(posthog.Capture{ 300 DistinctId: oauthRequest.Did, 301 Event: "signin", 302 }) 303 if err != nil { 304 log.Println("failed to enqueue posthog event:", err) 305 } 306 } 307 308 http.Redirect(w, r, "/", http.StatusFound) 309} 310 311func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 312 err := o.oauth.ClearSession(r, w) 313 if err != nil { 314 log.Println("failed to clear session:", err) 315 http.Redirect(w, r, "/", http.StatusFound) 316 return 317 } 318 319 log.Println("session cleared successfully") 320 o.pages.HxRedirect(w, "/login") 321} 322 323func pubKeyFromJwk(jwks string) (jwk.Key, error) { 324 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 325 if err != nil { 326 return nil, err 327 } 328 pubKey, err := k.PublicKey() 329 if err != nil { 330 return nil, err 331 } 332 return pubKey, nil 333} 334 335func (o *OAuthHandler) addToDefaultKnot(did string) { 336 defaultKnot := "knot1.tangled.sh" 337 338 log.Printf("adding %s to default knot", did) 339 err := o.enforcer.AddKnotMember(defaultKnot, did) 340 if err != nil { 341 log.Println("failed to add user to knot1.tangled.sh: ", err) 342 return 343 } 344 err = o.enforcer.E.SavePolicy() 345 if err != nil { 346 log.Println("failed to add user to knot1.tangled.sh: ", err) 347 return 348 } 349 350 secret, err := db.GetRegistrationKey(o.db, defaultKnot) 351 if err != nil { 352 log.Println("failed to get registration key for knot1.tangled.sh") 353 return 354 } 355 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev) 356 resp, err := signedClient.AddMember(did) 357 if err != nil { 358 log.Println("failed to add user to knot1.tangled.sh: ", err) 359 return 360 } 361 362 if resp.StatusCode != http.StatusNoContent { 363 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode) 364 return 365 } 366}