forked from tangled.org/core
this repo has no description
at test-ci 7.9 kB view raw
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "net/url" 8 "time" 9 10 "github.com/gorilla/sessions" 11 oauth "tangled.sh/icyphox.sh/atproto-oauth" 12 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 13 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 14 "tangled.sh/tangled.sh/core/appview/config" 15 "tangled.sh/tangled.sh/core/appview/oauth/client" 16 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 17) 18 19type OAuth struct { 20 store *sessions.CookieStore 21 config *config.Config 22 sess *sessioncache.SessionStore 23} 24 25func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth { 26 return &OAuth{ 27 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 28 config: config, 29 sess: sess, 30 } 31} 32 33func (o *OAuth) Stores() *sessions.CookieStore { 34 return o.store 35} 36 37func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error { 38 // first we save the did in the user session 39 userSession, err := o.store.Get(r, SessionName) 40 if err != nil { 41 return err 42 } 43 44 userSession.Values[SessionDid] = oreq.Did 45 userSession.Values[SessionHandle] = oreq.Handle 46 userSession.Values[SessionPds] = oreq.PdsUrl 47 userSession.Values[SessionAuthenticated] = true 48 err = userSession.Save(r, w) 49 if err != nil { 50 return fmt.Errorf("error saving user session: %w", err) 51 } 52 53 // then save the whole thing in the db 54 session := sessioncache.OAuthSession{ 55 Did: oreq.Did, 56 Handle: oreq.Handle, 57 PdsUrl: oreq.PdsUrl, 58 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 59 AuthServerIss: oreq.AuthserverIss, 60 DpopPrivateJwk: oreq.DpopPrivateJwk, 61 AccessJwt: oresp.AccessToken, 62 RefreshJwt: oresp.RefreshToken, 63 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 64 } 65 66 return o.sess.SaveSession(r.Context(), session) 67} 68 69func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 70 userSession, err := o.store.Get(r, SessionName) 71 if err != nil || userSession.IsNew { 72 return fmt.Errorf("error getting user session (or new session?): %w", err) 73 } 74 75 did := userSession.Values[SessionDid].(string) 76 77 err = o.sess.DeleteSession(r.Context(), did) 78 if err != nil { 79 return fmt.Errorf("error deleting oauth session: %w", err) 80 } 81 82 userSession.Options.MaxAge = -1 83 84 return userSession.Save(r, w) 85} 86 87func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) { 88 userSession, err := o.store.Get(r, SessionName) 89 if err != nil || userSession.IsNew { 90 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 91 } 92 93 did := userSession.Values[SessionDid].(string) 94 auth := userSession.Values[SessionAuthenticated].(bool) 95 96 session, err := o.sess.GetSession(r.Context(), did) 97 if err != nil { 98 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 99 } 100 101 expiry, err := time.Parse(time.RFC3339, session.Expiry) 102 if err != nil { 103 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 104 } 105 if expiry.Sub(time.Now()) <= 5*time.Minute { 106 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 107 if err != nil { 108 return nil, false, err 109 } 110 111 self := o.ClientMetadata() 112 113 oauthClient, err := client.NewClient( 114 self.ClientID, 115 o.config.OAuth.Jwks, 116 self.RedirectURIs[0], 117 ) 118 119 if err != nil { 120 return nil, false, err 121 } 122 123 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 124 if err != nil { 125 return nil, false, err 126 } 127 128 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 129 err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry) 130 if err != nil { 131 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 132 } 133 134 // update the current session 135 session.AccessJwt = resp.AccessToken 136 session.RefreshJwt = resp.RefreshToken 137 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 138 session.Expiry = newExpiry 139 } 140 141 return session, auth, nil 142} 143 144type User struct { 145 Handle string 146 Did string 147 Pds string 148} 149 150func (a *OAuth) GetUser(r *http.Request) *User { 151 clientSession, err := a.store.Get(r, SessionName) 152 153 if err != nil || clientSession.IsNew { 154 return nil 155 } 156 157 return &User{ 158 Handle: clientSession.Values[SessionHandle].(string), 159 Did: clientSession.Values[SessionDid].(string), 160 Pds: clientSession.Values[SessionPds].(string), 161 } 162} 163 164func (a *OAuth) GetDid(r *http.Request) string { 165 clientSession, err := a.store.Get(r, SessionName) 166 167 if err != nil || clientSession.IsNew { 168 return "" 169 } 170 171 return clientSession.Values[SessionDid].(string) 172} 173 174func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 175 session, auth, err := o.GetSession(r) 176 if err != nil { 177 return nil, fmt.Errorf("error getting session: %w", err) 178 } 179 if !auth { 180 return nil, fmt.Errorf("not authorized") 181 } 182 183 client := &oauth.XrpcClient{ 184 OnDpopPdsNonceChanged: func(did, newNonce string) { 185 err := o.sess.UpdateNonce(r.Context(), did, newNonce) 186 if err != nil { 187 log.Printf("error updating dpop pds nonce: %v", err) 188 } 189 }, 190 } 191 192 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 193 if err != nil { 194 return nil, fmt.Errorf("error parsing private jwk: %w", err) 195 } 196 197 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 198 Did: session.Did, 199 PdsUrl: session.PdsUrl, 200 DpopPdsNonce: session.PdsUrl, 201 AccessToken: session.AccessJwt, 202 Issuer: session.AuthServerIss, 203 DpopPrivateJwk: privateJwk, 204 }) 205 206 return xrpcClient, nil 207} 208 209type ClientMetadata struct { 210 ClientID string `json:"client_id"` 211 ClientName string `json:"client_name"` 212 SubjectType string `json:"subject_type"` 213 ClientURI string `json:"client_uri"` 214 RedirectURIs []string `json:"redirect_uris"` 215 GrantTypes []string `json:"grant_types"` 216 ResponseTypes []string `json:"response_types"` 217 ApplicationType string `json:"application_type"` 218 DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 219 JwksURI string `json:"jwks_uri"` 220 Scope string `json:"scope"` 221 TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 222 TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 223} 224 225func (o *OAuth) ClientMetadata() ClientMetadata { 226 makeRedirectURIs := func(c string) []string { 227 return []string{fmt.Sprintf("%s/oauth/callback", c)} 228 } 229 230 clientURI := o.config.Core.AppviewHost 231 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 232 redirectURIs := makeRedirectURIs(clientURI) 233 234 if o.config.Core.Dev { 235 clientURI = fmt.Sprintf("http://127.0.0.1:3000") 236 redirectURIs = makeRedirectURIs(clientURI) 237 238 query := url.Values{} 239 query.Add("redirect_uri", redirectURIs[0]) 240 query.Add("scope", "atproto transition:generic") 241 clientID = fmt.Sprintf("http://localhost?%s", query.Encode()) 242 } 243 244 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI) 245 246 return ClientMetadata{ 247 ClientID: clientID, 248 ClientName: "Tangled", 249 SubjectType: "public", 250 ClientURI: clientURI, 251 RedirectURIs: redirectURIs, 252 GrantTypes: []string{"authorization_code", "refresh_token"}, 253 ResponseTypes: []string{"code"}, 254 ApplicationType: "web", 255 DpopBoundAccessTokens: true, 256 JwksURI: jwksURI, 257 Scope: "atproto transition:generic", 258 TokenEndpointAuthMethod: "private_key_jwt", 259 TokenEndpointAuthSigningAlg: "ES256", 260 } 261}