A better Rust ATProto crate
1use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; 2use http::{Request, StatusCode}; 3use jacquard_common::CowStr; 4use jacquard_common::types::did_doc::DidDocument; 5use jacquard_common::types::ident::AtIdentifier; 6use jacquard_common::{IntoStatic, error::TransportError}; 7use jacquard_common::{http_client::HttpClient, types::did::Did}; 8use jacquard_identity::resolver::{IdentityError, IdentityResolver}; 9use url::Url; 10 11/// Compare two issuer strings strictly but without spuriously failing on trivial differences. 12/// 13/// Rules: 14/// - Schemes must match exactly. 15/// - Hostnames and effective ports must match (treat missing port the same as default port). 16/// - Path must match, except that an empty path and `/` are equivalent. 17/// - Query/fragment are not considered; if present on either side, the comparison fails. 18pub(crate) fn issuer_equivalent(a: &str, b: &str) -> bool { 19 fn normalize(url: &Url) -> Option<(String, String, u16, String)> { 20 if url.query().is_some() || url.fragment().is_some() { 21 return None; 22 } 23 let scheme = url.scheme().to_string(); 24 let host = url.host_str()?.to_string(); 25 let port = url.port_or_known_default()?; 26 let path = match url.path() { 27 "" => "/".to_string(), 28 "/" => "/".to_string(), 29 other => other.to_string(), 30 }; 31 Some((scheme, host, port, path)) 32 } 33 34 match (Url::parse(a), Url::parse(b)) { 35 (Ok(ua), Ok(ub)) => match (normalize(&ua), normalize(&ub)) { 36 (Some((sa, ha, pa, pa_path)), Some((sb, hb, pb, pb_path))) => { 37 if sa != sb || ha != hb || pa != pb { 38 return false; 39 } 40 if pa_path == "/" && pb_path == "/" { 41 return true; 42 } 43 pa_path == pb_path 44 } 45 _ => false, 46 }, 47 _ => a == b, 48 } 49} 50 51#[derive(thiserror::Error, Debug, miette::Diagnostic)] 52pub enum ResolverError { 53 #[error("resource not found")] 54 #[diagnostic( 55 code(jacquard_oauth::resolver::not_found), 56 help("check the base URL or identifier") 57 )] 58 NotFound, 59 #[error("invalid at identifier: {0}")] 60 #[diagnostic( 61 code(jacquard_oauth::resolver::at_identifier), 62 help("ensure a valid handle or DID was provided") 63 )] 64 AtIdentifier(String), 65 #[error("invalid did: {0}")] 66 #[diagnostic( 67 code(jacquard_oauth::resolver::did), 68 help("ensure DID is correctly formed (did:plc or did:web)") 69 )] 70 Did(String), 71 #[error("invalid did document: {0}")] 72 #[diagnostic( 73 code(jacquard_oauth::resolver::did_document), 74 help("verify the DID document structure and service entries") 75 )] 76 DidDocument(String), 77 #[error("protected resource metadata is invalid: {0}")] 78 #[diagnostic( 79 code(jacquard_oauth::resolver::protected_resource_metadata), 80 help("PDS must advertise an authorization server in its protected resource metadata") 81 )] 82 ProtectedResourceMetadata(String), 83 #[error("authorization server metadata is invalid: {0}")] 84 #[diagnostic( 85 code(jacquard_oauth::resolver::authorization_server_metadata), 86 help("issuer must match and include the PDS resource") 87 )] 88 AuthorizationServerMetadata(String), 89 #[error("error resolving identity: {0}")] 90 #[diagnostic(code(jacquard_oauth::resolver::identity))] 91 IdentityResolverError(#[from] IdentityError), 92 #[error("unsupported did method: {0:?}")] 93 #[diagnostic( 94 code(jacquard_oauth::resolver::unsupported_did_method), 95 help("supported DID methods: did:web, did:plc") 96 )] 97 UnsupportedDidMethod(Did<'static>), 98 #[error(transparent)] 99 #[diagnostic(code(jacquard_oauth::resolver::transport))] 100 Transport(#[from] TransportError), 101 #[error("http status: {0:?}")] 102 #[diagnostic( 103 code(jacquard_oauth::resolver::http_status), 104 help("check well-known paths and server configuration") 105 )] 106 HttpStatus(StatusCode), 107 #[error(transparent)] 108 #[diagnostic(code(jacquard_oauth::resolver::serde_json))] 109 SerdeJson(#[from] serde_json::Error), 110 #[error(transparent)] 111 #[diagnostic(code(jacquard_oauth::resolver::serde_form))] 112 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 113 #[error(transparent)] 114 #[diagnostic(code(jacquard_oauth::resolver::url))] 115 Uri(#[from] url::ParseError), 116} 117 118pub trait OAuthResolver: IdentityResolver + HttpClient { 119 fn verify_issuer( 120 &self, 121 server_metadata: &OAuthAuthorizationServerMetadata<'_>, 122 sub: &Did<'_>, 123 ) -> impl std::future::Future<Output = Result<Url, ResolverError>> + Send 124 where 125 Self: Sync, 126 { 127 async { 128 let (metadata, identity) = self.resolve_from_identity(sub).await?; 129 if !issuer_equivalent(&metadata.issuer, &server_metadata.issuer) { 130 return Err(ResolverError::AuthorizationServerMetadata( 131 "issuer mismatch".to_string(), 132 )); 133 } 134 Ok(identity 135 .pds_endpoint() 136 .ok_or(ResolverError::DidDocument(format!("{:?}", identity).into()))?) 137 } 138 } 139 fn resolve_oauth( 140 &self, 141 input: &str, 142 ) -> impl Future< 143 Output = Result< 144 ( 145 OAuthAuthorizationServerMetadata<'static>, 146 Option<DidDocument<'static>>, 147 ), 148 ResolverError, 149 >, 150 > + Send 151 where 152 Self: Sync, 153 { 154 // Allow using an entryway, or PDS url, directly as login input (e.g. 155 // when the user forgot their handle, or when the handle does not 156 // resolve to a DID) 157 async { 158 Ok(if input.starts_with("https://") { 159 let url = Url::parse(input).map_err(|_| ResolverError::NotFound)?; 160 (self.resolve_from_service(&url).await?, None) 161 } else { 162 let (metadata, identity) = self.resolve_from_identity(input).await?; 163 (metadata, Some(identity)) 164 }) 165 } 166 } 167 fn resolve_from_service( 168 &self, 169 input: &Url, 170 ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 171 where 172 Self: Sync, 173 { 174 async { 175 // Assume first that input is a PDS URL (as required by ATPROTO) 176 if let Ok(metadata) = self.get_resource_server_metadata(input).await { 177 return Ok(metadata); 178 } 179 // Fallback to trying to fetch as an issuer (Entryway) 180 self.get_authorization_server_metadata(input).await 181 } 182 } 183 fn resolve_from_identity( 184 &self, 185 input: &str, 186 ) -> impl Future< 187 Output = Result< 188 ( 189 OAuthAuthorizationServerMetadata<'static>, 190 DidDocument<'static>, 191 ), 192 ResolverError, 193 >, 194 > + Send 195 where 196 Self: Sync, 197 { 198 async { 199 let actor = AtIdentifier::new(input) 200 .map_err(|e| ResolverError::AtIdentifier(format!("{:?}", e)))?; 201 let identity = self.resolve_ident_owned(&actor).await?; 202 if let Some(pds) = &identity.pds_endpoint() { 203 let metadata = self.get_resource_server_metadata(pds).await?; 204 Ok((metadata, identity)) 205 } else { 206 Err(ResolverError::DidDocument(format!("Did doc lacking pds"))) 207 } 208 } 209 } 210 fn get_authorization_server_metadata( 211 &self, 212 issuer: &Url, 213 ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 214 where 215 Self: Sync, 216 { 217 async { 218 let mut md = resolve_authorization_server(self, issuer).await?; 219 // Normalize issuer string to the input URL representation to avoid slash quirks 220 md.issuer = jacquard_common::CowStr::from(issuer.as_str()).into_static(); 221 Ok(md) 222 } 223 } 224 fn get_resource_server_metadata( 225 &self, 226 pds: &Url, 227 ) -> impl Future<Output = Result<OAuthAuthorizationServerMetadata<'static>, ResolverError>> + Send 228 where 229 Self: Sync, 230 { 231 async move { 232 let rs_metadata = resolve_protected_resource_info(self, pds).await?; 233 // ATPROTO requires one, and only one, authorization server entry 234 // > That document MUST contain a single item in the authorization_servers array. 235 // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata 236 let issuer = match &rs_metadata.authorization_servers { 237 Some(servers) if !servers.is_empty() => { 238 if servers.len() > 1 { 239 return Err(ResolverError::ProtectedResourceMetadata(format!( 240 "unable to determine authorization server for PDS: {pds}" 241 ))); 242 } 243 &servers[0] 244 } 245 _ => { 246 return Err(ResolverError::ProtectedResourceMetadata(format!( 247 "no authorization server found for PDS: {pds}" 248 ))); 249 } 250 }; 251 let as_metadata = self.get_authorization_server_metadata(issuer).await?; 252 // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#name-authorization-server-metada 253 if let Some(protected_resources) = &as_metadata.protected_resources { 254 let resource_url = rs_metadata 255 .resource 256 .strip_suffix('/') 257 .unwrap_or(rs_metadata.resource.as_str()); 258 if !protected_resources.contains(&CowStr::Borrowed(resource_url)) { 259 return Err(ResolverError::AuthorizationServerMetadata(format!( 260 "pds {pds}, resource {0} not protected by issuer: {issuer}, protected resources: {1:?}", 261 rs_metadata.resource, protected_resources 262 ))); 263 } 264 } 265 266 // TODO: atproot specific validation? 267 // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata 268 // 269 // eg. 270 // https://drafts.aaronpk.com/draft-parecki-oauth-client-id-metadata-document/draft-parecki-oauth-client-id-metadata-document.html 271 // if as_metadata.client_id_metadata_document_supported != Some(true) { 272 // return Err(Error::AuthorizationServerMetadata(format!( 273 // "authorization server does not support client_id_metadata_document: {issuer}" 274 // ))); 275 // } 276 277 Ok(as_metadata) 278 } 279 } 280} 281 282pub async fn resolve_authorization_server<T: HttpClient + ?Sized>( 283 client: &T, 284 server: &Url, 285) -> Result<OAuthAuthorizationServerMetadata<'static>, ResolverError> { 286 let url = server 287 .join("/.well-known/oauth-authorization-server") 288 .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 289 290 let req = Request::builder() 291 .uri(url.to_string()) 292 .body(Vec::new()) 293 .map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?; 294 let res = client 295 .send_http(req) 296 .await 297 .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 298 if res.status() == StatusCode::OK { 299 let mut metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body()) 300 .map_err(ResolverError::SerdeJson)?; 301 // https://datatracker.ietf.org/doc/html/rfc8414#section-3.3 302 // Accept semantically equivalent issuer (normalize to the requested URL form) 303 if issuer_equivalent(&metadata.issuer, server.as_str()) { 304 metadata.issuer = server.as_str().into(); 305 Ok(metadata.into_static()) 306 } else { 307 Err(ResolverError::AuthorizationServerMetadata(format!( 308 "invalid issuer: {}", 309 metadata.issuer 310 ))) 311 } 312 } else { 313 Err(ResolverError::HttpStatus(res.status())) 314 } 315} 316 317pub async fn resolve_protected_resource_info<T: HttpClient + ?Sized>( 318 client: &T, 319 server: &Url, 320) -> Result<OAuthProtectedResourceMetadata<'static>, ResolverError> { 321 let url = server 322 .join("/.well-known/oauth-protected-resource") 323 .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 324 325 let req = Request::builder() 326 .uri(url.to_string()) 327 .body(Vec::new()) 328 .map_err(|e| ResolverError::Transport(TransportError::InvalidRequest(e.to_string())))?; 329 let res = client 330 .send_http(req) 331 .await 332 .map_err(|e| ResolverError::Transport(TransportError::Other(Box::new(e))))?; 333 if res.status() == StatusCode::OK { 334 let mut metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body()) 335 .map_err(ResolverError::SerdeJson)?; 336 // https://datatracker.ietf.org/doc/html/rfc8414#section-3.3 337 // Accept semantically equivalent resource URL (normalize to the requested URL form) 338 if issuer_equivalent(&metadata.resource, server.as_str()) { 339 metadata.resource = server.as_str().into(); 340 Ok(metadata.into_static()) 341 } else { 342 Err(ResolverError::AuthorizationServerMetadata(format!( 343 "invalid resource: {}", 344 metadata.resource 345 ))) 346 } 347 } else { 348 Err(ResolverError::HttpStatus(res.status())) 349 } 350} 351 352impl OAuthResolver for jacquard_identity::JacquardResolver {} 353 354#[cfg(test)] 355mod tests { 356 use super::*; 357 use http::{Request as HttpRequest, Response as HttpResponse, StatusCode}; 358 use jacquard_common::http_client::HttpClient; 359 360 #[derive(Default, Clone)] 361 struct MockHttp { 362 next: std::sync::Arc<tokio::sync::Mutex<Option<HttpResponse<Vec<u8>>>>>, 363 } 364 365 impl HttpClient for MockHttp { 366 type Error = std::convert::Infallible; 367 fn send_http( 368 &self, 369 _request: HttpRequest<Vec<u8>>, 370 ) -> impl core::future::Future< 371 Output = core::result::Result<HttpResponse<Vec<u8>>, Self::Error>, 372 > + Send { 373 let next = self.next.clone(); 374 async move { Ok(next.lock().await.take().unwrap()) } 375 } 376 } 377 378 #[tokio::test] 379 async fn authorization_server_http_status() { 380 let client = MockHttp::default(); 381 *client.next.lock().await = Some( 382 HttpResponse::builder() 383 .status(StatusCode::NOT_FOUND) 384 .body(Vec::new()) 385 .unwrap(), 386 ); 387 let issuer = url::Url::parse("https://issuer").unwrap(); 388 let err = super::resolve_authorization_server(&client, &issuer) 389 .await 390 .unwrap_err(); 391 matches!(err, ResolverError::HttpStatus(StatusCode::NOT_FOUND)); 392 } 393 394 #[tokio::test] 395 async fn authorization_server_bad_json() { 396 let client = MockHttp::default(); 397 *client.next.lock().await = Some( 398 HttpResponse::builder() 399 .status(StatusCode::OK) 400 .body(b"{not json}".to_vec()) 401 .unwrap(), 402 ); 403 let issuer = url::Url::parse("https://issuer").unwrap(); 404 let err = super::resolve_authorization_server(&client, &issuer) 405 .await 406 .unwrap_err(); 407 matches!(err, ResolverError::SerdeJson(_)); 408 } 409 410 #[test] 411 fn issuer_equivalence_rules() { 412 assert!(super::issuer_equivalent( 413 "https://issuer", 414 "https://issuer/" 415 )); 416 assert!(super::issuer_equivalent( 417 "https://issuer:443/", 418 "https://issuer/" 419 )); 420 assert!(!super::issuer_equivalent( 421 "http://issuer/", 422 "https://issuer/" 423 )); 424 assert!(!super::issuer_equivalent( 425 "https://issuer/foo", 426 "https://issuer/" 427 )); 428 assert!(!super::issuer_equivalent( 429 "https://issuer/?q=1", 430 "https://issuer/" 431 )); 432 } 433}