1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 error::{AuthError, ClientError, TransportError, XrpcResult},
15 http_client::HttpClient,
16 types::did::Did,
17 xrpc::{
18 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, build_http_request,
19 process_response,
20 },
21};
22use jacquard_identity::JacquardResolver;
23use jose_jwk::JwkSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26use url::Url;
27
/// Entry point for the OAuth flows: bundles the shared session registry with
/// the resolver used to reach authorization servers.
pub struct OAuthClient<T, S>
where
    T: OAuthResolver,
    S: ClientAuthStore,
{
    /// Shared session registry. The flows in this file read its `client_data`
    /// and `store` fields and call its `get`/`set`/`del` methods.
    pub registry: Arc<SessionRegistry<T, S>>,
    /// Shared resolver used for OAuth metadata lookup and outgoing requests.
    pub client: Arc<T>,
}
36
37impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
38 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
39 let client = JacquardResolver::default();
40 Self::new_from_resolver(store, client, client_data)
41 }
42}
43
44impl<T, S> OAuthClient<T, S>
45where
46 T: OAuthResolver,
47 S: ClientAuthStore,
48{
49 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
50 let client = Arc::new(client);
51 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
52 Self { registry, client }
53 }
54
55 pub fn new_with_shared(
56 store: Arc<S>,
57 client: Arc<T>,
58 client_data: ClientData<'static>,
59 ) -> Self {
60 let registry = Arc::new(SessionRegistry::new_shared(
61 store,
62 client.clone(),
63 client_data,
64 ));
65 Self { registry, client }
66 }
67}
68
69impl<T, S> OAuthClient<T, S>
70where
71 S: ClientAuthStore + Send + Sync + 'static,
72 T: OAuthResolver + DpopExt + Send + Sync + 'static,
73{
74 pub fn jwks(&self) -> JwkSet {
75 self.registry
76 .client_data
77 .keyset
78 .as_ref()
79 .map(|keyset| keyset.public_jwks())
80 .unwrap_or_default()
81 }
82 pub async fn start_auth(
83 &self,
84 input: impl AsRef<str>,
85 options: AuthorizeOptions<'_>,
86 ) -> Result<String> {
87 let client_metadata = atproto_client_metadata(
88 self.registry.client_data.config.clone(),
89 &self.registry.client_data.keyset,
90 )?;
91
92 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
93 let login_hint = if identity.is_some() {
94 Some(input.as_ref().into())
95 } else {
96 None
97 };
98 let metadata = OAuthMetadata {
99 server_metadata,
100 client_metadata,
101 keyset: self.registry.client_data.keyset.clone(),
102 };
103 let auth_req_info =
104 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
105 // Persist state for callback handling
106 self.registry
107 .store
108 .save_auth_req_info(&auth_req_info)
109 .await?;
110
111 #[derive(serde::Serialize)]
112 struct Parameters<'s> {
113 client_id: Url,
114 request_uri: CowStr<'s>,
115 }
116 Ok(metadata.server_metadata.authorization_endpoint.to_string()
117 + "?"
118 + &serde_html_form::to_string(Parameters {
119 client_id: metadata.client_metadata.client_id.clone(),
120 request_uri: auth_req_info.request_uri,
121 })
122 .unwrap())
123 }
124
125 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
126 let Some(state_key) = params.state else {
127 return Err(CallbackError::MissingState.into());
128 };
129
130 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
131 return Err(CallbackError::MissingState.into());
132 };
133
134 self.registry.store.delete_auth_req_info(&state_key).await?;
135
136 let metadata = self
137 .client
138 .get_authorization_server_metadata(&auth_req_info.authserver_url)
139 .await?;
140
141 if let Some(iss) = params.iss {
142 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
143 return Err(CallbackError::IssuerMismatch {
144 expected: metadata.issuer.to_string(),
145 got: iss.to_string(),
146 }
147 .into());
148 }
149 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
150 return Err(CallbackError::MissingIssuer.into());
151 }
152 let metadata = OAuthMetadata {
153 server_metadata: metadata,
154 client_metadata: atproto_client_metadata(
155 self.registry.client_data.config.clone(),
156 &self.registry.client_data.keyset,
157 )?,
158 keyset: self.registry.client_data.keyset.clone(),
159 };
160 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
161
162 match exchange_code(
163 self.client.as_ref(),
164 &mut auth_req_info.dpop_data.clone(),
165 ¶ms.code,
166 &auth_req_info.pkce_verifier,
167 &metadata,
168 )
169 .await
170 {
171 Ok(token_set) => {
172 let scopes = if let Some(scope) = &token_set.scope {
173 Scope::parse_multiple_reduced(&scope)
174 .expect("Failed to parse scopes")
175 .into_static()
176 } else {
177 vec![]
178 };
179 let client_data = ClientSessionData {
180 account_did: token_set.sub.clone(),
181 session_id: auth_req_info.state,
182 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
183 authserver_url: auth_req_info.authserver_url,
184 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
185 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
186 scopes,
187 dpop_data: DpopClientData {
188 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
189 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
190 dpop_host_nonce: auth_req_info
191 .dpop_data
192 .dpop_authserver_nonce
193 .unwrap_or(CowStr::default()),
194 },
195 token_set,
196 };
197
198 self.create_session(client_data).await
199 }
200 Err(e) => Err(e.into()),
201 }
202 }
203
204 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
205 self.registry.set(data.clone()).await?;
206 Ok(OAuthSession::new(
207 self.registry.clone(),
208 self.client.clone(),
209 data.into_static(),
210 ))
211 }
212
213 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
214 self.create_session(self.registry.get(did, session_id, false).await?)
215 .await
216 }
217
218 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
219 Ok(self.registry.del(did, session_id).await?)
220 }
221}
222
/// An authenticated OAuth session: shared registry and resolver handles plus
/// the session's own data and call options, each behind an async `RwLock`.
pub struct OAuthSession<T, S>
where
    T: OAuthResolver,
    S: ClientAuthStore,
{
    /// Shared session registry used to refresh, persist and delete the session.
    pub registry: Arc<SessionRegistry<T, S>>,
    /// Shared resolver used to send requests.
    pub client: Arc<T>,
    /// Current session data (account DID, session id, host URL, tokens, DPoP state).
    pub data: RwLock<ClientSessionData<'static>>,
    /// Call options; replaced through `with_options` / `set_options`.
    pub options: RwLock<CallOptions<'static>>,
}
233
234impl<T, S> OAuthSession<T, S>
235where
236 T: OAuthResolver,
237 S: ClientAuthStore,
238{
239 pub fn new(
240 registry: Arc<SessionRegistry<T, S>>,
241 client: Arc<T>,
242 data: ClientSessionData<'static>,
243 ) -> Self {
244 Self {
245 registry,
246 client,
247 data: RwLock::new(data),
248 options: RwLock::new(CallOptions::default()),
249 }
250 }
251
252 pub fn with_options(self, options: CallOptions<'_>) -> Self {
253 Self {
254 registry: self.registry,
255 client: self.client,
256 data: self.data,
257 options: RwLock::new(options.into_static()),
258 }
259 }
260
261 pub async fn set_options(&self, options: CallOptions<'_>) {
262 *self.options.write().await = options.into_static();
263 }
264
265 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
266 let data = self.data.read().await;
267 (data.account_did.clone(), data.session_id.clone())
268 }
269
270 pub async fn endpoint(&self) -> Url {
271 self.data.read().await.host_url.clone()
272 }
273
274 pub async fn access_token(&self) -> AuthorizationToken<'_> {
275 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
276 }
277
278 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
279 self.data
280 .read()
281 .await
282 .token_set
283 .refresh_token
284 .as_ref()
285 .map(|t| AuthorizationToken::Dpop(t.clone()))
286 }
287}
288impl<T, S> OAuthSession<T, S>
289where
290 S: ClientAuthStore + Send + Sync + 'static,
291 T: OAuthResolver + DpopExt + Send + Sync + 'static,
292{
293 pub async fn logout(&self) -> Result<()> {
294 use crate::request::{OAuthMetadata, revoke};
295 let mut data = self.data.write().await;
296 let meta =
297 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
298 if meta.server_metadata.revocation_endpoint.is_some() {
299 let token = data.token_set.access_token.clone();
300 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
301 .await
302 .ok();
303 }
304 // Remove from store
305 self.registry
306 .del(&data.account_did, &data.session_id)
307 .await?;
308 Ok(())
309 }
310}
311
312impl<T, S> OAuthClient<T, S>
313where
314 T: OAuthResolver,
315 S: ClientAuthStore,
316{
317 pub fn from_session(session: &OAuthSession<T, S>) -> Self {
318 Self {
319 registry: session.registry.clone(),
320 client: session.client.clone(),
321 }
322 }
323}
324impl<T, S> OAuthSession<T, S>
325where
326 S: ClientAuthStore + Send + Sync + 'static,
327 T: OAuthResolver + DpopExt + Send + Sync + 'static,
328{
329 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
330 // Read identifiers without holding the lock across await
331 let (did, sid) = {
332 let data = self.data.read().await;
333 (data.account_did.clone(), data.session_id.clone())
334 };
335 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
336 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
337 // Write back updated session
338 *self.data.write().await = refreshed.clone().into_static();
339 // Store in the registry
340 self.registry.set(refreshed).await?;
341 Ok(token)
342 }
343}
344
345impl<T, S> HttpClient for OAuthSession<T, S>
346where
347 S: ClientAuthStore + Send + Sync + 'static,
348 T: OAuthResolver + DpopExt + Send + Sync + 'static,
349{
350 type Error = T::Error;
351
352 async fn send_http(
353 &self,
354 request: http::Request<Vec<u8>>,
355 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
356 self.client.send_http(request).await
357 }
358}
359
360impl<T, S> XrpcClient for OAuthSession<T, S>
361where
362 S: ClientAuthStore + Send + Sync + 'static,
363 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
364{
365 fn base_uri(&self) -> Url {
366 // base_uri is a synchronous trait method; we must avoid async `.read().await`.
367 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
368 if tokio::runtime::Handle::try_current().is_ok() {
369 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
370 } else {
371 self.data.blocking_read().host_url.clone()
372 }
373 }
374
375 async fn opts(&self) -> CallOptions<'_> {
376 self.options.read().await.clone()
377 }
378
379 async fn send<'s, R>(
380 &self,
381 request: R,
382 ) -> XrpcResult<Response<<R as XrpcRequest<'s>>::Response>>
383 where
384 R: XrpcRequest<'s>,
385 {
386 let base_uri = self.base_uri();
387 let mut opts = self.options.read().await.clone();
388 opts.auth = Some(self.access_token().await);
389 let guard = self.data.read().await;
390 let mut dpop = guard.dpop_data.clone();
391 let http_response = self
392 .client
393 .dpop_call(&mut dpop)
394 .send(build_http_request(&base_uri, &request, &opts)?)
395 .await
396 .map_err(|e| TransportError::Other(Box::new(e)))?;
397 let resp = process_response(http_response);
398 drop(guard);
399 if is_invalid_token_response(&resp) {
400 opts.auth = Some(
401 self.refresh()
402 .await
403 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
404 );
405 let guard = self.data.read().await;
406 let mut dpop = guard.dpop_data.clone();
407 let http_response = self
408 .client
409 .dpop_call(&mut dpop)
410 .send(build_http_request(&base_uri, &request, &opts)?)
411 .await
412 .map_err(|e| TransportError::Other(Box::new(e)))?;
413 process_response(http_response)
414 } else {
415 resp
416 }
417 }
418}
419
420fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
421 match response {
422 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
423 Err(ClientError::Auth(AuthError::Other(value))) => value
424 .to_str()
425 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
426 Ok(resp) => match resp.parse() {
427 Err(jacquard_common::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
428 _ => false,
429 },
430 _ => false,
431 }
432}