forked from tangled.org/core
Monorepo for Tangled — https://tangled.org

appview: oauth: clean up router init

Signed-off-by: Anirudh Oppiliappan <anirudh@tangled.sh>

Changed files
+90 -78
appview
+1 -1
appview/consts.go appview/oauth/consts.go
···
-
package appview
+
package oauth
const (
SessionName = "appview-session"
+68 -46
appview/oauth/handler/handler.go
···
)
type OAuthHandler struct {
-
Config *config.Config
-
Pages *pages.Pages
-
Idresolver *idresolver.Resolver
-
Db *db.DB
-
Store *sessions.CookieStore
-
OAuth *oauth.OAuth
-
Enforcer *rbac.Enforcer
-
Posthog posthog.Client
+
config *config.Config
+
pages *pages.Pages
+
idResolver *idresolver.Resolver
+
db *db.DB
+
store *sessions.CookieStore
+
oauth *oauth.OAuth
+
enforcer *rbac.Enforcer
+
posthog posthog.Client
+
}
+
+
func New(
+
config *config.Config,
+
pages *pages.Pages,
+
idResolver *idresolver.Resolver,
+
db *db.DB,
+
store *sessions.CookieStore,
+
oauth *oauth.OAuth,
+
enforcer *rbac.Enforcer,
+
posthog posthog.Client,
+
) *OAuthHandler {
+
return &OAuthHandler{
+
config: config,
+
pages: pages,
+
idResolver: idResolver,
+
db: db,
+
store: store,
+
oauth: oauth,
+
enforcer: enforcer,
+
posthog: posthog,
+
}
}
func (o *OAuthHandler) Router() http.Handler {
···
r.Get("/login", o.login)
r.Post("/login", o.login)
-
r.With(middleware.AuthMiddleware(o.OAuth)).Post("/logout", o.logout)
+
r.With(middleware.AuthMiddleware(o.oauth)).Post("/logout", o.logout)
r.Get("/oauth/client-metadata.json", o.clientMetadata)
r.Get("/oauth/jwks.json", o.jwks)
···
func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
-
json.NewEncoder(w).Encode(o.OAuth.ClientMetadata())
+
json.NewEncoder(w).Encode(o.oauth.ClientMetadata())
}
func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
-
jwks := o.Config.OAuth.Jwks
+
jwks := o.config.OAuth.Jwks
pubKey, err := pubKeyFromJwk(jwks)
if err != nil {
log.Printf("error parsing public key: %v", err)
···
func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
-
o.Pages.Login(w, pages.LoginParams{})
+
o.pages.Login(w, pages.LoginParams{})
case http.MethodPost:
handle := strings.TrimPrefix(r.FormValue("handle"), "@")
-
resolved, err := o.Idresolver.ResolveIdent(r.Context(), handle)
+
resolved, err := o.idResolver.ResolveIdent(r.Context(), handle)
if err != nil {
log.Println("failed to resolve handle:", err)
-
o.Pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
+
o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
return
}
-
self := o.OAuth.ClientMetadata()
+
self := o.oauth.ClientMetadata()
oauthClient, err := client.NewClient(
self.ClientID,
-
o.Config.OAuth.Jwks,
+
o.config.OAuth.Jwks,
self.RedirectURIs[0],
)
if err != nil {
log.Println("failed to create oauth client:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
if err != nil {
log.Println("failed to resolve auth server:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
if err != nil {
log.Println("failed to fetch auth server metadata:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
dpopKey, err := helpers.GenerateKey(nil)
if err != nil {
log.Println("failed to generate dpop key:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
dpopKeyJson, err := json.Marshal(dpopKey)
if err != nil {
log.Println("failed to marshal dpop key:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
if err != nil {
log.Println("failed to send par auth request:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
-
err = db.SaveOAuthRequest(o.Db, db.OAuthRequest{
+
err = db.SaveOAuthRequest(o.db, db.OAuthRequest{
Did: resolved.DID.String(),
PdsUrl: resolved.PDSEndpoint(),
Handle: handle,
···
})
if err != nil {
log.Println("failed to save oauth request:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
···
query.Add("client_id", self.ClientID)
query.Add("request_uri", parResp.RequestUri)
u.RawQuery = query.Encode()
-
o.Pages.HxRedirect(w, u.String())
+
o.pages.HxRedirect(w, u.String())
}
}
func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
state := r.FormValue("state")
-
oauthRequest, err := db.GetOAuthRequestByState(o.Db, state)
+
oauthRequest, err := db.GetOAuthRequestByState(o.db, state)
if err != nil {
log.Println("failed to get oauth request:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
defer func() {
-
err := db.DeleteOAuthRequestByState(o.Db, state)
+
err := db.DeleteOAuthRequestByState(o.db, state)
if err != nil {
log.Println("failed to delete oauth request for state:", state, err)
}
···
errorDescription := r.FormValue("error_description")
if error != "" || errorDescription != "" {
log.Printf("error: %s, %s", error, errorDescription)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
code := r.FormValue("code")
if code == "" {
log.Println("missing code for state: ", state)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
iss := r.FormValue("iss")
if iss == "" {
log.Println("missing iss for state: ", state)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
-
self := o.OAuth.ClientMetadata()
+
self := o.oauth.ClientMetadata()
oauthClient, err := client.NewClient(
self.ClientID,
-
o.Config.OAuth.Jwks,
+
o.config.OAuth.Jwks,
self.RedirectURIs[0],
)
if err != nil {
log.Println("failed to create oauth client:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
if err != nil {
log.Println("failed to parse jwk:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
···
)
if err != nil {
log.Println("failed to get token:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
if tokenResp.Scope != oauthScope {
log.Println("scope doesn't match:", tokenResp.Scope)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
-
err = o.OAuth.SaveSession(w, r, oauthRequest, tokenResp)
+
err = o.oauth.SaveSession(w, r, oauthRequest, tokenResp)
if err != nil {
log.Println("failed to save session:", err)
-
o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
return
}
log.Println("session saved successfully")
go o.addToDefaultKnot(oauthRequest.Did)
-
if !o.Config.Core.Dev {
-
err = o.Posthog.Enqueue(posthog.Capture{
+
if !o.config.Core.Dev {
+
err = o.posthog.Enqueue(posthog.Capture{
DistinctId: oauthRequest.Did,
Event: "signin",
})
···
}
func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
-
err := o.OAuth.ClearSession(r, w)
+
err := o.oauth.ClearSession(r, w)
if err != nil {
log.Println("failed to clear session:", err)
http.Redirect(w, r, "/", http.StatusFound)
···
defaultKnot := "knot1.tangled.sh"
log.Printf("adding %s to default knot", did)
-
err := o.Enforcer.AddMember(defaultKnot, did)
+
err := o.enforcer.AddMember(defaultKnot, did)
if err != nil {
log.Println("failed to add user to knot1.tangled.sh: ", err)
return
}
-
err = o.Enforcer.E.SavePolicy()
+
err = o.enforcer.E.SavePolicy()
if err != nil {
log.Println("failed to add user to knot1.tangled.sh: ", err)
return
}
-
secret, err := db.GetRegistrationKey(o.Db, defaultKnot)
+
secret, err := db.GetRegistrationKey(o.db, defaultKnot)
if err != nil {
log.Println("failed to get registration key for knot1.tangled.sh")
return
}
-
signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.Config.Core.Dev)
+
signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev)
resp, err := signedClient.AddMember(did)
if err != nil {
log.Println("failed to add user to knot1.tangled.sh: ", err)
+16 -17
appview/oauth/oauth.go
···
"github.com/gorilla/sessions"
oauth "tangled.sh/icyphox.sh/atproto-oauth"
"tangled.sh/icyphox.sh/atproto-oauth/helpers"
-
"tangled.sh/tangled.sh/core/appview"
"tangled.sh/tangled.sh/core/appview/config"
"tangled.sh/tangled.sh/core/appview/db"
"tangled.sh/tangled.sh/core/appview/oauth/client"
···
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
// first we save the did in the user session
-
userSession, err := o.Store.Get(r, appview.SessionName)
+
userSession, err := o.Store.Get(r, SessionName)
if err != nil {
return err
}
-
userSession.Values[appview.SessionDid] = oreq.Did
-
userSession.Values[appview.SessionHandle] = oreq.Handle
-
userSession.Values[appview.SessionPds] = oreq.PdsUrl
-
userSession.Values[appview.SessionAuthenticated] = true
+
userSession.Values[SessionDid] = oreq.Did
+
userSession.Values[SessionHandle] = oreq.Handle
+
userSession.Values[SessionPds] = oreq.PdsUrl
+
userSession.Values[SessionAuthenticated] = true
err = userSession.Save(r, w)
if err != nil {
return fmt.Errorf("error saving user session: %w", err)
···
}
func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
-
userSession, err := o.Store.Get(r, appview.SessionName)
+
userSession, err := o.Store.Get(r, SessionName)
if err != nil || userSession.IsNew {
return fmt.Errorf("error getting user session (or new session?): %w", err)
}
-
did := userSession.Values[appview.SessionDid].(string)
+
did := userSession.Values[SessionDid].(string)
err = db.DeleteOAuthSessionByDid(o.Db, did)
if err != nil {
···
}
func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
-
userSession, err := o.Store.Get(r, appview.SessionName)
+
userSession, err := o.Store.Get(r, SessionName)
if err != nil || userSession.IsNew {
return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
}
-
did := userSession.Values[appview.SessionDid].(string)
-
auth := userSession.Values[appview.SessionAuthenticated].(bool)
+
did := userSession.Values[SessionDid].(string)
+
auth := userSession.Values[SessionAuthenticated].(bool)
session, err := db.GetOAuthSessionByDid(o.Db, did)
if err != nil {
···
}
func (a *OAuth) GetUser(r *http.Request) *User {
-
clientSession, err := a.Store.Get(r, appview.SessionName)
+
clientSession, err := a.Store.Get(r, SessionName)
if err != nil || clientSession.IsNew {
return nil
}
return &User{
-
Handle: clientSession.Values[appview.SessionHandle].(string),
-
Did: clientSession.Values[appview.SessionDid].(string),
-
Pds: clientSession.Values[appview.SessionPds].(string),
+
Handle: clientSession.Values[SessionHandle].(string),
+
Did: clientSession.Values[SessionDid].(string),
+
Pds: clientSession.Values[SessionPds].(string),
}
}
func (a *OAuth) GetDid(r *http.Request) string {
-
clientSession, err := a.Store.Get(r, appview.SessionName)
+
clientSession, err := a.Store.Get(r, SessionName)
if err != nil || clientSession.IsNew {
return ""
}
-
return clientSession.Values[appview.SessionDid].(string)
+
return clientSession.Values[SessionDid].(string)
}
func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
+3 -12
appview/state/router.go
···
"github.com/go-chi/chi/v5"
"github.com/gorilla/sessions"
"tangled.sh/tangled.sh/core/appview/middleware"
-
oauthhandler "tangled.sh/tangled.sh/core/appview/oauth/handler"
+
oauth "tangled.sh/tangled.sh/core/appview/oauth/handler"
"tangled.sh/tangled.sh/core/appview/pulls"
"tangled.sh/tangled.sh/core/appview/repo"
"tangled.sh/tangled.sh/core/appview/settings"
···
}
func (s *State) OAuthRouter() http.Handler {
-
oauth := &oauthhandler.OAuthHandler{
-
Config: s.config,
-
Pages: s.pages,
-
Idresolver: s.idResolver,
-
Db: s.db,
-
Store: sessions.NewCookieStore([]byte(s.config.Core.CookieSecret)),
-
OAuth: s.oauth,
-
Enforcer: s.enforcer,
-
Posthog: s.posthog,
-
}
-
+
store := sessions.NewCookieStore([]byte(s.config.Core.CookieSecret))
+
oauth := oauth.New(s.config, s.pages, s.idResolver, s.db, store, s.oauth, s.enforcer, s.posthog)
return oauth.Router()
}
+2 -2
appview/state/state.go
···
return
case http.MethodPost:
-
session, err := s.oauth.Store.Get(r, appview.SessionName)
+
session, err := s.oauth.Store.Get(r, oauth.SessionName)
if err != nil || session.IsNew {
log.Println("unauthorized attempt to generate registration key")
http.Error(w, "Forbidden", http.StatusUnauthorized)
return
}
-
did := session.Values[appview.SessionDid].(string)
+
did := session.Values[oauth.SessionDid].(string)
// check if domain is valid url, and strip extra bits down to just host
domain := r.FormValue("domain")