1use std::sync::Arc;
2
3use crate::{
4 atproto::{AtprotoClientMetadata, atproto_client_metadata},
5 authstore::ClientAuthStore,
6 dpop::DpopExt,
7 keyset::Keyset,
8 request::{OAuthMetadata, refresh},
9 resolver::OAuthResolver,
10 scopes::Scope,
11 types::TokenSet,
12};
13
14use dashmap::DashMap;
15use jacquard_common::{
16 CowStr, IntoStatic,
17 http_client::HttpClient,
18 session::SessionStoreError,
19 types::{did::Did, string::Datetime},
20};
21use jose_jwk::Key;
22use serde::{Deserialize, Serialize};
23use smol_str::{SmolStr, format_smolstr};
24use tokio::sync::Mutex;
25use url::Url;
26
27pub trait DpopDataSource {
28 fn key(&self) -> &Key;
29 fn authserver_nonce(&self) -> Option<CowStr<'_>>;
30 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>);
31 fn host_nonce(&self) -> Option<CowStr<'_>>;
32 fn set_host_nonce(&mut self, nonce: CowStr<'_>);
33}
34
35/// Persisted information about an OAuth session. Used to resume an active session.
36#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
37pub struct ClientSessionData<'s> {
38 // Account DID for this session. Assuming only one active session per account, this can be used as "primary key" for storing and retrieving this information.
39 #[serde(borrow)]
40 pub account_did: Did<'s>,
41
42 // Identifier to distinguish this particular session for the account. Server backends generally support multiple sessions for the same account. This package will re-use the random 'state' token from the auth flow as the session ID.
43 pub session_id: CowStr<'s>,
44
45 // Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info.
46 pub host_url: Url,
47
48 // Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info.
49 pub authserver_url: Url,
50
51 // Full token endpoint
52 pub authserver_token_endpoint: CowStr<'s>,
53
54 // Full revocation endpoint, if it exists
55 #[serde(skip_serializing_if = "std::option::Option::is_none")]
56 pub authserver_revocation_endpoint: Option<CowStr<'s>>,
57
58 // The set of scopes approved for this session (returned in the initial token request)
59 pub scopes: Vec<Scope<'s>>,
60
61 #[serde(flatten)]
62 pub dpop_data: DpopClientData<'s>,
63
64 #[serde(flatten)]
65 pub token_set: TokenSet<'s>,
66}
67
68impl IntoStatic for ClientSessionData<'_> {
69 type Output = ClientSessionData<'static>;
70
71 fn into_static(self) -> Self::Output {
72 ClientSessionData {
73 authserver_url: self.authserver_url,
74 authserver_token_endpoint: self.authserver_token_endpoint.into_static(),
75 authserver_revocation_endpoint: self
76 .authserver_revocation_endpoint
77 .map(IntoStatic::into_static),
78 scopes: self.scopes.into_static(),
79 dpop_data: self.dpop_data.into_static(),
80 token_set: self.token_set.into_static(),
81 account_did: self.account_did.into_static(),
82 session_id: self.session_id.into_static(),
83 host_url: self.host_url,
84 }
85 }
86}
87
88impl ClientSessionData<'_> {
89 pub fn update_with_tokens(&mut self, token_set: TokenSet<'_>) {
90 if let Some(Ok(scopes)) = token_set
91 .scope
92 .as_ref()
93 .map(|scope| Scope::parse_multiple_reduced(&scope).map(IntoStatic::into_static))
94 {
95 self.scopes = scopes;
96 }
97 self.token_set = token_set.into_static();
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
102pub struct DpopClientData<'s> {
103 pub dpop_key: Key,
104 // Current auth server DPoP nonce
105 #[serde(borrow)]
106 pub dpop_authserver_nonce: CowStr<'s>,
107 // Current host ("resource server", eg PDS) DPoP nonce
108 pub dpop_host_nonce: CowStr<'s>,
109}
110
111impl IntoStatic for DpopClientData<'_> {
112 type Output = DpopClientData<'static>;
113
114 fn into_static(self) -> Self::Output {
115 DpopClientData {
116 dpop_key: self.dpop_key,
117 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(),
118 dpop_host_nonce: self.dpop_host_nonce.into_static(),
119 }
120 }
121}
122
123impl DpopDataSource for DpopClientData<'_> {
124 fn key(&self) -> &Key {
125 &self.dpop_key
126 }
127 fn authserver_nonce(&self) -> Option<CowStr<'_>> {
128 Some(self.dpop_authserver_nonce.clone())
129 }
130
131 fn host_nonce(&self) -> Option<CowStr<'_>> {
132 Some(self.dpop_host_nonce.clone())
133 }
134
135 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) {
136 self.dpop_authserver_nonce = nonce.into_static();
137 }
138
139 fn set_host_nonce(&mut self, nonce: CowStr<'_>) {
140 self.dpop_host_nonce = nonce.into_static();
141 }
142}
143
144#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
145pub struct AuthRequestData<'s> {
146 // The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information.
147 #[serde(borrow)]
148 pub state: CowStr<'s>,
149
150 // URL of the auth server (eg, PDS or entryway)
151 pub authserver_url: Url,
152
153 // If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response.
154 #[serde(skip_serializing_if = "std::option::Option::is_none")]
155 pub account_did: Option<Did<'s>>,
156
157 // OAuth scope strings
158 pub scopes: Vec<Scope<'s>>,
159
160 // unique token in URI format, which will be used by the client in the auth flow redirect
161 pub request_uri: CowStr<'s>,
162
163 // Full token endpoint URL
164 pub authserver_token_endpoint: CowStr<'s>,
165
166 // Full revocation endpoint, if it exists
167 #[serde(skip_serializing_if = "std::option::Option::is_none")]
168 pub authserver_revocation_endpoint: Option<CowStr<'s>>,
169
170 // The secret token/nonce which a code challenge was generated from
171 pub pkce_verifier: CowStr<'s>,
172
173 #[serde(flatten)]
174 pub dpop_data: DpopReqData<'s>,
175}
176
177impl IntoStatic for AuthRequestData<'_> {
178 type Output = AuthRequestData<'static>;
179 fn into_static(self) -> AuthRequestData<'static> {
180 AuthRequestData {
181 request_uri: self.request_uri.into_static(),
182 authserver_token_endpoint: self.authserver_token_endpoint.into_static(),
183 authserver_revocation_endpoint: self
184 .authserver_revocation_endpoint
185 .map(|s| s.into_static()),
186 pkce_verifier: self.pkce_verifier.into_static(),
187 dpop_data: self.dpop_data.into_static(),
188 state: self.state.into_static(),
189 authserver_url: self.authserver_url,
190 account_did: self.account_did.into_static(),
191 scopes: self.scopes.into_static(),
192 }
193 }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
197pub struct DpopReqData<'s> {
198 // The secret cryptographic key generated by the client for this specific OAuth session
199 pub dpop_key: Key,
200 // Server-provided DPoP nonce from auth request (PAR)
201 #[serde(borrow)]
202 pub dpop_authserver_nonce: Option<CowStr<'s>>,
203}
204
205impl IntoStatic for DpopReqData<'_> {
206 type Output = DpopReqData<'static>;
207 fn into_static(self) -> DpopReqData<'static> {
208 DpopReqData {
209 dpop_key: self.dpop_key,
210 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(),
211 }
212 }
213}
214
215impl DpopDataSource for DpopReqData<'_> {
216 fn key(&self) -> &Key {
217 &self.dpop_key
218 }
219 fn authserver_nonce(&self) -> Option<CowStr<'_>> {
220 self.dpop_authserver_nonce.clone()
221 }
222
223 fn host_nonce(&self) -> Option<CowStr<'_>> {
224 None
225 }
226
227 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) {
228 self.dpop_authserver_nonce = Some(nonce.into_static());
229 }
230
231 fn set_host_nonce(&mut self, _nonce: CowStr<'_>) {}
232}
233
234#[derive(Clone, Debug)]
235pub struct ClientData<'s> {
236 pub keyset: Option<Keyset>,
237 pub config: AtprotoClientMetadata<'s>,
238}
239
240pub struct ClientSession<'s> {
241 pub keyset: Option<Keyset>,
242 pub config: AtprotoClientMetadata<'s>,
243 pub session_data: ClientSessionData<'s>,
244}
245
246impl<'s> ClientSession<'s> {
247 pub fn new(
248 ClientData { keyset, config }: ClientData<'s>,
249 session_data: ClientSessionData<'s>,
250 ) -> Self {
251 Self {
252 keyset,
253 config,
254 session_data,
255 }
256 }
257
258 pub async fn metadata<T: HttpClient + OAuthResolver + Send + Sync>(
259 &self,
260 client: &T,
261 ) -> Result<OAuthMetadata, Error> {
262 Ok(OAuthMetadata {
263 server_metadata: client
264 .get_authorization_server_metadata(&self.session_data.authserver_url)
265 .await
266 .map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?,
267 client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset)
268 .unwrap()
269 .into_static(),
270 keyset: self.keyset.clone(),
271 })
272 }
273}
274
275#[derive(thiserror::Error, Debug, miette::Diagnostic)]
276pub enum Error {
277 #[error(transparent)]
278 #[diagnostic(code(jacquard_oauth::session::request))]
279 ServerAgent(#[from] crate::request::RequestError),
280 #[error(transparent)]
281 #[diagnostic(code(jacquard_oauth::session::storage))]
282 Store(#[from] SessionStoreError),
283 #[error("session does not exist")]
284 #[diagnostic(code(jacquard_oauth::session::not_found))]
285 SessionNotFound,
286}
287
288pub struct SessionRegistry<T, S>
289where
290 T: OAuthResolver,
291 S: ClientAuthStore,
292{
293 pub store: Arc<S>,
294 pub client: Arc<T>,
295 pub client_data: ClientData<'static>,
296 pending: DashMap<SmolStr, Arc<Mutex<()>>>,
297}
298
299impl<T, S> SessionRegistry<T, S>
300where
301 S: ClientAuthStore,
302 T: OAuthResolver,
303{
304 pub fn new(store: S, client: Arc<T>, client_data: ClientData<'static>) -> Self {
305 let store = Arc::new(store);
306 Self {
307 store: Arc::clone(&store),
308 client,
309 client_data,
310 pending: DashMap::new(),
311 }
312 }
313
314 pub fn new_shared(store: Arc<S>, client: Arc<T>, client_data: ClientData<'static>) -> Self {
315 Self {
316 store,
317 client,
318 client_data,
319 pending: DashMap::new(),
320 }
321 }
322}
323
324impl<T, S> SessionRegistry<T, S>
325where
326 S: ClientAuthStore + Send + Sync + 'static,
327 T: OAuthResolver + DpopExt + Send + Sync + 'static,
328{
329 async fn get_refreshed(
330 &self,
331 did: &Did<'_>,
332 session_id: &str,
333 ) -> Result<ClientSessionData<'_>, Error> {
334 let key = format_smolstr!("{}_{}", did, session_id);
335 let lock = self
336 .pending
337 .entry(key)
338 .or_insert_with(|| Arc::new(Mutex::new(())))
339 .clone();
340 let _guard = lock.lock().await;
341
342 let mut session = self
343 .store
344 .get_session(did, session_id)
345 .await?
346 .ok_or(Error::SessionNotFound)?;
347 if let Some(expires_at) = &session.token_set.expires_at {
348 if expires_at > &Datetime::now() {
349 return Ok(session);
350 }
351 }
352 let metadata =
353 OAuthMetadata::new(self.client.as_ref(), &self.client_data, &session).await?;
354 session = refresh(self.client.as_ref(), session, &metadata).await?;
355 self.store.upsert_session(session.clone()).await?;
356
357 Ok(session)
358 }
359 pub async fn get(
360 &self,
361 did: &Did<'_>,
362 session_id: &str,
363 refresh: bool,
364 ) -> Result<ClientSessionData<'_>, Error> {
365 if refresh {
366 self.get_refreshed(did, session_id).await
367 } else {
368 // TODO: cached?
369 self.store
370 .get_session(did, session_id)
371 .await?
372 .ok_or(Error::SessionNotFound)
373 }
374 }
375 pub async fn set(&self, value: ClientSessionData<'_>) -> Result<(), Error> {
376 self.store.upsert_session(value).await?;
377 Ok(())
378 }
379 pub async fn del(&self, did: &Did<'_>, session_id: &str) -> Result<(), Error> {
380 self.store.delete_session(did, session_id).await?;
381 Ok(())
382 }
383}