1use std::sync::Arc;
2
3use crate::{
4 atproto::{AtprotoClientMetadata, atproto_client_metadata},
5 authstore::ClientAuthStore,
6 dpop::DpopExt,
7 keyset::Keyset,
8 request::{OAuthMetadata, refresh},
9 resolver::OAuthResolver,
10 scopes::Scope,
11 types::TokenSet,
12};
13
14use dashmap::DashMap;
15use jacquard_common::{
16 CowStr, IntoStatic,
17 http_client::HttpClient,
18 session::SessionStoreError,
19 types::{did::Did, string::Datetime},
20};
21use jose_jwk::Key;
22use serde::{Deserialize, Serialize};
23use smol_str::{SmolStr, format_smolstr};
24use tokio::sync::Mutex;
25use url::Url;
26
27pub trait DpopDataSource {
28 fn key(&self) -> &Key;
29 fn authserver_nonce(&self) -> Option<CowStr<'_>>;
30 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>);
31 fn host_nonce(&self) -> Option<CowStr<'_>>;
32 fn set_host_nonce(&mut self, nonce: CowStr<'_>);
33}
34
/// Persisted information about an OAuth session. Used to resume an active session.
///
/// NOTE(review): the derived `Debug` prints every field, including `dpop_data.dpop_key` and
/// `token_set`; confirm those types redact secret material before logging values of this type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientSessionData<'s> {
    /// Account DID for this session. Assuming only one active session per account, this can be
    /// used as "primary key" for storing and retrieving this information.
    #[serde(borrow)]
    pub account_did: Did<'s>,

    /// Identifier to distinguish this particular session for the account. Server backends
    /// generally support multiple sessions for the same account. The random `state` token from
    /// the auth flow is intended to be re-used as the session ID.
    pub session_id: CowStr<'s>,

    /// Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no
    /// path or auth info.
    pub host_url: Url,

    /// Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname,
    /// port; no path or auth info.
    pub authserver_url: Url,

    /// Full token endpoint URL.
    pub authserver_token_endpoint: CowStr<'s>,

    /// Full revocation endpoint, if it exists. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "std::option::Option::is_none")]
    pub authserver_revocation_endpoint: Option<CowStr<'s>>,

    /// The set of scopes approved for this session (returned in the initial token request and
    /// replaced by `update_with_tokens` when a new token set carries a parseable scope).
    pub scopes: Vec<Scope<'s>>,

    /// DPoP key and current nonces; its fields are flattened into this struct's serialized form.
    #[serde(flatten)]
    pub dpop_data: DpopClientData<'s>,

    /// Current token set; its fields are flattened into this struct's serialized form.
    #[serde(flatten)]
    pub token_set: TokenSet<'s>,
}
67
68impl IntoStatic for ClientSessionData<'_> {
69 type Output = ClientSessionData<'static>;
70
71 fn into_static(self) -> Self::Output {
72 ClientSessionData {
73 authserver_url: self.authserver_url,
74 authserver_token_endpoint: self.authserver_token_endpoint.into_static(),
75 authserver_revocation_endpoint: self
76 .authserver_revocation_endpoint
77 .map(IntoStatic::into_static),
78 scopes: self.scopes.into_static(),
79 dpop_data: self.dpop_data.into_static(),
80 token_set: self.token_set.into_static(),
81 account_did: self.account_did.into_static(),
82 session_id: self.session_id.into_static(),
83 host_url: self.host_url,
84 }
85 }
86}
87
88impl ClientSessionData<'_> {
89 pub fn update_with_tokens(&mut self, token_set: TokenSet<'_>) {
90 if let Some(Ok(scopes)) = token_set
91 .scope
92 .as_ref()
93 .map(|scope| Scope::parse_multiple_reduced(&scope).map(IntoStatic::into_static))
94 {
95 self.scopes = scopes;
96 }
97 self.token_set = token_set.into_static();
98 }
99}
100
/// DPoP state stored with an active session: the key plus the latest nonces from the auth
/// server and the host.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DpopClientData<'s> {
    /// DPoP key for this session.
    pub dpop_key: Key,
    /// Current auth server DPoP nonce.
    #[serde(borrow)]
    pub dpop_authserver_nonce: CowStr<'s>,
    /// Current host ("resource server", eg PDS) DPoP nonce.
    pub dpop_host_nonce: CowStr<'s>,
}
110
111impl IntoStatic for DpopClientData<'_> {
112 type Output = DpopClientData<'static>;
113
114 fn into_static(self) -> Self::Output {
115 DpopClientData {
116 dpop_key: self.dpop_key,
117 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(),
118 dpop_host_nonce: self.dpop_host_nonce.into_static(),
119 }
120 }
121}
122
123impl DpopDataSource for DpopClientData<'_> {
124 fn key(&self) -> &Key {
125 &self.dpop_key
126 }
127 fn authserver_nonce(&self) -> Option<CowStr<'_>> {
128 Some(self.dpop_authserver_nonce.clone())
129 }
130
131 fn host_nonce(&self) -> Option<CowStr<'_>> {
132 Some(self.dpop_host_nonce.clone())
133 }
134
135 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) {
136 self.dpop_authserver_nonce = nonce.into_static();
137 }
138
139 fn set_host_nonce(&mut self, nonce: CowStr<'_>) {
140 self.dpop_host_nonce = nonce.into_static();
141 }
142}
143
/// Persisted state of an in-progress OAuth authorization request.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthRequestData<'s> {
    /// The random identifier generated by the client for the auth request flow. Can be used as
    /// "primary key" for storing and retrieving this information.
    #[serde(borrow)]
    pub state: CowStr<'s>,

    /// URL of the auth server (eg, PDS or entryway).
    pub authserver_url: Url,

    /// If the flow started with an account identifier (DID or handle), it should be persisted,
    /// to verify against the initial token response. Omitted from the serialized form when
    /// `None`.
    #[serde(skip_serializing_if = "std::option::Option::is_none")]
    pub account_did: Option<Did<'s>>,

    /// Requested OAuth scopes.
    pub scopes: Vec<Scope<'s>>,

    /// Unique token in URI format, which will be used by the client in the auth flow redirect.
    pub request_uri: CowStr<'s>,

    /// Full token endpoint URL.
    pub authserver_token_endpoint: CowStr<'s>,

    /// Full revocation endpoint, if it exists. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "std::option::Option::is_none")]
    pub authserver_revocation_endpoint: Option<CowStr<'s>>,

    /// The secret token/nonce which a code challenge was generated from.
    pub pkce_verifier: CowStr<'s>,

    /// DPoP key and auth server nonce; its fields are flattened into this struct's serialized
    /// form.
    #[serde(flatten)]
    pub dpop_data: DpopReqData<'s>,
}
176
177impl IntoStatic for AuthRequestData<'_> {
178 type Output = AuthRequestData<'static>;
179 fn into_static(self) -> AuthRequestData<'static> {
180 AuthRequestData {
181 request_uri: self.request_uri.into_static(),
182 authserver_token_endpoint: self.authserver_token_endpoint.into_static(),
183 authserver_revocation_endpoint: self
184 .authserver_revocation_endpoint
185 .map(|s| s.into_static()),
186 pkce_verifier: self.pkce_verifier.into_static(),
187 dpop_data: self.dpop_data.into_static(),
188 state: self.state.into_static(),
189 authserver_url: self.authserver_url,
190 account_did: self.account_did.into_static(),
191 scopes: self.scopes.into_static(),
192 }
193 }
194}
195
/// DPoP state stored with a pending authorization request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DpopReqData<'s> {
    /// The secret cryptographic key generated by the client for this specific OAuth session.
    pub dpop_key: Key,
    /// Server-provided DPoP nonce from the auth request (PAR), if one has been received.
    #[serde(borrow)]
    pub dpop_authserver_nonce: Option<CowStr<'s>>,
}
204
205impl IntoStatic for DpopReqData<'_> {
206 type Output = DpopReqData<'static>;
207 fn into_static(self) -> DpopReqData<'static> {
208 DpopReqData {
209 dpop_key: self.dpop_key,
210 dpop_authserver_nonce: self.dpop_authserver_nonce.into_static(),
211 }
212 }
213}
214
215impl DpopDataSource for DpopReqData<'_> {
216 fn key(&self) -> &Key {
217 &self.dpop_key
218 }
219 fn authserver_nonce(&self) -> Option<CowStr<'_>> {
220 self.dpop_authserver_nonce.clone()
221 }
222
223 fn host_nonce(&self) -> Option<CowStr<'_>> {
224 None
225 }
226
227 fn set_authserver_nonce(&mut self, nonce: CowStr<'_>) {
228 self.dpop_authserver_nonce = Some(nonce.into_static());
229 }
230
231 fn set_host_nonce(&mut self, _nonce: CowStr<'_>) {}
232}
233
/// Client-wide OAuth configuration shared by all sessions.
#[derive(Clone, Debug)]
pub struct ClientData<'s> {
    /// Optional client keyset (passed to `atproto_client_metadata` when building metadata).
    pub keyset: Option<Keyset>,
    /// atproto client metadata configuration.
    pub config: AtprotoClientMetadata<'s>,
}
239
/// Client configuration paired with the persisted data of a single session.
pub struct ClientSession<'s> {
    /// Optional client keyset.
    pub keyset: Option<Keyset>,
    /// atproto client metadata configuration.
    pub config: AtprotoClientMetadata<'s>,
    /// Persisted data for this session.
    pub session_data: ClientSessionData<'s>,
}
245
246impl<'s> ClientSession<'s> {
247 pub fn new(
248 ClientData { keyset, config }: ClientData<'s>,
249 session_data: ClientSessionData<'s>,
250 ) -> Self {
251 Self {
252 keyset,
253 config,
254 session_data,
255 }
256 }
257
258 pub async fn metadata<T: HttpClient + OAuthResolver + Send + Sync>(
259 &self,
260 client: &T,
261 ) -> Result<OAuthMetadata, Error> {
262 Ok(OAuthMetadata {
263 server_metadata: client
264 .get_authorization_server_metadata(&self.session_data.authserver_url)
265 .await
266 .map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?,
267 client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset)
268 .unwrap()
269 .into_static(),
270 keyset: self.keyset.clone(),
271 })
272 }
273}
274
/// Errors produced by session lookup, persistence, and refresh.
#[derive(thiserror::Error, Debug, miette::Diagnostic)]
pub enum Error {
    /// An OAuth request or resolver step failed (e.g. fetching auth server metadata).
    #[error(transparent)]
    #[diagnostic(code(jacquard_oauth::session::request))]
    ServerAgent(#[from] crate::request::RequestError),
    /// The session store returned an error.
    #[error(transparent)]
    #[diagnostic(code(jacquard_oauth::session::storage))]
    Store(#[from] SessionStoreError),
    /// No session is stored for the requested DID and session ID.
    #[error("session does not exist")]
    #[diagnostic(code(jacquard_oauth::session::not_found))]
    SessionNotFound,
}
287
/// Access point for persisted OAuth sessions, backed by a [`ClientAuthStore`].
///
/// Besides plain get/set/delete, it serializes concurrent refreshes of the same session through
/// per-session async locks.
pub struct SessionRegistry<T, S>
where
    T: OAuthResolver,
    S: ClientAuthStore,
{
    /// Backing session store.
    pub store: Arc<S>,
    /// Client used for resolution and token-refresh requests.
    pub client: Arc<T>,
    /// Client configuration used to build OAuth metadata when refreshing.
    pub client_data: ClientData<'static>,
    // One async mutex per "{did}_{session_id}" key, used to serialize refreshes of that session.
    pending: DashMap<SmolStr, Arc<Mutex<()>>>,
}
298
299impl<T, S> SessionRegistry<T, S>
300where
301 S: ClientAuthStore,
302 T: OAuthResolver,
303{
304 pub fn new(store: S, client: Arc<T>, client_data: ClientData<'static>) -> Self {
305 let store = Arc::new(store);
306 Self {
307 store: Arc::clone(&store),
308 client,
309 client_data,
310 pending: DashMap::new(),
311 }
312 }
313}
314
315impl<T, S> SessionRegistry<T, S>
316where
317 S: ClientAuthStore + Send + Sync + 'static,
318 T: OAuthResolver + DpopExt + Send + Sync + 'static,
319{
320 async fn get_refreshed(
321 &self,
322 did: &Did<'_>,
323 session_id: &str,
324 ) -> Result<ClientSessionData<'_>, Error> {
325 let key = format_smolstr!("{}_{}", did, session_id);
326 let lock = self
327 .pending
328 .entry(key)
329 .or_insert_with(|| Arc::new(Mutex::new(())))
330 .clone();
331 let _guard = lock.lock().await;
332
333 let mut session = self
334 .store
335 .get_session(did, session_id)
336 .await?
337 .ok_or(Error::SessionNotFound)?;
338 if let Some(expires_at) = &session.token_set.expires_at {
339 if expires_at > &Datetime::now() {
340 return Ok(session);
341 }
342 }
343 let metadata =
344 OAuthMetadata::new(self.client.as_ref(), &self.client_data, &session).await?;
345 session = refresh(self.client.as_ref(), session, &metadata).await?;
346 self.store.upsert_session(session.clone()).await?;
347
348 Ok(session)
349 }
350 pub async fn get(
351 &self,
352 did: &Did<'_>,
353 session_id: &str,
354 refresh: bool,
355 ) -> Result<ClientSessionData<'_>, Error> {
356 if refresh {
357 self.get_refreshed(did, session_id).await
358 } else {
359 // TODO: cached?
360 self.store
361 .get_session(did, session_id)
362 .await?
363 .ok_or(Error::SessionNotFound)
364 }
365 }
366 pub async fn set(&self, value: ClientSessionData<'_>) -> Result<(), Error> {
367 self.store.upsert_session(value).await?;
368 Ok(())
369 }
370 pub async fn del(&self, did: &Did<'_>, session_id: &str) -> Result<(), Error> {
371 self.store.delete_session(did, session_id).await?;
372 Ok(())
373 }
374}