axum_service_auth_middleware.rs
1use crate::AppState;
2use atproto_identity::key::{identify_key, validate};
3use axum::Json;
4use axum::body::Body;
5use axum::extract::State;
6use axum::http::{Request, StatusCode, header};
7use axum::middleware::Next;
8use axum::response::{IntoResponse, Response};
9use jacquard_common::types::did::Did;
10use jacquard_common::url::Url;
11use jacquard_identity::PublicResolver;
12use jacquard_identity::resolver::IdentityResolver;
13use jwt_compact::{Claims, UntrustedToken};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
17pub mod com_atproto_pdsmoover_backup_handlers;
18
19pub struct XrpcErrorResponse {
20 error: XrpcError,
21 pub status: StatusCode,
22}
23
24impl XrpcErrorResponse {
25 pub fn internal_server_error() -> Self {
26 Self {
27 error: XrpcError {
28 error: "InternalServerError".to_string(),
29 message: None,
30 },
31 status: StatusCode::INTERNAL_SERVER_ERROR,
32 }
33 }
34
35 pub fn auth_missing() -> Self {
36 Self {
37 error: XrpcError {
38 error: "AuthMissing".to_string(),
39 message: Some("Authentication Required".to_string()),
40 },
41 status: StatusCode::UNAUTHORIZED,
42 }
43 }
44
45 pub fn invalid_token(message: Option<&str>) -> Self {
46 Self {
47 error: XrpcError {
48 error: "InvalidToken".to_string(),
49 message: None,
50 },
51 status: StatusCode::UNAUTHORIZED,
52 }
53 }
54}
55
56#[derive(serde::Deserialize, serde::Serialize)]
57pub struct XrpcError {
58 pub error: String,
59 #[serde(skip_serializing_if = "std::option::Option::is_none")]
60 pub message: Option<String>,
61}
62
63impl IntoResponse for XrpcErrorResponse {
64 fn into_response(self) -> Response {
65 (self.status, Json(self.error)).into_response()
66 }
67}
68
69/// Subset of data returned that has been validated of the user making the call
70#[derive(Clone, Serialize, Deserialize, Debug)]
71pub struct VerifiedServiceAuthResults {
72 /// The user's did
73 pub did: String,
74 /// The user's pds url
75 pub pds_url: Url,
76 /// The user's atproto multikey used to verify requests and signing
77 pub multi_key: String,
78}
79
80/// A subset of the claims that are in the service auth token.
81#[derive(Serialize, Deserialize)]
82struct ServiceAuthClaims {
83 /// User's did
84 pub iss: String,
85 /// Audience (did:web in this case that was proxied)
86 pub aud: String,
87 /// Lexicon XRPC endpoint requested. example com.atproto.sync.getRecord
88 pub lxm: String,
89}
90
91/// Verifies the service auth token that is appended to an XRPC proxy request
92async fn verify_service_auth(
93 jwt: String,
94 lxm: &str,
95 public_resolver: Arc<PublicResolver>,
96 did_web: String,
97) -> anyhow::Result<VerifiedServiceAuthResults> {
98 let token = UntrustedToken::new(&jwt)?;
99
100 let claims: Claims<ServiceAuthClaims> = token.deserialize_claims_unchecked()?;
101 let did = Did::new(claims.custom.iss.as_str())?;
102 //TODO change to shared one later
103 let doc_response = public_resolver.resolve_did_doc(&did).await?;
104 let doc = doc_response.parse()?;
105
106 let multi_key = match doc.atproto_multikey() {
107 Some(key) => key,
108 None => {
109 return Err(anyhow::anyhow!("No atproto_multikey in did doc"));
110 }
111 };
112 let identified_key = identify_key(&multi_key)?;
113
114 // If no error is throw it's valid. Should check expiry time here as well (I think)
115 let _ = validate(
116 &identified_key,
117 &token.signature_bytes(),
118 &token.signed_data,
119 )?;
120
121 if claims.custom.aud != did_web {
122 return Err(anyhow::anyhow!("Invalid audience (did:web)"));
123 }
124
125 if claims.custom.lxm != lxm {
126 return Err(anyhow::anyhow!("Invalid XRPC endpoint requested"));
127 }
128 let pds_url = match doc.pds_endpoint() {
129 None => {
130 return Err(anyhow::anyhow!("No pds_endpoint in did doc"));
131 }
132 Some(endpoint) => endpoint,
133 };
134
135 Ok(VerifiedServiceAuthResults {
136 did: did.to_string(),
137 pds_url,
138 multi_key: multi_key.to_string(),
139 })
140}
141
142async fn service_auth_middleware(
143 State(state): State<AppState>,
144 mut req: Request<Body>,
145 next: Next,
146) -> Result<Response, XrpcErrorResponse> {
147 // Expect Authorization: Bearer <jwt>
148 let auth_header = req
149 .headers()
150 .get(header::AUTHORIZATION)
151 .and_then(|v| v.to_str().ok())
152 .map(str::to_string);
153
154 let Some(value) = auth_header else {
155 return Err(XrpcErrorResponse::auth_missing());
156 };
157
158 // Ensure Bearer prefix
159 let token = value.strip_prefix("Bearer ").unwrap_or("").trim();
160 if token.is_empty() {
161 return Err(XrpcErrorResponse::auth_missing());
162 }
163
164 // Build lxm from request path by removing /xrpc/ prefix
165 let path = req.uri().path();
166 let lxm = path.strip_prefix("/xrpc/").unwrap_or(path);
167
168 // Verify token
169 let verified = verify_service_auth(
170 token.to_string(),
171 lxm,
172 state.public_resolver.clone(),
173 state.did_web.0.clone(),
174 )
175 .await;
176
177 match verified {
178 Ok(results) => {
179 req.extensions_mut().insert(results);
180 Ok(next.run(req).await)
181 }
182 Err(err) => {
183 tracing::warn!(error = %err, "Invalid service auth token");
184 Err(XrpcErrorResponse::invalid_token(None))
185 }
186 }
187}