A better Rust ATProto crate
at oauth 12 kB view raw
1use std::sync::Arc; 2 3use crate::{ 4 atproto::{AtprotoClientMetadata, atproto_client_metadata}, 5 authstore::ClientAuthStore, 6 dpop::DpopExt, 7 keyset::Keyset, 8 request::{OAuthMetadata, refresh}, 9 resolver::OAuthResolver, 10 scopes::Scope, 11 types::TokenSet, 12}; 13 14use dashmap::DashMap; 15use jacquard_common::{ 16 CowStr, IntoStatic, 17 http_client::HttpClient, 18 session::SessionStoreError, 19 types::{did::Did, string::Datetime}, 20}; 21use jose_jwk::Key; 22use serde::{Deserialize, Serialize}; 23use smol_str::{SmolStr, format_smolstr}; 24use tokio::sync::Mutex; 25use url::Url; 26 27pub trait DpopDataSource { 28 fn key(&self) -> &Key; 29 fn authserver_nonce(&self) -> Option<CowStr<'_>>; 30 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>); 31 fn host_nonce(&self) -> Option<CowStr<'_>>; 32 fn set_host_nonce(&mut self, nonce: CowStr<'_>); 33} 34 35/// Persisted information about an OAuth session. Used to resume an active session. 36#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 37pub struct ClientSessionData<'s> { 38 // Account DID for this session. Assuming only one active session per account, this can be used as "primary key" for storing and retrieving this information. 39 #[serde(borrow)] 40 pub account_did: Did<'s>, 41 42 // Identifier to distinguish this particular session for the account. Server backends generally support multiple sessions for the same account. This package will re-use the random 'state' token from the auth flow as the session ID. 43 pub session_id: CowStr<'s>, 44 45 // Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info. 46 pub host_url: Url, 47 48 // Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info. 49 pub authserver_url: Url, 50 51 // Full token endpoint 52 pub authserver_token_endpoint: CowStr<'s>, 53 54 // Full revocation endpoint, if it exists 55 #[serde(skip_serializing_if = "std::option::Option::is_none")] 56 pub authserver_revocation_endpoint: Option<CowStr<'s>>, 57 58 // The set of scopes approved for this session (returned in the initial token request) 59 pub scopes: Vec<Scope<'s>>, 60 61 #[serde(flatten)] 62 pub dpop_data: DpopClientData<'s>, 63 64 #[serde(flatten)] 65 pub token_set: TokenSet<'s>, 66} 67 68impl IntoStatic for ClientSessionData<'_> { 69 type Output = ClientSessionData<'static>; 70 71 fn into_static(self) -> Self::Output { 72 ClientSessionData { 73 authserver_url: self.authserver_url, 74 authserver_token_endpoint: self.authserver_token_endpoint.into_static(), 75 authserver_revocation_endpoint: self 76 .authserver_revocation_endpoint 77 .map(IntoStatic::into_static), 78 scopes: self.scopes.into_static(), 79 dpop_data: self.dpop_data.into_static(), 80 token_set: self.token_set.into_static(), 81 account_did: self.account_did.into_static(), 82 session_id: self.session_id.into_static(), 83 host_url: self.host_url, 84 } 85 } 86} 87 88impl ClientSessionData<'_> { 89 pub fn update_with_tokens(&mut self, token_set: TokenSet<'_>) { 90 if let Some(Ok(scopes)) = token_set 91 .scope 92 .as_ref() 93 .map(|scope| Scope::parse_multiple_reduced(&scope).map(IntoStatic::into_static)) 94 { 95 self.scopes = scopes; 96 } 97 self.token_set = token_set.into_static(); 98 } 99} 100 101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 102pub struct DpopClientData<'s> { 103 pub dpop_key: Key, 104 // Current auth server DPoP nonce 105 #[serde(borrow)] 106 pub dpop_authserver_nonce: CowStr<'s>, 107 // Current host ("resource server", eg PDS) DPoP nonce 108 pub dpop_host_nonce: CowStr<'s>, 109} 110 111impl IntoStatic for DpopClientData<'_> { 112 type Output = DpopClientData<'static>; 113 114 fn into_static(self) -> Self::Output { 115 DpopClientData { 116 dpop_key: self.dpop_key, 117 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), 118 dpop_host_nonce: self.dpop_host_nonce.into_static(), 119 } 120 } 121} 122 123impl DpopDataSource for DpopClientData<'_> { 124 fn key(&self) -> &Key { 125 &self.dpop_key 126 } 127 fn authserver_nonce(&self) -> Option<CowStr<'_>> { 128 Some(self.dpop_authserver_nonce.clone()) 129 } 130 131 fn host_nonce(&self) -> Option<CowStr<'_>> { 132 Some(self.dpop_host_nonce.clone()) 133 } 134 135 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { 136 self.dpop_authserver_nonce = nonce.into_static(); 137 } 138 139 fn set_host_nonce(&mut self, nonce: CowStr<'_>) { 140 self.dpop_host_nonce = nonce.into_static(); 141 } 142} 143 144#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 145pub struct AuthRequestData<'s> { 146 // The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information. 147 #[serde(borrow)] 148 pub state: CowStr<'s>, 149 150 // URL of the auth server (eg, PDS or entryway) 151 pub authserver_url: Url, 152 153 // If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response. 154 #[serde(skip_serializing_if = "std::option::Option::is_none")] 155 pub account_did: Option<Did<'s>>, 156 157 // OAuth scope strings 158 pub scopes: Vec<Scope<'s>>, 159 160 // unique token in URI format, which will be used by the client in the auth flow redirect 161 pub request_uri: CowStr<'s>, 162 163 // Full token endpoint URL 164 pub authserver_token_endpoint: CowStr<'s>, 165 166 // Full revocation endpoint, if it exists 167 #[serde(skip_serializing_if = "std::option::Option::is_none")] 168 pub authserver_revocation_endpoint: Option<CowStr<'s>>, 169 170 // The secret token/nonce which a code challenge was generated from 171 pub pkce_verifier: CowStr<'s>, 172 173 #[serde(flatten)] 174 pub dpop_data: DpopReqData<'s>, 175} 176 177impl IntoStatic for AuthRequestData<'_> { 178 type Output = AuthRequestData<'static>; 179 fn into_static(self) -> AuthRequestData<'static> { 180 AuthRequestData { 181 request_uri: self.request_uri.into_static(), 182 authserver_token_endpoint: self.authserver_token_endpoint.into_static(), 183 authserver_revocation_endpoint: self 184 .authserver_revocation_endpoint 185 .map(|s| s.into_static()), 186 pkce_verifier: self.pkce_verifier.into_static(), 187 dpop_data: self.dpop_data.into_static(), 188 state: self.state.into_static(), 189 authserver_url: self.authserver_url, 190 account_did: self.account_did.into_static(), 191 scopes: self.scopes.into_static(), 192 } 193 } 194} 195 196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 197pub struct DpopReqData<'s> { 198 // The secret cryptographic key generated by the client for this specific OAuth session 199 pub dpop_key: Key, 200 // Server-provided DPoP nonce from auth request (PAR) 201 #[serde(borrow)] 202 pub dpop_authserver_nonce: Option<CowStr<'s>>, 203} 204 205impl IntoStatic for DpopReqData<'_> { 206 type Output = DpopReqData<'static>; 207 fn into_static(self) -> DpopReqData<'static> { 208 DpopReqData { 209 dpop_key: self.dpop_key, 210 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), 211 } 212 } 213} 214 215impl DpopDataSource for DpopReqData<'_> { 216 fn key(&self) -> &Key { 217 &self.dpop_key 218 } 219 fn authserver_nonce(&self) -> Option<CowStr<'_>> { 220 self.dpop_authserver_nonce.clone() 221 } 222 223 fn host_nonce(&self) -> Option<CowStr<'_>> { 224 None 225 } 226 227 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { 228 self.dpop_authserver_nonce = Some(nonce.into_static()); 229 } 230 231 fn set_host_nonce(&mut self, _nonce: CowStr<'_>) {} 232} 233 234#[derive(Clone, Debug)] 235pub struct ClientData<'s> { 236 pub keyset: Option<Keyset>, 237 pub config: AtprotoClientMetadata<'s>, 238} 239 240pub struct ClientSession<'s> { 241 pub keyset: Option<Keyset>, 242 pub config: AtprotoClientMetadata<'s>, 243 pub session_data: ClientSessionData<'s>, 244} 245 246impl<'s> ClientSession<'s> { 247 pub fn new( 248 ClientData { keyset, config }: ClientData<'s>, 249 session_data: ClientSessionData<'s>, 250 ) -> Self { 251 Self { 252 keyset, 253 config, 254 session_data, 255 } 256 } 257 258 pub async fn metadata<T: HttpClient + OAuthResolver + Send + Sync>( 259 &self, 260 client: &T, 261 ) -> Result<OAuthMetadata, Error> { 262 Ok(OAuthMetadata { 263 server_metadata: client 264 .get_authorization_server_metadata(&self.session_data.authserver_url) 265 .await 266 .map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?, 267 client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset) 268 .unwrap() 269 .into_static(), 270 keyset: self.keyset.clone(), 271 }) 272 } 273} 274 275#[derive(thiserror::Error, Debug, miette::Diagnostic)] 276pub enum Error { 277 #[error(transparent)] 278 #[diagnostic(code(jacquard_oauth::session::request))] 279 ServerAgent(#[from] crate::request::RequestError), 280 #[error(transparent)] 281 #[diagnostic(code(jacquard_oauth::session::storage))] 282 Store(#[from] SessionStoreError), 283 #[error("session does not exist")] 284 #[diagnostic(code(jacquard_oauth::session::not_found))] 285 SessionNotFound, 286} 287 288pub struct SessionRegistry<T, S> 289where 290 T: OAuthResolver, 291 S: ClientAuthStore, 292{ 293 pub store: Arc<S>, 294 pub client: Arc<T>, 295 pub client_data: ClientData<'static>, 296 pending: DashMap<SmolStr, Arc<Mutex<()>>>, 297} 298 299impl<T, S> SessionRegistry<T, S> 300where 301 S: ClientAuthStore, 302 T: OAuthResolver, 303{ 304 pub fn new(store: S, client: Arc<T>, client_data: ClientData<'static>) -> Self { 305 let store = Arc::new(store); 306 Self { 307 store: Arc::clone(&store), 308 client, 309 client_data, 310 pending: DashMap::new(), 311 } 312 } 313} 314 315impl<T, S> SessionRegistry<T, S> 316where 317 S: ClientAuthStore + Send + Sync + 'static, 318 T: OAuthResolver + DpopExt + Send + Sync + 'static, 319{ 320 async fn get_refreshed( 321 &self, 322 did: &Did<'_>, 323 session_id: &str, 324 ) -> Result<ClientSessionData<'_>, Error> { 325 let key = format_smolstr!("{}_{}", did, session_id); 326 let lock = self 327 .pending 328 .entry(key) 329 .or_insert_with(|| Arc::new(Mutex::new(()))) 330 .clone(); 331 let _guard = lock.lock().await; 332 333 let mut session = self 334 .store 335 .get_session(did, session_id) 336 .await? 337 .ok_or(Error::SessionNotFound)?; 338 if let Some(expires_at) = &session.token_set.expires_at { 339 if expires_at > &Datetime::now() { 340 return Ok(session); 341 } 342 } 343 let metadata = 344 OAuthMetadata::new(self.client.as_ref(), &self.client_data, &session).await?; 345 session = refresh(self.client.as_ref(), session, &metadata).await?; 346 self.store.upsert_session(session.clone()).await?; 347 348 Ok(session) 349 } 350 pub async fn get( 351 &self, 352 did: &Did<'_>, 353 session_id: &str, 354 refresh: bool, 355 ) -> Result<ClientSessionData<'_>, Error> { 356 if refresh { 357 self.get_refreshed(did, session_id).await 358 } else { 359 // TODO: cached? 360 self.store 361 .get_session(did, session_id) 362 .await? 363 .ok_or(Error::SessionNotFound) 364 } 365 } 366 pub async fn set(&self, value: ClientSessionData<'_>) -> Result<(), Error> { 367 self.store.upsert_session(value).await?; 368 Ok(()) 369 } 370 pub async fn del(&self, did: &Did<'_>, session_id: &str) -> Result<(), Error> { 371 self.store.delete_session(did, session_id).await?; 372 Ok(()) 373 } 374}