// A better Rust ATProto crate
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 error::{AuthError, ClientError, TransportError, XrpcResult}, 15 http_client::HttpClient, 16 types::{ 17 did::Did, 18 xrpc::{ 19 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, build_http_request, 20 process_response, 21 }, 22 }, 23}; 24use jacquard_identity::JacquardResolver; 25use jose_jwk::JwkSet; 26use std::sync::Arc; 27use tokio::sync::RwLock; 28use url::Url; 29 30pub struct OAuthClient<T, S> 31where 32 T: OAuthResolver, 33 S: ClientAuthStore, 34{ 35 pub registry: Arc<SessionRegistry<T, S>>, 36 pub client: Arc<T>, 37} 38 39impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 40 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 41 let client = JacquardResolver::default(); 42 Self::new_from_resolver(store, client, client_data) 43 } 44} 45 46impl<T, S> OAuthClient<T, S> 47where 48 T: OAuthResolver, 49 S: ClientAuthStore, 50{ 51 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 52 let client = Arc::new(client); 53 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 54 Self { registry, client } 55 } 56 57 pub fn new_with_shared( 58 store: Arc<S>, 59 client: Arc<T>, 60 client_data: ClientData<'static>, 61 ) -> Self { 62 let registry = Arc::new(SessionRegistry::new_shared( 63 store, 64 client.clone(), 65 client_data, 66 )); 67 Self { registry, client } 68 } 69} 70 71impl<T, S> OAuthClient<T, S> 72where 73 S: ClientAuthStore + Send + Sync + 'static, 74 T: OAuthResolver + DpopExt + Send + Sync + 'static, 75{ 76 pub fn jwks(&self) -> JwkSet { 77 
self.registry 78 .client_data 79 .keyset 80 .as_ref() 81 .map(|keyset| keyset.public_jwks()) 82 .unwrap_or_default() 83 } 84 pub async fn start_auth( 85 &self, 86 input: impl AsRef<str>, 87 options: AuthorizeOptions<'_>, 88 ) -> Result<String> { 89 let client_metadata = atproto_client_metadata( 90 self.registry.client_data.config.clone(), 91 &self.registry.client_data.keyset, 92 )?; 93 94 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 95 let login_hint = if identity.is_some() { 96 Some(input.as_ref().into()) 97 } else { 98 None 99 }; 100 let metadata = OAuthMetadata { 101 server_metadata, 102 client_metadata, 103 keyset: self.registry.client_data.keyset.clone(), 104 }; 105 let auth_req_info = 106 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 107 // Persist state for callback handling 108 self.registry 109 .store 110 .save_auth_req_info(&auth_req_info) 111 .await?; 112 113 #[derive(serde::Serialize)] 114 struct Parameters<'s> { 115 client_id: Url, 116 request_uri: CowStr<'s>, 117 } 118 Ok(metadata.server_metadata.authorization_endpoint.to_string() 119 + "?" 120 + &serde_html_form::to_string(Parameters { 121 client_id: metadata.client_metadata.client_id.clone(), 122 request_uri: auth_req_info.request_uri, 123 }) 124 .unwrap()) 125 } 126 127 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 128 let Some(state_key) = params.state else { 129 return Err(CallbackError::MissingState.into()); 130 }; 131 132 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? 
else { 133 return Err(CallbackError::MissingState.into()); 134 }; 135 136 self.registry.store.delete_auth_req_info(&state_key).await?; 137 138 let metadata = self 139 .client 140 .get_authorization_server_metadata(&auth_req_info.authserver_url) 141 .await?; 142 143 if let Some(iss) = params.iss { 144 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 145 return Err(CallbackError::IssuerMismatch { 146 expected: metadata.issuer.to_string(), 147 got: iss.to_string(), 148 } 149 .into()); 150 } 151 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 152 return Err(CallbackError::MissingIssuer.into()); 153 } 154 let metadata = OAuthMetadata { 155 server_metadata: metadata, 156 client_metadata: atproto_client_metadata( 157 self.registry.client_data.config.clone(), 158 &self.registry.client_data.keyset, 159 )?, 160 keyset: self.registry.client_data.keyset.clone(), 161 }; 162 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 163 164 match exchange_code( 165 self.client.as_ref(), 166 &mut auth_req_info.dpop_data.clone(), 167 &params.code, 168 &auth_req_info.pkce_verifier, 169 &metadata, 170 ) 171 .await 172 { 173 Ok(token_set) => { 174 let scopes = if let Some(scope) = &token_set.scope { 175 Scope::parse_multiple_reduced(&scope) 176 .expect("Failed to parse scopes") 177 .into_static() 178 } else { 179 vec![] 180 }; 181 let client_data = ClientSessionData { 182 account_did: token_set.sub.clone(), 183 session_id: auth_req_info.state, 184 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 185 authserver_url: auth_req_info.authserver_url, 186 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 187 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 188 scopes, 189 dpop_data: DpopClientData { 190 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 191 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 192 dpop_host_nonce: 
auth_req_info 193 .dpop_data 194 .dpop_authserver_nonce 195 .unwrap_or(CowStr::default()), 196 }, 197 token_set, 198 }; 199 200 self.create_session(client_data).await 201 } 202 Err(e) => Err(e.into()), 203 } 204 } 205 206 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 207 self.registry.set(data.clone()).await?; 208 Ok(OAuthSession::new( 209 self.registry.clone(), 210 self.client.clone(), 211 data.into_static(), 212 )) 213 } 214 215 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 216 self.create_session(self.registry.get(did, session_id, false).await?) 217 .await 218 } 219 220 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 221 Ok(self.registry.del(did, session_id).await?) 222 } 223} 224 225pub struct OAuthSession<T, S> 226where 227 T: OAuthResolver, 228 S: ClientAuthStore, 229{ 230 pub registry: Arc<SessionRegistry<T, S>>, 231 pub client: Arc<T>, 232 pub data: RwLock<ClientSessionData<'static>>, 233 pub options: RwLock<CallOptions<'static>>, 234} 235 236impl<T, S> OAuthSession<T, S> 237where 238 T: OAuthResolver, 239 S: ClientAuthStore, 240{ 241 pub fn new( 242 registry: Arc<SessionRegistry<T, S>>, 243 client: Arc<T>, 244 data: ClientSessionData<'static>, 245 ) -> Self { 246 Self { 247 registry, 248 client, 249 data: RwLock::new(data), 250 options: RwLock::new(CallOptions::default()), 251 } 252 } 253 254 pub fn with_options(self, options: CallOptions<'_>) -> Self { 255 Self { 256 registry: self.registry, 257 client: self.client, 258 data: self.data, 259 options: RwLock::new(options.into_static()), 260 } 261 } 262 263 pub async fn set_options(&self, options: CallOptions<'_>) { 264 *self.options.write().await = options.into_static(); 265 } 266 267 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 268 let data = self.data.read().await; 269 (data.account_did.clone(), data.session_id.clone()) 270 } 271 272 pub async fn endpoint(&self) -> Url { 
273 self.data.read().await.host_url.clone() 274 } 275 276 pub async fn access_token(&self) -> AuthorizationToken<'_> { 277 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 278 } 279 280 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 281 self.data 282 .read() 283 .await 284 .token_set 285 .refresh_token 286 .as_ref() 287 .map(|t| AuthorizationToken::Dpop(t.clone())) 288 } 289} 290impl<T, S> OAuthSession<T, S> 291where 292 S: ClientAuthStore + Send + Sync + 'static, 293 T: OAuthResolver + DpopExt + Send + Sync + 'static, 294{ 295 pub async fn logout(&self) -> Result<()> { 296 use crate::request::{OAuthMetadata, revoke}; 297 let mut data = self.data.write().await; 298 let meta = 299 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 300 if meta.server_metadata.revocation_endpoint.is_some() { 301 let token = data.token_set.access_token.clone(); 302 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 303 .await 304 .ok(); 305 } 306 // Remove from store 307 self.registry 308 .del(&data.account_did, &data.session_id) 309 .await?; 310 Ok(()) 311 } 312} 313 314impl<T, S> OAuthClient<T, S> 315where 316 T: OAuthResolver, 317 S: ClientAuthStore, 318{ 319 pub fn from_session(session: &OAuthSession<T, S>) -> Self { 320 Self { 321 registry: session.registry.clone(), 322 client: session.client.clone(), 323 } 324 } 325} 326impl<T, S> OAuthSession<T, S> 327where 328 S: ClientAuthStore + Send + Sync + 'static, 329 T: OAuthResolver + DpopExt + Send + Sync + 'static, 330{ 331 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 332 // Read identifiers without holding the lock across await 333 let (did, sid) = { 334 let data = self.data.read().await; 335 (data.account_did.clone(), data.session_id.clone()) 336 }; 337 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?; 338 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 339 // 
Write back updated session 340 *self.data.write().await = refreshed.clone().into_static(); 341 // Store in the registry 342 self.registry.set(refreshed).await?; 343 Ok(token) 344 } 345} 346 347impl<T, S> HttpClient for OAuthSession<T, S> 348where 349 S: ClientAuthStore + Send + Sync + 'static, 350 T: OAuthResolver + DpopExt + Send + Sync + 'static, 351{ 352 type Error = T::Error; 353 354 async fn send_http( 355 &self, 356 request: http::Request<Vec<u8>>, 357 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 358 self.client.send_http(request).await 359 } 360} 361 362impl<T, S> XrpcClient for OAuthSession<T, S> 363where 364 S: ClientAuthStore + Send + Sync + 'static, 365 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 366{ 367 fn base_uri(&self) -> Url { 368 // base_uri is a synchronous trait method; we must avoid async `.read().await`. 369 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 370 if tokio::runtime::Handle::try_current().is_ok() { 371 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 372 } else { 373 self.data.blocking_read().host_url.clone() 374 } 375 } 376 377 async fn opts(&self) -> CallOptions<'_> { 378 self.options.read().await.clone() 379 } 380 381 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>( 382 self, 383 request: &R, 384 ) -> XrpcResult<Response<R>> { 385 let base_uri = self.base_uri(); 386 let mut opts = self.options.read().await.clone(); 387 opts.auth = Some(self.access_token().await); 388 let guard = self.data.read().await; 389 let mut dpop = guard.dpop_data.clone(); 390 let http_response = self 391 .client 392 .dpop_call(&mut dpop) 393 .send(build_http_request(&base_uri, request, &opts)?) 
394 .await 395 .map_err(|e| TransportError::Other(Box::new(e)))?; 396 drop(guard); 397 let res = process_response(http_response); 398 if is_invalid_token_response(&res) { 399 opts.auth = Some( 400 self.refresh() 401 .await 402 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 403 ); 404 let guard = self.data.read().await; 405 let mut dpop = guard.dpop_data.clone(); 406 let http_response = self 407 .client 408 .dpop_call(&mut dpop) 409 .send(build_http_request(&base_uri, request, &opts)?) 410 .await 411 .map_err(|e| TransportError::Other(Box::new(e)))?; 412 process_response(http_response) 413 } else { 414 res 415 } 416 } 417} 418 419fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool { 420 match response { 421 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 422 Err(ClientError::Auth(AuthError::Other(value))) => value 423 .to_str() 424 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 425 Ok(resp) => match resp.parse() { 426 Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 427 _ => false, 428 }, 429 _ => false, 430 } 431}