A better Rust ATProto crate
1use jacquard_common::{ 2 AuthorizationToken, CowStr, IntoStatic, 3 error::{AuthError, ClientError, TransportError, XrpcResult}, 4 http_client::HttpClient, 5 types::{ 6 did::Did, 7 xrpc::{CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest}, 8 }, 9}; 10use jose_jwk::JwkSet; 11use smol_str::SmolStr; 12use std::sync::Arc; 13use tokio::sync::RwLock; 14use url::Url; 15 16use crate::{ 17 atproto::atproto_client_metadata, 18 authstore::ClientAuthStore, 19 dpop::DpopExt, 20 error::{OAuthError, Result}, 21 request::{OAuthMetadata, exchange_code, par}, 22 resolver::OAuthResolver, 23 scopes::Scope, 24 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 25 types::{AuthorizeOptions, CallbackParams}, 26}; 27 28pub struct OAuthClient<T, S> 29where 30 T: OAuthResolver, 31 S: ClientAuthStore, 32{ 33 pub registry: Arc<SessionRegistry<T, S>>, 34 pub client: Arc<T>, 35} 36 37impl<T, S> OAuthClient<T, S> 38where 39 T: OAuthResolver, 40 S: ClientAuthStore, 41{ 42 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 43 let client = Arc::new(client); 44 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 45 Self { registry, client } 46 } 47} 48 49impl<T, S> OAuthClient<T, S> 50where 51 S: ClientAuthStore + Send + Sync + 'static, 52 T: OAuthResolver + DpopExt + Send + Sync + 'static, 53{ 54 pub fn jwks(&self) -> JwkSet { 55 self.registry 56 .client_data 57 .keyset 58 .as_ref() 59 .map(|keyset| keyset.public_jwks()) 60 .unwrap_or_default() 61 } 62 pub async fn start_auth( 63 &self, 64 input: impl AsRef<str>, 65 options: AuthorizeOptions<'_>, 66 ) -> Result<String> { 67 let client_metadata = atproto_client_metadata( 68 self.registry.client_data.config.clone(), 69 &self.registry.client_data.keyset, 70 )?; 71 72 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 73 let login_hint = if identity.is_some() { 74 Some(input.as_ref().into()) 75 } else { 76 None 77 }; 78 let metadata = OAuthMetadata { 79 server_metadata, 80 client_metadata, 81 keyset: self.registry.client_data.keyset.clone(), 82 }; 83 let auth_req_info = 84 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?; 85 86 #[derive(serde::Serialize)] 87 struct Parameters<'s> { 88 client_id: Url, 89 request_uri: CowStr<'s>, 90 } 91 Ok(metadata.server_metadata.authorization_endpoint.to_string() 92 + "?" 93 + &serde_html_form::to_string(Parameters { 94 client_id: metadata.client_metadata.client_id.clone(), 95 request_uri: auth_req_info.request_uri, 96 }) 97 .unwrap()) 98 } 99 100 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 101 let Some(state_key) = params.state else { 102 return Err(OAuthError::Callback("missing state parameter".into())); 103 }; 104 105 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else { 106 return Err(OAuthError::Callback(format!( 107 "unknown authorization state: {state_key}" 108 ))); 109 }; 110 111 self.registry.store.delete_auth_req_info(&state_key).await?; 112 113 let metadata = self 114 .client 115 .get_authorization_server_metadata(&auth_req_info.authserver_url) 116 .await?; 117 118 if let Some(iss) = params.iss { 119 if iss != metadata.issuer { 120 return Err(OAuthError::Callback(format!( 121 "issuer mismatch: expected {}, got {iss}", 122 metadata.issuer 123 ))); 124 } 125 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 126 return Err(OAuthError::Callback("missing `iss` parameter".into())); 127 } 128 let metadata = OAuthMetadata { 129 server_metadata: metadata, 130 client_metadata: atproto_client_metadata( 131 self.registry.client_data.config.clone(), 132 &self.registry.client_data.keyset, 133 )?, 134 keyset: self.registry.client_data.keyset.clone(), 135 }; 136 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 137 138 match exchange_code( 139 self.client.as_ref(), 140 &mut auth_req_info.dpop_data.clone(), 141 &params.code, 142 &auth_req_info.pkce_verifier, 143 &metadata, 144 ) 145 .await 146 { 147 Ok(token_set) => { 148 let scopes = if let Some(scope) = &token_set.scope { 149 Scope::parse_multiple_reduced(&scope) 150 .expect("Failed to parse scopes") 151 .into_static() 152 } else { 153 vec![] 154 }; 155 let client_data = ClientSessionData { 156 account_did: token_set.sub.clone(), 157 session_id: auth_req_info.state, 158 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"), 159 authserver_url: auth_req_info.authserver_url, 160 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 161 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 162 scopes, 163 dpop_data: DpopClientData { 164 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 165 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 166 dpop_host_nonce: auth_req_info 167 .dpop_data 168 .dpop_authserver_nonce 169 .unwrap_or(CowStr::default()), 170 }, 171 token_set, 172 }; 173 174 self.create_session(client_data).await 175 } 176 Err(e) => Err(e.into()), 177 } 178 } 179 180 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 181 Ok(OAuthSession::new( 182 self.registry.clone(), 183 self.client.clone(), 184 data.into_static(), 185 )) 186 } 187 188 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 189 self.create_session(self.registry.get(did, session_id, false).await?) 190 .await 191 } 192 193 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 194 Ok(self.registry.del(did, session_id).await?) 195 } 196} 197 198pub struct OAuthSession<T, S> 199where 200 T: OAuthResolver, 201 S: ClientAuthStore, 202{ 203 pub registry: Arc<SessionRegistry<T, S>>, 204 pub client: Arc<T>, 205 pub data: RwLock<ClientSessionData<'static>>, 206 pub options: RwLock<CallOptions<'static>>, 207} 208 209impl<T, S> OAuthSession<T, S> 210where 211 T: OAuthResolver, 212 S: ClientAuthStore, 213{ 214 pub fn new( 215 registry: Arc<SessionRegistry<T, S>>, 216 client: Arc<T>, 217 data: ClientSessionData<'static>, 218 ) -> Self { 219 Self { 220 registry, 221 client, 222 data: RwLock::new(data), 223 options: RwLock::new(CallOptions::default()), 224 } 225 } 226 227 pub fn with_options(self, options: CallOptions<'_>) -> Self { 228 Self { 229 registry: self.registry, 230 client: self.client, 231 data: self.data, 232 options: RwLock::new(options.into_static()), 233 } 234 } 235 236 pub async fn set_options(&self, options: CallOptions<'_>) { 237 *self.options.write().await = options.into_static(); 238 } 239 240 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 241 let data = self.data.read().await; 242 (data.account_did.clone(), data.session_id.clone()) 243 } 244 245 pub async fn pds(&self) -> Url { 246 self.data.read().await.host_url.clone() 247 } 248 249 pub async fn access_token(&self) -> AuthorizationToken<'_> { 250 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 251 } 252 253 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 254 self.data 255 .read() 256 .await 257 .token_set 258 .refresh_token 259 .as_ref() 260 .map(|token| AuthorizationToken::Dpop(token.clone())) 261 } 262} 263impl<T, S> OAuthSession<T, S> 264where 265 S: ClientAuthStore + Send + Sync + 'static, 266 T: OAuthResolver + DpopExt + Send + Sync + 'static, 267{ 268 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 269 let mut data = self.data.write().await; 270 let refreshed = self 271 .registry 272 .as_ref() 273 .get(&data.account_did, &data.session_id, true) 274 .await?; 275 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 276 *data = refreshed.into_static(); 277 Ok(token) 278 } 279} 280 281impl<T, S> HttpClient for OAuthSession<T, S> 282where 283 S: ClientAuthStore + Send + Sync + 'static, 284 T: OAuthResolver + DpopExt + Send + Sync + 'static, 285{ 286 type Error = T::Error; 287 288 async fn send_http( 289 &self, 290 request: http::Request<Vec<u8>>, 291 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 292 self.client.send_http(request).await 293 } 294} 295 296impl<T, S> XrpcClient for OAuthSession<T, S> 297where 298 S: ClientAuthStore + Send + Sync + 'static, 299 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 300{ 301 fn base_uri(&self) -> Url { 302 self.data.blocking_read().host_url.clone() 303 } 304 305 async fn opts(&self) -> CallOptions<'_> { 306 self.options.read().await.clone() 307 } 308 309 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>( 310 self, 311 request: &R, 312 ) -> XrpcResult<Response<R>> { 313 let base_uri = self.base_uri(); 314 let auth = self.access_token().await; 315 let mut opts = self.options.read().await.clone(); 316 opts.auth = Some(auth); 317 let res = self 318 .client 319 .xrpc(base_uri.clone()) 320 .with_options(opts.clone()) 321 .send(request) 322 .await; 323 if is_invalid_token_response(&res) { 324 opts.auth = Some( 325 self.refresh() 326 .await 327 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?, 328 ); 329 self.client 330 .xrpc(base_uri) 331 .with_options(opts) 332 .send(request) 333 .await 334 } else { 335 res 336 } 337 } 338} 339 340fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool { 341 match response { 342 Err(ClientError::Auth(AuthError::InvalidToken)) => true, 343 Err(ClientError::Auth(AuthError::Other(value))) => value 344 .to_str() 345 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 346 _ => false, 347 } 348}