forked from tangled.org/core
Monorepo for Tangled — https://tangled.org

appview/auth: implement background session refresh

After successful login, we start a goroutine to handle refreshing the
session token in the background, with session expiry defined in
auth.ExpiryDuration.

To kill the goroutine cleanly at Logout, s.sessionCancelFuncs maintains
a map of did->cancelFunc which we look up and call at logout.

Accounting for this map getting cleared at program end, we restore a
token refresher at the *first* authenticated request using RestoreSessionIfNeeded
in the AuthMiddleware.

anirudh.fi 7e904629 53feba3a

verified
Changed files
+167 -46
appview
+59 -1
appview/auth/auth.go
···
"github.com/sotangled/tangled/appview"
)
+
const ExpiryDuration = 15 * time.Minute
+
type Auth struct {
Store *sessions.CookieStore
}
···
GetStatus() *string
}
+
type ClientSessionish struct {
+
sessions.Session
+
}
+
+
func (c *ClientSessionish) GetAccessJwt() string {
+
return c.Values[appview.SessionAccessJwt].(string)
+
}
+
+
func (c *ClientSessionish) GetActive() *bool {
+
return c.Values[appview.SessionAuthenticated].(*bool)
+
}
+
+
func (c *ClientSessionish) GetDid() string {
+
return c.Values[appview.SessionDid].(string)
+
}
+
+
func (c *ClientSessionish) GetDidDoc() *interface{} {
+
return nil
+
}
+
+
func (c *ClientSessionish) GetHandle() string {
+
return c.Values[appview.SessionHandle].(string)
+
}
+
+
func (c *ClientSessionish) GetRefreshJwt() string {
+
return c.Values[appview.SessionRefreshJwt].(string)
+
}
+
+
func (c *ClientSessionish) GetStatus() *string {
+
return nil
+
}
+
// Create a wrapper type for ServerRefreshSession_Output
type RefreshSessionWrapper struct {
*comatproto.ServerRefreshSession_Output
···
clientSession.Values[appview.SessionPds] = pdsEndpoint
clientSession.Values[appview.SessionAccessJwt] = atSessionish.GetAccessJwt()
clientSession.Values[appview.SessionRefreshJwt] = atSessionish.GetRefreshJwt()
-
clientSession.Values[appview.SessionExpiry] = time.Now().Add(time.Minute * 15).Format(time.RFC3339)
+
clientSession.Values[appview.SessionExpiry] = time.Now().Add(ExpiryDuration).Format(time.RFC3339)
clientSession.Values[appview.SessionAuthenticated] = true
return clientSession.Save(r, w)
+
}
+
+
func (a *Auth) RefreshSession(ctx context.Context, r *http.Request, w http.ResponseWriter, atSessionish Sessionish, pdsEndpoint string) error {
+
client := xrpc.Client{
+
Host: pdsEndpoint,
+
Auth: &xrpc.AuthInfo{
+
Did: atSessionish.GetDid(),
+
AccessJwt: atSessionish.GetRefreshJwt(),
+
RefreshJwt: atSessionish.GetRefreshJwt(),
+
},
+
}
+
+
atSession, err := comatproto.ServerRefreshSession(ctx, &client)
+
if err != nil {
+
return fmt.Errorf("failed to refresh session: %w", err)
+
}
+
+
newAtSessionish := &RefreshSessionWrapper{atSession}
+
err = a.StoreSession(r, w, newAtSessionish, pdsEndpoint)
+
if err != nil {
+
return fmt.Errorf("failed to store refreshed session: %w", err)
+
}
+
+
return nil
}
func (a *Auth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
+9 -44
appview/state/middleware.go
···
"log"
"net/http"
"strings"
-
"time"
-
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/atproto/identity"
-
"github.com/bluesky-social/indigo/xrpc"
"github.com/go-chi/chi/v5"
"github.com/sotangled/tangled/appview"
-
"github.com/sotangled/tangled/appview/auth"
"github.com/sotangled/tangled/appview/db"
)
···
func AuthMiddleware(s *State) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
session, _ := s.auth.Store.Get(r, appview.SessionName)
-
authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
-
if !ok || !authorized {
-
log.Printf("not logged in, redirecting")
+
if s.auth == nil {
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
return
}
-
-
// refresh if nearing expiry
-
// TODO: dedup with /login
-
expiryStr := session.Values[appview.SessionExpiry].(string)
-
expiry, err := time.Parse(time.RFC3339, expiryStr)
+
err := s.RestoreSessionIfNeeded(r, w)
if err != nil {
-
log.Println("invalid expiry time", err)
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
return
}
-
pdsUrl := session.Values[appview.SessionPds].(string)
-
did := session.Values[appview.SessionDid].(string)
-
refreshJwt := session.Values[appview.SessionRefreshJwt].(string)
-
if time.Now().After(expiry) {
-
log.Println("token expired, refreshing ...")
-
-
client := xrpc.Client{
-
Host: pdsUrl,
-
Auth: &xrpc.AuthInfo{
-
Did: did,
-
AccessJwt: refreshJwt,
-
RefreshJwt: refreshJwt,
-
},
-
}
-
atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
-
if err != nil {
-
log.Println("failed to refresh session", err)
-
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
-
return
-
}
-
-
sessionish := auth.RefreshSessionWrapper{atSession}
-
-
err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
-
if err != nil {
-
log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
-
return
-
}
-
-
log.Println("successfully refreshed token")
+
session, _ := s.auth.Store.Get(r, appview.SessionName)
+
authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
+
if !ok || !authorized {
+
log.Printf("not logged in, redirecting")
+
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
+
return
}
+
// refresh if nearing expiry
next.ServeHTTP(w, r)
})
}
+72
appview/state/session.go
···
+
package state
+
+
import (
+
"context"
+
"fmt"
+
"log"
+
"net/http"
+
"time"
+
+
"github.com/gorilla/sessions"
+
"github.com/sotangled/tangled/appview"
+
"github.com/sotangled/tangled/appview/auth"
+
)
+
+
func (s *State) StartTokenRefresher(
+
ctx context.Context,
+
refreshInterval time.Duration,
+
r *http.Request,
+
w http.ResponseWriter,
+
atSessionish auth.Sessionish,
+
pdsEndpoint string,
+
) {
+
go func() {
+
ticker := time.NewTicker(refreshInterval)
+
defer ticker.Stop()
+
+
for {
+
select {
+
case <-ticker.C:
+
err := s.auth.RefreshSession(ctx, r, w, atSessionish, pdsEndpoint)
+
if err != nil {
+
log.Printf("token refresh failed: %v", err)
+
} else {
+
log.Println("token refreshed successfully")
+
}
+
case <-ctx.Done():
+
log.Println("stopping token refresher")
+
return
+
}
+
}
+
}()
+
}
+
+
// RestoreSessionIfNeeded checks if a session exists in the request and starts a
+
// token refresher if it doesn't have one running already.
+
func (s *State) RestoreSessionIfNeeded(r *http.Request, w http.ResponseWriter) error {
+
var session *sessions.Session
+
var err error
+
session, err = s.auth.GetSession(r)
+
if err != nil {
+
fmt.Errorf("error getting session: %w", err)
+
}
+
+
did, ok := session.Values[appview.SessionDid].(string)
+
if !ok {
+
return fmt.Errorf("session did not contain a did")
+
}
+
sessionish := auth.ClientSessionish{Session: *session}
+
pdsEndpoint := session.Values[appview.SessionPds].(string)
+
+
// If no refresher is running for this session, start one
+
if _, exists := s.sessionCancelFuncs[did]; !exists {
+
sessionCtx, cancel := context.WithCancel(context.Background())
+
s.sessionCancelFuncs[did] = cancel
+
+
s.StartTokenRefresher(sessionCtx, auth.ExpiryDuration, r, w, &sessionish, pdsEndpoint)
+
+
log.Printf("restored session refresher for %s", did)
+
}
+
+
return nil
+
}
+27 -1
appview/state/state.go
···
resolver *appview.Resolver
jc *jetstream.JetstreamClient
config *appview.Config
+
+
sessionCancelFuncs map[string]context.CancelFunc
}
func Make(config *appview.Config) (*State, error) {
···
return nil, fmt.Errorf("failed to start jetstream watcher: %w", err)
}
+
sessionCancelFuncs := make(map[string]context.CancelFunc)
+
state := &State{
d,
auth,
···
resolver,
jc,
config,
+
sessionCancelFuncs,
}
return state, nil
···
}
log.Printf("successfully saved session for %s (%s)", atSession.Handle, atSession.Did)
+
+
sessionCtx, cancel := context.WithCancel(context.Background())
+
s.sessionCancelFuncs[sessionish.GetDid()] = cancel
+
expiry := auth.ExpiryDuration
+
+
go s.StartTokenRefresher(sessionCtx, expiry, r, w, &sessionish, resolved.PDSEndpoint())
+
s.pages.HxRedirect(w, "/")
return
}
}
func (s *State) Logout(w http.ResponseWriter, r *http.Request) {
+
session, err := s.auth.GetSession(r)
+
did := session.Values[appview.SessionDid].(string)
+
if err == nil {
+
if cancel, exists := s.sessionCancelFuncs[did]; exists {
+
cancel()
+
delete(s.sessionCancelFuncs, did)
+
}
+
}
+
s.auth.ClearSession(r, w)
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
}
···
switch r.Method {
case http.MethodGet:
user := s.auth.GetUser(r)
+
err := s.enforcer.AddMember("knot1.tangled.sh", user.Did)
+
if err != nil {
+
log.Println("failed to add user to knot1.tangled.sh: ", err)
+
s.pages.Notice(w, "repo", "Failed to add user to knot1.tangled.sh. You should be able to use your own knot however.")
+
}
+
knots, err := s.enforcer.GetDomainsForUser(user.Did)
-
if err != nil {
s.pages.Notice(w, "repo", "Invalid user account.")
return