A better Rust ATProto crate
1use std::sync::Arc; 2 3use crate::{ 4 atproto::{AtprotoClientMetadata, atproto_client_metadata}, 5 authstore::ClientAuthStore, 6 dpop::DpopExt, 7 keyset::Keyset, 8 request::{OAuthMetadata, refresh}, 9 resolver::OAuthResolver, 10 scopes::Scope, 11 types::TokenSet, 12}; 13 14use dashmap::DashMap; 15use jacquard_common::{ 16 CowStr, IntoStatic, 17 http_client::HttpClient, 18 session::SessionStoreError, 19 types::{did::Did, string::Datetime}, 20}; 21use jose_jwk::Key; 22use serde::{Deserialize, Serialize}; 23use smol_str::{SmolStr, format_smolstr}; 24use tokio::sync::Mutex; 25use url::Url; 26 27pub trait DpopDataSource { 28 fn key(&self) -> &Key; 29 fn authserver_nonce(&self) -> Option<CowStr<'_>>; 30 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>); 31 fn host_nonce(&self) -> Option<CowStr<'_>>; 32 fn set_host_nonce(&mut self, nonce: CowStr<'_>); 33} 34 35/// Persisted information about an OAuth session. Used to resume an active session. 36#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 37pub struct ClientSessionData<'s> { 38 // Account DID for this session. Assuming only one active session per account, this can be used as "primary key" for storing and retrieving this information. 39 #[serde(borrow)] 40 pub account_did: Did<'s>, 41 42 // Identifier to distinguish this particular session for the account. Server backends generally support multiple sessions for the same account. This package will re-use the random 'state' token from the auth flow as the session ID. 43 pub session_id: CowStr<'s>, 44 45 // Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info. 46 pub host_url: Url, 47 48 // Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info. 49 pub authserver_url: Url, 50 51 // Full token endpoint 52 pub authserver_token_endpoint: CowStr<'s>, 53 54 // Full revocation endpoint, if it exists 55 #[serde(skip_serializing_if = "std::option::Option::is_none")] 56 pub authserver_revocation_endpoint: Option<CowStr<'s>>, 57 58 // The set of scopes approved for this session (returned in the initial token request) 59 pub scopes: Vec<Scope<'s>>, 60 61 #[serde(flatten)] 62 pub dpop_data: DpopClientData<'s>, 63 64 #[serde(flatten)] 65 pub token_set: TokenSet<'s>, 66} 67 68impl IntoStatic for ClientSessionData<'_> { 69 type Output = ClientSessionData<'static>; 70 71 fn into_static(self) -> Self::Output { 72 ClientSessionData { 73 authserver_url: self.authserver_url, 74 authserver_token_endpoint: self.authserver_token_endpoint.into_static(), 75 authserver_revocation_endpoint: self 76 .authserver_revocation_endpoint 77 .map(IntoStatic::into_static), 78 scopes: self.scopes.into_static(), 79 dpop_data: self.dpop_data.into_static(), 80 token_set: self.token_set.into_static(), 81 account_did: self.account_did.into_static(), 82 session_id: self.session_id.into_static(), 83 host_url: self.host_url, 84 } 85 } 86} 87 88impl ClientSessionData<'_> { 89 pub fn update_with_tokens(&mut self, token_set: TokenSet<'_>) { 90 if let Some(Ok(scopes)) = token_set 91 .scope 92 .as_ref() 93 .map(|scope| Scope::parse_multiple_reduced(&scope).map(IntoStatic::into_static)) 94 { 95 self.scopes = scopes; 96 } 97 self.token_set = token_set.into_static(); 98 } 99} 100 101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 102pub struct DpopClientData<'s> { 103 pub dpop_key: Key, 104 // Current auth server DPoP nonce 105 #[serde(borrow)] 106 pub dpop_authserver_nonce: CowStr<'s>, 107 // Current host ("resource server", eg PDS) DPoP nonce 108 pub dpop_host_nonce: CowStr<'s>, 109} 110 111impl IntoStatic for DpopClientData<'_> { 112 type Output = DpopClientData<'static>; 113 114 fn into_static(self) -> Self::Output { 115 DpopClientData { 116 dpop_key: self.dpop_key, 117 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), 118 dpop_host_nonce: self.dpop_host_nonce.into_static(), 119 } 120 } 121} 122 123impl DpopDataSource for DpopClientData<'_> { 124 fn key(&self) -> &Key { 125 &self.dpop_key 126 } 127 fn authserver_nonce(&self) -> Option<CowStr<'_>> { 128 Some(self.dpop_authserver_nonce.clone()) 129 } 130 131 fn host_nonce(&self) -> Option<CowStr<'_>> { 132 Some(self.dpop_host_nonce.clone()) 133 } 134 135 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { 136 self.dpop_authserver_nonce = nonce.into_static(); 137 } 138 139 fn set_host_nonce(&mut self, nonce: CowStr<'_>) { 140 self.dpop_host_nonce = nonce.into_static(); 141 } 142} 143 144#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 145pub struct AuthRequestData<'s> { 146 // The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information. 147 #[serde(borrow)] 148 pub state: CowStr<'s>, 149 150 // URL of the auth server (eg, PDS or entryway) 151 pub authserver_url: Url, 152 153 // If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response. 154 #[serde(skip_serializing_if = "std::option::Option::is_none")] 155 pub account_did: Option<Did<'s>>, 156 157 // OAuth scope strings 158 pub scopes: Vec<Scope<'s>>, 159 160 // unique token in URI format, which will be used by the client in the auth flow redirect 161 pub request_uri: CowStr<'s>, 162 163 // Full token endpoint URL 164 pub authserver_token_endpoint: CowStr<'s>, 165 166 // Full revocation endpoint, if it exists 167 #[serde(skip_serializing_if = "std::option::Option::is_none")] 168 pub authserver_revocation_endpoint: Option<CowStr<'s>>, 169 170 // The secret token/nonce which a code challenge was generated from 171 pub pkce_verifier: CowStr<'s>, 172 173 #[serde(flatten)] 174 pub dpop_data: DpopReqData<'s>, 175} 176 177impl IntoStatic for AuthRequestData<'_> { 178 type Output = AuthRequestData<'static>; 179 fn into_static(self) -> AuthRequestData<'static> { 180 AuthRequestData { 181 request_uri: self.request_uri.into_static(), 182 authserver_token_endpoint: self.authserver_token_endpoint.into_static(), 183 authserver_revocation_endpoint: self 184 .authserver_revocation_endpoint 185 .map(|s| s.into_static()), 186 pkce_verifier: self.pkce_verifier.into_static(), 187 dpop_data: self.dpop_data.into_static(), 188 state: self.state.into_static(), 189 authserver_url: self.authserver_url, 190 account_did: self.account_did.into_static(), 191 scopes: self.scopes.into_static(), 192 } 193 } 194} 195 196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 197pub struct DpopReqData<'s> { 198 // The secret cryptographic key generated by the client for this specific OAuth session 199 pub dpop_key: Key, 200 // Server-provided DPoP nonce from auth request (PAR) 201 #[serde(borrow)] 202 pub dpop_authserver_nonce: Option<CowStr<'s>>, 203} 204 205impl IntoStatic for DpopReqData<'_> { 206 type Output = DpopReqData<'static>; 207 fn into_static(self) -> DpopReqData<'static> { 208 DpopReqData { 209 dpop_key: self.dpop_key, 210 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), 211 } 212 } 213} 214 215impl DpopDataSource for DpopReqData<'_> { 216 fn key(&self) -> &Key { 217 &self.dpop_key 218 } 219 fn authserver_nonce(&self) -> Option<CowStr<'_>> { 220 self.dpop_authserver_nonce.clone() 221 } 222 223 fn host_nonce(&self) -> Option<CowStr<'_>> { 224 None 225 } 226 227 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { 228 self.dpop_authserver_nonce = Some(nonce.into_static()); 229 } 230 231 fn set_host_nonce(&mut self, _nonce: CowStr<'_>) {} 232} 233 234#[derive(Clone, Debug)] 235pub struct ClientData<'s> { 236 pub keyset: Option<Keyset>, 237 pub config: AtprotoClientMetadata<'s>, 238} 239 240pub struct ClientSession<'s> { 241 pub keyset: Option<Keyset>, 242 pub config: AtprotoClientMetadata<'s>, 243 pub session_data: ClientSessionData<'s>, 244} 245 246impl<'s> ClientSession<'s> { 247 pub fn new( 248 ClientData { keyset, config }: ClientData<'s>, 249 session_data: ClientSessionData<'s>, 250 ) -> Self { 251 Self { 252 keyset, 253 config, 254 session_data, 255 } 256 } 257 258 pub async fn metadata<T: HttpClient + OAuthResolver + Send + Sync>( 259 &self, 260 client: &T, 261 ) -> Result<OAuthMetadata, Error> { 262 Ok(OAuthMetadata { 263 server_metadata: client 264 .get_authorization_server_metadata(&self.session_data.authserver_url) 265 .await 266 .map_err(|e| Error::ServerAgent(crate::request::Error::ResolverError(e)))?, 267 client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset) 268 .unwrap() 269 .into_static(), 270 keyset: self.keyset.clone(), 271 }) 272 } 273} 274 275#[derive(thiserror::Error, Debug)] 276pub enum Error { 277 #[error(transparent)] 278 ServerAgent(#[from] crate::request::Error), 279 #[error(transparent)] 280 Store(#[from] SessionStoreError), 281 #[error("session does not exist")] 282 SessionNotFound, 283} 284 285pub struct SessionRegistry<T, S> 286where 287 T: OAuthResolver, 288 S: ClientAuthStore, 289{ 290 pub store: Arc<S>, 291 pub client: Arc<T>, 292 pub client_data: ClientData<'static>, 293 pending: DashMap<SmolStr, Arc<Mutex<()>>>, 294} 295 296impl<T, S> SessionRegistry<T, S> 297where 298 S: ClientAuthStore, 299 T: OAuthResolver, 300{ 301 pub fn new(store: S, client: Arc<T>, client_data: ClientData<'static>) -> Self { 302 let store = Arc::new(store); 303 Self { 304 store: Arc::clone(&store), 305 client, 306 client_data, 307 pending: DashMap::new(), 308 } 309 } 310} 311 312impl<T, S> SessionRegistry<T, S> 313where 314 S: ClientAuthStore + Send + Sync + 'static, 315 T: OAuthResolver + DpopExt + Send + Sync + 'static, 316{ 317 async fn get_refreshed( 318 &self, 319 did: &Did<'_>, 320 session_id: &str, 321 ) -> Result<ClientSessionData<'_>, Error> { 322 let key = format_smolstr!("{}_{}", did, session_id); 323 let lock = self 324 .pending 325 .entry(key) 326 .or_insert_with(|| Arc::new(Mutex::new(()))) 327 .clone(); 328 let _guard = lock.lock().await; 329 330 let mut session = self 331 .store 332 .get_session(did, session_id) 333 .await? 334 .ok_or(Error::SessionNotFound)?; 335 if let Some(expires_at) = &session.token_set.expires_at { 336 if expires_at > &Datetime::now() { 337 return Ok(session); 338 } 339 } 340 let metadata = 341 OAuthMetadata::new(self.client.as_ref(), &self.client_data, &session).await?; 342 session = refresh(self.client.as_ref(), session, &metadata).await?; 343 self.store.upsert_session(session.clone()).await?; 344 345 Ok(session) 346 } 347 pub async fn get( 348 &self, 349 did: &Did<'_>, 350 session_id: &str, 351 refresh: bool, 352 ) -> Result<ClientSessionData<'_>, Error> { 353 if refresh { 354 self.get_refreshed(did, session_id).await 355 } else { 356 // TODO: cached? 357 self.store 358 .get_session(did, session_id) 359 .await? 360 .ok_or(Error::SessionNotFound) 361 } 362 } 363 pub async fn set(&self, value: ClientSessionData<'_>) -> Result<(), Error> { 364 self.store.upsert_session(value).await?; 365 Ok(()) 366 } 367 pub async fn del(&self, did: &Did<'_>, session_id: &str) -> Result<(), Error> { 368 self.store.delete_session(did, session_id).await?; 369 Ok(()) 370 } 371}