forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "time" 8 9 comatproto "github.com/bluesky-social/indigo/api/atproto" 10 "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 atpclient "github.com/bluesky-social/indigo/atproto/client" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 xrpc "github.com/bluesky-social/indigo/xrpc" 14 "github.com/gorilla/sessions" 15 "github.com/lestrrat-go/jwx/v2/jwk" 16 "tangled.org/core/appview/config" 17) 18 19func New(config *config.Config) (*OAuth, error) { 20 21 var oauthConfig oauth.ClientConfig 22 var clientUri string 23 24 if config.Core.Dev { 25 clientUri = "http://127.0.0.1:3000" 26 callbackUri := clientUri + "/oauth/callback" 27 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"}) 28 } else { 29 clientUri = config.Core.AppviewHost 30 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 31 callbackUri := clientUri + "/oauth/callback" 32 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"}) 33 } 34 35 jwksUri := clientUri + "/oauth/jwks.json" 36 37 authStore, err := NewRedisStore(config.Redis.ToURL()) 38 if err != nil { 39 return nil, err 40 } 41 42 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 43 44 return &OAuth{ 45 ClientApp: oauth.NewClientApp(&oauthConfig, authStore), 46 Config: config, 47 SessStore: sessStore, 48 JwksUri: jwksUri, 49 }, nil 50} 51 52type OAuth struct { 53 ClientApp *oauth.ClientApp 54 SessStore *sessions.CookieStore 55 Config *config.Config 56 JwksUri string 57} 58 59func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 60 // first we save the did in the user session 61 userSession, err := o.SessStore.Get(r, SessionName) 62 if err != nil { 63 return err 64 } 65 66 userSession.Values[SessionDid] = sessData.AccountDID.String() 67 userSession.Values[SessionPds] = sessData.HostURL 68 userSession.Values[SessionId] = sessData.SessionID 69 userSession.Values[SessionAuthenticated] = true 70 return userSession.Save(r, w) 71} 72 73func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 74 userSession, err := o.SessStore.Get(r, SessionName) 75 if err != nil { 76 return nil, fmt.Errorf("error getting user session: %w", err) 77 } 78 if userSession.IsNew { 79 return nil, fmt.Errorf("no session available for user") 80 } 81 82 d := userSession.Values[SessionDid].(string) 83 sessDid, err := syntax.ParseDID(d) 84 if err != nil { 85 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 86 } 87 88 sessId := userSession.Values[SessionId].(string) 89 90 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 91 if err != nil { 92 return nil, fmt.Errorf("failed to resume session: %w", err) 93 } 94 95 return clientSess, nil 96} 97 98func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 99 userSession, err := o.SessStore.Get(r, SessionName) 100 if err != nil { 101 return fmt.Errorf("error getting user session: %w", err) 102 } 103 if userSession.IsNew { 104 return fmt.Errorf("no session available for user") 105 } 106 107 d := userSession.Values[SessionDid].(string) 108 sessDid, err := syntax.ParseDID(d) 109 if err != nil { 110 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 111 } 112 113 sessId := userSession.Values[SessionId].(string) 114 115 // delete the session 116 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 117 118 // remove the cookie 119 userSession.Options.MaxAge = -1 120 err2 := o.SessStore.Save(r, w, userSession) 121 122 return errors.Join(err1, err2) 123} 124 125func pubKeyFromJwk(jwks string) (jwk.Key, error) { 126 k, err := jwk.ParseKey([]byte(jwks)) 127 if err != nil { 128 return nil, err 129 } 130 pubKey, err := k.PublicKey() 131 if err != nil { 132 return nil, err 133 } 134 return pubKey, nil 135} 136 137type User struct { 138 Did string 139 Pds string 140} 141 142func (o *OAuth) GetUser(r *http.Request) *User { 143 sess, err := o.SessStore.Get(r, SessionName) 144 145 if err != nil || sess.IsNew { 146 return nil 147 } 148 149 return &User{ 150 Did: sess.Values[SessionDid].(string), 151 Pds: sess.Values[SessionPds].(string), 152 } 153} 154 155func (o *OAuth) GetDid(r *http.Request) string { 156 if u := o.GetUser(r); u != nil { 157 return u.Did 158 } 159 160 return "" 161} 162 163func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 164 session, err := o.ResumeSession(r) 165 if err != nil { 166 return nil, fmt.Errorf("error getting session: %w", err) 167 } 168 return session.APIClient(), nil 169} 170 171// this is a higher level abstraction on ServerGetServiceAuth 172type ServiceClientOpts struct { 173 service string 174 exp int64 175 lxm string 176 dev bool 177} 178 179type ServiceClientOpt func(*ServiceClientOpts) 180 181func WithService(service string) ServiceClientOpt { 182 return func(s *ServiceClientOpts) { 183 s.service = service 184 } 185} 186 187// Specify the Duration in seconds for the expiry of this token 188// 189// The time of expiry is calculated as time.Now().Unix() + exp 190func WithExp(exp int64) ServiceClientOpt { 191 return func(s *ServiceClientOpts) { 192 s.exp = time.Now().Unix() + exp 193 } 194} 195 196func WithLxm(lxm string) ServiceClientOpt { 197 return func(s *ServiceClientOpts) { 198 s.lxm = lxm 199 } 200} 201 202func WithDev(dev bool) ServiceClientOpt { 203 return func(s *ServiceClientOpts) { 204 s.dev = dev 205 } 206} 207 208func (s *ServiceClientOpts) Audience() string { 209 return fmt.Sprintf("did:web:%s", s.service) 210} 211 212func (s *ServiceClientOpts) Host() string { 213 scheme := "https://" 214 if s.dev { 215 scheme = "http://" 216 } 217 218 return scheme + s.service 219} 220 221func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 222 opts := ServiceClientOpts{} 223 for _, o := range os { 224 o(&opts) 225 } 226 227 client, err := o.AuthorizedClient(r) 228 if err != nil { 229 return nil, err 230 } 231 232 // force expiry to atleast 60 seconds in the future 233 sixty := time.Now().Unix() + 60 234 if opts.exp < sixty { 235 opts.exp = sixty 236 } 237 238 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 239 if err != nil { 240 return nil, err 241 } 242 243 return &xrpc.Client{ 244 Auth: &xrpc.AuthInfo{ 245 AccessJwt: resp.Token, 246 }, 247 Host: opts.Host(), 248 Client: &http.Client{ 249 Timeout: time.Second * 5, 250 }, 251 }, nil 252}