A better Rust ATProto crate
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 error::{AuthError, ClientError, TransportError, XrpcResult}, 15 http_client::HttpClient, 16 types::{did::Did, string::Handle}, 17 xrpc::{ 18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request, 19 process_response, 20 }, 21}; 22use jacquard_identity::{JacquardResolver, resolver::IdentityResolver}; 23use jose_jwk::JwkSet; 24use std::sync::Arc; 25use tokio::sync::RwLock; 26use url::Url; 27 28pub struct OAuthClient<T, S> 29where 30 T: OAuthResolver, 31 S: ClientAuthStore, 32{ 33 pub registry: Arc<SessionRegistry<T, S>>, 34 pub client: Arc<T>, 35} 36 37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 38 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 39 let client = JacquardResolver::default(); 40 Self::new_from_resolver(store, client, client_data) 41 } 42 43 /// Create an OAuth client with the provided store and default localhost client metadata. 44 /// 45 /// This is a convenience constructor for quickly setting up an OAuth client 46 /// with default localhost redirect URIs and "atproto transition:generic" scopes. 47 /// 48 /// # Example 49 /// 50 /// ```no_run 51 /// # use jacquard_oauth::client::OAuthClient; 52 /// # use jacquard_oauth::authstore::MemoryAuthStore; 53 /// # #[tokio::main] 54 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 55 /// let store = MemoryAuthStore::new(); 56 /// let oauth = OAuthClient::with_default_config(store); 57 /// # Ok(()) 58 /// # } 59 /// ``` 60 pub fn with_default_config(store: S) -> Self { 61 let client_data = ClientData { 62 keyset: None, 63 config: crate::atproto::AtprotoClientMetadata::default_localhost(), 64 }; 65 Self::new(store, client_data) 66 } 67} 68 69impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> { 70 /// Create an OAuth client with an in-memory auth store and default localhost client metadata. 71 /// 72 /// This is a convenience constructor for simple testing and development. 73 /// The session will not persist across restarts. 74 /// 75 /// # Example 76 /// 77 /// ```no_run 78 /// # use jacquard_oauth::client::OAuthClient; 79 /// # #[tokio::main] 80 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 81 /// let oauth = OAuthClient::with_memory_store(); 82 /// # Ok(()) 83 /// # } 84 /// ``` 85 pub fn with_memory_store() -> Self { 86 Self::with_default_config(crate::authstore::MemoryAuthStore::new()) 87 } 88} 89 90impl<T, S> OAuthClient<T, S> 91where 92 T: OAuthResolver, 93 S: ClientAuthStore, 94{ 95 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 96 #[cfg(feature = "tracing")] 97 tracing::info!( 98 redirect_uris = ?client_data.config.redirect_uris, 99 scopes = ?client_data.config.scopes, 100 has_keyset = client_data.keyset.is_some(), 101 "oauth client created" 102 ); 103 104 let client = Arc::new(client); 105 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 106 Self { registry, client } 107 } 108 109 pub fn new_with_shared( 110 store: Arc<S>, 111 client: Arc<T>, 112 client_data: ClientData<'static>, 113 ) -> Self { 114 let registry = Arc::new(SessionRegistry::new_shared( 115 store, 116 client.clone(), 117 client_data, 118 )); 119 Self { registry, client } 120 } 121} 122 123impl<T, S> OAuthClient<T, S> 124where 125 S: ClientAuthStore + Send + Sync + 'static, 126 T: OAuthResolver + DpopExt + Send + Sync + 'static, 127{ 128 pub fn jwks(&self) -> JwkSet { 129 self.registry 130 .client_data 131 .keyset 132 .as_ref() 133 .map(|keyset| keyset.public_jwks()) 134 .unwrap_or_default() 135 } 136 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))] 137 pub async fn start_auth( 138 &self, 139 input: impl AsRef<str>, 140 options: AuthorizeOptions<'_>, 141 ) -> Result<String> { 142 let client_metadata = atproto_client_metadata( 143 self.registry.client_data.config.clone(), 144 &self.registry.client_data.keyset, 145 )?; 146 147 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 148 let login_hint = if identity.is_some() { 149 Some(input.as_ref().into()) 150 } else { 151 None 152 }; 153 let metadata = OAuthMetadata { 154 server_metadata, 155 client_metadata, 156 keyset: self.registry.client_data.keyset.clone(), 157 }; 158 let auth_req_info = 159 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 160 // Persist state for callback handling 161 self.registry 162 .store 163 .save_auth_req_info(&auth_req_info) 164 .await?; 165 166 #[derive(serde::Serialize)] 167 struct Parameters<'s> { 168 client_id: Url, 169 request_uri: CowStr<'s>, 170 } 171 Ok(metadata.server_metadata.authorization_endpoint.to_string() 172 + "?" 173 + &serde_html_form::to_string(Parameters { 174 client_id: metadata.client_metadata.client_id.clone(), 175 request_uri: auth_req_info.request_uri, 176 }) 177 .unwrap()) 178 } 179 180 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_ref()))))] 181 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 182 let Some(state_key) = params.state else { 183 return Err(CallbackError::MissingState.into()); 184 }; 185 186 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else { 187 return Err(CallbackError::MissingState.into()); 188 }; 189 190 self.registry.store.delete_auth_req_info(&state_key).await?; 191 192 let metadata = self 193 .client 194 .get_authorization_server_metadata(&auth_req_info.authserver_url) 195 .await?; 196 197 if let Some(iss) = params.iss { 198 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 199 return Err(CallbackError::IssuerMismatch { 200 expected: metadata.issuer.to_string(), 201 got: iss.to_string(), 202 } 203 .into()); 204 } 205 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 206 return Err(CallbackError::MissingIssuer.into()); 207 } 208 let metadata = OAuthMetadata { 209 server_metadata: metadata, 210 client_metadata: atproto_client_metadata( 211 self.registry.client_data.config.clone(), 212 &self.registry.client_data.keyset, 213 )?, 214 keyset: self.registry.client_data.keyset.clone(), 215 }; 216 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 217 218 match exchange_code( 219 self.client.as_ref(), 220 &mut auth_req_info.dpop_data.clone(), 221 &params.code, 222 &auth_req_info.pkce_verifier, 223 &metadata, 224 ) 225 .await 226 { 227 Ok(token_set) => { 228 let scopes = if let Some(scope) = &token_set.scope { 229 Scope::parse_multiple_reduced(&scope) 230 .expect("Failed to parse scopes") 231 .into_static() 232 } else { 233 vec![] 234 }; 235 let client_data = ClientSessionData { 236 account_did: token_set.sub.clone(), 237 session_id: auth_req_info.state, 238 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 239 authserver_url: auth_req_info.authserver_url, 240 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 241 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 242 scopes, 243 dpop_data: DpopClientData { 244 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 245 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 246 dpop_host_nonce: auth_req_info 247 .dpop_data 248 .dpop_authserver_nonce 249 .unwrap_or(CowStr::default()), 250 }, 251 token_set, 252 }; 253 254 self.create_session(client_data).await 255 } 256 Err(e) => Err(e.into()), 257 } 258 } 259 260 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 261 self.registry.set(data.clone()).await?; 262 Ok(OAuthSession::new( 263 self.registry.clone(), 264 self.client.clone(), 265 data.into_static(), 266 )) 267 } 268 269 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 270 self.create_session(self.registry.get(did, session_id, false).await?) 271 .await 272 } 273 274 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 275 Ok(self.registry.del(did, session_id).await?) 276 } 277} 278 279pub struct OAuthSession<T, S> 280where 281 T: OAuthResolver, 282 S: ClientAuthStore, 283{ 284 pub registry: Arc<SessionRegistry<T, S>>, 285 pub client: Arc<T>, 286 pub data: RwLock<ClientSessionData<'static>>, 287 pub options: RwLock<CallOptions<'static>>, 288} 289 290impl<T, S> OAuthSession<T, S> 291where 292 T: OAuthResolver, 293 S: ClientAuthStore, 294{ 295 pub fn new( 296 registry: Arc<SessionRegistry<T, S>>, 297 client: Arc<T>, 298 data: ClientSessionData<'static>, 299 ) -> Self { 300 Self { 301 registry, 302 client, 303 data: RwLock::new(data), 304 options: RwLock::new(CallOptions::default()), 305 } 306 } 307 308 pub fn with_options(self, options: CallOptions<'_>) -> Self { 309 Self { 310 registry: self.registry, 311 client: self.client, 312 data: self.data, 313 options: RwLock::new(options.into_static()), 314 } 315 } 316 317 pub async fn set_options(&self, options: CallOptions<'_>) { 318 *self.options.write().await = options.into_static(); 319 } 320 321 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 322 let data = self.data.read().await; 323 (data.account_did.clone(), data.session_id.clone()) 324 } 325 326 pub async fn endpoint(&self) -> Url { 327 self.data.read().await.host_url.clone() 328 } 329 330 pub async fn access_token(&self) -> AuthorizationToken<'_> { 331 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 332 } 333 334 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 335 self.data 336 .read() 337 .await 338 .token_set 339 .refresh_token 340 .as_ref() 341 .map(|t| AuthorizationToken::Dpop(t.clone())) 342 } 343} 344impl<T, S> OAuthSession<T, S> 345where 346 S: ClientAuthStore + Send + Sync + 'static, 347 T: OAuthResolver + DpopExt + Send + Sync + 'static, 348{ 349 pub async fn logout(&self) -> Result<()> { 350 use crate::request::{OAuthMetadata, revoke}; 351 let mut data = self.data.write().await; 352 let meta = 353 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 354 if meta.server_metadata.revocation_endpoint.is_some() { 355 let token = data.token_set.access_token.clone(); 356 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 357 .await 358 .ok(); 359 } 360 // Remove from store 361 self.registry 362 .del(&data.account_did, &data.session_id) 363 .await?; 364 Ok(()) 365 } 366} 367 368impl<T, S> OAuthClient<T, S> 369where 370 T: OAuthResolver, 371 S: ClientAuthStore, 372{ 373 pub fn from_session(session: &OAuthSession<T, S>) -> Self { 374 Self { 375 registry: session.registry.clone(), 376 client: session.client.clone(), 377 } 378 } 379} 380impl<T, S> OAuthSession<T, S> 381where 382 S: ClientAuthStore + Send + Sync + 'static, 383 T: OAuthResolver + DpopExt + Send + Sync + 'static, 384{ 385 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 386 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 387 // Read identifiers without holding the lock across await 388 let (did, sid) = { 389 let data = self.data.read().await; 390 (data.account_did.clone(), data.session_id.clone()) 391 }; 392 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?; 393 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 394 // Write back updated session 395 *self.data.write().await = refreshed.clone().into_static(); 396 // Store in the registry 397 self.registry.set(refreshed).await?; 398 Ok(token) 399 } 400} 401 402impl<T, S> HttpClient for OAuthSession<T, S> 403where 404 S: ClientAuthStore + Send + Sync + 'static, 405 T: OAuthResolver + DpopExt + Send + Sync + 'static, 406{ 407 type Error = T::Error; 408 409 async fn send_http( 410 &self, 411 request: http::Request<Vec<u8>>, 412 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 413 self.client.send_http(request).await 414 } 415} 416 417impl<T, S> XrpcClient for OAuthSession<T, S> 418where 419 S: ClientAuthStore + Send + Sync + 'static, 420 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 421{ 422 fn base_uri(&self) -> Url { 423 // base_uri is a synchronous trait method; we must avoid async `.read().await`. 424 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 425 if tokio::runtime::Handle::try_current().is_ok() { 426 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 427 } else { 428 self.data.blocking_read().host_url.clone() 429 } 430 } 431 432 async fn opts(&self) -> CallOptions<'_> { 433 self.options.read().await.clone() 434 } 435 436 async fn send<R>( 437 &self, 438 request: R, 439 ) -> XrpcResult<Response<<R as XrpcRequest>::Response>> 440 where 441 R: XrpcRequest, 442 { 443 let base_uri = self.base_uri(); 444 let mut opts = self.options.read().await.clone(); 445 opts.auth = Some(self.access_token().await); 446 let guard = self.data.read().await; 447 let mut dpop = guard.dpop_data.clone(); 448 let http_response = self 449 .client 450 .dpop_call(&mut dpop) 451 .send(build_http_request(&base_uri, &request, &opts)?) 452 .await 453 .map_err(|e| TransportError::Other(Box::new(e)))?; 454 let resp = process_response(http_response); 455 drop(guard); 456 if is_invalid_token_response(&resp) { 457 opts.auth = Some( 458 self.refresh() 459 .await 460 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 461 ); 462 let guard = self.data.read().await; 463 let mut dpop = guard.dpop_data.clone(); 464 let http_response = self 465 .client 466 .dpop_call(&mut dpop) 467 .send(build_http_request(&base_uri, &request, &opts)?) 468 .await 469 .map_err(|e| TransportError::Other(Box::new(e)))?; 470 process_response(http_response) 471 } else { 472 resp 473 } 474 } 475} 476 477fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool { 478 match response { 479 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 480 Err(ClientError::Auth(AuthError::Other(value))) => value 481 .to_str() 482 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 483 Ok(resp) => match resp.parse() { 484 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 485 _ => false, 486 }, 487 _ => false, 488 } 489} 490 491impl<T, S> IdentityResolver for OAuthSession<T, S> 492where 493 S: ClientAuthStore + Send + Sync + 'static, 494 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static, 495{ 496 fn options(&self) -> &jacquard_identity::resolver::ResolverOptions { 497 self.client.options() 498 } 499 500 fn resolve_handle( 501 &self, 502 handle: &Handle<'_>, 503 ) -> impl Future< 504 Output = std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError>, 505 > { 506 async { self.client.resolve_handle(handle).await } 507 } 508 509 fn resolve_did_doc( 510 &self, 511 did: &Did<'_>, 512 ) -> impl Future< 513 Output = std::result::Result< 514 jacquard_identity::resolver::DidDocResponse, 515 jacquard_identity::resolver::IdentityError, 516 >, 517 > { 518 async { self.client.resolve_did_doc(did).await } 519 } 520}