use std::sync::Arc; use crate::{ atproto::{AtprotoClientMetadata, atproto_client_metadata}, authstore::ClientAuthStore, dpop::DpopExt, keyset::Keyset, request::{OAuthMetadata, refresh}, resolver::OAuthResolver, scopes::Scope, types::TokenSet, }; use dashmap::DashMap; use jacquard_common::{ CowStr, IntoStatic, http_client::HttpClient, session::SessionStoreError, types::{did::Did, string::Datetime}, }; use jose_jwk::Key; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, format_smolstr}; use tokio::sync::Mutex; use url::Url; pub trait DpopDataSource { fn key(&self) -> &Key; fn authserver_nonce(&self) -> Option>; fn set_authserver_nonce(&mut self, nonce: CowStr<'_>); fn host_nonce(&self) -> Option>; fn set_host_nonce(&mut self, nonce: CowStr<'_>); } /// Persisted information about an OAuth session. Used to resume an active session. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ClientSessionData<'s> { // Account DID for this session. Assuming only one active session per account, this can be used as "primary key" for storing and retrieving this information. #[serde(borrow)] pub account_did: Did<'s>, // Identifier to distinguish this particular session for the account. Server backends generally support multiple sessions for the same account. This package will re-use the random 'state' token from the auth flow as the session ID. pub session_id: CowStr<'s>, // Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info. pub host_url: Url, // Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info. pub authserver_url: Url, // Full token endpoint pub authserver_token_endpoint: CowStr<'s>, // Full revocation endpoint, if it exists #[serde(skip_serializing_if = "std::option::Option::is_none")] pub authserver_revocation_endpoint: Option>, // The set of scopes approved for this session (returned in the initial token request) pub scopes: Vec>, #[serde(flatten)] pub dpop_data: DpopClientData<'s>, #[serde(flatten)] pub token_set: TokenSet<'s>, } impl IntoStatic for ClientSessionData<'_> { type Output = ClientSessionData<'static>; fn into_static(self) -> Self::Output { ClientSessionData { authserver_url: self.authserver_url, authserver_token_endpoint: self.authserver_token_endpoint.into_static(), authserver_revocation_endpoint: self .authserver_revocation_endpoint .map(IntoStatic::into_static), scopes: self.scopes.into_static(), dpop_data: self.dpop_data.into_static(), token_set: self.token_set.into_static(), account_did: self.account_did.into_static(), session_id: self.session_id.into_static(), host_url: self.host_url, } } } impl ClientSessionData<'_> { pub fn update_with_tokens(&mut self, token_set: TokenSet<'_>) { if let Some(Ok(scopes)) = token_set .scope .as_ref() .map(|scope| Scope::parse_multiple_reduced(&scope).map(IntoStatic::into_static)) { self.scopes = scopes; } self.token_set = token_set.into_static(); } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct DpopClientData<'s> { pub dpop_key: Key, // Current auth server DPoP nonce #[serde(borrow)] pub dpop_authserver_nonce: CowStr<'s>, // Current host ("resource server", eg PDS) DPoP nonce pub dpop_host_nonce: CowStr<'s>, } impl IntoStatic for DpopClientData<'_> { type Output = DpopClientData<'static>; fn into_static(self) -> Self::Output { DpopClientData { dpop_key: self.dpop_key, dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), dpop_host_nonce: self.dpop_host_nonce.into_static(), } } } impl DpopDataSource for DpopClientData<'_> { fn key(&self) -> &Key { &self.dpop_key } fn authserver_nonce(&self) -> Option> { Some(self.dpop_authserver_nonce.clone()) } fn host_nonce(&self) -> Option> { Some(self.dpop_host_nonce.clone()) } fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { self.dpop_authserver_nonce = nonce.into_static(); } fn set_host_nonce(&mut self, nonce: CowStr<'_>) { self.dpop_host_nonce = nonce.into_static(); } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AuthRequestData<'s> { // The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information. #[serde(borrow)] pub state: CowStr<'s>, // URL of the auth server (eg, PDS or entryway) pub authserver_url: Url, // If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response. #[serde(skip_serializing_if = "std::option::Option::is_none")] pub account_did: Option>, // OAuth scope strings pub scopes: Vec>, // unique token in URI format, which will be used by the client in the auth flow redirect pub request_uri: CowStr<'s>, // Full token endpoint URL pub authserver_token_endpoint: CowStr<'s>, // Full revocation endpoint, if it exists #[serde(skip_serializing_if = "std::option::Option::is_none")] pub authserver_revocation_endpoint: Option>, // The secret token/nonce which a code challenge was generated from pub pkce_verifier: CowStr<'s>, #[serde(flatten)] pub dpop_data: DpopReqData<'s>, } impl IntoStatic for AuthRequestData<'_> { type Output = AuthRequestData<'static>; fn into_static(self) -> AuthRequestData<'static> { AuthRequestData { request_uri: self.request_uri.into_static(), authserver_token_endpoint: self.authserver_token_endpoint.into_static(), authserver_revocation_endpoint: self .authserver_revocation_endpoint .map(|s| s.into_static()), pkce_verifier: self.pkce_verifier.into_static(), dpop_data: self.dpop_data.into_static(), state: self.state.into_static(), authserver_url: self.authserver_url, account_did: self.account_did.into_static(), scopes: self.scopes.into_static(), } } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct DpopReqData<'s> { // The secret cryptographic key generated by the client for this specific OAuth session pub dpop_key: Key, // Server-provided DPoP nonce from auth request (PAR) #[serde(borrow)] pub dpop_authserver_nonce: Option>, } impl IntoStatic for DpopReqData<'_> { type Output = DpopReqData<'static>; fn into_static(self) -> DpopReqData<'static> { DpopReqData { dpop_key: self.dpop_key, dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(), } } } impl DpopDataSource for DpopReqData<'_> { fn key(&self) -> &Key { &self.dpop_key } fn authserver_nonce(&self) -> Option> { self.dpop_authserver_nonce.clone() } fn host_nonce(&self) -> Option> { None } fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) { self.dpop_authserver_nonce = Some(nonce.into_static()); } fn set_host_nonce(&mut self, _nonce: CowStr<'_>) {} } #[derive(Clone, Debug)] pub struct ClientData<'s> { pub keyset: Option, pub config: AtprotoClientMetadata<'s>, } pub struct ClientSession<'s> { pub keyset: Option, pub config: AtprotoClientMetadata<'s>, pub session_data: ClientSessionData<'s>, } impl<'s> ClientSession<'s> { pub fn new( ClientData { keyset, config }: ClientData<'s>, session_data: ClientSessionData<'s>, ) -> Self { Self { keyset, config, session_data, } } pub async fn metadata( &self, client: &T, ) -> Result { Ok(OAuthMetadata { server_metadata: client .get_authorization_server_metadata(&self.session_data.authserver_url) .await .map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?, client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset) .unwrap() .into_static(), keyset: self.keyset.clone(), }) } } #[derive(thiserror::Error, Debug, miette::Diagnostic)] pub enum Error { #[error(transparent)] #[diagnostic(code(jacquard_oauth::session::request))] ServerAgent(#[from] crate::request::RequestError), #[error(transparent)] #[diagnostic(code(jacquard_oauth::session::storage))] Store(#[from] SessionStoreError), #[error("session does not exist")] #[diagnostic(code(jacquard_oauth::session::not_found))] SessionNotFound, } pub struct SessionRegistry where T: OAuthResolver, S: ClientAuthStore, { pub store: Arc, pub client: Arc, pub client_data: ClientData<'static>, pending: DashMap>>, } impl SessionRegistry where S: ClientAuthStore, T: OAuthResolver, { pub fn new(store: S, client: Arc, client_data: ClientData<'static>) -> Self { let store = Arc::new(store); Self { store: Arc::clone(&store), client, client_data, pending: DashMap::new(), } } } impl SessionRegistry where S: ClientAuthStore + Send + Sync + 'static, T: OAuthResolver + DpopExt + Send + Sync + 'static, { async fn get_refreshed( &self, did: &Did<'_>, session_id: &str, ) -> Result, Error> { let key = format_smolstr!("{}_{}", did, session_id); let lock = self .pending .entry(key) .or_insert_with(|| Arc::new(Mutex::new(()))) .clone(); let _guard = lock.lock().await; let mut session = self .store .get_session(did, session_id) .await? .ok_or(Error::SessionNotFound)?; if let Some(expires_at) = &session.token_set.expires_at { if expires_at > &Datetime::now() { return Ok(session); } } let metadata = OAuthMetadata::new(self.client.as_ref(), &self.client_data, &session).await?; session = refresh(self.client.as_ref(), session, &metadata).await?; self.store.upsert_session(session.clone()).await?; Ok(session) } pub async fn get( &self, did: &Did<'_>, session_id: &str, refresh: bool, ) -> Result, Error> { if refresh { self.get_refreshed(did, session_id).await } else { // TODO: cached? self.store .get_session(did, session_id) .await? .ok_or(Error::SessionNotFound) } } pub async fn set(&self, value: ClientSessionData<'_>) -> Result<(), Error> { self.store.upsert_session(value).await?; Ok(()) } pub async fn del(&self, did: &Did<'_>, session_id: &str) -> Result<(), Error> { self.store.delete_session(did, session_id).await?; Ok(()) } }