1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 error::{AuthError, ClientError, TransportError, XrpcResult},
15 http_client::HttpClient,
16 types::{did::Did, string::Handle},
17 xrpc::{
18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request,
19 process_response,
20 },
21};
22use jacquard_identity::{JacquardResolver, resolver::IdentityResolver};
23use jose_jwk::JwkSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26use url::Url;
27
28pub struct OAuthClient<T, S>
29where
30 T: OAuthResolver,
31 S: ClientAuthStore,
32{
33 pub registry: Arc<SessionRegistry<T, S>>,
34 pub client: Arc<T>,
35}
36
37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
38 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
39 let client = JacquardResolver::default();
40 Self::new_from_resolver(store, client, client_data)
41 }
42
43 /// Create an OAuth client with the provided store and default localhost client metadata.
44 ///
45 /// This is a convenience constructor for quickly setting up an OAuth client
46 /// with default localhost redirect URIs and "atproto transition:generic" scopes.
47 ///
48 /// # Example
49 ///
50 /// ```no_run
51 /// # use jacquard_oauth::client::OAuthClient;
52 /// # use jacquard_oauth::authstore::MemoryAuthStore;
53 /// # #[tokio::main]
54 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
55 /// let store = MemoryAuthStore::new();
56 /// let oauth = OAuthClient::with_default_config(store);
57 /// # Ok(())
58 /// # }
59 /// ```
60 pub fn with_default_config(store: S) -> Self {
61 let client_data = ClientData {
62 keyset: None,
63 config: crate::atproto::AtprotoClientMetadata::default_localhost(),
64 };
65 Self::new(store, client_data)
66 }
67}
68
69impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> {
70 /// Create an OAuth client with an in-memory auth store and default localhost client metadata.
71 ///
72 /// This is a convenience constructor for simple testing and development.
73 /// The session will not persist across restarts.
74 ///
75 /// # Example
76 ///
77 /// ```no_run
78 /// # use jacquard_oauth::client::OAuthClient;
79 /// # #[tokio::main]
80 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
81 /// let oauth = OAuthClient::with_memory_store();
82 /// # Ok(())
83 /// # }
84 /// ```
85 pub fn with_memory_store() -> Self {
86 Self::with_default_config(crate::authstore::MemoryAuthStore::new())
87 }
88}
89
90impl<T, S> OAuthClient<T, S>
91where
92 T: OAuthResolver,
93 S: ClientAuthStore,
94{
95 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
96 let client = Arc::new(client);
97 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
98 Self { registry, client }
99 }
100
101 pub fn new_with_shared(
102 store: Arc<S>,
103 client: Arc<T>,
104 client_data: ClientData<'static>,
105 ) -> Self {
106 let registry = Arc::new(SessionRegistry::new_shared(
107 store,
108 client.clone(),
109 client_data,
110 ));
111 Self { registry, client }
112 }
113}
114
115impl<T, S> OAuthClient<T, S>
116where
117 S: ClientAuthStore + Send + Sync + 'static,
118 T: OAuthResolver + DpopExt + Send + Sync + 'static,
119{
120 pub fn jwks(&self) -> JwkSet {
121 self.registry
122 .client_data
123 .keyset
124 .as_ref()
125 .map(|keyset| keyset.public_jwks())
126 .unwrap_or_default()
127 }
128 pub async fn start_auth(
129 &self,
130 input: impl AsRef<str>,
131 options: AuthorizeOptions<'_>,
132 ) -> Result<String> {
133 let client_metadata = atproto_client_metadata(
134 self.registry.client_data.config.clone(),
135 &self.registry.client_data.keyset,
136 )?;
137
138 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
139 let login_hint = if identity.is_some() {
140 Some(input.as_ref().into())
141 } else {
142 None
143 };
144 let metadata = OAuthMetadata {
145 server_metadata,
146 client_metadata,
147 keyset: self.registry.client_data.keyset.clone(),
148 };
149 let auth_req_info =
150 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
151 // Persist state for callback handling
152 self.registry
153 .store
154 .save_auth_req_info(&auth_req_info)
155 .await?;
156
157 #[derive(serde::Serialize)]
158 struct Parameters<'s> {
159 client_id: Url,
160 request_uri: CowStr<'s>,
161 }
162 Ok(metadata.server_metadata.authorization_endpoint.to_string()
163 + "?"
164 + &serde_html_form::to_string(Parameters {
165 client_id: metadata.client_metadata.client_id.clone(),
166 request_uri: auth_req_info.request_uri,
167 })
168 .unwrap())
169 }
170
171 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
172 let Some(state_key) = params.state else {
173 return Err(CallbackError::MissingState.into());
174 };
175
176 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
177 return Err(CallbackError::MissingState.into());
178 };
179
180 self.registry.store.delete_auth_req_info(&state_key).await?;
181
182 let metadata = self
183 .client
184 .get_authorization_server_metadata(&auth_req_info.authserver_url)
185 .await?;
186
187 if let Some(iss) = params.iss {
188 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
189 return Err(CallbackError::IssuerMismatch {
190 expected: metadata.issuer.to_string(),
191 got: iss.to_string(),
192 }
193 .into());
194 }
195 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
196 return Err(CallbackError::MissingIssuer.into());
197 }
198 let metadata = OAuthMetadata {
199 server_metadata: metadata,
200 client_metadata: atproto_client_metadata(
201 self.registry.client_data.config.clone(),
202 &self.registry.client_data.keyset,
203 )?,
204 keyset: self.registry.client_data.keyset.clone(),
205 };
206 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
207
208 match exchange_code(
209 self.client.as_ref(),
210 &mut auth_req_info.dpop_data.clone(),
211 ¶ms.code,
212 &auth_req_info.pkce_verifier,
213 &metadata,
214 )
215 .await
216 {
217 Ok(token_set) => {
218 let scopes = if let Some(scope) = &token_set.scope {
219 Scope::parse_multiple_reduced(&scope)
220 .expect("Failed to parse scopes")
221 .into_static()
222 } else {
223 vec![]
224 };
225 let client_data = ClientSessionData {
226 account_did: token_set.sub.clone(),
227 session_id: auth_req_info.state,
228 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
229 authserver_url: auth_req_info.authserver_url,
230 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
231 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
232 scopes,
233 dpop_data: DpopClientData {
234 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
235 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
236 dpop_host_nonce: auth_req_info
237 .dpop_data
238 .dpop_authserver_nonce
239 .unwrap_or(CowStr::default()),
240 },
241 token_set,
242 };
243
244 self.create_session(client_data).await
245 }
246 Err(e) => Err(e.into()),
247 }
248 }
249
250 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
251 self.registry.set(data.clone()).await?;
252 Ok(OAuthSession::new(
253 self.registry.clone(),
254 self.client.clone(),
255 data.into_static(),
256 ))
257 }
258
259 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
260 self.create_session(self.registry.get(did, session_id, false).await?)
261 .await
262 }
263
264 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
265 Ok(self.registry.del(did, session_id).await?)
266 }
267}
268
269pub struct OAuthSession<T, S>
270where
271 T: OAuthResolver,
272 S: ClientAuthStore,
273{
274 pub registry: Arc<SessionRegistry<T, S>>,
275 pub client: Arc<T>,
276 pub data: RwLock<ClientSessionData<'static>>,
277 pub options: RwLock<CallOptions<'static>>,
278}
279
280impl<T, S> OAuthSession<T, S>
281where
282 T: OAuthResolver,
283 S: ClientAuthStore,
284{
285 pub fn new(
286 registry: Arc<SessionRegistry<T, S>>,
287 client: Arc<T>,
288 data: ClientSessionData<'static>,
289 ) -> Self {
290 Self {
291 registry,
292 client,
293 data: RwLock::new(data),
294 options: RwLock::new(CallOptions::default()),
295 }
296 }
297
298 pub fn with_options(self, options: CallOptions<'_>) -> Self {
299 Self {
300 registry: self.registry,
301 client: self.client,
302 data: self.data,
303 options: RwLock::new(options.into_static()),
304 }
305 }
306
307 pub async fn set_options(&self, options: CallOptions<'_>) {
308 *self.options.write().await = options.into_static();
309 }
310
311 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
312 let data = self.data.read().await;
313 (data.account_did.clone(), data.session_id.clone())
314 }
315
316 pub async fn endpoint(&self) -> Url {
317 self.data.read().await.host_url.clone()
318 }
319
320 pub async fn access_token(&self) -> AuthorizationToken<'_> {
321 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
322 }
323
324 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
325 self.data
326 .read()
327 .await
328 .token_set
329 .refresh_token
330 .as_ref()
331 .map(|t| AuthorizationToken::Dpop(t.clone()))
332 }
333}
334impl<T, S> OAuthSession<T, S>
335where
336 S: ClientAuthStore + Send + Sync + 'static,
337 T: OAuthResolver + DpopExt + Send + Sync + 'static,
338{
339 pub async fn logout(&self) -> Result<()> {
340 use crate::request::{OAuthMetadata, revoke};
341 let mut data = self.data.write().await;
342 let meta =
343 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
344 if meta.server_metadata.revocation_endpoint.is_some() {
345 let token = data.token_set.access_token.clone();
346 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
347 .await
348 .ok();
349 }
350 // Remove from store
351 self.registry
352 .del(&data.account_did, &data.session_id)
353 .await?;
354 Ok(())
355 }
356}
357
358impl<T, S> OAuthClient<T, S>
359where
360 T: OAuthResolver,
361 S: ClientAuthStore,
362{
363 pub fn from_session(session: &OAuthSession<T, S>) -> Self {
364 Self {
365 registry: session.registry.clone(),
366 client: session.client.clone(),
367 }
368 }
369}
370impl<T, S> OAuthSession<T, S>
371where
372 S: ClientAuthStore + Send + Sync + 'static,
373 T: OAuthResolver + DpopExt + Send + Sync + 'static,
374{
375 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
376 // Read identifiers without holding the lock across await
377 let (did, sid) = {
378 let data = self.data.read().await;
379 (data.account_did.clone(), data.session_id.clone())
380 };
381 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
382 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
383 // Write back updated session
384 *self.data.write().await = refreshed.clone().into_static();
385 // Store in the registry
386 self.registry.set(refreshed).await?;
387 Ok(token)
388 }
389}
390
391impl<T, S> HttpClient for OAuthSession<T, S>
392where
393 S: ClientAuthStore + Send + Sync + 'static,
394 T: OAuthResolver + DpopExt + Send + Sync + 'static,
395{
396 type Error = T::Error;
397
398 async fn send_http(
399 &self,
400 request: http::Request<Vec<u8>>,
401 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
402 self.client.send_http(request).await
403 }
404}
405
406impl<T, S> XrpcClient for OAuthSession<T, S>
407where
408 S: ClientAuthStore + Send + Sync + 'static,
409 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
410{
411 fn base_uri(&self) -> Url {
412 // base_uri is a synchronous trait method; we must avoid async `.read().await`.
413 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
414 if tokio::runtime::Handle::try_current().is_ok() {
415 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
416 } else {
417 self.data.blocking_read().host_url.clone()
418 }
419 }
420
421 async fn opts(&self) -> CallOptions<'_> {
422 self.options.read().await.clone()
423 }
424
425 async fn send<'s, R>(
426 &self,
427 request: R,
428 ) -> XrpcResult<Response<<R as XrpcRequest<'s>>::Response>>
429 where
430 R: XrpcRequest<'s>,
431 {
432 let base_uri = self.base_uri();
433 let mut opts = self.options.read().await.clone();
434 opts.auth = Some(self.access_token().await);
435 let guard = self.data.read().await;
436 let mut dpop = guard.dpop_data.clone();
437 let http_response = self
438 .client
439 .dpop_call(&mut dpop)
440 .send(build_http_request(&base_uri, &request, &opts)?)
441 .await
442 .map_err(|e| TransportError::Other(Box::new(e)))?;
443 let resp = process_response(http_response);
444 drop(guard);
445 if is_invalid_token_response(&resp) {
446 opts.auth = Some(
447 self.refresh()
448 .await
449 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
450 );
451 let guard = self.data.read().await;
452 let mut dpop = guard.dpop_data.clone();
453 let http_response = self
454 .client
455 .dpop_call(&mut dpop)
456 .send(build_http_request(&base_uri, &request, &opts)?)
457 .await
458 .map_err(|e| TransportError::Other(Box::new(e)))?;
459 process_response(http_response)
460 } else {
461 resp
462 }
463 }
464}
465
466fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
467 match response {
468 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
469 Err(ClientError::Auth(AuthError::Other(value))) => value
470 .to_str()
471 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
472 Ok(resp) => match resp.parse() {
473 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
474 _ => false,
475 },
476 _ => false,
477 }
478}
479
480impl<T, S> IdentityResolver for OAuthSession<T, S>
481where
482 S: ClientAuthStore + Send + Sync + 'static,
483 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static,
484{
485 fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
486 self.client.options()
487 }
488
489 fn resolve_handle(
490 &self,
491 handle: &Handle<'_>,
492 ) -> impl Future<
493 Output = std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError>,
494 > {
495 async { self.client.resolve_handle(handle).await }
496 }
497
498 fn resolve_did_doc(
499 &self,
500 did: &Did<'_>,
501 ) -> impl Future<
502 Output = std::result::Result<
503 jacquard_identity::resolver::DidDocResponse,
504 jacquard_identity::resolver::IdentityError,
505 >,
506 > {
507 async { self.client.resolve_did_doc(did).await }
508 }
509}