1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 error::{AuthError, ClientError, TransportError, XrpcResult},
15 http_client::HttpClient,
16 types::{
17 did::Did,
18 xrpc::{CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest},
19 },
20};
21use jacquard_identity::JacquardResolver;
22use jose_jwk::JwkSet;
23use std::sync::Arc;
24use tokio::sync::RwLock;
25use url::Url;
26
27pub struct OAuthClient<T, S>
28where
29 T: OAuthResolver,
30 S: ClientAuthStore,
31{
32 pub registry: Arc<SessionRegistry<T, S>>,
33 pub client: Arc<T>,
34}
35
36impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
37 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
38 let client = JacquardResolver::default();
39 Self::new_from_resolver(store, client, client_data)
40 }
41}
42
43impl<T, S> OAuthClient<T, S>
44where
45 T: OAuthResolver,
46 S: ClientAuthStore,
47{
48 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
49 let client = Arc::new(client);
50 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
51 Self { registry, client }
52 }
53}
54
55impl<T, S> OAuthClient<T, S>
56where
57 S: ClientAuthStore + Send + Sync + 'static,
58 T: OAuthResolver + DpopExt + Send + Sync + 'static,
59{
60 pub fn jwks(&self) -> JwkSet {
61 self.registry
62 .client_data
63 .keyset
64 .as_ref()
65 .map(|keyset| keyset.public_jwks())
66 .unwrap_or_default()
67 }
68 pub async fn start_auth(
69 &self,
70 input: impl AsRef<str>,
71 options: AuthorizeOptions<'_>,
72 ) -> Result<String> {
73 let client_metadata = atproto_client_metadata(
74 self.registry.client_data.config.clone(),
75 &self.registry.client_data.keyset,
76 )?;
77
78 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
79 let login_hint = if identity.is_some() {
80 Some(input.as_ref().into())
81 } else {
82 None
83 };
84 let metadata = OAuthMetadata {
85 server_metadata,
86 client_metadata,
87 keyset: self.registry.client_data.keyset.clone(),
88 };
89 let auth_req_info =
90 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
91
92 #[derive(serde::Serialize)]
93 struct Parameters<'s> {
94 client_id: Url,
95 request_uri: CowStr<'s>,
96 }
97 Ok(metadata.server_metadata.authorization_endpoint.to_string()
98 + "?"
99 + &serde_html_form::to_string(Parameters {
100 client_id: metadata.client_metadata.client_id.clone(),
101 request_uri: auth_req_info.request_uri,
102 })
103 .unwrap())
104 }
105
106 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
107 let Some(state_key) = params.state else {
108 return Err(CallbackError::MissingState.into());
109 };
110
111 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
112 return Err(CallbackError::MissingState.into());
113 };
114
115 self.registry.store.delete_auth_req_info(&state_key).await?;
116
117 let metadata = self
118 .client
119 .get_authorization_server_metadata(&auth_req_info.authserver_url)
120 .await?;
121
122 if let Some(iss) = params.iss {
123 if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
124 return Err(CallbackError::IssuerMismatch { expected: metadata.issuer.to_string(), got: iss.to_string() }.into());
125 }
126 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
127 return Err(CallbackError::MissingIssuer.into());
128 }
129 let metadata = OAuthMetadata {
130 server_metadata: metadata,
131 client_metadata: atproto_client_metadata(
132 self.registry.client_data.config.clone(),
133 &self.registry.client_data.keyset,
134 )?,
135 keyset: self.registry.client_data.keyset.clone(),
136 };
137 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
138
139 match exchange_code(
140 self.client.as_ref(),
141 &mut auth_req_info.dpop_data.clone(),
142 ¶ms.code,
143 &auth_req_info.pkce_verifier,
144 &metadata,
145 )
146 .await
147 {
148 Ok(token_set) => {
149 let scopes = if let Some(scope) = &token_set.scope {
150 Scope::parse_multiple_reduced(&scope)
151 .expect("Failed to parse scopes")
152 .into_static()
153 } else {
154 vec![]
155 };
156 let client_data = ClientSessionData {
157 account_did: token_set.sub.clone(),
158 session_id: auth_req_info.state,
159 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
160 authserver_url: auth_req_info.authserver_url,
161 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
162 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
163 scopes,
164 dpop_data: DpopClientData {
165 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
166 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
167 dpop_host_nonce: auth_req_info
168 .dpop_data
169 .dpop_authserver_nonce
170 .unwrap_or(CowStr::default()),
171 },
172 token_set,
173 };
174
175 self.create_session(client_data).await
176 }
177 Err(e) => Err(e.into()),
178 }
179 }
180
181 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
182 Ok(OAuthSession::new(
183 self.registry.clone(),
184 self.client.clone(),
185 data.into_static(),
186 ))
187 }
188
189 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
190 self.create_session(self.registry.get(did, session_id, false).await?)
191 .await
192 }
193
194 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
195 Ok(self.registry.del(did, session_id).await?)
196 }
197}
198
199pub struct OAuthSession<T, S>
200where
201 T: OAuthResolver,
202 S: ClientAuthStore,
203{
204 pub registry: Arc<SessionRegistry<T, S>>,
205 pub client: Arc<T>,
206 pub data: RwLock<ClientSessionData<'static>>,
207 pub options: RwLock<CallOptions<'static>>,
208}
209
210impl<T, S> OAuthSession<T, S>
211where
212 T: OAuthResolver,
213 S: ClientAuthStore,
214{
215 pub fn new(
216 registry: Arc<SessionRegistry<T, S>>,
217 client: Arc<T>,
218 data: ClientSessionData<'static>,
219 ) -> Self {
220 Self {
221 registry,
222 client,
223 data: RwLock::new(data),
224 options: RwLock::new(CallOptions::default()),
225 }
226 }
227
228 pub fn with_options(self, options: CallOptions<'_>) -> Self {
229 Self {
230 registry: self.registry,
231 client: self.client,
232 data: self.data,
233 options: RwLock::new(options.into_static()),
234 }
235 }
236
237 pub async fn set_options(&self, options: CallOptions<'_>) {
238 *self.options.write().await = options.into_static();
239 }
240
241 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
242 let data = self.data.read().await;
243 (data.account_did.clone(), data.session_id.clone())
244 }
245
246 pub async fn endpoint(&self) -> Url {
247 self.data.read().await.host_url.clone()
248 }
249
250 pub async fn access_token(&self) -> AuthorizationToken<'_> {
251 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
252 }
253
254 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
255 self.data.read().await.token_set.refresh_token.as_ref().map(|t| AuthorizationToken::Dpop(t.clone()))
256 }
257}
258impl<T, S> OAuthSession<T, S>
259where
260 S: ClientAuthStore + Send + Sync + 'static,
261 T: OAuthResolver + DpopExt + Send + Sync + 'static,
262{
263 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
264 // Read identifiers without holding the lock across await
265 let (did, sid) = {
266 let data = self.data.read().await;
267 (data.account_did.clone(), data.session_id.clone())
268 };
269 let refreshed = self
270 .registry
271 .as_ref()
272 .get(&did, &sid, true)
273 .await?;
274 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
275 // Write back updated session
276 *self.data.write().await = refreshed.into_static();
277 Ok(token)
278 }
279}
280
281impl<T, S> HttpClient for OAuthSession<T, S>
282where
283 S: ClientAuthStore + Send + Sync + 'static,
284 T: OAuthResolver + DpopExt + Send + Sync + 'static,
285{
286 type Error = T::Error;
287
288 async fn send_http(
289 &self,
290 request: http::Request<Vec<u8>>,
291 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
292 self.client.send_http(request).await
293 }
294}
295
296impl<T, S> XrpcClient for OAuthSession<T, S>
297where
298 S: ClientAuthStore + Send + Sync + 'static,
299 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
300{
301 fn base_uri(&self) -> Url {
302 // base_uri is a synchronous trait method; we must avoid async `.read().await`.
303 // Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
304 if tokio::runtime::Handle::try_current().is_ok() {
305 tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
306 } else {
307 self.data.blocking_read().host_url.clone()
308 }
309 }
310
311 async fn opts(&self) -> CallOptions<'_> {
312 self.options.read().await.clone()
313 }
314
315 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>(
316 self,
317 request: &R,
318 ) -> XrpcResult<Response<R>> {
319 let base_uri = self.base_uri();
320 let auth = self.access_token().await;
321 let mut opts = self.options.read().await.clone();
322 opts.auth = Some(auth);
323 let res = self
324 .client
325 .xrpc(base_uri.clone())
326 .with_options(opts.clone())
327 .send(request)
328 .await;
329 if is_invalid_token_response(&res) {
330 opts.auth = Some(
331 self.refresh()
332 .await
333 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
334 );
335 self.client
336 .xrpc(base_uri)
337 .with_options(opts)
338 .send(request)
339 .await
340 } else {
341 res
342 }
343 }
344}
345
346fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool {
347 match response {
348 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
349 Err(ClientError::Auth(AuthError::Other(value))) => value
350 .to_str()
351 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
352 Ok(resp) => match resp.parse() {
353 Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
354 _ => false,
355 },
356 _ => false,
357 }
358}