1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 error::{AuthError, ClientError, TransportError, XrpcResult},
15 http_client::HttpClient,
16 types::{
17 did::Did,
18 xrpc::{
19 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, build_http_request,
20 process_response,
21 },
22 },
23};
24use jacquard_identity::JacquardResolver;
25use jose_jwk::JwkSet;
26use std::sync::Arc;
27use tokio::sync::RwLock;
28use url::Url;
29
30pub struct OAuthClient<T, S>
31where
32 T: OAuthResolver,
33 S: ClientAuthStore,
34{
35 pub registry: Arc<SessionRegistry<T, S>>,
36 pub client: Arc<T>,
37}
38
39impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
40 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
41 let client = JacquardResolver::default();
42 Self::new_from_resolver(store, client, client_data)
43 }
44}
45
46impl<T, S> OAuthClient<T, S>
47where
48 T: OAuthResolver,
49 S: ClientAuthStore,
50{
51 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
52 let client = Arc::new(client);
53 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
54 Self { registry, client }
55 }
56
57 pub fn new_with_shared(
58 store: Arc<S>,
59 client: Arc<T>,
60 client_data: ClientData<'static>,
61 ) -> Self {
62 let registry = Arc::new(SessionRegistry::new_shared(
63 store,
64 client.clone(),
65 client_data,
66 ));
67 Self { registry, client }
68 }
69}
70
71impl<T, S> OAuthClient<T, S>
72where
73 S: ClientAuthStore + Send + Sync + 'static,
74 T: OAuthResolver + DpopExt + Send + Sync + 'static,
75{
76 pub fn jwks(&self) -> JwkSet {
77 self.registry
78 .client_data
79 .keyset
80 .as_ref()
81 .map(|keyset| keyset.public_jwks())
82 .unwrap_or_default()
83 }
84 pub async fn start_auth(
85 &self,
86 input: impl AsRef<str>,
87 options: AuthorizeOptions<'_>,
88 ) -> Result<String> {
89 let client_metadata = atproto_client_metadata(
90 self.registry.client_data.config.clone(),
91 &self.registry.client_data.keyset,
92 )?;
93
94 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
95 let login_hint = if identity.is_some() {
96 Some(input.as_ref().into())
97 } else {
98 None
99 };
100 let metadata = OAuthMetadata {
101 server_metadata,
102 client_metadata,
103 keyset: self.registry.client_data.keyset.clone(),
104 };
105 let auth_req_info =
106 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
107 // Persist state for callback handling
108 self.registry
109 .store
110 .save_auth_req_info(&auth_req_info)
111 .await?;
112
113 #[derive(serde::Serialize)]
114 struct Parameters<'s> {
115 client_id: Url,
116 request_uri: CowStr<'s>,
117 }
118 Ok(metadata.server_metadata.authorization_endpoint.to_string()
119 + "?"
120 + &serde_html_form::to_string(Parameters {
121 client_id: metadata.client_metadata.client_id.clone(),
122 request_uri: auth_req_info.request_uri,
123 })
124 .unwrap())
125 }
126
127 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
128 let Some(state_key) = params.state else {
129 return Err(CallbackError::MissingState.into());
130 };
131
132 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
133 return Err(CallbackError::MissingState.into());
134 };
135
136 self.registry.store.delete_auth_req_info(&state_key).await?;
137
138 let metadata = self
139 .client
140 .get_authorization_server_metadata(&auth_req_info.authserver_url)
141 .await?;
142
143 if let Some(iss) = params.iss {
144 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
145 return Err(CallbackError::IssuerMismatch {
146 expected: metadata.issuer.to_string(),
147 got: iss.to_string(),
148 }
149 .into());
150 }
151 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
152 return Err(CallbackError::MissingIssuer.into());
153 }
154 let metadata = OAuthMetadata {
155 server_metadata: metadata,
156 client_metadata: atproto_client_metadata(
157 self.registry.client_data.config.clone(),
158 &self.registry.client_data.keyset,
159 )?,
160 keyset: self.registry.client_data.keyset.clone(),
161 };
162 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
163
164 match exchange_code(
165 self.client.as_ref(),
166 &mut auth_req_info.dpop_data.clone(),
167 ¶ms.code,
168 &auth_req_info.pkce_verifier,
169 &metadata,
170 )
171 .await
172 {
173 Ok(token_set) => {
174 let scopes = if let Some(scope) = &token_set.scope {
175 Scope::parse_multiple_reduced(&scope)
176 .expect("Failed to parse scopes")
177 .into_static()
178 } else {
179 vec![]
180 };
181 let client_data = ClientSessionData {
182 account_did: token_set.sub.clone(),
183 session_id: auth_req_info.state,
184 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
185 authserver_url: auth_req_info.authserver_url,
186 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
187 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
188 scopes,
189 dpop_data: DpopClientData {
190 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
191 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
192 dpop_host_nonce: auth_req_info
193 .dpop_data
194 .dpop_authserver_nonce
195 .unwrap_or(CowStr::default()),
196 },
197 token_set,
198 };
199
200 self.create_session(client_data).await
201 }
202 Err(e) => Err(e.into()),
203 }
204 }
205
206 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
207 self.registry.set(data.clone()).await?;
208 Ok(OAuthSession::new(
209 self.registry.clone(),
210 self.client.clone(),
211 data.into_static(),
212 ))
213 }
214
215 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
216 self.create_session(self.registry.get(did, session_id, false).await?)
217 .await
218 }
219
220 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
221 Ok(self.registry.del(did, session_id).await?)
222 }
223}
224
225pub struct OAuthSession<T, S>
226where
227 T: OAuthResolver,
228 S: ClientAuthStore,
229{
230 pub registry: Arc<SessionRegistry<T, S>>,
231 pub client: Arc<T>,
232 pub data: RwLock<ClientSessionData<'static>>,
233 pub options: RwLock<CallOptions<'static>>,
234}
235
236impl<T, S> OAuthSession<T, S>
237where
238 T: OAuthResolver,
239 S: ClientAuthStore,
240{
241 pub fn new(
242 registry: Arc<SessionRegistry<T, S>>,
243 client: Arc<T>,
244 data: ClientSessionData<'static>,
245 ) -> Self {
246 Self {
247 registry,
248 client,
249 data: RwLock::new(data),
250 options: RwLock::new(CallOptions::default()),
251 }
252 }
253
254 pub fn with_options(self, options: CallOptions<'_>) -> Self {
255 Self {
256 registry: self.registry,
257 client: self.client,
258 data: self.data,
259 options: RwLock::new(options.into_static()),
260 }
261 }
262
263 pub async fn set_options(&self, options: CallOptions<'_>) {
264 *self.options.write().await = options.into_static();
265 }
266
267 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
268 let data = self.data.read().await;
269 (data.account_did.clone(), data.session_id.clone())
270 }
271
272 pub async fn endpoint(&self) -> Url {
273 self.data.read().await.host_url.clone()
274 }
275
276 pub async fn access_token(&self) -> AuthorizationToken<'_> {
277 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
278 }
279
280 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
281 self.data
282 .read()
283 .await
284 .token_set
285 .refresh_token
286 .as_ref()
287 .map(|t| AuthorizationToken::Dpop(t.clone()))
288 }
289}
290impl<T, S> OAuthSession<T, S>
291where
292 S: ClientAuthStore + Send + Sync + 'static,
293 T: OAuthResolver + DpopExt + Send + Sync + 'static,
294{
295 pub async fn logout(&self) -> Result<()> {
296 use crate::request::{OAuthMetadata, revoke};
297 let mut data = self.data.write().await;
298 let meta =
299 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
300 if meta.server_metadata.revocation_endpoint.is_some() {
301 let token = data.token_set.access_token.clone();
302 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
303 .await
304 .ok();
305 }
306 // Remove from store
307 self.registry
308 .del(&data.account_did, &data.session_id)
309 .await?;
310 Ok(())
311 }
312}
313
314impl<T, S> OAuthClient<T, S>
315where
316 T: OAuthResolver,
317 S: ClientAuthStore,
318{
319 pub fn from_session(session: &OAuthSession<T, S>) -> Self {
320 Self {
321 registry: session.registry.clone(),
322 client: session.client.clone(),
323 }
324 }
325}
326impl<T, S> OAuthSession<T, S>
327where
328 S: ClientAuthStore + Send + Sync + 'static,
329 T: OAuthResolver + DpopExt + Send + Sync + 'static,
330{
331 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
332 // Read identifiers without holding the lock across await
333 let (did, sid) = {
334 let data = self.data.read().await;
335 (data.account_did.clone(), data.session_id.clone())
336 };
337 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
338 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
339 // Write back updated session
340 *self.data.write().await = refreshed.clone().into_static();
341 // Store in the registry
342 self.registry.set(refreshed).await?;
343 Ok(token)
344 }
345}
346
347impl<T, S> HttpClient for OAuthSession<T, S>
348where
349 S: ClientAuthStore + Send + Sync + 'static,
350 T: OAuthResolver + DpopExt + Send + Sync + 'static,
351{
352 type Error = T::Error;
353
354 async fn send_http(
355 &self,
356 request: http::Request<Vec<u8>>,
357 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
358 self.client.send_http(request).await
359 }
360}
361
362impl<T, S> XrpcClient for OAuthSession<T, S>
363where
364 S: ClientAuthStore + Send + Sync + 'static,
365 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
366{
367 fn base_uri(&self) -> Url {
368 // base_uri is a synchronous trait method; we must avoid async `.read().await`.
369 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
370 if tokio::runtime::Handle::try_current().is_ok() {
371 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
372 } else {
373 self.data.blocking_read().host_url.clone()
374 }
375 }
376
377 async fn opts(&self) -> CallOptions<'_> {
378 self.options.read().await.clone()
379 }
380
381 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>(
382 self,
383 request: &R,
384 ) -> XrpcResult<Response<R>> {
385 let base_uri = self.base_uri();
386 let mut opts = self.options.read().await.clone();
387 opts.auth = Some(self.access_token().await);
388 let guard = self.data.read().await;
389 let mut dpop = guard.dpop_data.clone();
390 let http_response = self
391 .client
392 .dpop_call(&mut dpop)
393 .send(build_http_request(&base_uri, request, &opts)?)
394 .await
395 .map_err(|e| TransportError::Other(Box::new(e)))?;
396 drop(guard);
397 let res = process_response(http_response);
398 if is_invalid_token_response(&res) {
399 opts.auth = Some(
400 self.refresh()
401 .await
402 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
403 );
404 let guard = self.data.read().await;
405 let mut dpop = guard.dpop_data.clone();
406 let http_response = self
407 .client
408 .dpop_call(&mut dpop)
409 .send(build_http_request(&base_uri, request, &opts)?)
410 .await
411 .map_err(|e| TransportError::Other(Box::new(e)))?;
412 process_response(http_response)
413 } else {
414 res
415 }
416 }
417}
418
419fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool {
420 match response {
421 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
422 Err(ClientError::Auth(AuthError::Other(value))) => value
423 .to_str()
424 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
425 Ok(resp) => match resp.parse() {
426 Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
427 _ => false,
428 },
429 _ => false,
430 }
431}