forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
at opengraph 8.6 kB view raw
1package oauth 2 3import ( 4 "encoding/json" 5 "fmt" 6 "log" 7 "net/http" 8 "net/url" 9 "strings" 10 11 "github.com/go-chi/chi/v5" 12 "github.com/gorilla/sessions" 13 "github.com/haileyok/atproto-oauth-golang/helpers" 14 "github.com/lestrrat-go/jwx/v2/jwk" 15 "tangled.sh/tangled.sh/core/appview" 16 "tangled.sh/tangled.sh/core/appview/db" 17 "tangled.sh/tangled.sh/core/appview/knotclient" 18 "tangled.sh/tangled.sh/core/appview/middleware" 19 "tangled.sh/tangled.sh/core/appview/oauth" 20 "tangled.sh/tangled.sh/core/appview/oauth/client" 21 "tangled.sh/tangled.sh/core/appview/pages" 22 "tangled.sh/tangled.sh/core/rbac" 23) 24 25const ( 26 oauthScope = "atproto transition:generic" 27) 28 29type OAuthHandler struct { 30 Config *appview.Config 31 Pages *pages.Pages 32 Resolver *appview.Resolver 33 Db *db.DB 34 Store *sessions.CookieStore 35 OAuth *oauth.OAuth 36 Enforcer *rbac.Enforcer 37} 38 39func (o *OAuthHandler) Router() http.Handler { 40 r := chi.NewRouter() 41 42 r.Get("/login", o.login) 43 r.Post("/login", o.login) 44 45 r.With(middleware.AuthMiddleware(o.OAuth)).Post("/logout", o.logout) 46 47 r.Get("/oauth/client-metadata.json", o.clientMetadata) 48 r.Get("/oauth/jwks.json", o.jwks) 49 r.Get("/oauth/callback", o.callback) 50 return r 51} 52 53func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) { 54 w.Header().Set("Content-Type", "application/json") 55 w.WriteHeader(http.StatusOK) 56 json.NewEncoder(w).Encode(o.OAuth.ClientMetadata()) 57} 58 59func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) { 60 jwks := o.Config.OAuth.Jwks 61 pubKey, err := pubKeyFromJwk(jwks) 62 if err != nil { 63 log.Printf("error parsing public key: %v", err) 64 http.Error(w, err.Error(), http.StatusInternalServerError) 65 return 66 } 67 68 response := helpers.CreateJwksResponseObject(pubKey) 69 70 w.Header().Set("Content-Type", "application/json") 71 w.WriteHeader(http.StatusOK) 72 json.NewEncoder(w).Encode(response) 73} 74 75func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { 76 switch r.Method { 77 case http.MethodGet: 78 o.Pages.Login(w, pages.LoginParams{}) 79 case http.MethodPost: 80 handle := strings.TrimPrefix(r.FormValue("handle"), "@") 81 82 resolved, err := o.Resolver.ResolveIdent(r.Context(), handle) 83 if err != nil { 84 log.Println("failed to resolve handle:", err) 85 o.Pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle)) 86 return 87 } 88 self := o.OAuth.ClientMetadata() 89 oauthClient, err := client.NewClient( 90 self.ClientID, 91 o.Config.OAuth.Jwks, 92 self.RedirectURIs[0], 93 ) 94 95 if err != nil { 96 log.Println("failed to create oauth client:", err) 97 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 98 return 99 } 100 101 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint()) 102 if err != nil { 103 log.Println("failed to resolve auth server:", err) 104 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 105 return 106 } 107 108 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer) 109 if err != nil { 110 log.Println("failed to fetch auth server metadata:", err) 111 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 112 return 113 } 114 115 dpopKey, err := helpers.GenerateKey(nil) 116 if err != nil { 117 log.Println("failed to generate dpop key:", err) 118 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 119 return 120 } 121 122 dpopKeyJson, err := json.Marshal(dpopKey) 123 if err != nil { 124 log.Println("failed to marshal dpop key:", err) 125 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 126 return 127 } 128 129 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey) 130 if err != nil { 131 log.Println("failed to send par auth request:", err) 132 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 133 return 134 } 135 136 err = db.SaveOAuthRequest(o.Db, db.OAuthRequest{ 137 Did: resolved.DID.String(), 138 PdsUrl: resolved.PDSEndpoint(), 139 Handle: handle, 140 AuthserverIss: authMeta.Issuer, 141 PkceVerifier: parResp.PkceVerifier, 142 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 143 DpopPrivateJwk: string(dpopKeyJson), 144 State: parResp.State, 145 }) 146 if err != nil { 147 log.Println("failed to save oauth request:", err) 148 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 149 return 150 } 151 152 u, _ := url.Parse(authMeta.AuthorizationEndpoint) 153 query := url.Values{} 154 query.Add("client_id", self.ClientID) 155 query.Add("request_uri", parResp.RequestUri) 156 u.RawQuery = query.Encode() 157 o.Pages.HxRedirect(w, u.String()) 158 } 159} 160 161func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 162 state := r.FormValue("state") 163 164 oauthRequest, err := db.GetOAuthRequestByState(o.Db, state) 165 if err != nil { 166 log.Println("failed to get oauth request:", err) 167 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 168 return 169 } 170 171 defer func() { 172 err := db.DeleteOAuthRequestByState(o.Db, state) 173 if err != nil { 174 log.Println("failed to delete oauth request for state:", state, err) 175 } 176 }() 177 178 error := r.FormValue("error") 179 errorDescription := r.FormValue("error_description") 180 if error != "" || errorDescription != "" { 181 log.Printf("error: %s, %s", error, errorDescription) 182 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 183 return 184 } 185 186 code := r.FormValue("code") 187 if code == "" { 188 log.Println("missing code for state: ", state) 189 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 190 return 191 } 192 193 iss := r.FormValue("iss") 194 if iss == "" { 195 log.Println("missing iss for state: ", state) 196 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 197 return 198 } 199 200 self := o.OAuth.ClientMetadata() 201 202 oauthClient, err := client.NewClient( 203 self.ClientID, 204 o.Config.OAuth.Jwks, 205 self.RedirectURIs[0], 206 ) 207 208 if err != nil { 209 log.Println("failed to create oauth client:", err) 210 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 211 return 212 } 213 214 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 215 if err != nil { 216 log.Println("failed to parse jwk:", err) 217 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 218 return 219 } 220 221 tokenResp, err := oauthClient.InitialTokenRequest( 222 r.Context(), 223 code, 224 oauthRequest.AuthserverIss, 225 oauthRequest.PkceVerifier, 226 oauthRequest.DpopAuthserverNonce, 227 jwk, 228 ) 229 if err != nil { 230 log.Println("failed to get token:", err) 231 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 232 return 233 } 234 235 if tokenResp.Scope != oauthScope { 236 log.Println("scope doesn't match:", tokenResp.Scope) 237 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 238 return 239 } 240 241 err = o.OAuth.SaveSession(w, r, oauthRequest, tokenResp) 242 if err != nil { 243 log.Println("failed to save session:", err) 244 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") 245 return 246 } 247 248 log.Println("session saved successfully") 249 go o.addToDefaultKnot(oauthRequest.Did) 250 251 http.Redirect(w, r, "/", http.StatusFound) 252} 253 254func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { 255 err := o.OAuth.ClearSession(r, w) 256 if err != nil { 257 log.Println("failed to clear session:", err) 258 http.Redirect(w, r, "/", http.StatusFound) 259 return 260 } 261 262 log.Println("session cleared successfully") 263 http.Redirect(w, r, "/", http.StatusFound) 264} 265 266func pubKeyFromJwk(jwks string) (jwk.Key, error) { 267 k, err := helpers.ParseJWKFromBytes([]byte(jwks)) 268 if err != nil { 269 return nil, err 270 } 271 pubKey, err := k.PublicKey() 272 if err != nil { 273 return nil, err 274 } 275 return pubKey, nil 276} 277 278func (o *OAuthHandler) addToDefaultKnot(did string) { 279 defaultKnot := "knot1.tangled.sh" 280 281 log.Printf("adding %s to default knot", did) 282 err := o.Enforcer.AddMember(defaultKnot, did) 283 if err != nil { 284 log.Println("failed to add user to knot1.tangled.sh: ", err) 285 return 286 } 287 err = o.Enforcer.E.SavePolicy() 288 if err != nil { 289 log.Println("failed to add user to knot1.tangled.sh: ", err) 290 return 291 } 292 293 secret, err := db.GetRegistrationKey(o.Db, defaultKnot) 294 if err != nil { 295 log.Println("failed to get registration key for knot1.tangled.sh") 296 return 297 } 298 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.Config.Core.Dev) 299 resp, err := signedClient.AddMember(did) 300 if err != nil { 301 log.Println("failed to add user to knot1.tangled.sh: ", err) 302 return 303 } 304 305 if resp.StatusCode != http.StatusNoContent { 306 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode) 307 return 308 } 309}