1use jacquard_common::{
2 AuthorizationToken, CowStr, IntoStatic,
3 error::{AuthError, ClientError, TransportError, XrpcResult},
4 http_client::HttpClient,
5 types::{
6 did::Did,
7 xrpc::{CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest},
8 },
9};
10use jose_jwk::JwkSet;
11use smol_str::SmolStr;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use url::Url;
15
16use crate::{
17 atproto::atproto_client_metadata,
18 authstore::ClientAuthStore,
19 dpop::DpopExt,
20 error::{OAuthError, Result},
21 request::{OAuthMetadata, exchange_code, par},
22 resolver::OAuthResolver,
23 scopes::Scope,
24 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
25 types::{AuthorizeOptions, CallbackParams},
26};
27
28pub struct OAuthClient<T, S>
29where
30 T: OAuthResolver,
31 S: ClientAuthStore,
32{
33 pub registry: Arc<SessionRegistry<T, S>>,
34 pub client: Arc<T>,
35}
36
37impl<T, S> OAuthClient<T, S>
38where
39 T: OAuthResolver,
40 S: ClientAuthStore,
41{
42 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
43 let client = Arc::new(client);
44 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
45 Self { registry, client }
46 }
47}
48
49impl<T, S> OAuthClient<T, S>
50where
51 S: ClientAuthStore + Send + Sync + 'static,
52 T: OAuthResolver + DpopExt + Send + Sync + 'static,
53{
54 pub fn jwks(&self) -> JwkSet {
55 self.registry
56 .client_data
57 .keyset
58 .as_ref()
59 .map(|keyset| keyset.public_jwks())
60 .unwrap_or_default()
61 }
62 pub async fn start_auth(
63 &self,
64 input: impl AsRef<str>,
65 options: AuthorizeOptions<'_>,
66 ) -> Result<String> {
67 let client_metadata = atproto_client_metadata(
68 self.registry.client_data.config.clone(),
69 &self.registry.client_data.keyset,
70 )?;
71
72 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
73 let login_hint = if identity.is_some() {
74 Some(input.as_ref().into())
75 } else {
76 None
77 };
78 let metadata = OAuthMetadata {
79 server_metadata,
80 client_metadata,
81 keyset: self.registry.client_data.keyset.clone(),
82 };
83 let auth_req_info =
84 par(self.client.as_ref(), login_hint, options.prompt, &metadata).await?;
85
86 #[derive(serde::Serialize)]
87 struct Parameters<'s> {
88 client_id: Url,
89 request_uri: CowStr<'s>,
90 }
91 Ok(metadata.server_metadata.authorization_endpoint.to_string()
92 + "?"
93 + &serde_html_form::to_string(Parameters {
94 client_id: metadata.client_metadata.client_id.clone(),
95 request_uri: auth_req_info.request_uri,
96 })
97 .unwrap())
98 }
99
100 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
101 let Some(state_key) = params.state else {
102 return Err(OAuthError::Callback("missing state parameter".into()));
103 };
104
105 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
106 return Err(OAuthError::Callback(format!(
107 "unknown authorization state: {state_key}"
108 )));
109 };
110
111 self.registry.store.delete_auth_req_info(&state_key).await?;
112
113 let metadata = self
114 .client
115 .get_authorization_server_metadata(&auth_req_info.authserver_url)
116 .await?;
117
118 if let Some(iss) = params.iss {
119 if iss != metadata.issuer {
120 return Err(OAuthError::Callback(format!(
121 "issuer mismatch: expected {}, got {iss}",
122 metadata.issuer
123 )));
124 }
125 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
126 return Err(OAuthError::Callback("missing `iss` parameter".into()));
127 }
128 let metadata = OAuthMetadata {
129 server_metadata: metadata,
130 client_metadata: atproto_client_metadata(
131 self.registry.client_data.config.clone(),
132 &self.registry.client_data.keyset,
133 )?,
134 keyset: self.registry.client_data.keyset.clone(),
135 };
136 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
137
138 match exchange_code(
139 self.client.as_ref(),
140 &mut auth_req_info.dpop_data.clone(),
141 ¶ms.code,
142 &auth_req_info.pkce_verifier,
143 &metadata,
144 )
145 .await
146 {
147 Ok(token_set) => {
148 let scopes = if let Some(scope) = &token_set.scope {
149 Scope::parse_multiple_reduced(&scope)
150 .expect("Failed to parse scopes")
151 .into_static()
152 } else {
153 vec![]
154 };
155 let client_data = ClientSessionData {
156 account_did: token_set.sub.clone(),
157 session_id: auth_req_info.state,
158 host_url: Url::parse(&token_set.iss).expect("Failed to parse host URL"),
159 authserver_url: auth_req_info.authserver_url,
160 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
161 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
162 scopes,
163 dpop_data: DpopClientData {
164 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
165 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
166 dpop_host_nonce: auth_req_info
167 .dpop_data
168 .dpop_authserver_nonce
169 .unwrap_or(CowStr::default()),
170 },
171 token_set,
172 };
173
174 self.create_session(client_data).await
175 }
176 Err(e) => Err(e.into()),
177 }
178 }
179
180 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
181 Ok(OAuthSession::new(
182 self.registry.clone(),
183 self.client.clone(),
184 data.into_static(),
185 ))
186 }
187
188 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
189 self.create_session(self.registry.get(did, session_id, false).await?)
190 .await
191 }
192
193 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
194 Ok(self.registry.del(did, session_id).await?)
195 }
196}
197
198pub struct OAuthSession<T, S>
199where
200 T: OAuthResolver,
201 S: ClientAuthStore,
202{
203 pub registry: Arc<SessionRegistry<T, S>>,
204 pub client: Arc<T>,
205 pub data: RwLock<ClientSessionData<'static>>,
206 pub options: RwLock<CallOptions<'static>>,
207}
208
209impl<T, S> OAuthSession<T, S>
210where
211 T: OAuthResolver,
212 S: ClientAuthStore,
213{
214 pub fn new(
215 registry: Arc<SessionRegistry<T, S>>,
216 client: Arc<T>,
217 data: ClientSessionData<'static>,
218 ) -> Self {
219 Self {
220 registry,
221 client,
222 data: RwLock::new(data),
223 options: RwLock::new(CallOptions::default()),
224 }
225 }
226
227 pub fn with_options(self, options: CallOptions<'_>) -> Self {
228 Self {
229 registry: self.registry,
230 client: self.client,
231 data: self.data,
232 options: RwLock::new(options.into_static()),
233 }
234 }
235
236 pub async fn set_options(&self, options: CallOptions<'_>) {
237 *self.options.write().await = options.into_static();
238 }
239
240 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
241 let data = self.data.read().await;
242 (data.account_did.clone(), data.session_id.clone())
243 }
244
245 pub async fn pds(&self) -> Url {
246 self.data.read().await.host_url.clone()
247 }
248
249 pub async fn access_token(&self) -> AuthorizationToken<'_> {
250 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
251 }
252
253 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
254 self.data
255 .read()
256 .await
257 .token_set
258 .refresh_token
259 .as_ref()
260 .map(|token| AuthorizationToken::Dpop(token.clone()))
261 }
262}
263impl<T, S> OAuthSession<T, S>
264where
265 S: ClientAuthStore + Send + Sync + 'static,
266 T: OAuthResolver + DpopExt + Send + Sync + 'static,
267{
268 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
269 let mut data = self.data.write().await;
270 let refreshed = self
271 .registry
272 .as_ref()
273 .get(&data.account_did, &data.session_id, true)
274 .await?;
275 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
276 *data = refreshed.into_static();
277 Ok(token)
278 }
279}
280
281impl<T, S> HttpClient for OAuthSession<T, S>
282where
283 S: ClientAuthStore + Send + Sync + 'static,
284 T: OAuthResolver + DpopExt + Send + Sync + 'static,
285{
286 type Error = T::Error;
287
288 async fn send_http(
289 &self,
290 request: http::Request<Vec<u8>>,
291 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
292 self.client.send_http(request).await
293 }
294}
295
296impl<T, S> XrpcClient for OAuthSession<T, S>
297where
298 S: ClientAuthStore + Send + Sync + 'static,
299 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
300{
301 fn base_uri(&self) -> Url {
302 self.data.blocking_read().host_url.clone()
303 }
304
305 async fn opts(&self) -> CallOptions<'_> {
306 self.options.read().await.clone()
307 }
308
309 async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>(
310 self,
311 request: &R,
312 ) -> XrpcResult<Response<R>> {
313 let base_uri = self.base_uri();
314 let auth = self.access_token().await;
315 let mut opts = self.options.read().await.clone();
316 opts.auth = Some(auth);
317 let res = self
318 .client
319 .xrpc(base_uri.clone())
320 .with_options(opts.clone())
321 .send(request)
322 .await;
323 if is_invalid_token_response(&res) {
324 opts.auth = Some(
325 self.refresh()
326 .await
327 .map_err(|e| ClientError::Transport(TransportError::Other(e.into())))?,
328 );
329 self.client
330 .xrpc(base_uri)
331 .with_options(opts)
332 .send(request)
333 .await
334 } else {
335 res
336 }
337 }
338}
339
340fn is_invalid_token_response<R: XrpcRequest>(response: &XrpcResult<Response<R>>) -> bool {
341 match response {
342 Err(ClientError::Auth(AuthError::InvalidToken)) => true,
343 Err(ClientError::Auth(AuthError::Other(value))) => value
344 .to_str()
345 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
346 _ => false,
347 }
348}