A better Rust ATProto crate

home stretch. big pile of tests

Orual d02398d1 220b03d3

+1
Cargo.lock
···
dependencies = [
"async-trait",
"base64 0.22.1",
+
"bytes",
"chrono",
"dashmap",
"elliptic-curve",
+177 -14
crates/jacquard-common/src/types/xrpc.rs
···
+
//! Stateless XRPC utilities and request/response mapping
+
//!
+
//! Mapping overview:
+
//! - Success (2xx): parse body into the endpoint's typed output.
+
//! - 400: try typed error; on failure, fall back to a generic XRPC error (with
+
//! `nsid`, `method`, and `http_status`) and map common auth errors.
+
//! - 401: if `WWW-Authenticate` is present, return
+
//! `ClientError::Auth(AuthError::Other(header))` so higher layers (OAuth/DPoP)
+
//! can inspect `error="invalid_token"` or `error="use_dpop_nonce"` and refresh/retry.
+
//! If the header is absent, parse the body and map auth errors to
+
//! `AuthError::TokenExpired`/`InvalidToken`.
+
//!
use bytes::Bytes;
use http::{
HeaderName, HeaderValue, Request, StatusCode,
···
}
/// Send the given typed XRPC request and return a response wrapper.
+
///
+
/// Note on 401 handling:
+
/// - When the server returns 401 with a `WWW-Authenticate` header, this surfaces as
+
/// `ClientError::Auth(AuthError::Other(header))` so higher layers (e.g., OAuth/DPoP) can
+
/// inspect the header for `error="invalid_token"` or `error="use_dpop_nonce"` and react
+
/// (refresh/retry). If the header is absent, the 401 body flows through to `Response` and
+
/// can be parsed/mapped to `AuthError` as appropriate.
pub async fn send<R: XrpcRequest + Send>(self, request: &R) -> XrpcResult<Response<R>> {
let http_request = build_http_request(&self.base, request, &self.opts)
.map_err(crate::error::TransportError::from)?;
···
.map_err(|e| crate::error::TransportError::Other(Box::new(e)))?;
let status = http_response.status();
+
// If the server returned 401 with a WWW-Authenticate header, expose it so higher layers
+
// (e.g., DPoP handling) can detect `error="invalid_token"` and trigger refresh.
+
if status.as_u16() == 401 {
+
if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) {
+
return Err(crate::error::ClientError::Auth(
+
crate::error::AuthError::Other(hv.clone()),
+
));
+
}
+
}
let buffer = Bytes::from(http_response.into_body());
if !status.is_success() && !matches!(status.as_u16(), 400 | 401) {
···
Err(_) => {
// Fallback to generic error (InvalidRequest, ExpiredToken, etc.)
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
-
Ok(generic) => {
+
Ok(mut generic) => {
+
generic.nsid = R::NSID;
+
generic.method = R::METHOD.as_str();
+
generic.http_status = self.status;
// Map auth-related errors to AuthError
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
···
// 401: always auth error
} else {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
-
Ok(generic) => match generic.error.as_str() {
-
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
-
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
-
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
-
},
+
Ok(mut generic) => {
+
generic.nsid = R::NSID;
+
generic.method = R::METHOD.as_str();
+
generic.http_status = self.status;
+
match generic.error.as_str() {
+
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
+
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
+
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
+
}
+
}
Err(e) => Err(XrpcError::Decode(e)),
}
}
···
Err(_) => {
// Fallback to generic error (InvalidRequest, ExpiredToken, etc.)
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
-
Ok(generic) => {
+
Ok(mut generic) => {
+
generic.nsid = R::NSID;
+
generic.method = R::METHOD.as_str();
+
generic.http_status = self.status;
// Map auth-related errors to AuthError
match generic.error.as_ref() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
···
// 401: always auth error
} else {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
-
Ok(generic) => match generic.error.as_ref() {
-
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
-
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
-
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
-
},
+
Ok(mut generic) => {
+
let status = self.status;
+
generic.nsid = R::NSID;
+
generic.method = R::METHOD.as_str();
+
generic.http_status = status;
+
match generic.error.as_ref() {
+
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
+
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
+
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
+
}
+
}
Err(e) => Err(XrpcError::Decode(e)),
}
}
···
pub error: SmolStr,
/// Optional error message with details
pub message: Option<SmolStr>,
+
/// XRPC method NSID that produced this error (context only; not serialized)
+
#[serde(skip)]
+
pub nsid: &'static str,
+
/// HTTP method used (GET/POST) (context only; not serialized)
+
#[serde(skip)]
+
pub method: &'static str,
+
/// HTTP status code (context only; not serialized)
+
#[serde(skip)]
+
pub http_status: StatusCode,
}
impl std::fmt::Display for GenericXrpcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(msg) = &self.message {
-
write!(f, "{}: {}", self.error, msg)
+
write!(
+
f,
+
"{}: {} (nsid={}, method={}, status={})",
+
self.error, msg, self.nsid, self.method, self.http_status
+
)
} else {
-
write!(f, "{}", self.error)
+
write!(
+
f,
+
"{} (nsid={}, method={}, status={})",
+
self.error, self.nsid, self.method, self.http_status
+
)
}
}
}
···
pub enum XrpcError<E: std::error::Error + IntoStatic> {
/// Typed XRPC error from the endpoint's specific error enum
#[error("XRPC error: {0}")]
+
#[diagnostic(code(jacquard_common::xrpc::typed))]
Xrpc(E),
/// Authentication error (ExpiredToken, InvalidToken, etc.)
#[error("Authentication error: {0}")]
+
#[diagnostic(code(jacquard_common::xrpc::auth))]
Auth(#[from] AuthError),
/// Generic XRPC error not in the endpoint's error enum (e.g., InvalidRequest)
#[error("XRPC error: {0}")]
+
#[diagnostic(code(jacquard_common::xrpc::generic))]
Generic(GenericXrpcError),
/// Failed to decode the response body
#[error("Failed to decode response: {0}")]
+
#[diagnostic(code(jacquard_common::xrpc::decode))]
Decode(#[from] serde_json::Error),
+
}
+
+
#[cfg(test)]
+
mod tests {
+
use super::*;
+
use serde::{Deserialize, Serialize};
+
+
#[derive(Serialize)]
+
struct DummyReq;
+
+
#[derive(Deserialize, Debug, thiserror::Error)]
+
#[error("{0}")]
+
struct DummyErr<'a>(#[serde(borrow)] CowStr<'a>);
+
+
impl IntoStatic for DummyErr<'_> {
+
type Output = DummyErr<'static>;
+
fn into_static(self) -> Self::Output {
+
DummyErr(self.0.into_static())
+
}
+
}
+
+
impl XrpcRequest for DummyReq {
+
const NSID: &'static str = "test.dummy";
+
const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
+
const OUTPUT_ENCODING: &'static str = "application/json";
+
type Output<'de> = ();
+
type Err<'de> = DummyErr<'de>;
+
}
+
+
#[test]
+
fn generic_error_carries_context() {
+
let body = serde_json::json!({"error":"InvalidRequest","message":"missing"});
+
let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
+
let resp: Response<DummyReq> = Response::new(buf, StatusCode::BAD_REQUEST);
+
match resp.parse().unwrap_err() {
+
XrpcError::Generic(g) => {
+
assert_eq!(g.error.as_str(), "InvalidRequest");
+
assert_eq!(g.message.as_deref(), Some("missing"));
+
assert_eq!(g.nsid, DummyReq::NSID);
+
assert_eq!(g.method, DummyReq::METHOD.as_str());
+
assert_eq!(g.http_status, StatusCode::BAD_REQUEST);
+
}
+
other => panic!("unexpected: {other:?}"),
+
}
+
}
+
+
#[test]
+
fn auth_error_mapping() {
+
for (code, expect) in [
+
("ExpiredToken", AuthError::TokenExpired),
+
("InvalidToken", AuthError::InvalidToken),
+
] {
+
let body = serde_json::json!({"error": code});
+
let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
+
let resp: Response<DummyReq> = Response::new(buf, StatusCode::UNAUTHORIZED);
+
match resp.parse().unwrap_err() {
+
XrpcError::Auth(e) => match (e, expect) {
+
(AuthError::TokenExpired, AuthError::TokenExpired) => {}
+
(AuthError::InvalidToken, AuthError::InvalidToken) => {}
+
other => panic!("mismatch: {other:?}"),
+
},
+
other => panic!("unexpected: {other:?}"),
+
}
+
}
+
}
+
+
#[test]
+
fn no_double_slash_in_path() {
+
#[derive(Serialize)]
+
struct Req;
+
#[derive(Deserialize, Debug, thiserror::Error)]
+
#[error("{0}")]
+
struct Err<'a>(#[serde(borrow)] CowStr<'a>);
+
impl IntoStatic for Err<'_> {
+
type Output = Err<'static>;
+
fn into_static(self) -> Self::Output { Err(self.0.into_static()) }
+
}
+
impl XrpcRequest for Req {
+
const NSID: &'static str = "com.example.test";
+
const METHOD: XrpcMethod = XrpcMethod::Query;
+
const OUTPUT_ENCODING: &'static str = "application/json";
+
type Output<'de> = ();
+
type Err<'de> = Err<'de>;
+
}
+
+
let opts = CallOptions::default();
+
for base in [
+
Url::parse("https://pds").unwrap(),
+
Url::parse("https://pds/").unwrap(),
+
Url::parse("https://pds/base/").unwrap(),
+
] {
+
let req = build_http_request(&base, &Req, &opts).unwrap();
+
let uri = req.uri().to_string();
+
assert!(uri.contains("/xrpc/com.example.test"));
+
assert!(!uri.contains("//xrpc"));
+
}
+
}
}
/// Stateful XRPC call trait
+12
crates/jacquard-identity/src/resolver.rs
···
#[allow(missing_docs)]
pub enum IdentityError {
#[error("unsupported DID method: {0}")]
+
#[diagnostic(code(jacquard_identity::unsupported_did_method), help("supported DID methods: did:web, did:plc"))]
UnsupportedDidMethod(String),
#[error("invalid well-known atproto-did content")]
+
#[diagnostic(code(jacquard_identity::invalid_well_known), help("expected first non-empty line to be a DID"))]
InvalidWellKnown,
#[error("missing PDS endpoint in DID document")]
+
#[diagnostic(code(jacquard_identity::missing_pds_endpoint))]
MissingPdsEndpoint,
#[error("HTTP error: {0}")]
+
#[diagnostic(code(jacquard_identity::http), help("check network connectivity and TLS configuration"))]
Http(#[from] TransportError),
#[error("HTTP status {0}")]
+
#[diagnostic(code(jacquard_identity::http_status), help("verify well-known paths or PDS XRPC endpoints"))]
HttpStatus(StatusCode),
#[error("XRPC error: {0}")]
+
#[diagnostic(code(jacquard_identity::xrpc), help("enable PDS fallback or public resolver if needed"))]
Xrpc(String),
#[error("URL parse error: {0}")]
+
#[diagnostic(code(jacquard_identity::url))]
Url(#[from] url::ParseError),
#[error("DNS error: {0}")]
#[cfg(feature = "dns")]
+
#[diagnostic(code(jacquard_identity::dns))]
Dns(#[from] hickory_resolver::error::ResolveError),
#[error("serialize/deserialize error: {0}")]
+
#[diagnostic(code(jacquard_identity::serde))]
Serde(#[from] serde_json::Error),
#[error("invalid DID document: {0}")]
+
#[diagnostic(code(jacquard_identity::invalid_doc), help("validate keys and services; ensure AtprotoPersonalDataServer service exists"))]
InvalidDoc(String),
#[error(transparent)]
+
#[diagnostic(code(jacquard_identity::data))]
Data(#[from] AtDataError),
/// DID document id did not match requested DID; includes the fetched document
#[error("DID doc id mismatch")]
+
#[diagnostic(code(jacquard_identity::doc_id_mismatch), help("document id differs from requested DID; do not trust this document"))]
DocIdMismatch {
expected: Did<'static>,
doc: DidDocument<'static>,
+1
crates/jacquard-oauth/Cargo.toml
···
chrono = "0.4"
elliptic-curve = "0.13.8"
http.workspace = true
+
bytes.workspace = true
rand = { version = "0.8.5", features = ["small_rng"] }
async-trait = "0.1.89"
dashmap = "6.1.0"
+77 -26
crates/jacquard-oauth/src/atproto.rs
···
mut redirect_uris: Option<Vec<Url>>,
scopes: Option<Vec<Scope<'m>>>,
) -> Self {
-
// coerce redirect uris to localhost
+
// Coerce provided redirect URIs to http://localhost while preserving path
if let Some(redirect_uris) = &mut redirect_uris {
for redirect_uri in redirect_uris {
-
redirect_uri.set_host(Some("http://localhost")).unwrap();
+
let _ = redirect_uri.set_scheme("http");
+
redirect_uri.set_host(Some("localhost")).unwrap();
+
let _ = redirect_uri.set_port(None);
}
}
// determine client_id
···
metadata: AtprotoClientMetadata<'m>,
keyset: &Option<Keyset>,
) -> Result<OAuthClientMetadata<'m>> {
+
// For non-loopback clients, require a keyset/JWKs.
+
let is_loopback = metadata.client_id.scheme() == "http"
+
&& metadata.client_id.host_str() == Some("localhost");
+
if !is_loopback && keyset.is_none() {
+
return Err(Error::EmptyJwks);
+
}
if metadata.redirect_uris.is_empty() {
return Err(Error::EmptyRedirectUris);
}
···
client_uri: metadata.client_uri,
redirect_uris: metadata.redirect_uris,
token_endpoint_auth_method: Some(auth_method.into()),
-
grant_types: Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()),
-
scope: Some(Scope::serialize_multiple(metadata.scopes.as_slice())),
-
dpop_bound_access_tokens: Some(true),
+
grant_types: if keyset.is_some() {
+
Some(metadata.grant_types.into_iter().map(|v| v.into()).collect())
+
} else {
+
None
+
},
+
scope: if keyset.is_some() {
+
Some(Scope::serialize_multiple(metadata.scopes.as_slice()))
+
} else {
+
None
+
},
+
dpop_bound_access_tokens: if keyset.is_some() { Some(true) } else { None },
jwks_uri,
jwks,
token_endpoint_auth_signing_alg: if keyset.is_some() {
···
.expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: Url::from_str(
-
"http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric"
+
"http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric"
).unwrap(),
client_uri: None,
redirect_uris: vec![
-
Url::from_str("http://127.0.0.1/callback").unwrap(),
-
Url::from_str("http://[::1]/callback").unwrap(),
+
Url::from_str("http://localhost/callback").unwrap(),
+
Url::from_str("http://localhost/callback").unwrap(),
],
scope: None,
grant_types: None,
···
#[test]
fn test_localhost_client_metadata_invalid() {
+
// Invalid inputs are coerced to http://localhost rather than failing
{
-
let err = atproto_client_metadata(
+
let out = atproto_client_metadata(
AtprotoClientMetadata::new_localhost(
Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]),
None,
),
&None,
)
-
.expect_err("expected to fail");
-
assert!(matches!(
-
err,
-
Error::LocalhostClient(LocalhostClientError::NotHttpScheme)
-
));
+
.expect("should coerce to localhost");
+
assert_eq!(
+
out,
+
OAuthClientMetadata {
+
client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(),
+
client_uri: None,
+
redirect_uris: vec![Url::from_str("http://localhost/").unwrap()],
+
scope: None,
+
grant_types: None,
+
token_endpoint_auth_method: Some(AuthMethod::None.into()),
+
dpop_bound_access_tokens: None,
+
jwks_uri: None,
+
jwks: None,
+
token_endpoint_auth_signing_alg: None,
+
}
+
);
}
{
-
let err = atproto_client_metadata(
+
let out = atproto_client_metadata(
AtprotoClientMetadata::new_localhost(
Some(vec![Url::from_str("http://localhost:8000/").unwrap()]),
None,
),
&None,
)
-
.expect_err("expected to fail");
-
assert!(matches!(
-
err,
-
Error::LocalhostClient(LocalhostClientError::Localhost)
-
));
+
.expect("should coerce to localhost");
+
assert_eq!(
+
out,
+
OAuthClientMetadata {
+
client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(),
+
client_uri: None,
+
redirect_uris: vec![Url::from_str("http://localhost/").unwrap()],
+
scope: None,
+
grant_types: None,
+
token_endpoint_auth_method: Some(AuthMethod::None.into()),
+
dpop_bound_access_tokens: None,
+
jwks_uri: None,
+
jwks: None,
+
token_endpoint_auth_signing_alg: None,
+
}
+
);
}
{
-
let err = atproto_client_metadata(
+
let out = atproto_client_metadata(
AtprotoClientMetadata::new_localhost(
Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]),
None,
),
&None,
)
-
.expect_err("expected to fail");
-
assert!(matches!(
-
err,
-
Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)
-
));
+
.expect("should coerce to localhost");
+
assert_eq!(
+
out,
+
OAuthClientMetadata {
+
client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(),
+
client_uri: None,
+
redirect_uris: vec![Url::from_str("http://localhost/").unwrap()],
+
scope: None,
+
grant_types: None,
+
token_endpoint_auth_method: Some(AuthMethod::None.into()),
+
dpop_bound_access_tokens: None,
+
jwks_uri: None,
+
jwks: None,
+
token_endpoint_auth_signing_alg: None,
+
}
+
);
}
}
···
jwks_uri: None,
};
{
+
// Non-loopback clients without a keyset should fail (must provide JWKS)
let metadata = metadata.clone();
let err = atproto_client_metadata(metadata, &None).expect_err("expected to fail");
assert!(matches!(err, Error::EmptyJwks));
+26 -22
crates/jacquard-oauth/src/client.rs
···
atproto::atproto_client_metadata,
authstore::ClientAuthStore,
dpop::DpopExt,
-
error::{OAuthError, Result},
+
error::{CallbackError, Result},
request::{OAuthMetadata, exchange_code, par},
resolver::OAuthResolver,
scopes::Scope,
···
pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
let Some(state_key) = params.state else {
-
return Err(OAuthError::Callback("missing state parameter".into()));
+
return Err(CallbackError::MissingState.into());
};
let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
-
return Err(OAuthError::Callback(format!(
-
"unknown authorization state: {state_key}"
-
)));
+
return Err(CallbackError::MissingState.into());
};
self.registry.store.delete_auth_req_info(&state_key).await?;
···
.await?;
if let Some(iss) = params.iss {
-
if iss != metadata.issuer {
-
return Err(OAuthError::Callback(format!(
-
"issuer mismatch: expected {}, got {iss}",
-
metadata.issuer
-
)));
+
if !crate::resolver::issuer_equivalent(&iss, &metadata.issuer) {
+
return Err(CallbackError::IssuerMismatch { expected: metadata.issuer.to_string(), got: iss.to_string() }.into());
}
} else if metadata.authorization_response_iss_parameter_supported == Some(true) {
-
return Err(OAuthError::Callback("missing `iss` parameter".into()));
+
return Err(CallbackError::MissingIssuer.into());
}
let metadata = OAuthMetadata {
server_metadata: metadata,
···
}
pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
-
self.data
-
.read()
-
.await
-
.token_set
-
.refresh_token
-
.as_ref()
-
.map(|token| AuthorizationToken::Dpop(token.clone()))
+
self.data.read().await.token_set.refresh_token.as_ref().map(|t| AuthorizationToken::Dpop(t.clone()))
}
}
impl<T, S> OAuthSession<T, S>
···
T: OAuthResolver + DpopExt + Send + Sync + 'static,
{
pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
-
let mut data = self.data.write().await;
+
// Read identifiers without holding the lock across await
+
let (did, sid) = {
+
let data = self.data.read().await;
+
(data.account_did.clone(), data.session_id.clone())
+
};
let refreshed = self
.registry
.as_ref()
-
.get(&data.account_did, &data.session_id, true)
+
.get(&did, &sid, true)
.await?;
let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
-
*data = refreshed.into_static();
+
// Write back updated session
+
*self.data.write().await = refreshed.into_static();
Ok(token)
}
}
···
T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
{
fn base_uri(&self) -> Url {
-
self.data.blocking_read().host_url.clone()
+
// base_uri is a synchronous trait method; we must avoid async `.read().await`.
+
// Use `block_in_place` under Tokio to perform a blocking RwLock read safely.
+
if tokio::runtime::Handle::try_current().is_ok() {
+
tokio::task::block_in_place(|| self.data.blocking_read().host_url.clone())
+
} else {
+
self.data.blocking_read().host_url.clone()
+
}
}
async fn opts(&self) -> CallOptions<'_> {
···
Err(ClientError::Auth(AuthError::Other(value))) => value
.to_str()
.is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
+
Ok(resp) => match resp.parse() {
+
Err(jacquard_common::types::xrpc::XrpcError::Auth(AuthError::InvalidToken)) => true,
+
_ => false,
+
},
_ => false,
}
}
+50 -39
crates/jacquard-oauth/src/error.rs
···
use jacquard_common::session::SessionStoreError;
use miette::Diagnostic;
+
use crate::request::RequestError;
use crate::resolver::ResolverError;
-
/// Errors emitted by OAuth helpers.
+
/// High-level errors emitted by OAuth helpers.
#[derive(Debug, thiserror::Error, Diagnostic)]
pub enum OAuthError {
-
/// Invalid or unsupported JWK
-
#[error("invalid JWK: {0}")]
-
#[diagnostic(
-
code(jacquard_oauth::jwk),
-
help("Ensure EC P-256 JWK with base64url x,y,d values")
-
)]
-
Jwk(String),
-
/// Signing error
-
#[error("signing error: {0}")]
-
#[diagnostic(
-
code(jacquard_oauth::signing),
-
help("Check ES256 key material and input payloads")
-
)]
-
Signing(String),
-
/// Serialization error
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::resolver))]
+
Resolver(#[from] ResolverError),
+
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request))]
+
Request(#[from] RequestError),
+
#[error(transparent)]
-
#[diagnostic(code(jacquard_oauth::serde))]
-
Serde(#[from] serde_json::Error),
-
/// URL error
+
#[diagnostic(code(jacquard_oauth::storage))]
+
Storage(#[from] SessionStoreError),
+
#[error(transparent)]
-
#[diagnostic(code(jacquard_oauth::url))]
-
Url(#[from] url::ParseError),
-
/// URL error
+
#[diagnostic(code(jacquard_oauth::dpop))]
+
Dpop(#[from] crate::dpop::Error),
+
#[error(transparent)]
-
#[diagnostic(code(jacquard_oauth::url))]
-
UrlEncoding(#[from] serde_html_form::ser::Error),
-
/// PKCE error
-
#[error("pkce error: {0}")]
-
#[diagnostic(
-
code(jacquard_oauth::pkce),
-
help("PKCE must use S256; ensure verifier/challenge generated")
-
)]
-
Pkce(String),
-
#[error("authorize error: {0}")]
-
Authorize(String),
+
#[diagnostic(code(jacquard_oauth::keyset))]
+
Keyset(#[from] crate::keyset::Error),
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::atproto))]
Atproto(#[from] crate::atproto::Error),
-
#[error("callback error: {0}")]
-
Callback(String),
-
#[error(transparent)]
-
Storage(#[from] SessionStoreError),
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::session))]
Session(#[from] crate::session::Error),
+
#[error(transparent)]
-
Request(#[from] crate::request::Error),
+
#[diagnostic(code(jacquard_oauth::serde_json))]
+
SerdeJson(#[from] serde_json::Error),
+
#[error(transparent)]
-
Client(#[from] ResolverError),
+
#[diagnostic(code(jacquard_oauth::url))]
+
Url(#[from] url::ParseError),
+
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::form))]
+
Form(#[from] serde_html_form::ser::Error),
+
+
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::callback))]
+
Callback(#[from] CallbackError),
+
}
+
+
/// Typed callback validation errors (redirect handling).
+
#[derive(Debug, thiserror::Error, Diagnostic)]
+
pub enum CallbackError {
+
#[error("missing state parameter in callback")]
+
#[diagnostic(code(jacquard_oauth::callback::missing_state))]
+
MissingState,
+
#[error("missing `iss` parameter")]
+
#[diagnostic(code(jacquard_oauth::callback::missing_iss))]
+
MissingIssuer,
+
#[error("issuer mismatch: expected {expected}, got {got}")]
+
#[diagnostic(code(jacquard_oauth::callback::issuer_mismatch))]
+
IssuerMismatch { expected: String, got: String },
}
pub type Result<T> = core::result::Result<T, OAuthError>;
+207 -17
crates/jacquard-oauth/src/request.rs
···
const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str =
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
-
#[derive(Error, Debug)]
-
pub enum Error {
+
#[derive(Error, Debug, miette::Diagnostic)]
+
pub enum RequestError {
#[error("no {0} endpoint available")]
+
#[diagnostic(code(jacquard_oauth::request::no_endpoint), help("server does not advertise this endpoint"))]
NoEndpoint(CowStr<'static>),
#[error("token response verification failed")]
-
Token(CowStr<'static>),
+
#[diagnostic(code(jacquard_oauth::request::token_verification))]
+
TokenVerification,
#[error("unsupported authentication method")]
+
#[diagnostic(
+
code(jacquard_oauth::request::unsupported_auth_method),
+
help("server must support `private_key_jwt` or `none`; configure client metadata accordingly")
+
)]
UnsupportedAuthMethod,
#[error("no refresh token available")]
-
TokenRefresh,
+
#[diagnostic(code(jacquard_oauth::request::no_refresh_token))]
+
NoRefreshToken,
#[error("failed to parse DID: {0}")]
+
#[diagnostic(code(jacquard_oauth::request::invalid_did))]
InvalidDid(#[from] AtStrError),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::dpop))]
DpopClient(#[from] crate::dpop::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::storage))]
Storage(#[from] SessionStoreError),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::resolver))]
ResolverError(#[from] crate::resolver::ResolverError),
// #[error(transparent)]
// OAuthSession(#[from] crate::oauth_session::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::http_build))]
Http(#[from] http::Error),
-
#[error("http client error: {0}")]
-
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("http status: {0}")]
+
#[diagnostic(code(jacquard_oauth::request::http_status), help("see server response for details"))]
HttpStatus(StatusCode),
#[error("http status: {0}, body: {1:?}")]
+
#[diagnostic(code(jacquard_oauth::request::http_status_body), help("server returned error JSON; inspect fields like `error`, `error_description`"))]
HttpStatusWithBody(StatusCode, Value),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::identity))]
Identity(#[from] IdentityError),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::keyset))]
Keyset(#[from] crate::keyset::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::serde_form))]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::serde_json))]
SerdeJson(#[from] serde_json::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::request::atproto))]
Atproto(#[from] crate::atproto::Error),
}
-
pub type Result<T> = core::result::Result<T, Error>;
+
pub type Result<T> = core::result::Result<T, RequestError>;
#[allow(dead_code)]
pub enum OAuthRequest<'a> {
···
}
}
+
#[cfg(test)]
+
mod tests {
+
use super::*;
+
use crate::types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata};
+
use http::{Response as HttpResponse, StatusCode};
+
use bytes::Bytes;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_identity::resolver::IdentityResolver;
+
use std::sync::Arc;
+
use tokio::sync::Mutex;
+
+
#[derive(Clone, Default)]
+
struct MockClient {
+
resp: Arc<Mutex<Option<HttpResponse<Vec<u8>>>>>,
+
}
+
+
impl HttpClient for MockClient {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
_request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
let resp = self.resp.clone();
+
async move { Ok(resp.lock().await.take().unwrap()) }
+
}
+
}
+
+
// IdentityResolver methods won't be called in these tests; provide stubs.
+
#[async_trait::async_trait]
+
impl IdentityResolver for MockClient {
+
fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> =
+
LazyLock::new(|| jacquard_identity::resolver::ResolverOptions::default());
+
&OPTS
+
}
+
async fn resolve_handle(
+
&self,
+
_handle: &jacquard_common::types::string::Handle<'_>,
+
) -> std::result::Result<
+
jacquard_common::types::string::Did<'static>,
+
jacquard_identity::resolver::IdentityError,
+
> {
+
Ok(jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap())
+
}
+
async fn resolve_did_doc(
+
&self,
+
_did: &jacquard_common::types::string::Did<'_>,
+
) -> std::result::Result<
+
jacquard_identity::resolver::DidDocResponse,
+
jacquard_identity::resolver::IdentityError,
+
> {
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
let buf = Bytes::from(serde_json::to_vec(&doc).unwrap());
+
Ok(jacquard_identity::resolver::DidDocResponse {
+
buffer: buf,
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
// Allow using DPoP helpers on MockClient
+
impl crate::dpop::DpopExt for MockClient {}
+
impl crate::resolver::OAuthResolver for MockClient {}
+
+
fn base_metadata() -> OAuthMetadata {
+
let mut server = OAuthAuthorizationServerMetadata::default();
+
server.issuer = CowStr::from("https://issuer");
+
server.authorization_endpoint = CowStr::from("https://issuer/authorize");
+
server.token_endpoint = CowStr::from("https://issuer/token");
+
OAuthMetadata {
+
server_metadata: server,
+
client_metadata: OAuthClientMetadata {
+
client_id: url::Url::parse("https://client").unwrap(),
+
client_uri: None,
+
redirect_uris: vec![url::Url::parse("https://client/cb").unwrap()],
+
scope: Some(CowStr::from("atproto")),
+
grant_types: None,
+
token_endpoint_auth_method: Some(CowStr::from("none")),
+
dpop_bound_access_tokens: None,
+
jwks_uri: None,
+
jwks: None,
+
token_endpoint_auth_signing_alg: None,
+
},
+
keyset: None,
+
}
+
}
+
+
#[tokio::test]
+
async fn par_missing_endpoint() {
+
let mut meta = base_metadata();
+
meta.server_metadata.require_pushed_authorization_requests = Some(true);
+
meta.server_metadata.pushed_authorization_request_endpoint = None;
+
// require_pushed_authorization_requests is true and no endpoint
+
let err = super::par(&MockClient::default(), None, None, &meta)
+
.await
+
.unwrap_err();
+
match err {
+
RequestError::NoEndpoint(name) => {
+
assert_eq!(name.as_ref(), "pushed_authorization_request");
+
}
+
other => panic!("unexpected: {other:?}"),
+
}
+
}
+
+
#[tokio::test]
+
async fn refresh_no_refresh_token() {
+
let client = MockClient::default();
+
let meta = base_metadata();
+
let mut session = ClientSessionData {
+
account_did: jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap(),
+
session_id: CowStr::from("state"),
+
host_url: url::Url::parse("https://pds").unwrap(),
+
authserver_url: url::Url::parse("https://issuer").unwrap(),
+
authserver_token_endpoint: CowStr::from("https://issuer/token"),
+
authserver_revocation_endpoint: None,
+
scopes: vec![],
+
dpop_data: DpopClientData {
+
dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
+
dpop_authserver_nonce: CowStr::from(""),
+
dpop_host_nonce: CowStr::from(""),
+
},
+
token_set: crate::types::TokenSet {
+
iss: CowStr::from("https://issuer"),
+
sub: jacquard_common::types::string::Did::new_static("did:plc:alice").unwrap(),
+
aud: CowStr::from("https://pds"),
+
scope: None,
+
refresh_token: None,
+
access_token: CowStr::from("abc"),
+
token_type: crate::types::OAuthTokenType::DPoP,
+
expires_at: None,
+
},
+
};
+
let err = super::refresh(&client, session, &meta).await.unwrap_err();
+
matches!(err, RequestError::NoRefreshToken);
+
}
+
+
#[tokio::test]
+
async fn exchange_code_missing_sub() {
+
let client = MockClient::default();
+
// set mock HTTP response body: token response without `sub`
+
*client.resp.lock().await = Some(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.body(serde_json::to_vec(&serde_json::json!({
+
"access_token":"tok",
+
"token_type":"DPoP",
+
"expires_in": 3600
+
})).unwrap())
+
.unwrap(),
+
);
+
let meta = base_metadata();
+
let mut dpop = DpopReqData {
+
dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
+
dpop_authserver_nonce: None,
+
};
+
let err = super::exchange_code(&client, &mut dpop, "abc", "verifier", &meta)
+
.await
+
.unwrap_err();
+
matches!(err, RequestError::TokenVerification);
+
}
+
}
+
#[derive(Debug, Serialize)]
pub struct RequestPayload<'a, T>
where
···
let (code_challenge, verifier) = generate_pkce();
let Some(dpop_key) = generate_dpop_key(&metadata.server_metadata) else {
-
return Err(Error::Token("none of the algorithms worked".into()));
+
return Err(RequestError::TokenVerification);
};
let mut dpop_data = DpopReqData {
dpop_key,
···
.require_pushed_authorization_requests
== Some(true)
{
-
Err(Error::NoEndpoint(CowStr::new_static(
-
"server requires PAR but no endpoint is available",
+
Err(RequestError::NoEndpoint(CowStr::new_static(
+
"pushed_authorization_request",
)))
} else {
todo!("use of PAR is mandatory")
···
T: OAuthResolver + DpopExt + Send + Sync + 'static,
{
let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else {
-
return Err(Error::TokenRefresh);
+
return Err(RequestError::NoRefreshToken);
};
// /!\ IMPORTANT /!\
···
)
.await?;
let Some(sub) = token_response.sub else {
-
return Err(Error::Token("missing `sub` in token response".into()));
+
return Err(RequestError::TokenVerification);
};
let sub = Did::new_owned(sub)?;
let iss = metadata.server_metadata.issuer.clone();
···
D: DpopDataSource,
{
let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else {
-
return Err(Error::NoEndpoint(request.name()));
+
return Err(RequestError::NoEndpoint(request.name()));
};
let client_assertions = build_auth(
metadata.keyset.as_ref(),
···
.dpop_server_call(data_source)
.send(req)
.await
-
.map_err(Error::DpopClient)?;
+
.map_err(RequestError::DpopClient)?;
if res.status() == request.expected_status() {
let body = res.body();
if body.is_empty() {
···
Ok(output)
}
} else if res.status().is_client_error() {
-
Err(Error::HttpStatusWithBody(
+
Err(RequestError::HttpStatusWithBody(
res.status(),
serde_json::from_slice(res.body())?,
))
} else {
-
Err(Error::HttpStatus(res.status()))
+
Err(RequestError::HttpStatus(res.status()))
}
}
···
}
}
-
Err(Error::UnsupportedAuthMethod)
+
Err(RequestError::UnsupportedAuthMethod)
}
+129 -17
crates/jacquard-oauth/src/resolver.rs
···
use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
use http::{Request, StatusCode};
-
use jacquard_common::IntoStatic;
+
use jacquard_common::{IntoStatic, error::TransportError};
use jacquard_common::types::did_doc::DidDocument;
use jacquard_common::types::ident::AtIdentifier;
use jacquard_common::{http_client::HttpClient, types::did::Did};
use jacquard_identity::resolver::{IdentityError, IdentityResolver};
use url::Url;
+
/// Compare two issuer strings strictly but without spuriously failing on trivial differences.
+
///
+
/// Rules:
+
/// - Schemes must match exactly.
+
/// - Hostnames and effective ports must match (treat missing port the same as default port).
+
/// - Path must match, except that an empty path and `/` are equivalent.
+
/// - Query/fragment are not considered; if present on either side, the comparison fails.
+
pub(crate) fn issuer_equivalent(a: &str, b: &str) -> bool {
+
fn normalize(url: &Url) -> Option<(String, String, u16, String)> {
+
if url.query().is_some() || url.fragment().is_some() {
+
return None;
+
}
+
let scheme = url.scheme().to_string();
+
let host = url.host_str()?.to_string();
+
let port = url.port_or_known_default()?;
+
let path = match url.path() {
+
"" => "/".to_string(),
+
"/" => "/".to_string(),
+
other => other.to_string(),
+
};
+
Some((scheme, host, port, path))
+
}
+
+
match (Url::parse(a), Url::parse(b)) {
+
(Ok(ua), Ok(ub)) => match (normalize(&ua), normalize(&ub)) {
+
(Some((sa, ha, pa, pa_path)), Some((sb, hb, pb, pb_path))) => {
+
if sa != sb || ha != hb || pa != pb {
+
return false;
+
}
+
if pa_path == "/" && pb_path == "/" {
+
return true;
+
}
+
pa_path == pb_path
+
}
+
_ => false,
+
},
+
_ => a == b,
+
}
+
}
+
#[derive(thiserror::Error, Debug, miette::Diagnostic)]
pub enum ResolverError {
#[error("resource not found")]
+
#[diagnostic(code(jacquard_oauth::resolver::not_found), help("check the base URL or identifier"))]
NotFound,
#[error("invalid at identifier: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::at_identifier), help("ensure a valid handle or DID was provided"))]
AtIdentifier(String),
#[error("invalid did: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::did), help("ensure DID is correctly formed (did:plc or did:web)"))]
Did(String),
#[error("invalid did document: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::did_document), help("verify the DID document structure and service entries"))]
DidDocument(String),
#[error("protected resource metadata is invalid: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::protected_resource_metadata), help("PDS must advertise an authorization server in its protected resource metadata"))]
ProtectedResourceMetadata(String),
#[error("authorization server metadata is invalid: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::authorization_server_metadata), help("issuer must match and include the PDS resource"))]
AuthorizationServerMetadata(String),
#[error("error resolving identity: {0}")]
+
#[diagnostic(code(jacquard_oauth::resolver::identity))]
IdentityResolverError(#[from] IdentityError),
#[error("unsupported did method: {0:?}")]
+
#[diagnostic(code(jacquard_oauth::resolver::unsupported_did_method), help("supported DID methods: did:web, did:plc"))]
UnsupportedDidMethod(Did<'static>),
#[error(transparent)]
-
Http(#[from] http::Error),
-
#[error("http client error: {0}")]
-
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
+
#[diagnostic(code(jacquard_oauth::resolver::transport))]
+
Transport(#[from] TransportError),
#[error("http status: {0:?}")]
+
#[diagnostic(code(jacquard_oauth::resolver::http_status), help("check well-known paths and server configuration"))]
HttpStatus(StatusCode),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::resolver::serde_json))]
SerdeJson(#[from] serde_json::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::resolver::serde_form))]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::resolver::url))]
Uri(#[from] url::ParseError),
}
···
sub: &Did<'_>,
) -> Result<Url, ResolverError> {
let (metadata, identity) = self.resolve_from_identity(sub).await?;
-
if metadata.issuer != server_metadata.issuer {
-
return Err(ResolverError::Did(format!("DIDs did not match")));
+
if !issuer_equivalent(&metadata.issuer, &server_metadata.issuer) {
+
return Err(ResolverError::AuthorizationServerMetadata(
+
"issuer mismatch".to_string(),
+
));
}
Ok(identity
.pds_endpoint()
···
&self,
issuer: &Url,
) -> Result<OAuthAuthorizationServerMetadata<'static>, ResolverError> {
-
Ok(resolve_authorization_server(self, issuer).await?)
+
let mut md = resolve_authorization_server(self, issuer).await?;
+
// Normalize issuer string to the input URL representation to avoid slash quirks
+
md.issuer = jacquard_common::CowStr::from(issuer.as_str()).into_static();
+
Ok(md)
}
async fn get_resource_server_metadata(
&self,
···
) -> Result<OAuthAuthorizationServerMetadata<'static>, ResolverError> {
let url = server
.join("/.well-known/oauth-authorization-server")
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?;
let req = Request::builder()
.uri(url.to_string())
.body(Vec::new())
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?;
let res = client
.send_http(req)
.await
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?;
if res.status() == StatusCode::OK {
-
let metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body())
+
let mut metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body())
.map_err(ResolverError::SerdeJson)?;
// https://datatracker.ietf.org/doc/html/rfc8414#section-3.3
-
if metadata.issuer == server.as_str() {
+
// Accept semantically equivalent issuer (normalize to the requested URL form)
+
if issuer_equivalent(&metadata.issuer, server.as_str()) {
+
metadata.issuer = server.as_str().into();
Ok(metadata.into_static())
} else {
Err(ResolverError::AuthorizationServerMetadata(format!(
···
) -> Result<OAuthProtectedResourceMetadata<'static>, ResolverError> {
let url = server
.join("/.well-known/oauth-protected-resource")
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?;
let req = Request::builder()
.uri(url.to_string())
.body(Vec::new())
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?;
let res = client
.send_http(req)
.await
-
.map_err(|e| ResolverError::HttpClient(e.into()))?;
+
.map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?;
if res.status() == StatusCode::OK {
-
let metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body())
+
let mut metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body())
.map_err(ResolverError::SerdeJson)?;
// https://datatracker.ietf.org/doc/html/rfc8414#section-3.3
-
if metadata.resource == server.as_str() {
+
// Accept semantically equivalent resource URL (normalize to the requested URL form)
+
if issuer_equivalent(&metadata.resource, server.as_str()) {
+
metadata.resource = server.as_str().into();
Ok(metadata.into_static())
} else {
Err(ResolverError::AuthorizationServerMetadata(format!(
···
#[async_trait::async_trait]
impl OAuthResolver for jacquard_identity::JacquardResolver {}
+
+
#[cfg(test)]
+
mod tests {
+
use super::*;
+
use http::{Request as HttpRequest, Response as HttpResponse, StatusCode};
+
use jacquard_common::http_client::HttpClient;
+
+
#[derive(Default, Clone)]
+
struct MockHttp {
+
next: std::sync::Arc<tokio::sync::Mutex<Option<HttpResponse<Vec<u8>>>>>,
+
}
+
+
impl HttpClient for MockHttp {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
_request: HttpRequest<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<HttpResponse<Vec<u8>>, Self::Error>,
+
> + Send {
+
let next = self.next.clone();
+
async move { Ok(next.lock().await.take().unwrap()) }
+
}
+
}
+
+
#[tokio::test]
+
async fn authorization_server_http_status() {
+
let client = MockHttp::default();
+
*client.next.lock().await = Some(HttpResponse::builder().status(StatusCode::NOT_FOUND).body(Vec::new()).unwrap());
+
let issuer = url::Url::parse("https://issuer").unwrap();
+
let err = super::resolve_authorization_server(&client, &issuer).await.unwrap_err();
+
matches!(err, ResolverError::HttpStatus(StatusCode::NOT_FOUND));
+
}
+
+
#[tokio::test]
+
async fn authorization_server_bad_json() {
+
let client = MockHttp::default();
+
*client.next.lock().await = Some(HttpResponse::builder().status(StatusCode::OK).body(b"{not json}".to_vec()).unwrap());
+
let issuer = url::Url::parse("https://issuer").unwrap();
+
let err = super::resolve_authorization_server(&client, &issuer).await.unwrap_err();
+
matches!(err, ResolverError::SerdeJson(_));
+
}
+
+
#[test]
+
fn issuer_equivalence_rules() {
+
assert!(super::issuer_equivalent("https://issuer", "https://issuer/"));
+
assert!(super::issuer_equivalent("https://issuer:443/", "https://issuer/"));
+
assert!(!super::issuer_equivalent("http://issuer/", "https://issuer/"));
+
assert!(!super::issuer_equivalent("https://issuer/foo", "https://issuer/"));
+
assert!(!super::issuer_equivalent("https://issuer/?q=1", "https://issuer/"));
+
}
+
}
+6 -3
crates/jacquard-oauth/src/session.rs
···
server_metadata: client
.get_authorization_server_metadata(&self.session_data.authserver_url)
.await
-
.map_err(|e| Error::ServerAgent(crate::request::Error::ResolverError(e)))?,
+
.map_err(|e| Error::ServerAgent(crate::request::RequestError::ResolverError(e)))?,
client_metadata: atproto_client_metadata(self.config.clone(), &self.keyset)
.unwrap()
.into_static(),
···
}
}
-
#[derive(thiserror::Error, Debug)]
+
#[derive(thiserror::Error, Debug, miette::Diagnostic)]
pub enum Error {
#[error(transparent)]
-
ServerAgent(#[from] crate::request::Error),
+
#[diagnostic(code(jacquard_oauth::session::request))]
+
ServerAgent(#[from] crate::request::RequestError),
#[error(transparent)]
+
#[diagnostic(code(jacquard_oauth::session::storage))]
Store(#[from] SessionStoreError),
#[error("session does not exist")]
+
#[diagnostic(code(jacquard_oauth::session::not_found))]
SessionNotFound,
}
+7 -3
crates/jacquard/src/client.rs
···
//! This module provides HTTP and XRPC client traits along with an authenticated
//! client implementation that manages session tokens.
+
/// Stateful session client for app‑password auth with auto‑refresh.
pub mod credential_session;
+
/// Token storage and on‑disk formats shared across app‑password and OAuth.
pub mod token;
use core::future::Future;
···
}
}
-
/// A unified indicator for the type of authenticated session.
+
/// Identifies the active authentication mode for an agent/session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentKind {
/// App password (Bearer) session
···
}
/// Common interface for stateful sessions used by the Agent wrapper.
+
///
+
/// Implemented by `CredentialSession` (app‑password) and `OAuthSession` (DPoP).
pub trait AgentSession: XrpcClient + HttpClient + Send + Sync {
/// Identify the kind of session.
fn session_kind(&self) -> AgentKind;
···
}
}
-
/// Thin wrapper that erases the concrete session type while preserving type-safety.
+
/// Thin wrapper over a stateful session providing a uniform `XrpcClient`.
pub struct Agent<A: AgentSession> {
inner: A,
}
···
self.inner.endpoint().await
}
-
/// Override call options.
+
/// Override call options for subsequent requests.
pub async fn set_options(&self, opts: CallOptions<'_>) {
self.inner.set_options(opts).await
}
+41 -4
crates/jacquard/src/client/credential_session.rs
···
use jacquard_identity::resolver::IdentityResolver;
use std::any::Any;
+
/// Storage key for app‑password sessions: `(account DID, session id)`.
pub type SessionKey = (Did<'static>, CowStr<'static>);
+
/// Stateful client for app‑password based sessions.
+
///
+
/// - Persists sessions via a pluggable `SessionStore`.
+
/// - Automatically refreshes on token expiry.
+
/// - Tracks a base endpoint, defaulting to the public appview until login/restore.
pub struct CredentialSession<S, T>
where
S: SessionStore<SessionKey, AtpSession>,
{
store: Arc<S>,
client: Arc<T>,
+
/// Default call options applied to each request (auth/headers/labelers).
pub options: RwLock<CallOptions<'static>>,
+
/// Active session key, if any.
pub key: RwLock<Option<SessionKey>>,
+
/// Current base endpoint (PDS); defaults to public appview when unset.
pub endpoint: RwLock<Option<Url>>,
}
···
where
S: SessionStore<SessionKey, AtpSession>,
{
+
/// Create a new credential session using the given store and client.
pub fn new(store: Arc<S>, client: Arc<T>) -> Self {
Self {
store,
···
where
S: SessionStore<SessionKey, AtpSession>,
{
+
/// Return a copy configured with the provided default call options.
pub fn with_options(self, options: CallOptions<'_>) -> Self {
Self {
client: self.client,
···
}
}
+
/// Replace default call options.
pub async fn set_options(&self, options: CallOptions<'_>) {
*self.options.write().await = options.into_static();
}
+
/// Get the active session key (account DID and session id), if any.
pub async fn session_info(&self) -> Option<SessionKey> {
self.key.read().await.clone()
}
+
/// Current base endpoint. Defaults to the public appview when unset.
pub async fn endpoint(&self) -> Url {
self.endpoint.read().await.clone().unwrap_or(
Url::parse("https://public.bsky.app").expect("public appview should be valid url"),
)
}
+
/// Override the current base endpoint.
pub async fn set_endpoint(&self, endpoint: Url) {
*self.endpoint.write().await = Some(endpoint);
}
+
/// Current access token (Bearer), if logged in.
pub async fn access_token(&self) -> Option<AuthorizationToken<'_>> {
let key = self.key.read().await.clone()?;
let session = self.store.get(&key).await;
session.map(|session| AuthorizationToken::Bearer(session.access_jwt))
}
+
/// Current refresh token (Bearer), if logged in.
pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
let key = self.key.read().await.clone()?;
let session = self.store.get(&key).await;
···
S: SessionStore<SessionKey, AtpSession>,
T: HttpClient,
{
+
/// Refresh the active session by calling `com.atproto.server.refreshSession`.
pub async fn refresh(&self) -> Result<AuthorizationToken<'_>, ClientError> {
let key = self.key.read().await.clone().ok_or(ClientError::Auth(
jacquard_common::error::AuthError::NotAuthenticated,
···
///
/// - `identifier`: handle (preferred), DID, or `https://` PDS base URL.
/// - `session_id`: optional session label; defaults to "session".
+
/// - Persists and activates the session, and updates the base endpoint to the user's PDS.
pub async fn login(
&self,
identifier: CowStr<'_>,
···
Ok(())
}
-
/// Switch to a different stored session (and refresh endpoint from DID).
+
/// Switch to a different stored session (and refresh endpoint/PDS).
pub async fn switch_session(
&self,
did: Did<'_>,
···
T: HttpClient + XrpcExt + Send + Sync + 'static,
{
fn base_uri(&self) -> Url {
-
self.endpoint.blocking_read().clone().unwrap_or(
-
Url::parse("https://public.bsky.app").expect("public appview should be valid url"),
-
)
+
// base_uri is a synchronous trait method; avoid `.await` here.
+
// Under Tokio, use `block_in_place` to make a blocking RwLock read safe.
+
if tokio::runtime::Handle::try_current().is_ok() {
+
tokio::task::block_in_place(|| {
+
self.endpoint
+
.blocking_read()
+
.clone()
+
.unwrap_or(
+
Url::parse("https://public.bsky.app")
+
.expect("public appview should be valid url"),
+
)
+
})
+
} else {
+
self.endpoint
+
.blocking_read()
+
.clone()
+
.unwrap_or(
+
Url::parse("https://public.bsky.app")
+
.expect("public appview should be valid url"),
+
)
+
}
}
async fn send<R: jacquard_common::types::xrpc::XrpcRequest + Send>(
self,
+95 -25
crates/jacquard/src/client/token.rs
···
use serde_json::Value;
use url::Url;
+
/// On-disk session records for app-password and OAuth flows, sharing a single JSON map.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum StoredSession {
+
/// App-password session
Atp(StoredAtSession),
+
/// OAuth client session
OAuth(OAuthSession),
+
/// OAuth authorization request state
OAuthState(OAuthState),
}
+
/// Minimal persisted representation of an app‑password session.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct StoredAtSession {
+
/// Access token (JWT)
access_jwt: String,
+
/// Refresh token (JWT)
refresh_jwt: String,
+
/// Account DID
did: String,
+
/// Optional PDS endpoint for faster resume
#[serde(skip_serializing_if = "std::option::Option::is_none")]
pds: Option<String>,
+
/// Session id label (e.g., "session")
session_id: String,
+
/// Last known handle
handle: String,
}
+
/// Persisted OAuth client session (on-disk format).
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct OAuthSession {
+
/// Account DID
account_did: String,
+
/// Client-generated session id (usually auth `state`)
session_id: String,
-
// Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info.
+
/// Base URL of the resource server (PDS)
host_url: Url,
-
// Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info.
+
/// Base URL of the authorization server (PDS or entryway)
authserver_url: Url,
-
// Full token endpoint
+
/// Full token endpoint URL
authserver_token_endpoint: String,
-
// Full revocation endpoint, if it exists
+
/// Full revocation endpoint URL, if available
#[serde(skip_serializing_if = "std::option::Option::is_none")]
authserver_revocation_endpoint: Option<String>,
-
// The set of scopes approved for this session (returned in the initial token request)
+
/// Granted scopes
scopes: Vec<String>,
+
/// Client DPoP key material
pub dpop_key: Key,
-
// Current auth server DPoP nonce
+
/// Current auth server DPoP nonce
pub dpop_authserver_nonce: String,
-
// Current host ("resource server", eg PDS) DPoP nonce
+
/// Current resource server (PDS) DPoP nonce
pub dpop_host_nonce: String,
+
/// Token response issuer
pub iss: String,
+
/// Token subject (DID)
pub sub: String,
+
/// Token audience (verified PDS URL)
pub aud: String,
+
/// Token scopes (raw) if provided
pub scope: Option<String>,
+
/// Refresh token
pub refresh_token: Option<String>,
+
/// Access token
pub access_token: String,
+
/// Token type (e.g., DPoP)
pub token_type: OAuthTokenType,
+
/// Expiration timestamp
pub expires_at: Option<Datetime>,
}
···
}
}
+
/// Persisted OAuth authorization request state.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthState {
-
// The random identifier generated by the client for the auth request flow. Can be used as "primary key" for storing and retrieving this information.
+
/// Random identifier generated for the authorization flow (`state`)
pub state: String,
-
// URL of the auth server (eg, PDS or entryway)
+
/// Base URL of the authorization server (PDS or entryway)
pub authserver_url: Url,
-
// If the flow started with an account identifier (DID or handle), it should be persisted, to verify against the initial token response.
+
/// Optional pre-known account DID
#[serde(skip_serializing_if = "std::option::Option::is_none")]
pub account_did: Option<String>,
-
// OAuth scope strings
+
/// Requested scopes
pub scopes: Vec<String>,
-
// unique token in URI format, which will be used by the client in the auth flow redirect
+
/// Request URI for the authorization step
pub request_uri: String,
-
// Full token endpoint URL
+
/// Full token endpoint URL
pub authserver_token_endpoint: String,
-
// Full revocation endpoint, if it exists
+
/// Full revocation endpoint URL, if available
#[serde(skip_serializing_if = "std::option::Option::is_none")]
pub authserver_revocation_endpoint: Option<String>,
-
// The secret token/nonce which a code challenge was generated from
+
/// PKCE verifier
pub pkce_verifier: String,
+
/// Client DPoP key material
pub dpop_key: Key,
-
// Current auth server DPoP nonce
+
/// Auth server DPoP nonce at PAR time
#[serde(skip_serializing_if = "std::option::Option::is_none")]
pub dpop_authserver_nonce: Option<String>,
}
···
}
}
+
/// Convenience wrapper over `FileTokenStore` offering unified storage across auth modes.
pub struct FileAuthStore(FileTokenStore);
impl FileAuthStore {
···
let mut store: Value = serde_json::from_str(&file)?;
if let Some(map) = store.as_object_mut() {
if let Some(value) = map.get_mut(&key_str) {
-
if let Some(obj) = value.as_object_mut() {
-
obj.insert(
-
"pds".to_string(),
-
serde_json::Value::String(pds.to_string()),
-
);
-
std::fs::write(&self.0.path, serde_json::to_string_pretty(&store)?)?;
-
return Ok(());
+
if let Some(outer) = value.as_object_mut() {
+
if let Some(inner) = outer.get_mut("Atp").and_then(|v| v.as_object_mut()) {
+
inner.insert(
+
"pds".to_string(),
+
serde_json::Value::String(pds.to_string()),
+
);
+
std::fs::write(&self.0.path, serde_json::to_string_pretty(&store)?)?;
+
return Ok(());
+
}
}
}
}
···
let store: Value = serde_json::from_str(&file)?;
if let Some(value) = store.get(&key_str) {
if let Some(obj) = value.as_object() {
-
if let Some(serde_json::Value::String(pds)) = obj.get("pds") {
-
return Ok(Url::parse(pds).ok());
+
if let Some(serde_json::Value::Object(inner)) = obj.get("Atp") {
+
if let Some(serde_json::Value::String(pds)) = inner.get("pds") {
+
return Ok(Url::parse(pds).ok());
+
}
}
}
}
···
}
}
}
+
+
#[cfg(test)]
+
mod tests {
+
use super::*;
+
use crate::client::credential_session::SessionKey;
+
use crate::client::AtpSession;
+
use jacquard_common::types::string::{Did, Handle};
+
use std::fs;
+
use std::path::PathBuf;
+
+
fn temp_file() -> PathBuf {
+
let mut p = std::env::temp_dir();
+
p.push(format!("jacquard-test-{}.json", std::process::id()));
+
p
+
}
+
+
#[tokio::test]
+
async fn file_auth_store_roundtrip_atp() {
+
let path = temp_file();
+
// initialize empty store file
+
fs::write(&path, "{}").unwrap();
+
let store = FileAuthStore::new(&path);
+
let session = AtpSession {
+
access_jwt: "a".into(),
+
refresh_jwt: "r".into(),
+
did: Did::new_static("did:plc:alice").unwrap(),
+
handle: Handle::new_static("alice.bsky.social").unwrap(),
+
};
+
let key: SessionKey = (session.did.clone(), "session".into());
+
jacquard_common::session::SessionStore::set(&store, key.clone(), session.clone())
+
.await
+
.unwrap();
+
let restored = jacquard_common::session::SessionStore::get(&store, &key)
+
.await
+
.unwrap();
+
assert_eq!(restored.access_jwt.as_ref(), "a");
+
// clean up
+
let _ = fs::remove_file(&path);
+
}
+
}
+4 -4
crates/jacquard/src/lib.rs
···
//! use jacquard::client::credential_session::{CredentialSession, SessionKey};
//! use jacquard::client::{AtpSession, FileAuthStore, MemorySessionStore};
//! use jacquard::identity::PublicResolver as JacquardResolver;
+
//! use jacquard::types::xrpc::XrpcClient;
//! # use miette::IntoDiagnostic;
//!
//! # #[derive(Parser, Debug)]
···
//! .into_diagnostic()?;
//! // Fetch timeline
//! let timeline = session
-
//! .clone()
-
//! .send(GetTimeline::new().limit(5).build())
+
//! .send(&GetTimeline::new().limit(5).build())
//! .await
//! .into_diagnostic()?
//! .into_output()
···
//! let resp = http
//! .xrpc(base)
//! .send(
-
//! GetAuthorFeed::new()
+
//! &GetAuthorFeed::new()
//! .actor(AtIdentifier::new_static("pattern.atproto.systems").unwrap())
//! .limit(5)
//! .build(),
···
//! .accept_labelers(vec![CowStr::from("did:plc:labelerid")])
//! .header(http::header::USER_AGENT, http::HeaderValue::from_static("jacquard-example"))
//! .send(
-
//! GetAuthorFeed::new()
+
//! &GetAuthorFeed::new()
//! .actor(AtIdentifier::new_static("pattern.atproto.systems").unwrap())
//! .limit(5)
//! .build(),
+147
crates/jacquard/tests/agent.rs
···
+
use std::collections::VecDeque;
+
use std::sync::Arc;
+
+
use http::{HeaderValue, Response as HttpResponse, StatusCode};
+
use jacquard::client::credential_session::{CredentialSession, SessionKey};
+
use jacquard::client::{Agent, AtpSession};
+
use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions};
+
use jacquard::types::did::Did;
+
use jacquard::types::string::Handle;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_common::session::MemorySessionStore;
+
use tokio::sync::Mutex;
+
+
#[derive(Clone, Default)]
+
struct MockClient {
+
queue: Arc<Mutex<VecDeque<http::Response<Vec<u8>>>>>,
+
log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>,
+
}
+
+
impl MockClient {
+
async fn push(&self, resp: http::Response<Vec<u8>>) {
+
self.queue.lock().await.push_back(resp);
+
}
+
}
+
+
impl HttpClient for MockClient {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
let log = self.log.clone();
+
let queue = self.queue.clone();
+
async move {
+
log.lock().await.push(request);
+
Ok(queue.lock().await.pop_front().expect("no queued response"))
+
}
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl IdentityResolver for MockClient {
+
fn options(&self) -> &ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default);
+
&OPTS
+
}
+
async fn resolve_handle(
+
&self,
+
_handle: &Handle<'_>,
+
) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> {
+
Ok(Did::new_static("did:plc:alice").unwrap())
+
}
+
async fn resolve_did_doc(
+
&self,
+
_did: &Did<'_>,
+
) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> {
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
Ok(DidDocResponse {
+
buffer: bytes::Bytes::from(serde_json::to_vec(&doc).unwrap()),
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
// XrpcExt blanket impl applies via HttpClient
+
+
fn refresh_session_body(access: &str, refresh: &str) -> Vec<u8> {
+
serde_json::to_vec(&serde_json::json!({
+
"accessJwt": access,
+
"refreshJwt": refresh,
+
"did": "did:plc:alice",
+
"handle": "alice.bsky.social"
+
}))
+
.unwrap()
+
}
+
+
#[tokio::test]
+
async fn agent_delegates_to_session_and_refreshes() {
+
let client = Arc::new(MockClient::default());
+
let store: Arc<MemorySessionStore<SessionKey, AtpSession>> = Arc::new(Default::default());
+
let session = CredentialSession::new(store.clone(), client.clone());
+
+
// Seed a session in the store and activate it via restore (sets endpoint to PDS)
+
let atp = AtpSession {
+
access_jwt: "acc1".into(),
+
refresh_jwt: "ref1".into(),
+
did: Did::new_static("did:plc:alice").unwrap(),
+
handle: Handle::new_static("alice.bsky.social").unwrap(),
+
};
+
let key: SessionKey = (atp.did.clone(), "session".into());
+
jacquard_common::session::SessionStore::set(store.as_ref(), key.clone(), atp)
+
.await
+
.unwrap();
+
session
+
.restore(Did::new_static("did:plc:alice").unwrap(), "session".into())
+
.await
+
.unwrap();
+
+
let agent: Agent<_> = Agent::from(session);
+
assert_eq!(agent.kind(), jacquard::client::AgentKind::AppPassword);
+
let info = agent.info().await.expect("session info");
+
assert_eq!(info.0.as_str(), "did:plc:alice");
+
assert_eq!(info.1.as_ref().unwrap().as_str(), "session");
+
assert_eq!(agent.endpoint().await.as_str(), "https://pds/");
+
+
// Queue a refresh response and call agent.refresh(); Authorization header must use refresh token
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(refresh_session_body("acc2", "ref2"))
+
.unwrap(),
+
)
+
.await;
+
+
let token = agent.refresh().await.expect("refresh ok");
+
match token {
+
jacquard::AuthorizationToken::Bearer(s) => assert_eq!(s.as_ref(), "acc2"),
+
_ => panic!("expected Bearer token"),
+
}
+
+
// Validate the refreshSession call used the refresh token header
+
let log = client.log.lock().await;
+
assert_eq!(log.len(), 1);
+
assert!(
+
log[0]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.refreshSession")
+
);
+
assert_eq!(
+
log[0].headers().get(http::header::AUTHORIZATION),
+
Some(&HeaderValue::from_static("Bearer ref1"))
+
);
+
}
+261
crates/jacquard/tests/credential_session.rs
···
+
use std::collections::VecDeque;
+
use std::sync::Arc;
+
+
use bytes::Bytes;
+
use http::{HeaderValue, Method, Response as HttpResponse, StatusCode};
+
use jacquard::client::AtpSession;
+
use jacquard::client::credential_session::{CredentialSession, SessionKey};
+
use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions};
+
use jacquard::types::did::Did;
+
use jacquard::types::string::Handle;
+
use jacquard::types::xrpc::XrpcClient;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_common::session::{MemorySessionStore, SessionStore};
+
use tokio::sync::{Mutex, RwLock};
+
+
#[derive(Clone, Default)]
+
struct MockClient {
+
// Queue of HTTP responses to pop for each send_http call
+
queue: Arc<Mutex<VecDeque<HttpResponse<Vec<u8>>>>>,
+
// Capture requests for assertions
+
log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>,
+
// Count calls to identity resolver helpers
+
did_doc_calls: Arc<RwLock<usize>>,
+
}
+
+
impl MockClient {
+
async fn push(&self, resp: HttpResponse<Vec<u8>>) {
+
self.queue.lock().await.push_back(resp);
+
}
+
async fn take_log(&self) -> Vec<http::Request<Vec<u8>>> {
+
let mut log = self.log.lock().await;
+
let out = log.clone();
+
log.clear();
+
out
+
}
+
}
+
+
impl HttpClient for MockClient {
+
type Error = std::convert::Infallible;
+
+
fn send_http(
+
&self,
+
request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
let log = self.log.clone();
+
let queue = self.queue.clone();
+
async move {
+
log.lock().await.push(request);
+
Ok(queue.lock().await.pop_front().expect("no queued response"))
+
}
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl IdentityResolver for MockClient {
+
fn options(&self) -> &ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default);
+
&OPTS
+
}
+
+
async fn resolve_handle(
+
&self,
+
handle: &Handle<'_>,
+
) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> {
+
// Return a fixed DID for any handle
+
assert!(handle.as_str().contains('.'));
+
Ok(Did::new_static("did:plc:alice").unwrap())
+
}
+
+
async fn resolve_did_doc(
+
&self,
+
did: &Did<'_>,
+
) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> {
+
// Track calls and return a minimal DID doc with a PDS endpoint
+
*self.did_doc_calls.write().await += 1;
+
assert_eq!(did.as_str(), "did:plc:alice");
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
Ok(DidDocResponse {
+
buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()),
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
// XrpcExt blanket impl applies via HttpClient
+
+
fn create_session_body() -> Vec<u8> {
+
serde_json::to_vec(&serde_json::json!({
+
"accessJwt": "acc1",
+
"refreshJwt": "ref1",
+
"did": "did:plc:alice",
+
"handle": "alice.bsky.social"
+
}))
+
.unwrap()
+
}
+
+
fn refresh_session_body(access: &str, refresh: &str) -> Vec<u8> {
+
serde_json::to_vec(&serde_json::json!({
+
"accessJwt": access,
+
"refreshJwt": refresh,
+
"did": "did:plc:alice",
+
"handle": "alice.bsky.social"
+
}))
+
.unwrap()
+
}
+
+
fn get_session_ok_body() -> Vec<u8> {
+
serde_json::to_vec(&serde_json::json!({
+
"did": "did:plc:alice",
+
"handle": "alice.bsky.social",
+
"active": true
+
}))
+
.unwrap()
+
}
+
+
#[tokio::test(flavor = "multi_thread")]
+
async fn credential_login_and_auto_refresh() {
+
let client = Arc::new(MockClient::default());
+
+
// Queue responses in order: createSession 200 → getSession 401 → refreshSession 200 → getSession 200
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(create_session_body())
+
.unwrap(),
+
)
+
.await;
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::UNAUTHORIZED)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(serde_json::to_vec(&serde_json::json!({"error":"ExpiredToken"})).unwrap())
+
.unwrap(),
+
)
+
.await;
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(refresh_session_body("acc2", "ref2"))
+
.unwrap(),
+
)
+
.await;
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(get_session_ok_body())
+
.unwrap(),
+
)
+
.await;
+
+
let store: Arc<MemorySessionStore<SessionKey, AtpSession>> = Arc::new(Default::default());
+
let session = CredentialSession::new(store.clone(), client.clone());
+
+
// Before login, default endpoint should be public appview
+
assert_eq!(
+
session.endpoint().await.as_str(),
+
"https://public.bsky.app/"
+
);
+
+
// Login using handle; resolves to PDS and persists session
+
session
+
.login(
+
jacquard::CowStr::from("alice.bsky.social"),
+
jacquard::CowStr::from("apppass"),
+
Some(jacquard::CowStr::from("session")),
+
None,
+
None,
+
)
+
.await
+
.expect("login ok");
+
+
// Endpoint switches to PDS
+
assert_eq!(session.endpoint().await.as_str(), "https://pds/");
+
+
// Send a request that will first 401 (ExpiredToken), then refresh, then succeed
+
let resp = session
+
.send(&jacquard::api::com_atproto::server::get_session::GetSession)
+
.await
+
.expect("xrpc send ok");
+
assert_eq!(resp.status(), StatusCode::OK);
+
let out = resp
+
.parse()
+
.expect("parse ok after refresh (GetSession output)");
+
assert_eq!(out.handle.as_str(), "alice.bsky.social");
+
+
// Verify request sequence and Authorization headers used
+
let log = client.take_log().await;
+
assert_eq!(log.len(), 4, "expected four HTTP calls");
+
// 0: createSession (no auth)
+
assert_eq!(log[0].method(), Method::POST);
+
assert!(
+
log[0]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.createSession")
+
);
+
assert!(log[0].headers().get(http::header::AUTHORIZATION).is_none());
+
// 1: getSession (uses access token acc1)
+
assert_eq!(log[1].method(), Method::GET);
+
assert!(
+
log[1]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.getSession")
+
);
+
assert_eq!(
+
log[1].headers().get(http::header::AUTHORIZATION),
+
Some(&HeaderValue::from_static("Bearer acc1"))
+
);
+
// 2: refreshSession (uses refresh token ref1)
+
assert_eq!(log[2].method(), Method::POST);
+
assert!(
+
log[2]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.refreshSession")
+
);
+
assert_eq!(
+
log[2].headers().get(http::header::AUTHORIZATION),
+
Some(&HeaderValue::from_static("Bearer ref1"))
+
);
+
// 3: getSession (re-sent with new access token acc2)
+
assert_eq!(log[3].method(), Method::GET);
+
assert!(
+
log[3]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.getSession")
+
);
+
assert_eq!(
+
log[3].headers().get(http::header::AUTHORIZATION),
+
Some(&HeaderValue::from_static("Bearer acc2"))
+
);
+
+
// Verify store updated with refreshed tokens
+
let key: SessionKey = (
+
Did::new_static("did:plc:alice").unwrap(),
+
jacquard::CowStr::from("session"),
+
);
+
let updated = store.get(&key).await.expect("session present");
+
assert_eq!(updated.access_jwt.as_ref(), "acc2");
+
assert_eq!(updated.refresh_jwt.as_ref(), "ref2");
+
}
+374
crates/jacquard/tests/oauth_auto_refresh.rs
···
+
use std::collections::VecDeque;
+
use std::sync::Arc;
+
+
use bytes::Bytes;
+
use http::{HeaderValue, Method, Response as HttpResponse, StatusCode};
+
use jacquard::client::Agent;
+
use jacquard::IntoStatic;
+
use jacquard::types::did::Did;
+
use jacquard::types::xrpc::XrpcClient;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_oauth::atproto::AtprotoClientMetadata;
+
use jacquard_oauth::client::OAuthSession;
+
use jacquard_oauth::session::SessionRegistry;
+
use jacquard_oauth::resolver::OAuthResolver;
+
use jacquard_oauth::scopes::Scope;
+
use jacquard_oauth::session::{ClientData, ClientSessionData, DpopClientData};
+
use jacquard_oauth::types::{OAuthAuthorizationServerMetadata, OAuthTokenType, TokenSet};
+
use tokio::sync::Mutex;
+
+
#[derive(Clone, Default)]
+
struct MockClient {
+
queue: Arc<Mutex<VecDeque<http::Response<Vec<u8>>>>>,
+
log: Arc<Mutex<Vec<http::Request<Vec<u8>>>>>,
+
}
+
+
impl MockClient {
+
async fn push(&self, resp: http::Response<Vec<u8>>) {
+
self.queue.lock().await.push_back(resp);
+
}
+
}
+
+
impl HttpClient for MockClient {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
let log = self.log.clone();
+
let queue = self.queue.clone();
+
async move {
+
log.lock().await.push(request);
+
Ok(queue
+
.lock()
+
.await
+
.pop_front()
+
.expect("no queued response"))
+
}
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl jacquard::identity::resolver::IdentityResolver for MockClient {
+
fn options(&self) -> &jacquard::identity::resolver::ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<jacquard::identity::resolver::ResolverOptions> =
+
LazyLock::new(jacquard::identity::resolver::ResolverOptions::default);
+
&OPTS
+
}
+
async fn resolve_handle(
+
&self,
+
_handle: &jacquard::types::string::Handle<'_>,
+
) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> {
+
Ok(Did::new_static("did:plc:alice").unwrap())
+
}
+
async fn resolve_did_doc(
+
&self,
+
_did: &Did<'_>,
+
) -> std::result::Result<jacquard::identity::resolver::DidDocResponse, jacquard::identity::resolver::IdentityError> {
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
Ok(jacquard::identity::resolver::DidDocResponse {
+
buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()),
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl OAuthResolver for MockClient {
+
async fn get_authorization_server_metadata(
+
&self,
+
issuer: &url::Url,
+
) -> Result<OAuthAuthorizationServerMetadata<'static>, jacquard_oauth::resolver::ResolverError> {
+
// Return minimal metadata with supported auth method "none" and DPoP support
+
let mut md = OAuthAuthorizationServerMetadata::default();
+
md.issuer = jacquard::CowStr::from(issuer.as_str());
+
md.token_endpoint = jacquard::CowStr::from(format!("{}/token", issuer));
+
md.authorization_endpoint = jacquard::CowStr::from(format!("{}/authorize", issuer));
+
md.require_pushed_authorization_requests = Some(true);
+
md.pushed_authorization_request_endpoint =
+
Some(jacquard::CowStr::from(format!("{}/par", issuer)));
+
md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]);
+
md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]);
+
use jacquard::IntoStatic;
+
Ok(md.into_static())
+
}
+
+
async fn get_resource_server_metadata(
+
&self,
+
_pds: &url::Url,
+
) -> Result<OAuthAuthorizationServerMetadata<'static>, jacquard_oauth::resolver::ResolverError> {
+
// Return metadata pointing to the same issuer as above
+
let mut md = OAuthAuthorizationServerMetadata::default();
+
md.issuer = jacquard::CowStr::from("https://issuer");
+
md.token_endpoint = jacquard::CowStr::from("https://issuer/token");
+
md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize");
+
md.require_pushed_authorization_requests = Some(true);
+
md.pushed_authorization_request_endpoint = Some(jacquard::CowStr::from("https://issuer/par"));
+
md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]);
+
md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]);
+
Ok(md.into_static())
+
}
+
+
async fn verify_issuer(
+
&self,
+
_server_metadata: &OAuthAuthorizationServerMetadata<'_>,
+
_sub: &Did<'_>,
+
) -> Result<url::Url, jacquard_oauth::resolver::ResolverError> {
+
Ok(url::Url::parse("https://pds").unwrap())
+
}
+
}
+
+
fn get_session_unauthorized() -> http::Response<Vec<u8>> {
+
HttpResponse::builder()
+
.status(StatusCode::UNAUTHORIZED)
+
.header(
+
http::header::WWW_AUTHENTICATE,
+
HeaderValue::from_static("DPoP realm=\"pds\", error=\"invalid_token\""),
+
)
+
.body(Vec::new())
+
.unwrap()
+
}
+
+
fn get_session_unauthorized_body() -> http::Response<Vec<u8>> {
+
HttpResponse::builder()
+
.status(StatusCode::UNAUTHORIZED)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"error":"InvalidToken"
+
}))
+
.unwrap(),
+
)
+
.unwrap()
+
}
+
+
fn token_use_dpop_nonce() -> http::Response<Vec<u8>> {
+
HttpResponse::builder()
+
.status(StatusCode::BAD_REQUEST)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.header("DPoP-Nonce", HeaderValue::from_static("n1"))
+
.body(serde_json::to_vec(&serde_json::json!({"error":"use_dpop_nonce"})).unwrap())
+
.unwrap()
+
}
+
+
fn token_refresh_ok() -> http::Response<Vec<u8>> {
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"access_token":"newacc",
+
"token_type":"DPoP",
+
"refresh_token":"newref",
+
"expires_in": 3600
+
}))
+
.unwrap(),
+
)
+
.unwrap()
+
}
+
+
fn get_session_ok() -> http::Response<Vec<u8>> {
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"did":"did:plc:alice",
+
"handle":"alice.bsky.social",
+
"active":true
+
}))
+
.unwrap(),
+
)
+
.unwrap()
+
}
+
+
impl jacquard_oauth::dpop::DpopExt for MockClient {}
+
+
#[tokio::test(flavor = "multi_thread")]
+
async fn oauth_xrpc_invalid_token_triggers_refresh_and_retries() {
+
// (reopen test body since we inserted a trait impl)
+
let client = Arc::new(MockClient::default());
+
+
client.push(get_session_unauthorized()).await;
+
client.push(token_use_dpop_nonce()).await;
+
client.push(token_refresh_ok()).await;
+
client.push(get_session_ok()).await;
+
+
let mut path = std::env::temp_dir();
+
path.push(format!("jacquard-oauth-test-{}.json", std::process::id()));
+
std::fs::write(&path, "{}").unwrap();
+
let store = jacquard::client::FileAuthStore::new(&path);
+
+
let client_data = ClientData {
+
keyset: None,
+
config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])),
+
};
+
use jacquard::IntoStatic;
+
let session_data = ClientSessionData {
+
account_did: Did::new_static("did:plc:alice").unwrap(),
+
session_id: jacquard::CowStr::from("state"),
+
host_url: url::Url::parse("https://pds").unwrap(),
+
authserver_url: url::Url::parse("https://issuer").unwrap(),
+
authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"),
+
authserver_revocation_endpoint: None,
+
scopes: vec![Scope::Atproto],
+
dpop_data: DpopClientData {
+
dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")])
+
.unwrap(),
+
dpop_authserver_nonce: jacquard::CowStr::from(""),
+
dpop_host_nonce: jacquard::CowStr::from(""),
+
},
+
token_set: TokenSet {
+
iss: jacquard::CowStr::from("https://issuer"),
+
sub: Did::new_static("did:plc:alice").unwrap(),
+
aud: jacquard::CowStr::from("https://pds"),
+
scope: None,
+
refresh_token: Some(jacquard::CowStr::from("rt1")),
+
access_token: jacquard::CowStr::from("atk1"),
+
token_type: OAuthTokenType::DPoP,
+
expires_at: None,
+
},
+
}
+
.into_static();
+
let client_arc = client.clone();
+
let registry = Arc::new(SessionRegistry::new(store, client_arc.clone(), client_data));
+
// Seed the store so refresh can load the session
+
let data_store = ClientSessionData {
+
account_did: Did::new_static("did:plc:alice").unwrap(),
+
session_id: jacquard::CowStr::from("state"),
+
host_url: url::Url::parse("https://pds").unwrap(),
+
authserver_url: url::Url::parse("https://issuer").unwrap(),
+
authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"),
+
authserver_revocation_endpoint: None,
+
scopes: vec![Scope::Atproto],
+
dpop_data: DpopClientData {
+
dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")])
+
.unwrap(),
+
dpop_authserver_nonce: jacquard::CowStr::from(""),
+
dpop_host_nonce: jacquard::CowStr::from(""),
+
},
+
token_set: TokenSet {
+
iss: jacquard::CowStr::from("https://issuer"),
+
sub: Did::new_static("did:plc:alice").unwrap(),
+
aud: jacquard::CowStr::from("https://pds"),
+
scope: None,
+
refresh_token: Some(jacquard::CowStr::from("rt1")),
+
access_token: jacquard::CowStr::from("atk1"),
+
token_type: OAuthTokenType::DPoP,
+
expires_at: None,
+
},
+
}
+
.into_static();
+
registry.set(data_store).await.unwrap();
+
let session = OAuthSession::new(registry, client_arc, session_data);
+
+
let agent: Agent<_> = Agent::from(session);
+
let resp = agent
+
.send(&jacquard::api::com_atproto::server::get_session::GetSession)
+
.await
+
.expect("xrpc send ok after auto-refresh");
+
assert_eq!(resp.status(), StatusCode::OK);
+
+
// Inspect the request log
+
let log = client.log.lock().await;
+
assert_eq!(log.len(), 4, "expected 4 HTTP calls");
+
// 0: getSession with old token
+
assert_eq!(log[0].method(), Method::GET);
+
assert!(log[0].headers().get(http::header::AUTHORIZATION).unwrap().to_str().unwrap().starts_with("DPoP "));
+
assert!(log[0]
+
.uri()
+
.to_string()
+
.ends_with("/xrpc/com.atproto.server.getSession"));
+
// 1 and 2: token refresh attempts
+
assert_eq!(log[1].method(), Method::POST);
+
assert!(log[1].uri().to_string().ends_with("/token"));
+
assert!(log[1].headers().contains_key("DPoP"));
+
assert_eq!(log[2].method(), Method::POST);
+
assert!(log[2].uri().to_string().ends_with("/token"));
+
assert!(log[2].headers().contains_key("DPoP"));
+
// 3: retried getSession with new access token
+
assert_eq!(log[3].method(), Method::GET);
+
assert!(log[3]
+
.headers()
+
.get(http::header::AUTHORIZATION)
+
.unwrap()
+
.to_str()
+
.unwrap()
+
.starts_with("DPoP newacc"));
+
+
// Cleanup temp file
+
let _ = std::fs::remove_file(&path);
+
}
+
+
#[tokio::test(flavor = "multi_thread")]
+
async fn oauth_xrpc_invalid_token_body_triggers_refresh_and_retries() {
+
let client = Arc::new(MockClient::default());
+
+
// Queue responses: initial 401 with JSON body; token refresh 400(use_dpop_nonce); token refresh 200; retry getSession 200
+
client.push(get_session_unauthorized_body()).await;
+
client.push(token_use_dpop_nonce()).await;
+
client.push(token_refresh_ok()).await;
+
client.push(get_session_ok()).await;
+
+
let mut path = std::env::temp_dir();
+
path.push(format!("jacquard-oauth-test-body-{}.json", std::process::id()));
+
std::fs::write(&path, "{}").unwrap();
+
let store = jacquard::client::FileAuthStore::new(&path);
+
+
let client_data = ClientData {
+
keyset: None,
+
config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])),
+
};
+
use jacquard::IntoStatic;
+
let session_data = ClientSessionData {
+
account_did: Did::new_static("did:plc:alice").unwrap(),
+
session_id: jacquard::CowStr::from("state"),
+
host_url: url::Url::parse("https://pds").unwrap(),
+
authserver_url: url::Url::parse("https://issuer").unwrap(),
+
authserver_token_endpoint: jacquard::CowStr::from("https://issuer/token"),
+
authserver_revocation_endpoint: None,
+
scopes: vec![Scope::Atproto],
+
dpop_data: DpopClientData {
+
dpop_key: jacquard_oauth::utils::generate_key(&[jacquard::CowStr::from("ES256")])
+
.unwrap(),
+
dpop_authserver_nonce: jacquard::CowStr::from(""),
+
dpop_host_nonce: jacquard::CowStr::from(""),
+
},
+
token_set: TokenSet {
+
iss: jacquard::CowStr::from("https://issuer"),
+
sub: Did::new_static("did:plc:alice").unwrap(),
+
aud: jacquard::CowStr::from("https://pds"),
+
scope: None,
+
refresh_token: Some(jacquard::CowStr::from("rt1")),
+
access_token: jacquard::CowStr::from("atk1"),
+
token_type: OAuthTokenType::DPoP,
+
expires_at: None,
+
},
+
}
+
.into_static();
+
let client_arc = client.clone();
+
let registry = Arc::new(SessionRegistry::new(store, client_arc.clone(), client_data));
+
registry.set(session_data.clone()).await.unwrap();
+
let session = OAuthSession::new(registry, client_arc, session_data);
+
+
let agent: Agent<_> = Agent::from(session);
+
let resp = agent
+
.send(&jacquard::api::com_atproto::server::get_session::GetSession)
+
.await
+
.expect("xrpc send ok after auto-refresh");
+
assert_eq!(resp.status(), StatusCode::OK);
+
+
// Cleanup temp file
+
let _ = std::fs::remove_file(&path);
+
}
+293
crates/jacquard/tests/oauth_flow.rs
···
+
use std::collections::VecDeque;
+
use std::sync::Arc;
+
+
use bytes::Bytes;
+
use http::{Response as HttpResponse, StatusCode};
+
use jacquard::IntoStatic;
+
use jacquard::client::Agent;
+
use jacquard::types::xrpc::XrpcClient;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_oauth::atproto::AtprotoClientMetadata;
+
use jacquard_oauth::authstore::ClientAuthStore;
+
use jacquard_oauth::client::OAuthClient;
+
use jacquard_oauth::resolver::OAuthResolver;
+
use jacquard_oauth::scopes::Scope;
+
use jacquard_oauth::session::ClientData;
+
+
#[derive(Clone, Default)]
+
struct MockClient {
+
queue: Arc<tokio::sync::Mutex<VecDeque<http::Response<Vec<u8>>>>>,
+
}
+
+
impl MockClient {
+
async fn push(&self, resp: http::Response<Vec<u8>>) {
+
self.queue.lock().await.push_back(resp);
+
}
+
}
+
+
impl HttpClient for MockClient {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
_request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
let queue = self.queue.clone();
+
async move { Ok(queue.lock().await.pop_front().expect("no queued response")) }
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl jacquard::identity::resolver::IdentityResolver for MockClient {
+
fn options(&self) -> &jacquard::identity::resolver::ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<jacquard::identity::resolver::ResolverOptions> =
+
LazyLock::new(jacquard::identity::resolver::ResolverOptions::default);
+
&OPTS
+
}
+
async fn resolve_handle(
+
&self,
+
_handle: &jacquard::types::string::Handle<'_>,
+
) -> std::result::Result<
+
jacquard::types::did::Did<'static>,
+
jacquard::identity::resolver::IdentityError,
+
> {
+
Ok(jacquard::types::did::Did::new_static("did:plc:alice").unwrap())
+
}
+
async fn resolve_did_doc(
+
&self,
+
_did: &jacquard::types::did::Did<'_>,
+
) -> std::result::Result<
+
jacquard::identity::resolver::DidDocResponse,
+
jacquard::identity::resolver::IdentityError,
+
> {
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
Ok(jacquard::identity::resolver::DidDocResponse {
+
buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()),
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl OAuthResolver for MockClient {
+
async fn resolve_oauth(
+
&self,
+
_input: &str,
+
) -> Result<
+
(
+
jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>,
+
Option<jacquard_common::types::did_doc::DidDocument<'static>>,
+
),
+
jacquard_oauth::resolver::ResolverError,
+
> {
+
let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default();
+
md.issuer = jacquard::CowStr::from("https://issuer");
+
md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize");
+
md.token_endpoint = jacquard::CowStr::from("https://issuer/token");
+
md.require_pushed_authorization_requests = Some(true);
+
md.pushed_authorization_request_endpoint =
+
Some(jacquard::CowStr::from("https://issuer/par"));
+
md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]);
+
md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]);
+
+
// Simple DID doc pointing to https://pds
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds"
+
}]
+
});
+
let buf = Bytes::from(serde_json::to_vec(&doc).unwrap());
+
let did_doc_b: jacquard_common::types::did_doc::DidDocument<'_> =
+
serde_json::from_slice(&buf).unwrap();
+
let did_doc = did_doc_b.into_static();
+
Ok((md.into_static(), Some(did_doc)))
+
}
+
async fn get_authorization_server_metadata(
+
&self,
+
issuer: &url::Url,
+
) -> Result<
+
jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>,
+
jacquard_oauth::resolver::ResolverError,
+
> {
+
let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default();
+
md.issuer = jacquard::CowStr::from(issuer.as_str());
+
md.authorization_endpoint = jacquard::CowStr::from(format!("{}/authorize", issuer));
+
md.token_endpoint = jacquard::CowStr::from(format!("{}/token", issuer));
+
md.require_pushed_authorization_requests = Some(true);
+
md.pushed_authorization_request_endpoint =
+
Some(jacquard::CowStr::from(format!("{}/par", issuer)));
+
md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]);
+
md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]);
+
Ok(md.into_static())
+
}
+
+
async fn get_resource_server_metadata(
+
&self,
+
_pds: &url::Url,
+
) -> Result<
+
jacquard_oauth::types::OAuthAuthorizationServerMetadata<'static>,
+
jacquard_oauth::resolver::ResolverError,
+
> {
+
let mut md = jacquard_oauth::types::OAuthAuthorizationServerMetadata::default();
+
md.issuer = jacquard::CowStr::from("https://issuer/");
+
md.authorization_endpoint = jacquard::CowStr::from("https://issuer/authorize");
+
md.token_endpoint = jacquard::CowStr::from("https://issuer/token");
+
md.require_pushed_authorization_requests = Some(true);
+
md.pushed_authorization_request_endpoint =
+
Some(jacquard::CowStr::from("https://issuer/par"));
+
md.token_endpoint_auth_methods_supported = Some(vec![jacquard::CowStr::from("none")]);
+
md.dpop_signing_alg_values_supported = Some(vec![jacquard::CowStr::from("ES256")]);
+
Ok(md.into_static())
+
}
+
}
+
+
impl jacquard_oauth::dpop::DpopExt for MockClient {}
+
+
#[tokio::test(flavor = "multi_thread")]
+
async fn oauth_end_to_end_mock_flow() {
+
let client = Arc::new(MockClient::default());
+
// Queue responses: PAR 201, token 200, XRPC getSession 200
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::CREATED)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"request_uri": "urn:par:abc",
+
"expires_in": 60
+
}))
+
.unwrap(),
+
)
+
.unwrap(),
+
)
+
.await;
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.header("DPoP-Nonce", http::HeaderValue::from_static("n1"))
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"access_token": "atk1",
+
"token_type": "DPoP",
+
"refresh_token": "rt1",
+
"sub": "did:plc:alice",
+
"iss": "https://issuer",
+
"aud": "https://pds",
+
"expires_in": 3600
+
}))
+
.unwrap(),
+
)
+
.unwrap(),
+
)
+
.await;
+
client
+
.push(
+
HttpResponse::builder()
+
.status(StatusCode::OK)
+
.header(http::header::CONTENT_TYPE, "application/json")
+
.body(
+
serde_json::to_vec(&serde_json::json!({
+
"did": "did:plc:alice",
+
"handle": "alice.bsky.social",
+
"active": true
+
}))
+
.unwrap(),
+
)
+
.unwrap(),
+
)
+
.await;
+
+
// File-backed store for auth state/session
+
let mut path = std::env::temp_dir();
+
path.push(format!("jacquard-oauth-flow-{}.json", std::process::id()));
+
std::fs::write(&path, "{}").unwrap();
+
let store = jacquard::client::FileAuthStore::new(&path);
+
+
let client_data: ClientData<'static> = ClientData {
+
keyset: None,
+
config: AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])),
+
};
+
let client_arc = client.clone();
+
let oauth = OAuthClient::new_from_resolver(store, (*client_arc).clone(), client_data);
+
+
// Build metadata and call PAR to get an AuthRequestData, then save in store
+
let (server_metadata, identity) = client.resolve_oauth("alice.bsky.social").await.unwrap();
+
let metadata = jacquard_oauth::request::OAuthMetadata {
+
server_metadata,
+
client_metadata: jacquard_oauth::atproto::atproto_client_metadata(
+
AtprotoClientMetadata::new_localhost(None, Some(vec![Scope::Atproto])),
+
&None,
+
)
+
.unwrap()
+
.into_static(),
+
keyset: None,
+
};
+
let login_hint = identity.map(|_| jacquard::CowStr::from("alice.bsky.social"));
+
let mut auth_req = jacquard_oauth::request::par(client.as_ref(), login_hint, None, &metadata)
+
.await
+
.unwrap();
+
// Construct authorization URL as OAuthClient::start_auth would do
+
#[derive(serde::Serialize)]
+
struct Parameters<'s> {
+
client_id: url::Url,
+
request_uri: jacquard::CowStr<'s>,
+
}
+
let auth_url = format!(
+
"{}?{}",
+
metadata.server_metadata.authorization_endpoint,
+
serde_html_form::to_string(Parameters {
+
client_id: metadata.client_metadata.client_id.clone(),
+
request_uri: auth_req.request_uri.clone(),
+
})
+
.unwrap()
+
);
+
assert!(auth_url.contains("/authorize?"));
+
assert!(auth_url.contains("request_uri"));
+
// keep state for the callback
+
let state = auth_req.state.clone();
+
oauth
+
.registry
+
.store
+
.save_auth_req_info(&auth_req)
+
.await
+
.unwrap();
+
+
// callback: exchange code, create session
+
use jacquard_oauth::types::CallbackParams;
+
let session = oauth
+
.callback(CallbackParams {
+
code: jacquard::CowStr::from("code123"),
+
state: Some(state.clone()),
+
// Callback compares exact string with metadata.issuer (which is a URL string
+
// including trailing slash). Use normalized form to match.
+
iss: Some(jacquard::CowStr::from("https://issuer/")),
+
})
+
.await
+
.unwrap();
+
+
// Wrap in Agent and send a resource XRPC call to verify Authorization works
+
let agent: Agent<_> = Agent::from(session);
+
let resp = agent
+
.send(&jacquard::api::com_atproto::server::get_session::GetSession)
+
.await
+
.unwrap();
+
assert_eq!(resp.status(), StatusCode::OK);
+
+
let _ = std::fs::remove_file(&path);
+
}
+125
crates/jacquard/tests/restore_pds_cache.rs
···
+
use std::sync::Arc;
+
+
use bytes::Bytes;
+
use http::{Response as HttpResponse, StatusCode};
+
use jacquard::client::credential_session::{CredentialSession, SessionKey};
+
use jacquard::client::{AtpSession, FileAuthStore};
+
use jacquard::identity::resolver::{DidDocResponse, IdentityResolver, ResolverOptions};
+
use jacquard::types::did::Did;
+
use jacquard::types::string::Handle;
+
use jacquard_common::http_client::HttpClient;
+
use jacquard_common::session::SessionStore;
+
use std::fs;
+
use std::path::PathBuf;
+
use tokio::sync::RwLock;
+
use url::Url;
+
+
#[derive(Clone, Default)]
+
struct MockResolver {
+
// Count calls to DID doc resolution
+
did_doc_calls: Arc<RwLock<usize>>,
+
}
+
+
impl HttpClient for MockResolver {
+
type Error = std::convert::Infallible;
+
fn send_http(
+
&self,
+
_request: http::Request<Vec<u8>>,
+
) -> impl core::future::Future<
+
Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
+
> + Send {
+
async {
+
// Not used in this test
+
Ok(HttpResponse::builder()
+
.status(StatusCode::OK)
+
.body(Vec::new())
+
.unwrap())
+
}
+
}
+
}
+
+
#[async_trait::async_trait]
+
impl IdentityResolver for MockResolver {
+
fn options(&self) -> &ResolverOptions {
+
use std::sync::LazyLock;
+
static OPTS: LazyLock<ResolverOptions> = LazyLock::new(ResolverOptions::default);
+
&OPTS
+
}
+
async fn resolve_handle(
+
&self,
+
_handle: &Handle<'_>,
+
) -> std::result::Result<Did<'static>, jacquard::identity::resolver::IdentityError> {
+
Ok(Did::new_static("did:plc:alice").unwrap())
+
}
+
async fn resolve_did_doc(
+
&self,
+
_did: &Did<'_>,
+
) -> std::result::Result<DidDocResponse, jacquard::identity::resolver::IdentityError> {
+
*self.did_doc_calls.write().await += 1;
+
let doc = serde_json::json!({
+
"id": "did:plc:alice",
+
"service": [{
+
"id": "#pds",
+
"type": "AtprotoPersonalDataServer",
+
"serviceEndpoint": "https://pds-resolved"
+
}]
+
});
+
Ok(DidDocResponse {
+
buffer: Bytes::from(serde_json::to_vec(&doc).unwrap()),
+
status: StatusCode::OK,
+
requested: None,
+
})
+
}
+
}
+
+
fn temp_file() -> PathBuf {
+
let mut p = std::env::temp_dir();
+
p.push(format!("jacquard-test-restore-{}.json", std::process::id()));
+
p
+
}
+
+
#[tokio::test]
+
async fn restore_uses_cached_pds_when_present() {
+
let path = temp_file();
+
fs::write(&path, "{}").unwrap();
+
let store = Arc::new(FileAuthStore::new(&path));
+
let resolver = Arc::new(MockResolver::default());
+
+
// Seed an app-password session in the file store
+
let session = AtpSession {
+
access_jwt: "acc".into(),
+
refresh_jwt: "ref".into(),
+
did: Did::new_static("did:plc:alice").unwrap(),
+
handle: Handle::new_static("alice.bsky.social").unwrap(),
+
};
+
let key: SessionKey = (session.did.clone(), "session".into());
+
jacquard_common::session::SessionStore::set(store.as_ref(), key.clone(), session)
+
.await
+
.unwrap();
+
// Verify it is persisted
+
assert!(SessionStore::get(store.as_ref(), &key).await.is_some());
+
// Persist PDS endpoint cache to avoid DID resolution on restore
+
store
+
.set_atp_pds(&key, &Url::parse("https://pds-cached").unwrap())
+
.unwrap();
+
assert_eq!(
+
store
+
.get_atp_pds(&key)
+
.ok()
+
.flatten()
+
.expect("pds cached")
+
.as_str(),
+
"https://pds-cached/"
+
);
+
+
let session = CredentialSession::new(store.clone(), resolver.clone());
+
// Restore should pick cached PDS and NOT call resolve_did_doc
+
session
+
.restore(Did::new_static("did:plc:alice").unwrap(), "session".into())
+
.await
+
.expect("restore ok");
+
assert_eq!(session.endpoint().await.as_str(), "https://pds-cached/");
+
+
// Cleanup
+
let _ = fs::remove_file(&path);
+
}