// A better Rust ATProto crate
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 error::{AuthError, ClientError, TransportError, XrpcResult}, 15 http_client::HttpClient, 16 types::{did::Did, string::Handle}, 17 xrpc::{ 18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request, 19 process_response, 20 }, 21}; 22use jacquard_identity::{JacquardResolver, resolver::IdentityResolver}; 23use jose_jwk::JwkSet; 24use std::sync::Arc; 25use tokio::sync::RwLock; 26use url::Url; 27 28pub struct OAuthClient<T, S> 29where 30 T: OAuthResolver, 31 S: ClientAuthStore, 32{ 33 pub registry: Arc<SessionRegistry<T, S>>, 34 pub client: Arc<T>, 35} 36 37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 38 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 39 let client = JacquardResolver::default(); 40 Self::new_from_resolver(store, client, client_data) 41 } 42 43 /// Create an OAuth client with the provided store and default localhost client metadata. 44 /// 45 /// This is a convenience constructor for quickly setting up an OAuth client 46 /// with default localhost redirect URIs and "atproto transition:generic" scopes. 
47 /// 48 /// # Example 49 /// 50 /// ```no_run 51 /// # use jacquard_oauth::client::OAuthClient; 52 /// # use jacquard_oauth::authstore::MemoryAuthStore; 53 /// # #[tokio::main] 54 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 55 /// let store = MemoryAuthStore::new(); 56 /// let oauth = OAuthClient::with_default_config(store); 57 /// # Ok(()) 58 /// # } 59 /// ``` 60 pub fn with_default_config(store: S) -> Self { 61 let client_data = ClientData { 62 keyset: None, 63 config: crate::atproto::AtprotoClientMetadata::default_localhost(), 64 }; 65 Self::new(store, client_data) 66 } 67} 68 69impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> { 70 /// Create an OAuth client with an in-memory auth store and default localhost client metadata. 71 /// 72 /// This is a convenience constructor for simple testing and development. 73 /// The session will not persist across restarts. 74 /// 75 /// # Example 76 /// 77 /// ```no_run 78 /// # use jacquard_oauth::client::OAuthClient; 79 /// # #[tokio::main] 80 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 81 /// let oauth = OAuthClient::with_memory_store(); 82 /// # Ok(()) 83 /// # } 84 /// ``` 85 pub fn with_memory_store() -> Self { 86 Self::with_default_config(crate::authstore::MemoryAuthStore::new()) 87 } 88} 89 90impl<T, S> OAuthClient<T, S> 91where 92 T: OAuthResolver, 93 S: ClientAuthStore, 94{ 95 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 96 let client = Arc::new(client); 97 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 98 Self { registry, client } 99 } 100 101 pub fn new_with_shared( 102 store: Arc<S>, 103 client: Arc<T>, 104 client_data: ClientData<'static>, 105 ) -> Self { 106 let registry = Arc::new(SessionRegistry::new_shared( 107 store, 108 client.clone(), 109 client_data, 110 )); 111 Self { registry, client } 112 } 113} 114 115impl<T, S> OAuthClient<T, S> 116where 117 S: 
ClientAuthStore + Send + Sync + 'static, 118 T: OAuthResolver + DpopExt + Send + Sync + 'static, 119{ 120 pub fn jwks(&self) -> JwkSet { 121 self.registry 122 .client_data 123 .keyset 124 .as_ref() 125 .map(|keyset| keyset.public_jwks()) 126 .unwrap_or_default() 127 } 128 pub async fn start_auth( 129 &self, 130 input: impl AsRef<str>, 131 options: AuthorizeOptions<'_>, 132 ) -> Result<String> { 133 let client_metadata = atproto_client_metadata( 134 self.registry.client_data.config.clone(), 135 &self.registry.client_data.keyset, 136 )?; 137 138 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 139 let login_hint = if identity.is_some() { 140 Some(input.as_ref().into()) 141 } else { 142 None 143 }; 144 let metadata = OAuthMetadata { 145 server_metadata, 146 client_metadata, 147 keyset: self.registry.client_data.keyset.clone(), 148 }; 149 let auth_req_info = 150 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 151 // Persist state for callback handling 152 self.registry 153 .store 154 .save_auth_req_info(&auth_req_info) 155 .await?; 156 157 #[derive(serde::Serialize)] 158 struct Parameters<'s> { 159 client_id: Url, 160 request_uri: CowStr<'s>, 161 } 162 Ok(metadata.server_metadata.authorization_endpoint.to_string() 163 + "?" 164 + &serde_html_form::to_string(Parameters { 165 client_id: metadata.client_metadata.client_id.clone(), 166 request_uri: auth_req_info.request_uri, 167 }) 168 .unwrap()) 169 } 170 171 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 172 let Some(state_key) = params.state else { 173 return Err(CallbackError::MissingState.into()); 174 }; 175 176 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? 
else { 177 return Err(CallbackError::MissingState.into()); 178 }; 179 180 self.registry.store.delete_auth_req_info(&state_key).await?; 181 182 let metadata = self 183 .client 184 .get_authorization_server_metadata(&auth_req_info.authserver_url) 185 .await?; 186 187 if let Some(iss) = params.iss { 188 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 189 return Err(CallbackError::IssuerMismatch { 190 expected: metadata.issuer.to_string(), 191 got: iss.to_string(), 192 } 193 .into()); 194 } 195 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 196 return Err(CallbackError::MissingIssuer.into()); 197 } 198 let metadata = OAuthMetadata { 199 server_metadata: metadata, 200 client_metadata: atproto_client_metadata( 201 self.registry.client_data.config.clone(), 202 &self.registry.client_data.keyset, 203 )?, 204 keyset: self.registry.client_data.keyset.clone(), 205 }; 206 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 207 208 match exchange_code( 209 self.client.as_ref(), 210 &mut auth_req_info.dpop_data.clone(), 211 &params.code, 212 &auth_req_info.pkce_verifier, 213 &metadata, 214 ) 215 .await 216 { 217 Ok(token_set) => { 218 let scopes = if let Some(scope) = &token_set.scope { 219 Scope::parse_multiple_reduced(&scope) 220 .expect("Failed to parse scopes") 221 .into_static() 222 } else { 223 vec![] 224 }; 225 let client_data = ClientSessionData { 226 account_did: token_set.sub.clone(), 227 session_id: auth_req_info.state, 228 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 229 authserver_url: auth_req_info.authserver_url, 230 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 231 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 232 scopes, 233 dpop_data: DpopClientData { 234 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 235 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 236 dpop_host_nonce: 
auth_req_info 237 .dpop_data 238 .dpop_authserver_nonce 239 .unwrap_or(CowStr::default()), 240 }, 241 token_set, 242 }; 243 244 self.create_session(client_data).await 245 } 246 Err(e) => Err(e.into()), 247 } 248 } 249 250 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 251 self.registry.set(data.clone()).await?; 252 Ok(OAuthSession::new( 253 self.registry.clone(), 254 self.client.clone(), 255 data.into_static(), 256 )) 257 } 258 259 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 260 self.create_session(self.registry.get(did, session_id, false).await?) 261 .await 262 } 263 264 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 265 Ok(self.registry.del(did, session_id).await?) 266 } 267} 268 269pub struct OAuthSession<T, S> 270where 271 T: OAuthResolver, 272 S: ClientAuthStore, 273{ 274 pub registry: Arc<SessionRegistry<T, S>>, 275 pub client: Arc<T>, 276 pub data: RwLock<ClientSessionData<'static>>, 277 pub options: RwLock<CallOptions<'static>>, 278} 279 280impl<T, S> OAuthSession<T, S> 281where 282 T: OAuthResolver, 283 S: ClientAuthStore, 284{ 285 pub fn new( 286 registry: Arc<SessionRegistry<T, S>>, 287 client: Arc<T>, 288 data: ClientSessionData<'static>, 289 ) -> Self { 290 Self { 291 registry, 292 client, 293 data: RwLock::new(data), 294 options: RwLock::new(CallOptions::default()), 295 } 296 } 297 298 pub fn with_options(self, options: CallOptions<'_>) -> Self { 299 Self { 300 registry: self.registry, 301 client: self.client, 302 data: self.data, 303 options: RwLock::new(options.into_static()), 304 } 305 } 306 307 pub async fn set_options(&self, options: CallOptions<'_>) { 308 *self.options.write().await = options.into_static(); 309 } 310 311 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 312 let data = self.data.read().await; 313 (data.account_did.clone(), data.session_id.clone()) 314 } 315 316 pub async fn endpoint(&self) -> Url { 
317 self.data.read().await.host_url.clone() 318 } 319 320 pub async fn access_token(&self) -> AuthorizationToken<'_> { 321 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 322 } 323 324 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 325 self.data 326 .read() 327 .await 328 .token_set 329 .refresh_token 330 .as_ref() 331 .map(|t| AuthorizationToken::Dpop(t.clone())) 332 } 333} 334impl<T, S> OAuthSession<T, S> 335where 336 S: ClientAuthStore + Send + Sync + 'static, 337 T: OAuthResolver + DpopExt + Send + Sync + 'static, 338{ 339 pub async fn logout(&self) -> Result<()> { 340 use crate::request::{OAuthMetadata, revoke}; 341 let mut data = self.data.write().await; 342 let meta = 343 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 344 if meta.server_metadata.revocation_endpoint.is_some() { 345 let token = data.token_set.access_token.clone(); 346 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 347 .await 348 .ok(); 349 } 350 // Remove from store 351 self.registry 352 .del(&data.account_did, &data.session_id) 353 .await?; 354 Ok(()) 355 } 356} 357 358impl<T, S> OAuthClient<T, S> 359where 360 T: OAuthResolver, 361 S: ClientAuthStore, 362{ 363 pub fn from_session(session: &OAuthSession<T, S>) -> Self { 364 Self { 365 registry: session.registry.clone(), 366 client: session.client.clone(), 367 } 368 } 369} 370impl<T, S> OAuthSession<T, S> 371where 372 S: ClientAuthStore + Send + Sync + 'static, 373 T: OAuthResolver + DpopExt + Send + Sync + 'static, 374{ 375 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 376 // Read identifiers without holding the lock across await 377 let (did, sid) = { 378 let data = self.data.read().await; 379 (data.account_did.clone(), data.session_id.clone()) 380 }; 381 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?; 382 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 383 // 
Write back updated session 384 *self.data.write().await = refreshed.clone().into_static(); 385 // Store in the registry 386 self.registry.set(refreshed).await?; 387 Ok(token) 388 } 389} 390 391impl<T, S> HttpClient for OAuthSession<T, S> 392where 393 S: ClientAuthStore + Send + Sync + 'static, 394 T: OAuthResolver + DpopExt + Send + Sync + 'static, 395{ 396 type Error = T::Error; 397 398 async fn send_http( 399 &self, 400 request: http::Request<Vec<u8>>, 401 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 402 self.client.send_http(request).await 403 } 404} 405 406impl<T, S> XrpcClient for OAuthSession<T, S> 407where 408 S: ClientAuthStore + Send + Sync + 'static, 409 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 410{ 411 fn base_uri(&self) -> Url { 412 // base_uri is a synchronous trait method; we must avoid async `.read().await`. 413 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 414 if tokio::runtime::Handle::try_current().is_ok() { 415 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 416 } else { 417 self.data.blocking_read().host_url.clone() 418 } 419 } 420 421 async fn opts(&self) -> CallOptions<'_> { 422 self.options.read().await.clone() 423 } 424 425 async fn send<'s, R>( 426 &self, 427 request: R, 428 ) -> XrpcResult<Response<<R as XrpcRequest<'s>>::Response>> 429 where 430 R: XrpcRequest<'s>, 431 { 432 let base_uri = self.base_uri(); 433 let mut opts = self.options.read().await.clone(); 434 opts.auth = Some(self.access_token().await); 435 let guard = self.data.read().await; 436 let mut dpop = guard.dpop_data.clone(); 437 let http_response = self 438 .client 439 .dpop_call(&mut dpop) 440 .send(build_http_request(&base_uri, &request, &opts)?) 
441 .await 442 .map_err(|e| TransportError::Other(Box::new(e)))?; 443 let resp = process_response(http_response); 444 drop(guard); 445 if is_invalid_token_response(&resp) { 446 opts.auth = Some( 447 self.refresh() 448 .await 449 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 450 ); 451 let guard = self.data.read().await; 452 let mut dpop = guard.dpop_data.clone(); 453 let http_response = self 454 .client 455 .dpop_call(&mut dpop) 456 .send(build_http_request(&base_uri, &request, &opts)?) 457 .await 458 .map_err(|e| TransportError::Other(Box::new(e)))?; 459 process_response(http_response) 460 } else { 461 resp 462 } 463 } 464} 465 466fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool { 467 match response { 468 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 469 Err(ClientError::Auth(AuthError::Other(value))) => value 470 .to_str() 471 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 472 Ok(resp) => match resp.parse() { 473 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 474 _ => false, 475 }, 476 _ => false, 477 } 478} 479 480impl<T, S> IdentityResolver for OAuthSession<T, S> 481where 482 S: ClientAuthStore + Send + Sync + 'static, 483 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static, 484{ 485 fn options(&self) -> &jacquard_identity::resolver::ResolverOptions { 486 self.client.options() 487 } 488 489 fn resolve_handle( 490 &self, 491 handle: &Handle<'_>, 492 ) -> impl Future< 493 Output = std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError>, 494 > { 495 async { self.client.resolve_handle(handle).await } 496 } 497 498 fn resolve_did_doc( 499 &self, 500 did: &Did<'_>, 501 ) -> impl Future< 502 Output = std::result::Result< 503 jacquard_identity::resolver::DidDocResponse, 504 jacquard_identity::resolver::IdentityError, 505 >, 506 > { 507 async { self.client.resolve_did_doc(did).await } 508 } 509}