A better Rust ATProto crate
at lifetimes 14 kB view raw
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 error::{AuthError, ClientError, TransportError, XrpcResult}, 15 http_client::HttpClient, 16 types::did::Did, 17 xrpc::{ 18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request, 19 process_response, 20 }, 21}; 22use jacquard_identity::JacquardResolver; 23use jose_jwk::JwkSet; 24use std::sync::Arc; 25use tokio::sync::RwLock; 26use url::Url; 27 28pub struct OAuthClient<T, S> 29where 30 T: OAuthResolver, 31 S: ClientAuthStore, 32{ 33 pub registry: Arc<SessionRegistry<T, S>>, 34 pub client: Arc<T>, 35} 36 37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 38 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 39 let client = JacquardResolver::default(); 40 Self::new_from_resolver(store, client, client_data) 41 } 42} 43 44impl<T, S> OAuthClient<T, S> 45where 46 T: OAuthResolver, 47 S: ClientAuthStore, 48{ 49 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 50 let client = Arc::new(client); 51 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 52 Self { registry, client } 53 } 54 55 pub fn new_with_shared( 56 store: Arc<S>, 57 client: Arc<T>, 58 client_data: ClientData<'static>, 59 ) -> Self { 60 let registry = Arc::new(SessionRegistry::new_shared( 61 store, 62 client.clone(), 63 client_data, 64 )); 65 Self { registry, client } 66 } 67} 68 69impl<T, S> OAuthClient<T, S> 70where 71 S: ClientAuthStore + Send + Sync + 'static, 72 T: OAuthResolver + DpopExt + Send + Sync + 'static, 73{ 74 pub fn jwks(&self) -> JwkSet { 75 self.registry 76 .client_data 77 .keyset 78 .as_ref() 79 .map(|keyset| keyset.public_jwks()) 80 .unwrap_or_default() 81 } 82 pub async fn start_auth( 83 &self, 84 input: impl AsRef<str>, 85 options: AuthorizeOptions<'_>, 86 ) -> Result<String> { 87 let client_metadata = atproto_client_metadata( 88 self.registry.client_data.config.clone(), 89 &self.registry.client_data.keyset, 90 )?; 91 92 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 93 let login_hint = if identity.is_some() { 94 Some(input.as_ref().into()) 95 } else { 96 None 97 }; 98 let metadata = OAuthMetadata { 99 server_metadata, 100 client_metadata, 101 keyset: self.registry.client_data.keyset.clone(), 102 }; 103 let auth_req_info = 104 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 105 // Persist state for callback handling 106 self.registry 107 .store 108 .save_auth_req_info(&auth_req_info) 109 .await?; 110 111 #[derive(serde::Serialize)] 112 struct Parameters<'s> { 113 client_id: Url, 114 request_uri: CowStr<'s>, 115 } 116 Ok(metadata.server_metadata.authorization_endpoint.to_string() 117 + "?" 118 + &serde_html_form::to_string(Parameters { 119 client_id: metadata.client_metadata.client_id.clone(), 120 request_uri: auth_req_info.request_uri, 121 }) 122 .unwrap()) 123 } 124 125 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 126 let Some(state_key) = params.state else { 127 return Err(CallbackError::MissingState.into()); 128 }; 129 130 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else { 131 return Err(CallbackError::MissingState.into()); 132 }; 133 134 self.registry.store.delete_auth_req_info(&state_key).await?; 135 136 let metadata = self 137 .client 138 .get_authorization_server_metadata(&auth_req_info.authserver_url) 139 .await?; 140 141 if let Some(iss) = params.iss { 142 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) { 143 return Err(CallbackError::IssuerMismatch { 144 expected: metadata.issuer.to_string(), 145 got: iss.to_string(), 146 } 147 .into()); 148 } 149 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 150 return Err(CallbackError::MissingIssuer.into()); 151 } 152 let metadata = OAuthMetadata { 153 server_metadata: metadata, 154 client_metadata: atproto_client_metadata( 155 self.registry.client_data.config.clone(), 156 &self.registry.client_data.keyset, 157 )?, 158 keyset: self.registry.client_data.keyset.clone(), 159 }; 160 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 161 162 match exchange_code( 163 self.client.as_ref(), 164 &mut auth_req_info.dpop_data.clone(), 165 &params.code, 166 &auth_req_info.pkce_verifier, 167 &metadata, 168 ) 169 .await 170 { 171 Ok(token_set) => { 172 let scopes = if let Some(scope) = &token_set.scope { 173 Scope::parse_multiple_reduced(&scope) 174 .expect("Failed to parse scopes") 175 .into_static() 176 } else { 177 vec![] 178 }; 179 let client_data = ClientSessionData { 180 account_did: token_set.sub.clone(), 181 session_id: auth_req_info.state, 182 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 183 authserver_url: auth_req_info.authserver_url, 184 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 185 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 186 scopes, 187 dpop_data: DpopClientData { 188 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 189 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 190 dpop_host_nonce: auth_req_info 191 .dpop_data 192 .dpop_authserver_nonce 193 .unwrap_or(CowStr::default()), 194 }, 195 token_set, 196 }; 197 198 self.create_session(client_data).await 199 } 200 Err(e) => Err(e.into()), 201 } 202 } 203 204 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 205 self.registry.set(data.clone()).await?; 206 Ok(OAuthSession::new( 207 self.registry.clone(), 208 self.client.clone(), 209 data.into_static(), 210 )) 211 } 212 213 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 214 self.create_session(self.registry.get(did, session_id, false).await?) 215 .await 216 } 217 218 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 219 Ok(self.registry.del(did, session_id).await?) 220 } 221} 222 223pub struct OAuthSession<T, S> 224where 225 T: OAuthResolver, 226 S: ClientAuthStore, 227{ 228 pub registry: Arc<SessionRegistry<T, S>>, 229 pub client: Arc<T>, 230 pub data: RwLock<ClientSessionData<'static>>, 231 pub options: RwLock<CallOptions<'static>>, 232} 233 234impl<T, S> OAuthSession<T, S> 235where 236 T: OAuthResolver, 237 S: ClientAuthStore, 238{ 239 pub fn new( 240 registry: Arc<SessionRegistry<T, S>>, 241 client: Arc<T>, 242 data: ClientSessionData<'static>, 243 ) -> Self { 244 Self { 245 registry, 246 client, 247 data: RwLock::new(data), 248 options: RwLock::new(CallOptions::default()), 249 } 250 } 251 252 pub fn with_options(self, options: CallOptions<'_>) -> Self { 253 Self { 254 registry: self.registry, 255 client: self.client, 256 data: self.data, 257 options: RwLock::new(options.into_static()), 258 } 259 } 260 261 pub async fn set_options(&self, options: CallOptions<'_>) { 262 *self.options.write().await = options.into_static(); 263 } 264 265 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 266 let data = self.data.read().await; 267 (data.account_did.clone(), data.session_id.clone()) 268 } 269 270 pub async fn endpoint(&self) -> Url { 271 self.data.read().await.host_url.clone() 272 } 273 274 pub async fn access_token(&self) -> AuthorizationToken<'_> { 275 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 276 } 277 278 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 279 self.data 280 .read() 281 .await 282 .token_set 283 .refresh_token 284 .as_ref() 285 .map(|t| AuthorizationToken::Dpop(t.clone())) 286 } 287} 288impl<T, S> OAuthSession<T, S> 289where 290 S: ClientAuthStore + Send + Sync + 'static, 291 T: OAuthResolver + DpopExt + Send + Sync + 'static, 292{ 293 pub async fn logout(&self) -> Result<()> { 294 use crate::request::{OAuthMetadata, revoke}; 295 let mut data = self.data.write().await; 296 let meta = 297 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 298 if meta.server_metadata.revocation_endpoint.is_some() { 299 let token = data.token_set.access_token.clone(); 300 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 301 .await 302 .ok(); 303 } 304 // Remove from store 305 self.registry 306 .del(&data.account_did, &data.session_id) 307 .await?; 308 Ok(()) 309 } 310} 311 312impl<T, S> OAuthClient<T, S> 313where 314 T: OAuthResolver, 315 S: ClientAuthStore, 316{ 317 pub fn from_session(session: &OAuthSession<T, S>) -> Self { 318 Self { 319 registry: session.registry.clone(), 320 client: session.client.clone(), 321 } 322 } 323} 324impl<T, S> OAuthSession<T, S> 325where 326 S: ClientAuthStore + Send + Sync + 'static, 327 T: OAuthResolver + DpopExt + Send + Sync + 'static, 328{ 329 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 330 // Read identifiers without holding the lock across await 331 let (did, sid) = { 332 let data = self.data.read().await; 333 (data.account_did.clone(), data.session_id.clone()) 334 }; 335 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?; 336 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 337 // Write back updated session 338 *self.data.write().await = refreshed.clone().into_static(); 339 // Store in the registry 340 self.registry.set(refreshed).await?; 341 Ok(token) 342 } 343} 344 345impl<T, S> HttpClient for OAuthSession<T, S> 346where 347 S: ClientAuthStore + Send + Sync + 'static, 348 T: OAuthResolver + DpopExt + Send + Sync + 'static, 349{ 350 type Error = T::Error; 351 352 async fn send_http( 353 &self, 354 request: http::Request<Vec<u8>>, 355 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 356 self.client.send_http(request).await 357 } 358} 359 360impl<T, S> XrpcClient for OAuthSession<T, S> 361where 362 S: ClientAuthStore + Send + Sync + 'static, 363 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 364{ 365 fn base_uri(&self) -> Url { 366 // base_uri is a synchronous trait method; we must avoid async `.read().await`. 367 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely. 368 if tokio::runtime::Handle::try_current().is_ok() { 369 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone()) 370 } else { 371 self.data.blocking_read().host_url.clone() 372 } 373 } 374 375 async fn opts(&self) -> CallOptions<'_> { 376 self.options.read().await.clone() 377 } 378 379 async fn send<'s, R>( 380 &self, 381 request: R, 382 ) -> XrpcResult<Response<<R as XrpcRequest<'s>>::Response>> 383 where 384 R: XrpcRequest<'s>, 385 { 386 let base_uri = self.base_uri(); 387 let mut opts = self.options.read().await.clone(); 388 opts.auth = Some(self.access_token().await); 389 let guard = self.data.read().await; 390 let mut dpop = guard.dpop_data.clone(); 391 let http_response = self 392 .client 393 .dpop_call(&mut dpop) 394 .send(build_http_request(&base_uri, &request, &opts)?) 395 .await 396 .map_err(|e| TransportError::Other(Box::new(e)))?; 397 let resp = process_response(http_response); 398 drop(guard); 399 if is_invalid_token_response(&resp) { 400 opts.auth = Some( 401 self.refresh() 402 .await 403 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 404 ); 405 let guard = self.data.read().await; 406 let mut dpop = guard.dpop_data.clone(); 407 let http_response = self 408 .client 409 .dpop_call(&mut dpop) 410 .send(build_http_request(&base_uri, &request, &opts)?) 411 .await 412 .map_err(|e| TransportError::Other(Box::new(e)))?; 413 process_response(http_response) 414 } else { 415 resp 416 } 417 } 418} 419 420fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool { 421 match response { 422 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 423 Err(ClientError::Auth(AuthError::Other(value))) => value 424 .to_str() 425 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 426 Ok(resp) => match resp.parse() { 427 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true, 428 _ => false, 429 }, 430 _ => false, 431 } 432}