forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 atpclient "github.com/bluesky-social/indigo/atproto/client" 13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 xrpc "github.com/bluesky-social/indigo/xrpc" 16 "github.com/gorilla/sessions" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/appview/config" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type OAuth struct { 25 ClientApp *oauth.ClientApp 26 SessStore *sessions.CookieStore 27 Config *config.Config 28 JwksUri string 29 ClientName string 30 ClientUri string 31 Posthog posthog.Client 32 Db *db.DB 33 Enforcer *rbac.Enforcer 34 IdResolver *idresolver.Resolver 35 Logger *slog.Logger 36} 37 38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 39 var oauthConfig oauth.ClientConfig 40 var clientUri string 41 if config.Core.Dev { 42 clientUri = "http://127.0.0.1:3000" 43 callbackUri := clientUri + "/oauth/callback" 44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 45 } else { 46 clientUri = config.Core.AppviewHost 47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 48 callbackUri := clientUri + "/oauth/callback" 49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 50 } 51 52 // configure client secret 53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 54 if err != nil { 55 return nil, err 56 } 57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 58 return nil, err 59 } 60 61 jwksUri := clientUri + "/oauth/jwks.json" 62 63 authStore, err := NewRedisStore(&RedisStoreConfig{ 64 RedisURL: config.Redis.ToURL(), 65 SessionExpiryDuration: time.Hour * 24 * 90, 66 SessionInactivityDuration: time.Hour * 24 * 14, 67 AuthRequestExpiryDuration: time.Minute * 30, 68 }) 69 if err != nil { 70 return nil, err 71 } 72 73 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 74 75 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 76 clientApp.Dir = res.Directory() 77 78 clientName := config.Core.AppviewName 79 80 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 81 return &OAuth{ 82 ClientApp: clientApp, 83 Config: config, 84 SessStore: sessStore, 85 JwksUri: jwksUri, 86 ClientName: clientName, 87 ClientUri: clientUri, 88 Posthog: ph, 89 Db: db, 90 Enforcer: enforcer, 91 IdResolver: res, 92 Logger: logger, 93 }, nil 94} 95 96func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 97 // first we save the did in the user session 98 userSession, err := o.SessStore.Get(r, SessionName) 99 if err != nil { 100 return err 101 } 102 103 userSession.Values[SessionDid] = sessData.AccountDID.String() 104 userSession.Values[SessionPds] = sessData.HostURL 105 userSession.Values[SessionId] = sessData.SessionID 106 userSession.Values[SessionAuthenticated] = true 107 return userSession.Save(r, w) 108} 109 110func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 111 userSession, err := o.SessStore.Get(r, SessionName) 112 if err != nil { 113 return nil, fmt.Errorf("error getting user session: %w", err) 114 } 115 if userSession.IsNew { 116 return nil, fmt.Errorf("no session available for user") 117 } 118 119 d := userSession.Values[SessionDid].(string) 120 sessDid, err := syntax.ParseDID(d) 121 if err != nil { 122 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 123 } 124 125 sessId := userSession.Values[SessionId].(string) 126 127 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 128 if err != nil { 129 return nil, fmt.Errorf("failed to resume session: %w", err) 130 } 131 132 return clientSess, nil 133} 134 135func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 136 userSession, err := o.SessStore.Get(r, SessionName) 137 if err != nil { 138 return fmt.Errorf("error getting user session: %w", err) 139 } 140 if userSession.IsNew { 141 return fmt.Errorf("no session available for user") 142 } 143 144 d := userSession.Values[SessionDid].(string) 145 sessDid, err := syntax.ParseDID(d) 146 if err != nil { 147 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 148 } 149 150 sessId := userSession.Values[SessionId].(string) 151 152 // delete the session 153 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 154 155 // remove the cookie 156 userSession.Options.MaxAge = -1 157 err2 := o.SessStore.Save(r, w, userSession) 158 159 return errors.Join(err1, err2) 160} 161 162type User struct { 163 Did string 164 Pds string 165} 166 167func (o *OAuth) GetUser(r *http.Request) *User { 168 sess, err := o.ResumeSession(r) 169 if err != nil { 170 return nil 171 } 172 173 return &User{ 174 Did: sess.Data.AccountDID.String(), 175 Pds: sess.Data.HostURL, 176 } 177} 178 179func (o *OAuth) GetDid(r *http.Request) string { 180 if u := o.GetUser(r); u != nil { 181 return u.Did 182 } 183 184 return "" 185} 186 187func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 188 session, err := o.ResumeSession(r) 189 if err != nil { 190 return nil, fmt.Errorf("error getting session: %w", err) 191 } 192 return session.APIClient(), nil 193} 194 195// this is a higher level abstraction on ServerGetServiceAuth 196type ServiceClientOpts struct { 197 service string 198 exp int64 199 lxm string 200 dev bool 201} 202 203type ServiceClientOpt func(*ServiceClientOpts) 204 205func WithService(service string) ServiceClientOpt { 206 return func(s *ServiceClientOpts) { 207 s.service = service 208 } 209} 210 211// Specify the Duration in seconds for the expiry of this token 212// 213// The time of expiry is calculated as time.Now().Unix() + exp 214func WithExp(exp int64) ServiceClientOpt { 215 return func(s *ServiceClientOpts) { 216 s.exp = time.Now().Unix() + exp 217 } 218} 219 220func WithLxm(lxm string) ServiceClientOpt { 221 return func(s *ServiceClientOpts) { 222 s.lxm = lxm 223 } 224} 225 226func WithDev(dev bool) ServiceClientOpt { 227 return func(s *ServiceClientOpts) { 228 s.dev = dev 229 } 230} 231 232func (s *ServiceClientOpts) Audience() string { 233 return fmt.Sprintf("did:web:%s", s.service) 234} 235 236func (s *ServiceClientOpts) Host() string { 237 scheme := "https://" 238 if s.dev { 239 scheme = "http://" 240 } 241 242 return scheme + s.service 243} 244 245func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 246 opts := ServiceClientOpts{} 247 for _, o := range os { 248 o(&opts) 249 } 250 251 client, err := o.AuthorizedClient(r) 252 if err != nil { 253 return nil, err 254 } 255 256 // force expiry to atleast 60 seconds in the future 257 sixty := time.Now().Unix() + 60 258 if opts.exp < sixty { 259 opts.exp = sixty 260 } 261 262 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 263 if err != nil { 264 return nil, err 265 } 266 267 return &xrpc.Client{ 268 Auth: &xrpc.AuthInfo{ 269 AccessJwt: resp.Token, 270 }, 271 Host: opts.Host(), 272 Client: &http.Client{ 273 Timeout: time.Second * 5, 274 }, 275 }, nil 276}