forked from tangled.org/core
this repo has no description
at master 9.7 kB view raw
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "net/url" 8 "time" 9 10 indigo_xrpc "github.com/bluesky-social/indigo/xrpc" 11 "github.com/gorilla/sessions" 12 sessioncache "tangled.org/core/appview/cache/session" 13 "tangled.org/core/appview/config" 14 "tangled.org/core/appview/oauth/client" 15 xrpc "tangled.org/core/appview/xrpcclient" 16 oauth "tangled.sh/icyphox.sh/atproto-oauth" 17 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 18) 19 20type OAuth struct { 21 store *sessions.CookieStore 22 config *config.Config 23 sess *sessioncache.SessionStore 24} 25 26func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth { 27 return &OAuth{ 28 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 29 config: config, 30 sess: sess, 31 } 32} 33 34func (o *OAuth) Stores() *sessions.CookieStore { 35 return o.store 36} 37 38func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error { 39 // first we save the did in the user session 40 userSession, err := o.store.Get(r, SessionName) 41 if err != nil { 42 return err 43 } 44 45 userSession.Values[SessionDid] = oreq.Did 46 userSession.Values[SessionHandle] = oreq.Handle 47 userSession.Values[SessionPds] = oreq.PdsUrl 48 userSession.Values[SessionAuthenticated] = true 49 err = userSession.Save(r, w) 50 if err != nil { 51 return fmt.Errorf("error saving user session: %w", err) 52 } 53 54 // then save the whole thing in the db 55 session := sessioncache.OAuthSession{ 56 Did: oreq.Did, 57 Handle: oreq.Handle, 58 PdsUrl: oreq.PdsUrl, 59 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 60 AuthServerIss: oreq.AuthserverIss, 61 DpopPrivateJwk: oreq.DpopPrivateJwk, 62 AccessJwt: oresp.AccessToken, 63 RefreshJwt: oresp.RefreshToken, 64 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 65 } 66 67 return o.sess.SaveSession(r.Context(), session) 68} 69 70func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 71 userSession, err := o.store.Get(r, SessionName) 72 if err != nil || userSession.IsNew { 73 return fmt.Errorf("error getting user session (or new session?): %w", err) 74 } 75 76 did := userSession.Values[SessionDid].(string) 77 78 err = o.sess.DeleteSession(r.Context(), did) 79 if err != nil { 80 return fmt.Errorf("error deleting oauth session: %w", err) 81 } 82 83 userSession.Options.MaxAge = -1 84 85 return userSession.Save(r, w) 86} 87 88func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) { 89 userSession, err := o.store.Get(r, SessionName) 90 if err != nil || userSession.IsNew { 91 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 92 } 93 94 did := userSession.Values[SessionDid].(string) 95 auth := userSession.Values[SessionAuthenticated].(bool) 96 97 session, err := o.sess.GetSession(r.Context(), did) 98 if err != nil { 99 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 100 } 101 102 expiry, err := time.Parse(time.RFC3339, session.Expiry) 103 if err != nil { 104 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 105 } 106 if time.Until(expiry) <= 5*time.Minute { 107 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 108 if err != nil { 109 return nil, false, err 110 } 111 112 self := o.ClientMetadata() 113 114 oauthClient, err := client.NewClient( 115 self.ClientID, 116 o.config.OAuth.Jwks, 117 self.RedirectURIs[0], 118 ) 119 120 if err != nil { 121 return nil, false, err 122 } 123 124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 125 if err != nil { 126 return nil, false, err 127 } 128 129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 130 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry) 131 if err != nil { 132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 133 } 134 135 // update the current session 136 session.AccessJwt = resp.AccessToken 137 session.RefreshJwt = resp.RefreshToken 138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 139 session.Expiry = newExpiry 140 } 141 142 return session, auth, nil 143} 144 145type User struct { 146 Handle string 147 Did string 148 Pds string 149} 150 151func (a *OAuth) GetUser(r *http.Request) *User { 152 clientSession, err := a.store.Get(r, SessionName) 153 154 if err != nil || clientSession.IsNew { 155 return nil 156 } 157 158 return &User{ 159 Handle: clientSession.Values[SessionHandle].(string), 160 Did: clientSession.Values[SessionDid].(string), 161 Pds: clientSession.Values[SessionPds].(string), 162 } 163} 164 165func (a *OAuth) GetDid(r *http.Request) string { 166 clientSession, err := a.store.Get(r, SessionName) 167 168 if err != nil || clientSession.IsNew { 169 return "" 170 } 171 172 return clientSession.Values[SessionDid].(string) 173} 174 175func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 176 session, auth, err := o.GetSession(r) 177 if err != nil { 178 return nil, fmt.Errorf("error getting session: %w", err) 179 } 180 if !auth { 181 return nil, fmt.Errorf("not authorized") 182 } 183 184 client := &oauth.XrpcClient{ 185 OnDpopPdsNonceChanged: func(did, newNonce string) { 186 err := o.sess.UpdateNonce(r.Context(), did, newNonce) 187 if err != nil { 188 log.Printf("error updating dpop pds nonce: %v", err) 189 } 190 }, 191 } 192 193 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 194 if err != nil { 195 return nil, fmt.Errorf("error parsing private jwk: %w", err) 196 } 197 198 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 199 Did: session.Did, 200 PdsUrl: session.PdsUrl, 201 DpopPdsNonce: session.PdsUrl, 202 AccessToken: session.AccessJwt, 203 Issuer: session.AuthServerIss, 204 DpopPrivateJwk: privateJwk, 205 }) 206 207 return xrpcClient, nil 208} 209 210// use this to create a client to communicate with knots or spindles 211// 212// this is a higher level abstraction on ServerGetServiceAuth 213type ServiceClientOpts struct { 214 service string 215 exp int64 216 lxm string 217 dev bool 218} 219 220type ServiceClientOpt func(*ServiceClientOpts) 221 222func WithService(service string) ServiceClientOpt { 223 return func(s *ServiceClientOpts) { 224 s.service = service 225 } 226} 227 228// Specify the Duration in seconds for the expiry of this token 229// 230// The time of expiry is calculated as time.Now().Unix() + exp 231func WithExp(exp int64) ServiceClientOpt { 232 return func(s *ServiceClientOpts) { 233 s.exp = time.Now().Unix() + exp 234 } 235} 236 237func WithLxm(lxm string) ServiceClientOpt { 238 return func(s *ServiceClientOpts) { 239 s.lxm = lxm 240 } 241} 242 243func WithDev(dev bool) ServiceClientOpt { 244 return func(s *ServiceClientOpts) { 245 s.dev = dev 246 } 247} 248 249func (s *ServiceClientOpts) Audience() string { 250 return fmt.Sprintf("did:web:%s", s.service) 251} 252 253func (s *ServiceClientOpts) Host() string { 254 scheme := "https://" 255 if s.dev { 256 scheme = "http://" 257 } 258 259 return scheme + s.service 260} 261 262func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*indigo_xrpc.Client, error) { 263 opts := ServiceClientOpts{} 264 for _, o := range os { 265 o(&opts) 266 } 267 268 authorizedClient, err := o.AuthorizedClient(r) 269 if err != nil { 270 return nil, err 271 } 272 273 // force expiry to atleast 60 seconds in the future 274 sixty := time.Now().Unix() + 60 275 if opts.exp < sixty { 276 opts.exp = sixty 277 } 278 279 resp, err := authorizedClient.ServerGetServiceAuth(r.Context(), opts.Audience(), opts.exp, opts.lxm) 280 if err != nil { 281 return nil, err 282 } 283 284 return &indigo_xrpc.Client{ 285 Auth: &indigo_xrpc.AuthInfo{ 286 AccessJwt: resp.Token, 287 }, 288 Host: opts.Host(), 289 Client: &http.Client{ 290 Timeout: time.Second * 5, 291 }, 292 }, nil 293} 294 295type ClientMetadata struct { 296 ClientID string `json:"client_id"` 297 ClientName string `json:"client_name"` 298 SubjectType string `json:"subject_type"` 299 ClientURI string `json:"client_uri"` 300 RedirectURIs []string `json:"redirect_uris"` 301 GrantTypes []string `json:"grant_types"` 302 ResponseTypes []string `json:"response_types"` 303 ApplicationType string `json:"application_type"` 304 DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 305 JwksURI string `json:"jwks_uri"` 306 Scope string `json:"scope"` 307 TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 308 TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 309} 310 311func (o *OAuth) ClientMetadata() ClientMetadata { 312 makeRedirectURIs := func(c string) []string { 313 return []string{fmt.Sprintf("%s/oauth/callback", c)} 314 } 315 316 clientURI := o.config.Core.AppviewHost 317 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 318 redirectURIs := makeRedirectURIs(clientURI) 319 320 if o.config.Core.Dev { 321 clientURI = "http://127.0.0.1:3000" 322 redirectURIs = makeRedirectURIs(clientURI) 323 324 query := url.Values{} 325 query.Add("redirect_uri", redirectURIs[0]) 326 query.Add("scope", "atproto transition:generic") 327 clientID = fmt.Sprintf("http://localhost?%s", query.Encode()) 328 } 329 330 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI) 331 332 return ClientMetadata{ 333 ClientID: clientID, 334 ClientName: "Tangled", 335 SubjectType: "public", 336 ClientURI: clientURI, 337 RedirectURIs: redirectURIs, 338 GrantTypes: []string{"authorization_code", "refresh_token"}, 339 ResponseTypes: []string{"code"}, 340 ApplicationType: "web", 341 DpopBoundAccessTokens: true, 342 JwksURI: jwksURI, 343 Scope: "atproto transition:generic", 344 TokenEndpointAuthMethod: "private_key_jwt", 345 TokenEndpointAuthSigningAlg: "ES256", 346 } 347}