forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "time" 8 9 comatproto "github.com/bluesky-social/indigo/api/atproto" 10 "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 atpclient "github.com/bluesky-social/indigo/atproto/client" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 xrpc "github.com/bluesky-social/indigo/xrpc" 14 "github.com/gorilla/sessions" 15 "github.com/lestrrat-go/jwx/v2/jwk" 16 "github.com/posthog/posthog-go" 17 "tangled.org/core/appview/config" 18 "tangled.org/core/appview/db" 19 "tangled.org/core/idresolver" 20 "tangled.org/core/rbac" 21) 22 23func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver) (*OAuth, error) { 24 25 var oauthConfig oauth.ClientConfig 26 var clientUri string 27 28 if config.Core.Dev { 29 clientUri = "http://127.0.0.1:3000" 30 callbackUri := clientUri + "/oauth/callback" 31 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 32 } else { 33 clientUri = config.Core.AppviewHost 34 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 35 callbackUri := clientUri + "/oauth/callback" 36 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 37 } 38 39 jwksUri := clientUri + "/oauth/jwks.json" 40 41 authStore, err := NewRedisStore(config.Redis.ToURL()) 42 if err != nil { 43 return nil, err 44 } 45 46 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 47 48 return &OAuth{ 49 ClientApp: oauth.NewClientApp(&oauthConfig, authStore), 50 Config: config, 51 SessStore: sessStore, 52 JwksUri: jwksUri, 53 Posthog: ph, 54 Db: db, 55 Enforcer: enforcer, 56 IdResolver: res, 57 }, nil 58} 59 60type OAuth struct { 61 ClientApp *oauth.ClientApp 62 SessStore *sessions.CookieStore 63 Config *config.Config 64 JwksUri string 65 Posthog posthog.Client 66 Db *db.DB 67 Enforcer *rbac.Enforcer 68 IdResolver *idresolver.Resolver 69} 70 71func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 72 // first we save the did in the user session 73 userSession, err := o.SessStore.Get(r, SessionName) 74 if err != nil { 75 return err 76 } 77 78 userSession.Values[SessionDid] = sessData.AccountDID.String() 79 userSession.Values[SessionPds] = sessData.HostURL 80 userSession.Values[SessionId] = sessData.SessionID 81 userSession.Values[SessionAuthenticated] = true 82 return userSession.Save(r, w) 83} 84 85func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 86 userSession, err := o.SessStore.Get(r, SessionName) 87 if err != nil { 88 return nil, fmt.Errorf("error getting user session: %w", err) 89 } 90 if userSession.IsNew { 91 return nil, fmt.Errorf("no session available for user") 92 } 93 94 d := userSession.Values[SessionDid].(string) 95 sessDid, err := syntax.ParseDID(d) 96 if err != nil { 97 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 98 } 99 100 sessId := userSession.Values[SessionId].(string) 101 102 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 103 if err != nil { 104 return nil, fmt.Errorf("failed to resume session: %w", err) 105 } 106 107 return clientSess, nil 108} 109 110func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 111 userSession, err := o.SessStore.Get(r, SessionName) 112 if err != nil { 113 return fmt.Errorf("error getting user session: %w", err) 114 } 115 if userSession.IsNew { 116 return fmt.Errorf("no session available for user") 117 } 118 119 d := userSession.Values[SessionDid].(string) 120 sessDid, err := syntax.ParseDID(d) 121 if err != nil { 122 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 123 } 124 125 sessId := userSession.Values[SessionId].(string) 126 127 // delete the session 128 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 129 130 // remove the cookie 131 userSession.Options.MaxAge = -1 132 err2 := o.SessStore.Save(r, w, userSession) 133 134 return errors.Join(err1, err2) 135} 136 137func pubKeyFromJwk(jwks string) (jwk.Key, error) { 138 k, err := jwk.ParseKey([]byte(jwks)) 139 if err != nil { 140 return nil, err 141 } 142 pubKey, err := k.PublicKey() 143 if err != nil { 144 return nil, err 145 } 146 return pubKey, nil 147} 148 149type User struct { 150 Did string 151 Pds string 152} 153 154func (o *OAuth) GetUser(r *http.Request) *User { 155 sess, err := o.SessStore.Get(r, SessionName) 156 157 if err != nil || sess.IsNew { 158 return nil 159 } 160 161 return &User{ 162 Did: sess.Values[SessionDid].(string), 163 Pds: sess.Values[SessionPds].(string), 164 } 165} 166 167func (o *OAuth) GetDid(r *http.Request) string { 168 if u := o.GetUser(r); u != nil { 169 return u.Did 170 } 171 172 return "" 173} 174 175func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 176 session, err := o.ResumeSession(r) 177 if err != nil { 178 return nil, fmt.Errorf("error getting session: %w", err) 179 } 180 return session.APIClient(), nil 181} 182 183// this is a higher level abstraction on ServerGetServiceAuth 184type ServiceClientOpts struct { 185 service string 186 exp int64 187 lxm string 188 dev bool 189} 190 191type ServiceClientOpt func(*ServiceClientOpts) 192 193func WithService(service string) ServiceClientOpt { 194 return func(s *ServiceClientOpts) { 195 s.service = service 196 } 197} 198 199// Specify the Duration in seconds for the expiry of this token 200// 201// The time of expiry is calculated as time.Now().Unix() + exp 202func WithExp(exp int64) ServiceClientOpt { 203 return func(s *ServiceClientOpts) { 204 s.exp = time.Now().Unix() + exp 205 } 206} 207 208func WithLxm(lxm string) ServiceClientOpt { 209 return func(s *ServiceClientOpts) { 210 s.lxm = lxm 211 } 212} 213 214func WithDev(dev bool) ServiceClientOpt { 215 return func(s *ServiceClientOpts) { 216 s.dev = dev 217 } 218} 219 220func (s *ServiceClientOpts) Audience() string { 221 return fmt.Sprintf("did:web:%s", s.service) 222} 223 224func (s *ServiceClientOpts) Host() string { 225 scheme := "https://" 226 if s.dev { 227 scheme = "http://" 228 } 229 230 return scheme + s.service 231} 232 233func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 234 opts := ServiceClientOpts{} 235 for _, o := range os { 236 o(&opts) 237 } 238 239 client, err := o.AuthorizedClient(r) 240 if err != nil { 241 return nil, err 242 } 243 244 // force expiry to atleast 60 seconds in the future 245 sixty := time.Now().Unix() + 60 246 if opts.exp < sixty { 247 opts.exp = sixty 248 } 249 250 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 251 if err != nil { 252 return nil, err 253 } 254 255 return &xrpc.Client{ 256 Auth: &xrpc.AuthInfo{ 257 AccessJwt: resp.Token, 258 }, 259 Host: opts.Host(), 260 Client: &http.Client{ 261 Timeout: time.Second * 5, 262 }, 263 }, nil 264}