1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 error::{AuthError, ClientError, TransportError, XrpcResult},
15 http_client::HttpClient,
16 types::{did::Did, string::Handle},
17 xrpc::{
18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request,
19 process_response,
20 },
21};
22use jacquard_identity::{JacquardResolver, resolver::IdentityResolver};
23use jose_jwk::JwkSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26use url::Url;
27
28pub struct OAuthClient<T, S>
29where
30 T: OAuthResolver,
31 S: ClientAuthStore,
32{
33 pub registry: Arc<SessionRegistry<T, S>>,
34 pub client: Arc<T>,
35}
36
37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
38 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
39 let client = JacquardResolver::default();
40 Self::new_from_resolver(store, client, client_data)
41 }
42
43 /// Create an OAuth client with the provided store and default localhost client metadata.
44 ///
45 /// This is a convenience constructor for quickly setting up an OAuth client
46 /// with default localhost redirect URIs and "atproto transition:generic" scopes.
47 ///
48 /// # Example
49 ///
50 /// ```no_run
51 /// # use jacquard_oauth::client::OAuthClient;
52 /// # use jacquard_oauth::authstore::MemoryAuthStore;
53 /// # #[tokio::main]
54 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
55 /// let store = MemoryAuthStore::new();
56 /// let oauth = OAuthClient::with_default_config(store);
57 /// # Ok(())
58 /// # }
59 /// ```
60 pub fn with_default_config(store: S) -> Self {
61 let client_data = ClientData {
62 keyset: None,
63 config: crate::atproto::AtprotoClientMetadata::default_localhost(),
64 };
65 Self::new(store, client_data)
66 }
67}
68
69impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> {
70 /// Create an OAuth client with an in-memory auth store and default localhost client metadata.
71 ///
72 /// This is a convenience constructor for simple testing and development.
73 /// The session will not persist across restarts.
74 ///
75 /// # Example
76 ///
77 /// ```no_run
78 /// # use jacquard_oauth::client::OAuthClient;
79 /// # #[tokio::main]
80 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
81 /// let oauth = OAuthClient::with_memory_store();
82 /// # Ok(())
83 /// # }
84 /// ```
85 pub fn with_memory_store() -> Self {
86 Self::with_default_config(crate::authstore::MemoryAuthStore::new())
87 }
88}
89
90impl<T, S> OAuthClient<T, S>
91where
92 T: OAuthResolver,
93 S: ClientAuthStore,
94{
95 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
96 #[cfg(feature = "tracing")]
97 tracing::info!(
98 redirect_uris = ?client_data.config.redirect_uris,
99 scopes = ?client_data.config.scopes,
100 has_keyset = client_data.keyset.is_some(),
101 "oauth client created"
102 );
103
104 let client = Arc::new(client);
105 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
106 Self { registry, client }
107 }
108
109 pub fn new_with_shared(
110 store: Arc<S>,
111 client: Arc<T>,
112 client_data: ClientData<'static>,
113 ) -> Self {
114 let registry = Arc::new(SessionRegistry::new_shared(
115 store,
116 client.clone(),
117 client_data,
118 ));
119 Self { registry, client }
120 }
121}
122
123impl<T, S> OAuthClient<T, S>
124where
125 S: ClientAuthStore + Send + Sync + 'static,
126 T: OAuthResolver + DpopExt + Send + Sync + 'static,
127{
128 pub fn jwks(&self) -> JwkSet {
129 self.registry
130 .client_data
131 .keyset
132 .as_ref()
133 .map(|keyset| keyset.public_jwks())
134 .unwrap_or_default()
135 }
136 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))]
137 pub async fn start_auth(
138 &self,
139 input: impl AsRef<str>,
140 options: AuthorizeOptions<'_>,
141 ) -> Result<String> {
142 let client_metadata = atproto_client_metadata(
143 self.registry.client_data.config.clone(),
144 &self.registry.client_data.keyset,
145 )?;
146
147 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
148 let login_hint = if identity.is_some() {
149 Some(input.as_ref().into())
150 } else {
151 None
152 };
153 let metadata = OAuthMetadata {
154 server_metadata,
155 client_metadata,
156 keyset: self.registry.client_data.keyset.clone(),
157 };
158 let auth_req_info =
159 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
160 // Persist state for callback handling
161 self.registry
162 .store
163 .save_auth_req_info(&auth_req_info)
164 .await?;
165
166 #[derive(serde::Serialize)]
167 struct Parameters<'s> {
168 client_id: Url,
169 request_uri: CowStr<'s>,
170 }
171 Ok(metadata.server_metadata.authorization_endpoint.to_string()
172 + "?"
173 + &serde_html_form::to_string(Parameters {
174 client_id: metadata.client_metadata.client_id.clone(),
175 request_uri: auth_req_info.request_uri,
176 })
177 .unwrap())
178 }
179
180 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_ref()))))]
181 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
182 let Some(state_key) = params.state else {
183 return Err(CallbackError::MissingState.into());
184 };
185
186 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
187 return Err(CallbackError::MissingState.into());
188 };
189
190 self.registry.store.delete_auth_req_info(&state_key).await?;
191
192 let metadata = self
193 .client
194 .get_authorization_server_metadata(&auth_req_info.authserver_url)
195 .await?;
196
197 if let Some(iss) = params.iss {
198 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
199 return Err(CallbackError::IssuerMismatch {
200 expected: metadata.issuer.to_string(),
201 got: iss.to_string(),
202 }
203 .into());
204 }
205 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
206 return Err(CallbackError::MissingIssuer.into());
207 }
208 let metadata = OAuthMetadata {
209 server_metadata: metadata,
210 client_metadata: atproto_client_metadata(
211 self.registry.client_data.config.clone(),
212 &self.registry.client_data.keyset,
213 )?,
214 keyset: self.registry.client_data.keyset.clone(),
215 };
216 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
217
218 match exchange_code(
219 self.client.as_ref(),
220 &mut auth_req_info.dpop_data.clone(),
221 ¶ms.code,
222 &auth_req_info.pkce_verifier,
223 &metadata,
224 )
225 .await
226 {
227 Ok(token_set) => {
228 let scopes = if let Some(scope) = &token_set.scope {
229 Scope::parse_multiple_reduced(&scope)
230 .expect("Failed to parse scopes")
231 .into_static()
232 } else {
233 vec![]
234 };
235 let client_data = ClientSessionData {
236 account_did: token_set.sub.clone(),
237 session_id: auth_req_info.state,
238 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
239 authserver_url: auth_req_info.authserver_url,
240 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
241 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
242 scopes,
243 dpop_data: DpopClientData {
244 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
245 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
246 dpop_host_nonce: auth_req_info
247 .dpop_data
248 .dpop_authserver_nonce
249 .unwrap_or(CowStr::default()),
250 },
251 token_set,
252 };
253
254 self.create_session(client_data).await
255 }
256 Err(e) => Err(e.into()),
257 }
258 }
259
260 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
261 self.registry.set(data.clone()).await?;
262 Ok(OAuthSession::new(
263 self.registry.clone(),
264 self.client.clone(),
265 data.into_static(),
266 ))
267 }
268
269 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
270 self.create_session(self.registry.get(did, session_id, false).await?)
271 .await
272 }
273
274 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
275 Ok(self.registry.del(did, session_id).await?)
276 }
277}
278
279pub struct OAuthSession<T, S>
280where
281 T: OAuthResolver,
282 S: ClientAuthStore,
283{
284 pub registry: Arc<SessionRegistry<T, S>>,
285 pub client: Arc<T>,
286 pub data: RwLock<ClientSessionData<'static>>,
287 pub options: RwLock<CallOptions<'static>>,
288}
289
290impl<T, S> OAuthSession<T, S>
291where
292 T: OAuthResolver,
293 S: ClientAuthStore,
294{
295 pub fn new(
296 registry: Arc<SessionRegistry<T, S>>,
297 client: Arc<T>,
298 data: ClientSessionData<'static>,
299 ) -> Self {
300 Self {
301 registry,
302 client,
303 data: RwLock::new(data),
304 options: RwLock::new(CallOptions::default()),
305 }
306 }
307
308 pub fn with_options(self, options: CallOptions<'_>) -> Self {
309 Self {
310 registry: self.registry,
311 client: self.client,
312 data: self.data,
313 options: RwLock::new(options.into_static()),
314 }
315 }
316
317 pub async fn set_options(&self, options: CallOptions<'_>) {
318 *self.options.write().await = options.into_static();
319 }
320
321 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
322 let data = self.data.read().await;
323 (data.account_did.clone(), data.session_id.clone())
324 }
325
326 pub async fn endpoint(&self) -> Url {
327 self.data.read().await.host_url.clone()
328 }
329
330 pub async fn access_token(&self) -> AuthorizationToken<'_> {
331 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
332 }
333
334 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
335 self.data
336 .read()
337 .await
338 .token_set
339 .refresh_token
340 .as_ref()
341 .map(|t| AuthorizationToken::Dpop(t.clone()))
342 }
343}
344impl<T, S> OAuthSession<T, S>
345where
346 S: ClientAuthStore + Send + Sync + 'static,
347 T: OAuthResolver + DpopExt + Send + Sync + 'static,
348{
349 pub async fn logout(&self) -> Result<()> {
350 use crate::request::{OAuthMetadata, revoke};
351 let mut data = self.data.write().await;
352 let meta =
353 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
354 if meta.server_metadata.revocation_endpoint.is_some() {
355 let token = data.token_set.access_token.clone();
356 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
357 .await
358 .ok();
359 }
360 // Remove from store
361 self.registry
362 .del(&data.account_did, &data.session_id)
363 .await?;
364 Ok(())
365 }
366}
367
368impl<T, S> OAuthClient<T, S>
369where
370 T: OAuthResolver,
371 S: ClientAuthStore,
372{
373 pub fn from_session(session: &OAuthSession<T, S>) -> Self {
374 Self {
375 registry: session.registry.clone(),
376 client: session.client.clone(),
377 }
378 }
379}
380impl<T, S> OAuthSession<T, S>
381where
382 S: ClientAuthStore + Send + Sync + 'static,
383 T: OAuthResolver + DpopExt + Send + Sync + 'static,
384{
385 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
386 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
387 // Read identifiers without holding the lock across await
388 let (did, sid) = {
389 let data = self.data.read().await;
390 (data.account_did.clone(), data.session_id.clone())
391 };
392 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
393 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
394 // Write back updated session
395 *self.data.write().await = refreshed.clone().into_static();
396 // Store in the registry
397 self.registry.set(refreshed).await?;
398 Ok(token)
399 }
400}
401
402impl<T, S> HttpClient for OAuthSession<T, S>
403where
404 S: ClientAuthStore + Send + Sync + 'static,
405 T: OAuthResolver + DpopExt + Send + Sync + 'static,
406{
407 type Error = T::Error;
408
409 async fn send_http(
410 &self,
411 request: http::Request<Vec<u8>>,
412 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
413 self.client.send_http(request).await
414 }
415}
416
417impl<T, S> XrpcClient for OAuthSession<T, S>
418where
419 S: ClientAuthStore + Send + Sync + 'static,
420 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
421{
422 fn base_uri(&self) -> Url {
423 // base_uri is a synchronous trait method; we must avoid async `.read().await`.
424 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
425 if tokio::runtime::Handle::try_current().is_ok() {
426 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
427 } else {
428 self.data.blocking_read().host_url.clone()
429 }
430 }
431
432 async fn opts(&self) -> CallOptions<'_> {
433 self.options.read().await.clone()
434 }
435
436 async fn send<R>(
437 &self,
438 request: R,
439 ) -> XrpcResult<Response<<R as XrpcRequest>::Response>>
440 where
441 R: XrpcRequest,
442 {
443 let base_uri = self.base_uri();
444 let mut opts = self.options.read().await.clone();
445 opts.auth = Some(self.access_token().await);
446 let guard = self.data.read().await;
447 let mut dpop = guard.dpop_data.clone();
448 let http_response = self
449 .client
450 .dpop_call(&mut dpop)
451 .send(build_http_request(&base_uri, &request, &opts)?)
452 .await
453 .map_err(|e| TransportError::Other(Box::new(e)))?;
454 let resp = process_response(http_response);
455 drop(guard);
456 if is_invalid_token_response(&resp) {
457 opts.auth = Some(
458 self.refresh()
459 .await
460 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
461 );
462 let guard = self.data.read().await;
463 let mut dpop = guard.dpop_data.clone();
464 let http_response = self
465 .client
466 .dpop_call(&mut dpop)
467 .send(build_http_request(&base_uri, &request, &opts)?)
468 .await
469 .map_err(|e| TransportError::Other(Box::new(e)))?;
470 process_response(http_response)
471 } else {
472 resp
473 }
474 }
475}
476
477fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
478 match response {
479 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
480 Err(ClientError::Auth(AuthError::Other(value))) => value
481 .to_str()
482 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
483 Ok(resp) => match resp.parse() {
484 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
485 _ => false,
486 },
487 _ => false,
488 }
489}
490
491impl<T, S> IdentityResolver for OAuthSession<T, S>
492where
493 S: ClientAuthStore + Send + Sync + 'static,
494 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static,
495{
496 fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
497 self.client.options()
498 }
499
500 fn resolve_handle(
501 &self,
502 handle: &Handle<'_>,
503 ) -> impl Future<
504 Output = std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError>,
505 > {
506 async { self.client.resolve_handle(handle).await }
507 }
508
509 fn resolve_did_doc(
510 &self,
511 did: &Did<'_>,
512 ) -> impl Future<
513 Output = std::result::Result<
514 jacquard_identity::resolver::DidDocResponse,
515 jacquard_identity::resolver::IdentityError,
516 >,
517 > {
518 async { self.client.resolve_did_doc(did).await }
519 }
520}