forked from tangled.org/core
this repo has no description
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "net/url" 8 "time" 9 10 indigo_xrpc "github.com/bluesky-social/indigo/xrpc" 11 "github.com/gorilla/sessions" 12 oauth "tangled.sh/icyphox.sh/atproto-oauth" 13 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 14 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 15 "tangled.sh/tangled.sh/core/appview/config" 16 "tangled.sh/tangled.sh/core/appview/oauth/client" 17 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 18) 19 20type OAuth struct { 21 store *sessions.CookieStore 22 config *config.Config 23 sess *sessioncache.SessionStore 24} 25 26func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth { 27 return &OAuth{ 28 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 29 config: config, 30 sess: sess, 31 } 32} 33 34func (o *OAuth) Stores() *sessions.CookieStore { 35 return o.store 36} 37 38func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error { 39 // first we save the did in the user session 40 userSession, err := o.store.Get(r, SessionName) 41 if err != nil { 42 return err 43 } 44 45 userSession.Values[SessionDid] = oreq.Did 46 userSession.Values[SessionHandle] = oreq.Handle 47 userSession.Values[SessionPds] = oreq.PdsUrl 48 userSession.Values[SessionAuthenticated] = true 49 err = userSession.Save(r, w) 50 if err != nil { 51 return fmt.Errorf("error saving user session: %w", err) 52 } 53 54 // then save the whole thing in the db 55 session := sessioncache.OAuthSession{ 56 Did: oreq.Did, 57 Handle: oreq.Handle, 58 PdsUrl: oreq.PdsUrl, 59 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 60 AuthServerIss: oreq.AuthserverIss, 61 DpopPrivateJwk: oreq.DpopPrivateJwk, 62 AccessJwt: oresp.AccessToken, 63 RefreshJwt: oresp.RefreshToken, 64 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 65 } 66 67 return o.sess.SaveSession(r.Context(), session) 68} 69 70func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 71 userSession, err := o.store.Get(r, SessionName) 72 if err != nil || userSession.IsNew { 73 return fmt.Errorf("error getting user session (or new session?): %w", err) 74 } 75 76 did := userSession.Values[SessionDid].(string) 77 78 err = o.sess.DeleteSession(r.Context(), did) 79 if err != nil { 80 return fmt.Errorf("error deleting oauth session: %w", err) 81 } 82 83 userSession.Options.MaxAge = -1 84 85 return userSession.Save(r, w) 86} 87 88func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) { 89 userSession, err := o.store.Get(r, SessionName) 90 if err != nil || userSession.IsNew { 91 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 92 } 93 94 did := userSession.Values[SessionDid].(string) 95 auth := userSession.Values[SessionAuthenticated].(bool) 96 97 session, err := o.sess.GetSession(r.Context(), did) 98 if err != nil { 99 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 100 } 101 102 expiry, err := time.Parse(time.RFC3339, session.Expiry) 103 if err != nil { 104 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 105 } 106 if expiry.Sub(time.Now()) <= 5*time.Minute { 107 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 108 if err != nil { 109 return nil, false, err 110 } 111 112 self := o.ClientMetadata() 113 114 oauthClient, err := client.NewClient( 115 self.ClientID, 116 o.config.OAuth.Jwks, 117 self.RedirectURIs[0], 118 ) 119 120 if err != nil { 121 return nil, false, err 122 } 123 124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 125 if err != nil { 126 return nil, false, err 127 } 128 129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 130 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry) 131 if err != nil { 132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 133 } 134 135 // update the current session 136 session.AccessJwt = resp.AccessToken 137 session.RefreshJwt = resp.RefreshToken 138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 139 session.Expiry = newExpiry 140 } 141 142 return session, auth, nil 143} 144 145type User struct { 146 Handle string 147 Did string 148 Pds string 149} 150 151func (a *OAuth) GetUser(r *http.Request) *User { 152 clientSession, err := a.store.Get(r, SessionName) 153 154 if err != nil || clientSession.IsNew { 155 return nil 156 } 157 158 return &User{ 159 Handle: clientSession.Values[SessionHandle].(string), 160 Did: clientSession.Values[SessionDid].(string), 161 Pds: clientSession.Values[SessionPds].(string), 162 } 163} 164 165func (a *OAuth) GetDid(r *http.Request) string { 166 clientSession, err := a.store.Get(r, SessionName) 167 168 if err != nil || clientSession.IsNew { 169 return "" 170 } 171 172 return clientSession.Values[SessionDid].(string) 173} 174 175func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 176 session, auth, err := o.GetSession(r) 177 if err != nil { 178 return nil, fmt.Errorf("error getting session: %w", err) 179 } 180 if !auth { 181 return nil, fmt.Errorf("not authorized") 182 } 183 184 client := &oauth.XrpcClient{ 185 OnDpopPdsNonceChanged: func(did, newNonce string) { 186 err := o.sess.UpdateNonce(r.Context(), did, newNonce) 187 if err != nil { 188 log.Printf("error updating dpop pds nonce: %v", err) 189 } 190 }, 191 } 192 193 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 194 if err != nil { 195 return nil, fmt.Errorf("error parsing private jwk: %w", err) 196 } 197 198 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 199 Did: session.Did, 200 PdsUrl: session.PdsUrl, 201 DpopPdsNonce: session.PdsUrl, 202 AccessToken: session.AccessJwt, 203 Issuer: session.AuthServerIss, 204 DpopPrivateJwk: privateJwk, 205 }) 206 207 return xrpcClient, nil 208} 209 210// use this to create a client to communicate with knots or spindles 211// 212// this is a higher level abstraction on ServerGetServiceAuth 213type ServiceClientOpts struct { 214 service string 215 exp int64 216 lxm string 217 dev bool 218} 219 220type ServiceClientOpt func(*ServiceClientOpts) 221 222func WithService(service string) ServiceClientOpt { 223 return func(s *ServiceClientOpts) { 224 s.service = service 225 } 226} 227func WithExp(exp int64) ServiceClientOpt { 228 return func(s *ServiceClientOpts) { 229 s.exp = exp 230 } 231} 232 233func WithLxm(lxm string) ServiceClientOpt { 234 return func(s *ServiceClientOpts) { 235 s.lxm = lxm 236 } 237} 238 239func WithDev(dev bool) ServiceClientOpt { 240 return func(s *ServiceClientOpts) { 241 s.dev = dev 242 } 243} 244 245func (s *ServiceClientOpts) Audience() string { 246 return fmt.Sprintf("did:web:%s", s.service) 247} 248 249func (s *ServiceClientOpts) Host() string { 250 scheme := "https://" 251 if s.dev { 252 scheme = "http://" 253 } 254 255 return scheme + s.service 256} 257 258func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*indigo_xrpc.Client, error) { 259 opts := ServiceClientOpts{} 260 for _, o := range os { 261 o(&opts) 262 } 263 264 authorizedClient, err := o.AuthorizedClient(r) 265 if err != nil { 266 return nil, err 267 } 268 269 resp, err := authorizedClient.ServerGetServiceAuth(r.Context(), opts.Audience(), opts.exp, opts.lxm) 270 if err != nil { 271 return nil, err 272 } 273 274 return &indigo_xrpc.Client{ 275 Auth: &indigo_xrpc.AuthInfo{ 276 AccessJwt: resp.Token, 277 }, 278 Host: opts.Host(), 279 }, nil 280} 281 282type ClientMetadata struct { 283 ClientID string `json:"client_id"` 284 ClientName string `json:"client_name"` 285 SubjectType string `json:"subject_type"` 286 ClientURI string `json:"client_uri"` 287 RedirectURIs []string `json:"redirect_uris"` 288 GrantTypes []string `json:"grant_types"` 289 ResponseTypes []string `json:"response_types"` 290 ApplicationType string `json:"application_type"` 291 DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 292 JwksURI string `json:"jwks_uri"` 293 Scope string `json:"scope"` 294 TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 295 TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 296} 297 298func (o *OAuth) ClientMetadata() ClientMetadata { 299 makeRedirectURIs := func(c string) []string { 300 return []string{fmt.Sprintf("%s/oauth/callback", c)} 301 } 302 303 clientURI := o.config.Core.AppviewHost 304 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 305 redirectURIs := makeRedirectURIs(clientURI) 306 307 if o.config.Core.Dev { 308 clientURI = fmt.Sprintf("http://127.0.0.1:3000") 309 redirectURIs = makeRedirectURIs(clientURI) 310 311 query := url.Values{} 312 query.Add("redirect_uri", redirectURIs[0]) 313 query.Add("scope", "atproto transition:generic") 314 clientID = fmt.Sprintf("http://localhost?%s", query.Encode()) 315 } 316 317 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI) 318 319 return ClientMetadata{ 320 ClientID: clientID, 321 ClientName: "Tangled", 322 SubjectType: "public", 323 ClientURI: clientURI, 324 RedirectURIs: redirectURIs, 325 GrantTypes: []string{"authorization_code", "refresh_token"}, 326 ResponseTypes: []string{"code"}, 327 ApplicationType: "web", 328 DpopBoundAccessTokens: true, 329 JwksURI: jwksURI, 330 Scope: "atproto transition:generic", 331 TokenEndpointAuthMethod: "private_key_jwt", 332 TokenEndpointAuthSigningAlg: "ES256", 333 } 334}