A better Rust ATProto crate
Branch: oauth · 12 kB · view raw
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 error::{AuthError, ClientError, TransportError, XrpcResult}, 15 http_client::HttpClient, 16 types::{ 17 did::Did, 18 xrpc::{CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest}, 19 }, 20}; 21use jacquard_identity::JacquardResolver; 22use jose_jwk::JwkSet; 23use std::sync::Arc; 24use tokio::sync::RwLock; 25use url::Url; 26 27pub struct OAuthClient<T, S> 28where 29 T: OAuthResolver, 30 S: ClientAuthStore, 31{ 32 pub registry: Arc<SessionRegistry<T, S>>, 33 pub client: Arc<T>, 34} 35 36impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 37 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 38 let client = JacquardResolver::default(); 39 Self::new_from_resolver(store, client, client_data) 40 } 41} 42 43impl<T, S> OAuthClient<T, S> 44where 45 T: OAuthResolver, 46 S: ClientAuthStore, 47{ 48 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 49 let client = Arc::new(client); 50 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 51 Self { registry, client } 52 } 53} 54 55impl<T, S> OAuthClient<T, S> 56where 57 S: ClientAuthStore + Send + Sync + 'static, 58 T: OAuthResolver + DpopExt + Send + Sync + 'static, 59{ 60 pub fn jwks(&self) -> JwkSet { 61 self.registry 62 .client_data 63 .keyset 64 .as_ref() 65 .map(|keyset| keyset.public_jwks()) 66 .unwrap_or_default() 67 } 68 pub async fn start_auth( 69 &self, 70 input: impl AsRef<str>, 71 options: AuthorizeOptions<'_>, 72 ) -> Result<String> { 73 let client_metadata = atproto_client_metadata( 74 
self.registry.client_data.config.clone(), 75 &self.registry.client_data.keyset, 76 )?; 77 78 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 79 let login_hint = if identity.is_some() { 80 Some(input.as_ref().into()) 81 } else { 82 None 83 }; 84 let metadata = OAuthMetadata { 85 server_metadata, 86 client_metadata, 87 keyset: self.registry.client_data.keyset.clone(), 88 }; 89 let auth_req_info = 90 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 91 92 #[derive(serde::Serialize)] 93 struct Parameters<'s> { 94 client_id: Url, 95 request_uri: CowStr<'s>, 96 } 97 Ok(metadata.server_metadata.authorization_endpoint.to_string() 98 + "?" 99 + &serde_html_form::to_string(Parameters { 100 client_id: metadata.client_metadata.client_id.clone(), 101 request_uri: auth_req_info.request_uri, 102 }) 103 .unwrap()) 104 } 105 106 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 107 let Some(state_key) = params.state else { 108 return Err(CallbackError::MissingState.into()); 109 }; 110 111 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? 
else { 112 return Err(CallbackError::MissingState.into()); 113 }; 114 115 self.registry.store.delete_auth_req_info(&state_key).await?; 116 117 let metadata = self 118 .client 119 .get_authorization_server_metadata(&auth_req_info.authserver_url) 120 .await?; 121 122 if let Some(iss) = params.iss { 123 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 124 return Err(CallbackError::IssuerMismatch { expected: metadata.issuer.to_string(), got: iss.to_string() }.into()); 125 } 126 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 127 return Err(CallbackError::MissingIssuer.into()); 128 } 129 let metadata = OAuthMetadata { 130 server_metadata: metadata, 131 client_metadata: atproto_client_metadata( 132 self.registry.client_data.config.clone(), 133 &self.registry.client_data.keyset, 134 )?, 135 keyset: self.registry.client_data.keyset.clone(), 136 }; 137 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 138 139 match exchange_code( 140 self.client.as_ref(), 141 &mut auth_req_info.dpop_data.clone(), 142 &params.code, 143 &auth_req_info.pkce_verifier, 144 &metadata, 145 ) 146 .await 147 { 148 Ok(token_set) => { 149 let scopes = if let Some(scope) = &token_set.scope { 150 Scope::parse_multiple_reduced(&scope) 151 .expect("Failed to parse scopes") 152 .into_static() 153 } else { 154 vec![] 155 }; 156 let client_data = ClientSessionData { 157 account_did: token_set.sub.clone(), 158 session_id: auth_req_info.state, 159 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 160 authserver_url: auth_req_info.authserver_url, 161 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 162 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 163 scopes, 164 dpop_data: DpopClientData { 165 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 166 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 167 dpop_host_nonce: auth_req_info 168 
.dpop_data 169 .dpop_authserver_nonce 170 .unwrap_or(CowStr::default()), 171 }, 172 token_set, 173 }; 174 175 self.create_session(client_data).await 176 } 177 Err(e) => Err(e.into()), 178 } 179 } 180 181 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 182 Ok(OAuthSession::new( 183 self.registry.clone(), 184 self.client.clone(), 185 data.into_static(), 186 )) 187 } 188 189 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 190 self.create_session(self.registry.get(did, session_id, false).await?) 191 .await 192 } 193 194 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 195 Ok(self.registry.del(did, session_id).await?) 196 } 197} 198 199pub struct OAuthSession<T, S> 200where 201 T: OAuthResolver, 202 S: ClientAuthStore, 203{ 204 pub registry: Arc<SessionRegistry<T, S>>, 205 pub client: Arc<T>, 206 pub data: RwLock<ClientSessionData<'static>>, 207 pub options: RwLock<CallOptions<'static>>, 208} 209 210impl<T, S> OAuthSession<T, S> 211where 212 T: OAuthResolver, 213 S: ClientAuthStore, 214{ 215 pub fn new( 216 registry: Arc<SessionRegistry<T, S>>, 217 client: Arc<T>, 218 data: ClientSessionData<'static>, 219 ) -> Self { 220 Self { 221 registry, 222 client, 223 data: RwLock::new(data), 224 options: RwLock::new(CallOptions::default()), 225 } 226 } 227 228 pub fn with_options(self, options: CallOptions<'_>) -> Self { 229 Self { 230 registry: self.registry, 231 client: self.client, 232 data: self.data, 233 options: RwLock::new(options.into_static()), 234 } 235 } 236 237 pub async fn set_options(&self, options: CallOptions<'_>) { 238 *self.options.write().await = options.into_static(); 239 } 240 241 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 242 let data = self.data.read().await; 243 (data.account_did.clone(), data.session_id.clone()) 244 } 245 246 pub async fn endpoint(&self) -> Url { 247 self.data.read().await.host_url.clone() 248 } 249 250 pub 
async fn access_token(&self) -> AuthorizationToken<'_> { 251 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 252 } 253 254 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 255 self.data.read().await.token_set.refresh_token.as_ref().map(|t| AuthorizationToken::Dpop(t.clone())) 256 } 257} 258impl<T, S> OAuthSession<T, S> 259where 260 S: ClientAuthStore + Send + Sync + 'static, 261 T: OAuthResolver + DpopExt + Send + Sync + 'static, 262{ 263 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 264 // Read identifiers without holding the lock across await 265 let (did, sid) = { 266 let data = self.data.read().await; 267 (data.account_did.clone(), data.session_id.clone()) 268 }; 269 let refreshed = self 270 .registry 271 .as_ref() 272 .get(&did, &sid, true) 273 .await?; 274 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 275 // Write back updated session 276 *self.data.write().await = refreshed.into_static(); 277 Ok(token) 278 } 279} 280 281impl<T, S> HttpClient for OAuthSession<T, S> 282where 283 S: ClientAuthStore + Send + Sync + 'static, 284 T: OAuthResolver + DpopExt + Send + Sync + 'static, 285{ 286 type Error = T::Error; 287 288 async fn send_http( 289 &self, 290 request: http::Request<Vec<u8>>, 291 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 292 self.client.send_http(request).await 293 } 294} 295 296impl<T, S> XrpcClient for OAuthSession<T, S> 297where 298 S: ClientAuthStore + Send + Sync + 'static, 299 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 300{ 301 fn base_uri(&self) -> Url { 302 // base_uri is a synchronous trait method; we must avoid async `.read().await`. 303 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 
304 if tokio::runtime::Handle::try_current().is_ok() { 305 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 306 } else { 307 self.data.blocking_read().host_url.clone() 308 } 309 } 310 311 async fn opts(&self) -> CallOptions<'_> { 312 self.options.read().await.clone() 313 } 314 315 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>( 316 self, 317 request: &R, 318 ) -> XrpcResult<Response<R>> { 319 let base_uri = self.base_uri(); 320 let auth = self.access_token().await; 321 let mut opts = self.options.read().await.clone(); 322 opts.auth = Some(auth); 323 let res = self 324 .client 325 .xrpc(base_uri.clone()) 326 .with_options(opts.clone()) 327 .send(request) 328 .await; 329 if is_invalid_token_response(&res) { 330 opts.auth = Some( 331 self.refresh() 332 .await 333 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 334 ); 335 self.client 336 .xrpc(base_uri) 337 .with_options(opts) 338 .send(request) 339 .await 340 } else { 341 res 342 } 343 } 344} 345 346fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool { 347 match response { 348 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 349 Err(ClientError::Auth(AuthError::Other(value))) => value 350 .to_str() 351 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 352 Ok(resp) => match resp.parse() { 353 Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 354 _ => false, 355 }, 356 _ => false, 357 } 358}