forked from tangled.org/core
this repo has no description
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "net/url" 8 "time" 9 10 "github.com/gorilla/sessions" 11 oauth "tangled.sh/icyphox.sh/atproto-oauth" 12 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 13 "tangled.sh/tangled.sh/core/appview" 14 "tangled.sh/tangled.sh/core/appview/config" 15 "tangled.sh/tangled.sh/core/appview/db" 16 "tangled.sh/tangled.sh/core/appview/oauth/client" 17 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 18) 19 20type OAuthRequest struct { 21 ID uint 22 AuthserverIss string 23 State string 24 Did string 25 PdsUrl string 26 PkceVerifier string 27 DpopAuthserverNonce string 28 DpopPrivateJwk string 29} 30 31type OAuth struct { 32 Store *sessions.CookieStore 33 Db *db.DB 34 Config *config.Config 35} 36 37func NewOAuth(db *db.DB, config *config.Config) *OAuth { 38 return &OAuth{ 39 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 40 Db: db, 41 Config: config, 42 } 43} 44 45func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error { 46 // first we save the did in the user session 47 userSession, err := o.Store.Get(r, appview.SessionName) 48 if err != nil { 49 return err 50 } 51 52 userSession.Values[appview.SessionDid] = oreq.Did 53 userSession.Values[appview.SessionHandle] = oreq.Handle 54 userSession.Values[appview.SessionPds] = oreq.PdsUrl 55 userSession.Values[appview.SessionAuthenticated] = true 56 err = userSession.Save(r, w) 57 if err != nil { 58 return fmt.Errorf("error saving user session: %w", err) 59 } 60 61 // then save the whole thing in the db 62 session := db.OAuthSession{ 63 Did: oreq.Did, 64 Handle: oreq.Handle, 65 PdsUrl: oreq.PdsUrl, 66 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 67 AuthServerIss: oreq.AuthserverIss, 68 DpopPrivateJwk: oreq.DpopPrivateJwk, 69 AccessJwt: oresp.AccessToken, 70 RefreshJwt: oresp.RefreshToken, 71 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 72 } 73 74 return db.SaveOAuthSession(o.Db, session) 75} 76 77func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 78 userSession, err := o.Store.Get(r, appview.SessionName) 79 if err != nil || userSession.IsNew { 80 return fmt.Errorf("error getting user session (or new session?): %w", err) 81 } 82 83 did := userSession.Values[appview.SessionDid].(string) 84 85 err = db.DeleteOAuthSessionByDid(o.Db, did) 86 if err != nil { 87 return fmt.Errorf("error deleting oauth session: %w", err) 88 } 89 90 userSession.Options.MaxAge = -1 91 92 return userSession.Save(r, w) 93} 94 95func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { 96 userSession, err := o.Store.Get(r, appview.SessionName) 97 if err != nil || userSession.IsNew { 98 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 99 } 100 101 did := userSession.Values[appview.SessionDid].(string) 102 auth := userSession.Values[appview.SessionAuthenticated].(bool) 103 104 session, err := db.GetOAuthSessionByDid(o.Db, did) 105 if err != nil { 106 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 107 } 108 109 expiry, err := time.Parse(time.RFC3339, session.Expiry) 110 if err != nil { 111 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 112 } 113 if expiry.Sub(time.Now()) <= 5*time.Minute { 114 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 115 if err != nil { 116 return nil, false, err 117 } 118 119 self := o.ClientMetadata() 120 121 oauthClient, err := client.NewClient( 122 self.ClientID, 123 o.Config.OAuth.Jwks, 124 self.RedirectURIs[0], 125 ) 126 127 if err != nil { 128 return nil, false, err 129 } 130 131 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 132 if err != nil { 133 return nil, false, err 134 } 135 136 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 137 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) 138 if err != nil { 139 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 140 } 141 142 // update the current session 143 session.AccessJwt = resp.AccessToken 144 session.RefreshJwt = resp.RefreshToken 145 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 146 session.Expiry = newExpiry 147 } 148 149 return session, auth, nil 150} 151 152type User struct { 153 Handle string 154 Did string 155 Pds string 156} 157 158func (a *OAuth) GetUser(r *http.Request) *User { 159 clientSession, err := a.Store.Get(r, appview.SessionName) 160 161 if err != nil || clientSession.IsNew { 162 return nil 163 } 164 165 return &User{ 166 Handle: clientSession.Values[appview.SessionHandle].(string), 167 Did: clientSession.Values[appview.SessionDid].(string), 168 Pds: clientSession.Values[appview.SessionPds].(string), 169 } 170} 171 172func (a *OAuth) GetDid(r *http.Request) string { 173 clientSession, err := a.Store.Get(r, appview.SessionName) 174 175 if err != nil || clientSession.IsNew { 176 return "" 177 } 178 179 return clientSession.Values[appview.SessionDid].(string) 180} 181 182func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 183 session, auth, err := o.GetSession(r) 184 if err != nil { 185 return nil, fmt.Errorf("error getting session: %w", err) 186 } 187 if !auth { 188 return nil, fmt.Errorf("not authorized") 189 } 190 191 client := &oauth.XrpcClient{ 192 OnDpopPdsNonceChanged: func(did, newNonce string) { 193 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 194 if err != nil { 195 log.Printf("error updating dpop pds nonce: %v", err) 196 } 197 }, 198 } 199 200 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 201 if err != nil { 202 return nil, fmt.Errorf("error parsing private jwk: %w", err) 203 } 204 205 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 206 Did: session.Did, 207 PdsUrl: session.PdsUrl, 208 DpopPdsNonce: session.PdsUrl, 209 AccessToken: session.AccessJwt, 210 Issuer: session.AuthServerIss, 211 DpopPrivateJwk: privateJwk, 212 }) 213 214 return xrpcClient, nil 215} 216 217type ClientMetadata struct { 218 ClientID string `json:"client_id"` 219 ClientName string `json:"client_name"` 220 SubjectType string `json:"subject_type"` 221 ClientURI string `json:"client_uri"` 222 RedirectURIs []string `json:"redirect_uris"` 223 GrantTypes []string `json:"grant_types"` 224 ResponseTypes []string `json:"response_types"` 225 ApplicationType string `json:"application_type"` 226 DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 227 JwksURI string `json:"jwks_uri"` 228 Scope string `json:"scope"` 229 TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 230 TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 231} 232 233func (o *OAuth) ClientMetadata() ClientMetadata { 234 makeRedirectURIs := func(c string) []string { 235 return []string{fmt.Sprintf("%s/oauth/callback", c)} 236 } 237 238 clientURI := o.Config.Core.AppviewHost 239 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 240 redirectURIs := makeRedirectURIs(clientURI) 241 242 if o.Config.Core.Dev { 243 clientURI = fmt.Sprintf("http://127.0.0.1:3000") 244 redirectURIs = makeRedirectURIs(clientURI) 245 246 query := url.Values{} 247 query.Add("redirect_uri", redirectURIs[0]) 248 query.Add("scope", "atproto transition:generic") 249 clientID = fmt.Sprintf("http://localhost?%s", query.Encode()) 250 } 251 252 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI) 253 254 return ClientMetadata{ 255 ClientID: clientID, 256 ClientName: "Tangled", 257 SubjectType: "public", 258 ClientURI: clientURI, 259 RedirectURIs: redirectURIs, 260 GrantTypes: []string{"authorization_code", "refresh_token"}, 261 ResponseTypes: []string{"code"}, 262 ApplicationType: "web", 263 DpopBoundAccessTokens: true, 264 JwksURI: jwksURI, 265 Scope: "atproto transition:generic", 266 TokenEndpointAuthMethod: "private_key_jwt", 267 TokenEndpointAuthSigningAlg: "ES256", 268 } 269}