From c3d04acd6dec94cc961916440a607d82bd636050 Mon Sep 17 00:00:00 2001 From: Anirudh Oppiliappan Date: Wed, 4 Jun 2025 12:09:47 +0300 Subject: [PATCH] appview: oauth: swap out db store for redis cache Change-Id: ruoqnsmttnxxwpprynornwqktlmoolrx Signed-off-by: Anirudh Oppiliappan --- appview/cache/session/store.go | 6 +-- appview/oauth/handler/handler.go | 14 ++++--- appview/oauth/oauth.go | 63 ++++++++++++++------------------ appview/state/router.go | 4 +- appview/state/state.go | 17 +++++---- 5 files changed, 50 insertions(+), 54 deletions(-) diff --git a/appview/cache/session/store.go b/appview/cache/session/store.go index c5c43aa..6922d4e 100644 --- a/appview/cache/session/store.go +++ b/appview/cache/session/store.go @@ -102,7 +102,7 @@ func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSessio } func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) { - didKey, err := s.getRequestKey(ctx, state) + didKey, err := s.getRequestKeyFromState(ctx, state) if err != nil { return nil, err } @@ -127,12 +127,12 @@ func (s *SessionStore) DeleteSession(ctx context.Context, did string) error { } func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error { - didKey, err := s.getRequestKey(ctx, state) + didKey, err := s.getRequestKeyFromState(ctx, state) if err != nil { return err } - err = s.cache.Del(ctx, fmt.Sprintf(stateKey, "state")).Err() + err = s.cache.Del(ctx, fmt.Sprintf(stateKey, state)).Err() if err != nil { return err } diff --git a/appview/oauth/handler/handler.go b/appview/oauth/handler/handler.go index fc07f8d..de90f0f 100644 --- a/appview/oauth/handler/handler.go +++ b/appview/oauth/handler/handler.go @@ -13,6 +13,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/posthog/posthog-go" "tangled.sh/icyphox.sh/atproto-oauth/helpers" + sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" "tangled.sh/tangled.sh/core/appview/config" "tangled.sh/tangled.sh/core/appview/db" "tangled.sh/tangled.sh/core/appview/idresolver" @@ -32,6 +33,7 @@ type OAuthHandler struct { config *config.Config pages *pages.Pages idResolver *idresolver.Resolver + sess *sessioncache.SessionStore db *db.DB store *sessions.CookieStore oauth *oauth.OAuth @@ -44,6 +46,7 @@ func New( pages *pages.Pages, idResolver *idresolver.Resolver, db *db.DB, + sess *sessioncache.SessionStore, store *sessions.CookieStore, oauth *oauth.OAuth, enforcer *rbac.Enforcer, @@ -54,6 +57,7 @@ func New( pages: pages, idResolver: idResolver, db: db, + sess: sess, store: store, oauth: oauth, enforcer: enforcer, @@ -158,7 +162,7 @@ func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { return } - err = db.SaveOAuthRequest(o.db, db.OAuthRequest{ + err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ Did: resolved.DID.String(), PdsUrl: resolved.PDSEndpoint(), Handle: handle, @@ -186,7 +190,7 @@ func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) { func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { state := r.FormValue("state") - oauthRequest, err := db.GetOAuthRequestByState(o.db, state) + oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) if err != nil { log.Println("failed to get oauth request:", err) o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") @@ -194,7 +198,7 @@ func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { } defer func() { - err := db.DeleteOAuthRequestByState(o.db, state) + err := o.sess.DeleteRequestByState(r.Context(), state) if err != nil { log.Println("failed to delete oauth request for state:", state, err) } @@ -263,7 +267,7 @@ func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { return } - err = o.oauth.SaveSession(w, r, oauthRequest, tokenResp) + err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) if err != nil { log.Println("failed to save session:", err) o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") @@ -295,7 +299,7 @@ func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) { } log.Println("session cleared successfully") - http.Redirect(w, r, "/", http.StatusFound) + o.pages.HxRedirect(w, "/login") } func pubKeyFromJwk(jwks string) (jwk.Key, error) { diff --git a/appview/oauth/oauth.go b/appview/oauth/oauth.go index f5d307b..5ce628b 100644 --- a/appview/oauth/oauth.go +++ b/appview/oauth/oauth.go @@ -10,40 +10,33 @@ import ( "github.com/gorilla/sessions" oauth "tangled.sh/icyphox.sh/atproto-oauth" "tangled.sh/icyphox.sh/atproto-oauth/helpers" + sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" "tangled.sh/tangled.sh/core/appview/config" - "tangled.sh/tangled.sh/core/appview/db" "tangled.sh/tangled.sh/core/appview/oauth/client" xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" ) -type OAuthRequest struct { - ID uint - AuthserverIss string - State string - Did string - PdsUrl string - PkceVerifier string - DpopAuthserverNonce string - DpopPrivateJwk string -} - type OAuth struct { - Store *sessions.CookieStore - Db *db.DB - Config *config.Config + store *sessions.CookieStore + config *config.Config + sess *sessioncache.SessionStore } -func NewOAuth(db *db.DB, config *config.Config) *OAuth { +func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth { return &OAuth{ - Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), - Db: db, - Config: config, + store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), + config: config, + sess: sess, } } -func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error { +func (o *OAuth) Stores() *sessions.CookieStore { + return o.store +} + +func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error { // first we save the did in the user session - userSession, err := o.Store.Get(r, SessionName) + userSession, err := o.store.Get(r, SessionName) if err != nil { return err } @@ -58,7 +51,7 @@ func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAut } // then save the whole thing in the db - session := db.OAuthSession{ + session := sessioncache.OAuthSession{ Did: oreq.Did, Handle: oreq.Handle, PdsUrl: oreq.PdsUrl, @@ -70,18 +63,18 @@ func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAut Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), } - return db.SaveOAuthSession(o.Db, session) + return o.sess.SaveSession(r.Context(), session) } func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { - userSession, err := o.Store.Get(r, SessionName) + userSession, err := o.store.Get(r, SessionName) if err != nil || userSession.IsNew { return fmt.Errorf("error getting user session (or new session?): %w", err) } did := userSession.Values[SessionDid].(string) - err = db.DeleteOAuthSessionByDid(o.Db, did) + err = o.sess.DeleteSession(r.Context(), did) if err != nil { return fmt.Errorf("error deleting oauth session: %w", err) } @@ -91,8 +84,8 @@ func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { return userSession.Save(r, w) } -func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { - userSession, err := o.Store.Get(r, SessionName) +func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) { + userSession, err := o.store.Get(r, SessionName) if err != nil || userSession.IsNew { return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) } @@ -100,7 +93,7 @@ func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { did := userSession.Values[SessionDid].(string) auth := userSession.Values[SessionAuthenticated].(bool) - session, err := db.GetOAuthSessionByDid(o.Db, did) + session, err := o.sess.GetSession(r.Context(), did) if err != nil { return nil, false, fmt.Errorf("error getting oauth session: %w", err) } @@ -119,7 +112,7 @@ func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { oauthClient, err := client.NewClient( self.ClientID, - o.Config.OAuth.Jwks, + o.config.OAuth.Jwks, self.RedirectURIs[0], ) @@ -133,7 +126,7 @@ func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { } newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) - err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) + err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry) if err != nil { return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) } @@ -155,7 +148,7 @@ type User struct { } func (a *OAuth) GetUser(r *http.Request) *User { - clientSession, err := a.Store.Get(r, SessionName) + clientSession, err := a.store.Get(r, SessionName) if err != nil || clientSession.IsNew { return nil @@ -169,7 +162,7 @@ func (a *OAuth) GetUser(r *http.Request) *User { } func (a *OAuth) GetDid(r *http.Request) string { - clientSession, err := a.Store.Get(r, SessionName) + clientSession, err := a.store.Get(r, SessionName) if err != nil || clientSession.IsNew { return "" @@ -189,7 +182,7 @@ func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { client := &oauth.XrpcClient{ OnDpopPdsNonceChanged: func(did, newNonce string) { - err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) + err := o.sess.UpdateNonce(r.Context(), did, newNonce) if err != nil { log.Printf("error updating dpop pds nonce: %v", err) } @@ -234,11 +227,11 @@ func (o *OAuth) ClientMetadata() ClientMetadata { return []string{fmt.Sprintf("%s/oauth/callback", c)} } - clientURI := o.Config.Core.AppviewHost + clientURI := o.config.Core.AppviewHost clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) redirectURIs := makeRedirectURIs(clientURI) - if o.Config.Core.Dev { + if o.config.Core.Dev { clientURI = fmt.Sprintf("http://127.0.0.1:3000") redirectURIs = makeRedirectURIs(clientURI) diff --git a/appview/state/router.go b/appview/state/router.go index 926cede..e20d056 100644 --- a/appview/state/router.go +++ b/appview/state/router.go @@ -97,8 +97,6 @@ func (s *State) StandardRouter(mw *middleware.Middleware) http.Handler { r.Get("/", s.Timeline) - r.With(middleware.AuthMiddleware(s.oauth)).Post("/logout", s.Logout) - r.Route("/knots", func(r chi.Router) { r.Use(middleware.AuthMiddleware(s.oauth)) r.Get("/", s.Knots) @@ -156,7 +154,7 @@ func (s *State) StandardRouter(mw *middleware.Middleware) http.Handler { func (s *State) OAuthRouter() http.Handler { store := sessions.NewCookieStore([]byte(s.config.Core.CookieSecret)) - oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, store, s.oauth, s.enforcer, s.posthog) + oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, s.sess, store, s.oauth, s.enforcer, s.posthog) return oauth.Router() } diff --git a/appview/state/state.go b/appview/state/state.go index 2c9b336..4658987 100644 --- a/appview/state/state.go +++ b/appview/state/state.go @@ -20,6 +20,8 @@ import ( "github.com/posthog/posthog-go" "tangled.sh/tangled.sh/core/api/tangled" "tangled.sh/tangled.sh/core/appview" + "tangled.sh/tangled.sh/core/appview/cache" + "tangled.sh/tangled.sh/core/appview/cache/session" "tangled.sh/tangled.sh/core/appview/config" "tangled.sh/tangled.sh/core/appview/db" "tangled.sh/tangled.sh/core/appview/idresolver" @@ -37,6 +39,7 @@ type State struct { enforcer *rbac.Enforcer tidClock syntax.TIDClock pages *pages.Pages + sess *session.SessionStore idResolver *idresolver.Resolver posthog posthog.Client jc *jetstream.JetstreamClient @@ -65,7 +68,10 @@ func Make(config *config.Config) (*State, error) { res = idresolver.DefaultResolver() } - oauth := oauth.NewOAuth(d, config) + cache := cache.New(config.Redis.Addr) + sess := session.New(cache) + + oauth := oauth.NewOAuth(config, sess) posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint}) if err != nil { @@ -104,6 +110,7 @@ func Make(config *config.Config) (*State, error) { enforcer, clock, pgs, + sess, res, posthog, jc, @@ -118,12 +125,6 @@ func TID(c *syntax.TIDClock) string { return c.Next().String() } -func (s *State) Logout(w http.ResponseWriter, r *http.Request) { - s.oauth.ClearSession(r, w) - w.Header().Set("HX-Redirect", "/login") - w.WriteHeader(http.StatusSeeOther) -} - func (s *State) Timeline(w http.ResponseWriter, r *http.Request) { user := s.oauth.GetUser(r) @@ -176,7 +177,7 @@ func (s *State) RegistrationKey(w http.ResponseWriter, r *http.Request) { return case http.MethodPost: - session, err := s.oauth.Store.Get(r, oauth.SessionName) + session, err := s.oauth.Stores().Get(r, oauth.SessionName) if err != nil || session.IsNew { log.Println("unauthorized attempt to generate registration key") http.Error(w, "Forbidden", http.StatusUnauthorized) -- 2.43.0