forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 atpclient "github.com/bluesky-social/indigo/atproto/client" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 xrpc "github.com/bluesky-social/indigo/xrpc" 15 "github.com/gorilla/sessions" 16 "github.com/lestrrat-go/jwx/v2/jwk" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/appview/config" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type OAuth struct { 25 ClientApp *oauth.ClientApp 26 SessStore *sessions.CookieStore 27 Config *config.Config 28 JwksUri string 29 Posthog posthog.Client 30 Db *db.DB 31 Enforcer *rbac.Enforcer 32 IdResolver *idresolver.Resolver 33 Logger *slog.Logger 34} 35 36func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 37 38 var oauthConfig oauth.ClientConfig 39 var clientUri string 40 41 if config.Core.Dev { 42 clientUri = "http://127.0.0.1:3000" 43 callbackUri := clientUri + "/oauth/callback" 44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 45 } else { 46 clientUri = config.Core.AppviewHost 47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 48 callbackUri := clientUri + "/oauth/callback" 49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 50 } 51 52 jwksUri := clientUri + "/oauth/jwks.json" 53 54 authStore, err := NewRedisStore(config.Redis.ToURL()) 55 if err != nil { 56 return nil, err 57 } 58 59 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 60 61 return &OAuth{ 62 ClientApp: oauth.NewClientApp(&oauthConfig, authStore), 63 Config: config, 64 SessStore: sessStore, 65 JwksUri: jwksUri, 66 Posthog: ph, 67 Db: db, 68 Enforcer: enforcer, 69 IdResolver: res, 70 Logger: logger, 71 }, nil 72} 73 74func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 75 // first we save the did in the user session 76 userSession, err := o.SessStore.Get(r, SessionName) 77 if err != nil { 78 return err 79 } 80 81 userSession.Values[SessionDid] = sessData.AccountDID.String() 82 userSession.Values[SessionPds] = sessData.HostURL 83 userSession.Values[SessionId] = sessData.SessionID 84 userSession.Values[SessionAuthenticated] = true 85 return userSession.Save(r, w) 86} 87 88func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 89 userSession, err := o.SessStore.Get(r, SessionName) 90 if err != nil { 91 return nil, fmt.Errorf("error getting user session: %w", err) 92 } 93 if userSession.IsNew { 94 return nil, fmt.Errorf("no session available for user") 95 } 96 97 d := userSession.Values[SessionDid].(string) 98 sessDid, err := syntax.ParseDID(d) 99 if err != nil { 100 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 101 } 102 103 sessId := userSession.Values[SessionId].(string) 104 105 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 106 if err != nil { 107 return nil, fmt.Errorf("failed to resume session: %w", err) 108 } 109 110 return clientSess, nil 111} 112 113func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 114 userSession, err := o.SessStore.Get(r, SessionName) 115 if err != nil { 116 return fmt.Errorf("error getting user session: %w", err) 117 } 118 if userSession.IsNew { 119 return fmt.Errorf("no session available for user") 120 } 121 122 d := userSession.Values[SessionDid].(string) 123 sessDid, err := syntax.ParseDID(d) 124 if err != nil { 125 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 126 } 127 128 sessId := userSession.Values[SessionId].(string) 129 130 // delete the session 131 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 132 133 // remove the cookie 134 userSession.Options.MaxAge = -1 135 err2 := o.SessStore.Save(r, w, userSession) 136 137 return errors.Join(err1, err2) 138} 139 140func pubKeyFromJwk(jwks string) (jwk.Key, error) { 141 k, err := jwk.ParseKey([]byte(jwks)) 142 if err != nil { 143 return nil, err 144 } 145 pubKey, err := k.PublicKey() 146 if err != nil { 147 return nil, err 148 } 149 return pubKey, nil 150} 151 152type User struct { 153 Did string 154 Pds string 155} 156 157func (o *OAuth) GetUser(r *http.Request) *User { 158 sess, err := o.SessStore.Get(r, SessionName) 159 160 if err != nil || sess.IsNew { 161 return nil 162 } 163 164 return &User{ 165 Did: sess.Values[SessionDid].(string), 166 Pds: sess.Values[SessionPds].(string), 167 } 168} 169 170func (o *OAuth) GetDid(r *http.Request) string { 171 if u := o.GetUser(r); u != nil { 172 return u.Did 173 } 174 175 return "" 176} 177 178func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 179 session, err := o.ResumeSession(r) 180 if err != nil { 181 return nil, fmt.Errorf("error getting session: %w", err) 182 } 183 return session.APIClient(), nil 184} 185 186// this is a higher level abstraction on ServerGetServiceAuth 187type ServiceClientOpts struct { 188 service string 189 exp int64 190 lxm string 191 dev bool 192} 193 194type ServiceClientOpt func(*ServiceClientOpts) 195 196func WithService(service string) ServiceClientOpt { 197 return func(s *ServiceClientOpts) { 198 s.service = service 199 } 200} 201 202// Specify the Duration in seconds for the expiry of this token 203// 204// The time of expiry is calculated as time.Now().Unix() + exp 205func WithExp(exp int64) ServiceClientOpt { 206 return func(s *ServiceClientOpts) { 207 s.exp = time.Now().Unix() + exp 208 } 209} 210 211func WithLxm(lxm string) ServiceClientOpt { 212 return func(s *ServiceClientOpts) { 213 s.lxm = lxm 214 } 215} 216 217func WithDev(dev bool) ServiceClientOpt { 218 return func(s *ServiceClientOpts) { 219 s.dev = dev 220 } 221} 222 223func (s *ServiceClientOpts) Audience() string { 224 return fmt.Sprintf("did:web:%s", s.service) 225} 226 227func (s *ServiceClientOpts) Host() string { 228 scheme := "https://" 229 if s.dev { 230 scheme = "http://" 231 } 232 233 return scheme + s.service 234} 235 236func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 237 opts := ServiceClientOpts{} 238 for _, o := range os { 239 o(&opts) 240 } 241 242 client, err := o.AuthorizedClient(r) 243 if err != nil { 244 return nil, err 245 } 246 247 // force expiry to atleast 60 seconds in the future 248 sixty := time.Now().Unix() + 60 249 if opts.exp < sixty { 250 opts.exp = sixty 251 } 252 253 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 254 if err != nil { 255 return nil, err 256 } 257 258 return &xrpc.Client{ 259 Auth: &xrpc.AuthInfo{ 260 AccessJwt: resp.Token, 261 }, 262 Host: opts.Host(), 263 Client: &http.Client{ 264 Timeout: time.Second * 5, 265 }, 266 }, nil 267}