forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 atpclient "github.com/bluesky-social/indigo/atproto/client" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 xrpc "github.com/bluesky-social/indigo/xrpc" 15 "github.com/gorilla/sessions" 16 "github.com/lestrrat-go/jwx/v2/jwk" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/appview/config" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type OAuth struct { 25 ClientApp *oauth.ClientApp 26 SessStore *sessions.CookieStore 27 Config *config.Config 28 JwksUri string 29 Posthog posthog.Client 30 Db *db.DB 31 Enforcer *rbac.Enforcer 32 IdResolver *idresolver.Resolver 33 Logger *slog.Logger 34} 35 36func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 37 38 var oauthConfig oauth.ClientConfig 39 var clientUri string 40 41 if config.Core.Dev { 42 clientUri = "http://127.0.0.1:3000" 43 callbackUri := clientUri + "/oauth/callback" 44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 45 } else { 46 clientUri = config.Core.AppviewHost 47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 48 callbackUri := clientUri + "/oauth/callback" 49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 50 } 51 52 jwksUri := clientUri + "/oauth/jwks.json" 53 54 authStore, err := NewRedisStore(config.Redis.ToURL()) 55 if err != nil { 56 return nil, err 57 } 58 59 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 60 61 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 62 clientApp.Dir = res.Directory() 63 64 return &OAuth{ 65 ClientApp: clientApp, 66 Config: config, 67 SessStore: sessStore, 68 JwksUri: jwksUri, 69 Posthog: ph, 70 Db: db, 71 Enforcer: enforcer, 72 IdResolver: res, 73 Logger: logger, 74 }, nil 75} 76 77func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 78 // first we save the did in the user session 79 userSession, err := o.SessStore.Get(r, SessionName) 80 if err != nil { 81 return err 82 } 83 84 userSession.Values[SessionDid] = sessData.AccountDID.String() 85 userSession.Values[SessionPds] = sessData.HostURL 86 userSession.Values[SessionId] = sessData.SessionID 87 userSession.Values[SessionAuthenticated] = true 88 return userSession.Save(r, w) 89} 90 91func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 92 userSession, err := o.SessStore.Get(r, SessionName) 93 if err != nil { 94 return nil, fmt.Errorf("error getting user session: %w", err) 95 } 96 if userSession.IsNew { 97 return nil, fmt.Errorf("no session available for user") 98 } 99 100 d := userSession.Values[SessionDid].(string) 101 sessDid, err := syntax.ParseDID(d) 102 if err != nil { 103 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 104 } 105 106 sessId := userSession.Values[SessionId].(string) 107 108 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 109 if err != nil { 110 return nil, fmt.Errorf("failed to resume session: %w", err) 111 } 112 113 return clientSess, nil 114} 115 116func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 117 userSession, err := o.SessStore.Get(r, SessionName) 118 if err != nil { 119 return fmt.Errorf("error getting user session: %w", err) 120 } 121 if userSession.IsNew { 122 return fmt.Errorf("no session available for user") 123 } 124 125 d := userSession.Values[SessionDid].(string) 126 sessDid, err := syntax.ParseDID(d) 127 if err != nil { 128 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 129 } 130 131 sessId := userSession.Values[SessionId].(string) 132 133 // delete the session 134 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 135 136 // remove the cookie 137 userSession.Options.MaxAge = -1 138 err2 := o.SessStore.Save(r, w, userSession) 139 140 return errors.Join(err1, err2) 141} 142 143func pubKeyFromJwk(jwks string) (jwk.Key, error) { 144 k, err := jwk.ParseKey([]byte(jwks)) 145 if err != nil { 146 return nil, err 147 } 148 pubKey, err := k.PublicKey() 149 if err != nil { 150 return nil, err 151 } 152 return pubKey, nil 153} 154 155type User struct { 156 Did string 157 Pds string 158} 159 160func (o *OAuth) GetUser(r *http.Request) *User { 161 sess, err := o.SessStore.Get(r, SessionName) 162 163 if err != nil || sess.IsNew { 164 return nil 165 } 166 167 return &User{ 168 Did: sess.Values[SessionDid].(string), 169 Pds: sess.Values[SessionPds].(string), 170 } 171} 172 173func (o *OAuth) GetDid(r *http.Request) string { 174 if u := o.GetUser(r); u != nil { 175 return u.Did 176 } 177 178 return "" 179} 180 181func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 182 session, err := o.ResumeSession(r) 183 if err != nil { 184 return nil, fmt.Errorf("error getting session: %w", err) 185 } 186 return session.APIClient(), nil 187} 188 189// this is a higher level abstraction on ServerGetServiceAuth 190type ServiceClientOpts struct { 191 service string 192 exp int64 193 lxm string 194 dev bool 195} 196 197type ServiceClientOpt func(*ServiceClientOpts) 198 199func WithService(service string) ServiceClientOpt { 200 return func(s *ServiceClientOpts) { 201 s.service = service 202 } 203} 204 205// Specify the Duration in seconds for the expiry of this token 206// 207// The time of expiry is calculated as time.Now().Unix() + exp 208func WithExp(exp int64) ServiceClientOpt { 209 return func(s *ServiceClientOpts) { 210 s.exp = time.Now().Unix() + exp 211 } 212} 213 214func WithLxm(lxm string) ServiceClientOpt { 215 return func(s *ServiceClientOpts) { 216 s.lxm = lxm 217 } 218} 219 220func WithDev(dev bool) ServiceClientOpt { 221 return func(s *ServiceClientOpts) { 222 s.dev = dev 223 } 224} 225 226func (s *ServiceClientOpts) Audience() string { 227 return fmt.Sprintf("did:web:%s", s.service) 228} 229 230func (s *ServiceClientOpts) Host() string { 231 scheme := "https://" 232 if s.dev { 233 scheme = "http://" 234 } 235 236 return scheme + s.service 237} 238 239func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 240 opts := ServiceClientOpts{} 241 for _, o := range os { 242 o(&opts) 243 } 244 245 client, err := o.AuthorizedClient(r) 246 if err != nil { 247 return nil, err 248 } 249 250 // force expiry to atleast 60 seconds in the future 251 sixty := time.Now().Unix() + 60 252 if opts.exp < sixty { 253 opts.exp = sixty 254 } 255 256 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 257 if err != nil { 258 return nil, err 259 } 260 261 return &xrpc.Client{ 262 Auth: &xrpc.AuthInfo{ 263 AccessJwt: resp.Token, 264 }, 265 Host: opts.Host(), 266 Client: &http.Client{ 267 Timeout: time.Second * 5, 268 }, 269 }, nil 270}