appview: cache/session: init redis session store #210

closed
opened by anirudh.fi targeting master from push-ruoqnsmttnxx

And a high-level cache package for future use.

Signed-off-by: Anirudh Oppiliappan anirudh@tangled.sh

Changed files
+223 -42
appview
cache
session
oauth
state
+14
appview/cache/cache.go
···
+
package cache
+
+
import "github.com/redis/go-redis/v9"
+
+
type Cache struct {
+
*redis.Client
+
}
+
+
func New(addr string) *Cache {
+
rdb := redis.NewClient(&redis.Options{
+
Addr: addr,
+
})
+
return &Cache{rdb}
+
}
+163
appview/cache/session/store.go
···
+
package session
+
+
import (
+
"context"
+
"encoding/json"
+
"fmt"
+
"time"
+
+
"tangled.sh/tangled.sh/core/appview/cache"
+
)
+
+
type OAuthSession struct {
+
Handle string
+
Did string
+
PdsUrl string
+
AccessJwt string
+
RefreshJwt string
+
AuthServerIss string
+
DpopPdsNonce string
+
DpopAuthserverNonce string
+
DpopPrivateJwk string
+
Expiry string
+
}
+
+
type OAuthRequest struct {
+
AuthserverIss string
+
Handle string
+
State string
+
Did string
+
PdsUrl string
+
PkceVerifier string
+
DpopAuthserverNonce string
+
DpopPrivateJwk string
+
}
+
+
type SessionStore struct {
+
cache *cache.Cache
+
}
+
+
func New(cache *cache.Cache) *SessionStore {
+
return &SessionStore{cache: cache}
+
}
+
+
func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error {
+
key := fmt.Sprintf("oauthsession:%s", session.Did)
+
data, err := json.Marshal(session)
+
if err != nil {
+
return err
+
}
+
+
// Set with TTL = expires in + buffer
+
expiry, _ := time.Parse(time.RFC3339, session.Expiry)
+
ttl := time.Until(expiry) + time.Minute
+
+
return s.cache.Set(ctx, key, data, ttl).Err()
+
}
+
+
// SaveRequest stores the OAuth request to be later fetched in the callback. Since
+
// the fetching happens by comparing the state we get in the callback params, we
+
// store an additional state->did mapping which then lets us fetch the whole OAuth request.
+
func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error {
+
key := fmt.Sprintf("oauthrequest:%s", request.Did)
+
data, err := json.Marshal(request)
+
if err != nil {
+
return err
+
}
+
+
// oauth flow must complete within 30 minutes
+
err = s.cache.Set(ctx, key, data, 30*time.Minute).Err()
+
if err != nil {
+
return fmt.Errorf("error saving request: %w", err)
+
}
+
+
stateKey := fmt.Sprintf("oauthstate:%s", request.State)
+
err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err()
+
if err != nil {
+
return fmt.Errorf("error saving state->did mapping: %w", err)
+
}
+
+
return nil
+
}
+
+
func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) {
+
key := fmt.Sprintf("oauthsession:%s", did)
+
val, err := s.cache.Get(ctx, key).Result()
+
if err != nil {
+
return nil, err
+
}
+
+
var session OAuthSession
+
err = json.Unmarshal([]byte(val), &session)
+
if err != nil {
+
return nil, err
+
}
+
return &session, nil
+
}
+
+
func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) {
+
didKey, err := s.getRequestKey(ctx, state)
+
if err != nil {
+
return nil, err
+
}
+
+
val, err := s.cache.Get(ctx, didKey).Result()
+
if err != nil {
+
return nil, err
+
}
+
+
var request OAuthRequest
+
err = json.Unmarshal([]byte(val), &request)
+
if err != nil {
+
return nil, err
+
}
+
+
return &request, nil
+
}
+
+
func (s *SessionStore) DeleteSession(ctx context.Context, did string) error {
+
key := fmt.Sprintf("oauthsession:%s", did)
+
return s.cache.Del(ctx, key).Err()
+
}
+
+
func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error {
+
key := fmt.Sprintf("oauthstate:%s", state)
+
did, err := s.cache.Get(ctx, key).Result()
+
if err != nil {
+
return err
+
}
+
+
didKey := fmt.Sprintf("oauthrequest:%s", did)
+
return s.cache.Del(ctx, didKey).Err()
+
}
+
+
func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error {
+
session, err := s.GetSession(ctx, did)
+
if err != nil {
+
return err
+
}
+
session.AccessJwt = access
+
session.RefreshJwt = refresh
+
session.Expiry = expiry
+
return s.SaveSession(ctx, *session)
+
}
+
+
func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error {
+
session, err := s.GetSession(ctx, did)
+
if err != nil {
+
return err
+
}
+
session.DpopAuthserverNonce = nonce
+
return s.SaveSession(ctx, *session)
+
}
+
+
func (s *SessionStore) getRequestKey(ctx context.Context, state string) (string, error) {
+
key := fmt.Sprintf("oauthstate:%s", state)
+
did, err := s.cache.Get(ctx, key).Result()
+
if err != nil {
+
return "", err
+
}
+
+
didKey := fmt.Sprintf("oauthrequest:%s", did)
+
return didKey, nil
+
}
+8 -4
appview/oauth/handler/handler.go
···
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/posthog/posthog-go"
"tangled.sh/icyphox.sh/atproto-oauth/helpers"
+
sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
"tangled.sh/tangled.sh/core/appview/config"
"tangled.sh/tangled.sh/core/appview/db"
"tangled.sh/tangled.sh/core/appview/idresolver"
···
config *config.Config
pages *pages.Pages
idResolver *idresolver.Resolver
+
sess *sessioncache.SessionStore
db *db.DB
store *sessions.CookieStore
oauth *oauth.OAuth
···
pages *pages.Pages,
idResolver *idresolver.Resolver,
db *db.DB,
+
sess *sessioncache.SessionStore,
store *sessions.CookieStore,
oauth *oauth.OAuth,
enforcer *rbac.Enforcer,
···
pages: pages,
idResolver: idResolver,
db: db,
+
sess: sess,
store: store,
oauth: oauth,
enforcer: enforcer,
···
return
}
-
err = db.SaveOAuthRequest(o.db, db.OAuthRequest{
+
err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{
Did: resolved.DID.String(),
PdsUrl: resolved.PDSEndpoint(),
Handle: handle,
···
func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
state := r.FormValue("state")
-
oauthRequest, err := db.GetOAuthRequestByState(o.db, state)
+
oauthRequest, err := o.sess.GetRequestByState(r.Context(), state)
if err != nil {
log.Println("failed to get oauth request:", err)
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
···
}
defer func() {
-
err := db.DeleteOAuthRequestByState(o.db, state)
+
err := o.sess.DeleteRequestByState(r.Context(), state)
if err != nil {
log.Println("failed to delete oauth request for state:", state, err)
}
···
return
}
-
err = o.oauth.SaveSession(w, r, oauthRequest, tokenResp)
+
err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp)
if err != nil {
log.Println("failed to save session:", err)
o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+28 -35
appview/oauth/oauth.go
···
"github.com/gorilla/sessions"
oauth "tangled.sh/icyphox.sh/atproto-oauth"
"tangled.sh/icyphox.sh/atproto-oauth/helpers"
+
sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
"tangled.sh/tangled.sh/core/appview/config"
-
"tangled.sh/tangled.sh/core/appview/db"
"tangled.sh/tangled.sh/core/appview/oauth/client"
xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
)
-
type OAuthRequest struct {
-
ID uint
-
AuthserverIss string
-
State string
-
Did string
-
PdsUrl string
-
PkceVerifier string
-
DpopAuthserverNonce string
-
DpopPrivateJwk string
-
}
-
type OAuth struct {
-
Store *sessions.CookieStore
-
Db *db.DB
-
Config *config.Config
+
store *sessions.CookieStore
+
config *config.Config
+
sess *sessioncache.SessionStore
}
-
func NewOAuth(db *db.DB, config *config.Config) *OAuth {
+
func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth {
return &OAuth{
-
Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
-
Db: db,
-
Config: config,
+
store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
+
config: config,
+
sess: sess,
}
}
-
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
+
func (o *OAuth) Stores() *sessions.CookieStore {
+
return o.store
+
}
+
+
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error {
// first we save the did in the user session
-
userSession, err := o.Store.Get(r, SessionName)
+
userSession, err := o.store.Get(r, SessionName)
if err != nil {
return err
}
···
}
// then save the whole thing in the db
-
session := db.OAuthSession{
+
session := sessioncache.OAuthSession{
Did: oreq.Did,
Handle: oreq.Handle,
PdsUrl: oreq.PdsUrl,
···
Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
}
-
return db.SaveOAuthSession(o.Db, session)
+
return o.sess.SaveSession(r.Context(), session)
}
func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
-
userSession, err := o.Store.Get(r, SessionName)
+
userSession, err := o.store.Get(r, SessionName)
if err != nil || userSession.IsNew {
return fmt.Errorf("error getting user session (or new session?): %w", err)
}
did := userSession.Values[SessionDid].(string)
-
err = db.DeleteOAuthSessionByDid(o.Db, did)
+
err = o.sess.DeleteSession(r.Context(), did)
if err != nil {
return fmt.Errorf("error deleting oauth session: %w", err)
}
···
return userSession.Save(r, w)
}
-
func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
-
userSession, err := o.Store.Get(r, SessionName)
+
func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) {
+
userSession, err := o.store.Get(r, SessionName)
if err != nil || userSession.IsNew {
return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
}
···
did := userSession.Values[SessionDid].(string)
auth := userSession.Values[SessionAuthenticated].(bool)
-
session, err := db.GetOAuthSessionByDid(o.Db, did)
+
session, err := o.sess.GetSession(r.Context(), did)
if err != nil {
return nil, false, fmt.Errorf("error getting oauth session: %w", err)
}
···
oauthClient, err := client.NewClient(
self.ClientID,
-
o.Config.OAuth.Jwks,
+
o.config.OAuth.Jwks,
self.RedirectURIs[0],
)
···
}
newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
-
err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry)
+
err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry)
if err != nil {
return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
}
···
}
func (a *OAuth) GetUser(r *http.Request) *User {
-
clientSession, err := a.Store.Get(r, SessionName)
+
clientSession, err := a.store.Get(r, SessionName)
if err != nil || clientSession.IsNew {
return nil
···
}
func (a *OAuth) GetDid(r *http.Request) string {
-
clientSession, err := a.Store.Get(r, SessionName)
+
clientSession, err := a.store.Get(r, SessionName)
if err != nil || clientSession.IsNew {
return ""
···
client := &oauth.XrpcClient{
OnDpopPdsNonceChanged: func(did, newNonce string) {
-
err := db.UpdateDpopPdsNonce(o.Db, did, newNonce)
+
err := o.sess.UpdateNonce(r.Context(), did, newNonce)
if err != nil {
log.Printf("error updating dpop pds nonce: %v", err)
}
···
return []string{fmt.Sprintf("%s/oauth/callback", c)}
}
-
clientURI := o.Config.Core.AppviewHost
+
clientURI := o.config.Core.AppviewHost
clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
redirectURIs := makeRedirectURIs(clientURI)
-
if o.Config.Core.Dev {
+
if o.config.Core.Dev {
clientURI = fmt.Sprintf("http://127.0.0.1:3000")
redirectURIs = makeRedirectURIs(clientURI)
+1 -1
appview/state/router.go
···
func (s *State) OAuthRouter() http.Handler {
store := sessions.NewCookieStore([]byte(s.config.Core.CookieSecret))
-
oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, store, s.oauth, s.enforcer, s.posthog)
+
oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, s.sess, store, s.oauth, s.enforcer, s.posthog)
return oauth.Router()
}
+9 -2
appview/state/state.go
···
"github.com/posthog/posthog-go"
"tangled.sh/tangled.sh/core/api/tangled"
"tangled.sh/tangled.sh/core/appview"
+
"tangled.sh/tangled.sh/core/appview/cache"
+
"tangled.sh/tangled.sh/core/appview/cache/session"
"tangled.sh/tangled.sh/core/appview/config"
"tangled.sh/tangled.sh/core/appview/db"
"tangled.sh/tangled.sh/core/appview/idresolver"
···
enforcer *rbac.Enforcer
tidClock syntax.TIDClock
pages *pages.Pages
+
sess *session.SessionStore
idResolver *idresolver.Resolver
posthog posthog.Client
jc *jetstream.JetstreamClient
···
res = idresolver.DefaultResolver()
}
-
oauth := oauth.NewOAuth(d, config)
+
cache := cache.New(config.Redis.Addr)
+
sess := session.New(cache)
+
+
oauth := oauth.NewOAuth(config, sess)
posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint})
if err != nil {
···
enforcer,
clock,
pgs,
+
sess,
res,
posthog,
jc,
···
return
case http.MethodPost:
-
session, err := s.oauth.Store.Get(r, oauth.SessionName)
+
session, err := s.oauth.Stores().Get(r, oauth.SessionName)
if err != nil || session.IsNew {
log.Println("unauthorized attempt to generate registration key")
http.Error(w, "Forbidden", http.StatusUnauthorized)