// A better Rust ATProto crate
1use std::str::FromStr; 2 3use crate::types::OAuthClientMetadata; 4use crate::{keyset::Keyset, scopes::Scope}; 5use jacquard_common::CowStr; 6use serde::{Deserialize, Serialize}; 7use thiserror::Error; 8use url::Url; 9 10#[derive(Error, Debug)] 11pub enum Error { 12 #[error("`client_id` must be a valid URL")] 13 InvalidClientId, 14 #[error("`grant_types` must include `authorization_code`")] 15 InvalidGrantTypes, 16 #[error("`scope` must not include `atproto`")] 17 InvalidScope, 18 #[error("`redirect_uris` must not be empty")] 19 EmptyRedirectUris, 20 #[error("`private_key_jwt` auth method requires `jwks` keys")] 21 EmptyJwks, 22 #[error( 23 "`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided" 24 )] 25 AuthSigningAlg, 26 #[error(transparent)] 27 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 28 #[error(transparent)] 29 LocalhostClient(#[from] LocalhostClientError), 30} 31 32#[derive(Error, Debug)] 33pub enum LocalhostClientError { 34 #[error("invalid redirect_uri: {0}")] 35 Invalid(#[from] url::ParseError), 36 #[error("loopback client_id must use `http:` redirect_uri")] 37 NotHttpScheme, 38 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")] 39 Localhost, 40 #[error("loopback client_id must not use loopback addresses as redirect_uri")] 41 NotLoopbackHost, 42} 43 44pub type Result<T> = core::result::Result<T, Error>; 45 46#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 47#[serde(rename_all = "snake_case")] 48pub enum AuthMethod { 49 None, 50 // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication 51 PrivateKeyJwt, 52} 53 54impl From<AuthMethod> for CowStr<'static> { 55 fn from(value: AuthMethod) -> Self { 56 match value { 57 AuthMethod::None => CowStr::new_static("none"), 58 AuthMethod::PrivateKeyJwt => CowStr::new_static("private_key_jwt"), 59 } 60 } 61} 62 63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 64#[serde(rename_all = 
"snake_case")] 65pub enum GrantType { 66 AuthorizationCode, 67 RefreshToken, 68} 69 70impl From<GrantType> for CowStr<'static> { 71 fn from(value: GrantType) -> Self { 72 match value { 73 GrantType::AuthorizationCode => CowStr::new_static("authorization_code"), 74 GrantType::RefreshToken => CowStr::new_static("refresh_token"), 75 } 76 } 77} 78 79#[derive(Clone, Debug, PartialEq, Eq)] 80pub struct AtprotoClientMetadata<'m> { 81 pub client_id: Url, 82 pub client_uri: Option<Url>, 83 pub redirect_uris: Vec<Url>, 84 pub grant_types: Vec<GrantType>, 85 pub scopes: Vec<Scope<'m>>, 86 pub jwks_uri: Option<Url>, 87} 88 89impl<'m> AtprotoClientMetadata<'m> { 90 pub fn new( 91 client_id: Url, 92 client_uri: Option<Url>, 93 redirect_uris: Vec<Url>, 94 grant_types: Vec<GrantType>, 95 scopes: Vec<Scope<'m>>, 96 jwks_uri: Option<Url>, 97 ) -> Self { 98 Self { 99 client_id, 100 client_uri, 101 redirect_uris, 102 grant_types, 103 scopes, 104 jwks_uri, 105 } 106 } 107 108 pub fn default_localhost() -> Self { 109 Self::new_localhost( 110 None, 111 Some(Scope::parse_multiple("atproto transition:generic").unwrap()), 112 ) 113 } 114 115 pub fn new_localhost( 116 mut redirect_uris: Option<Vec<Url>>, 117 scopes: Option<Vec<Scope<'m>>>, 118 ) -> Self { 119 // Coerce provided redirect URIs to http://localhost while preserving path 120 if let Some(redirect_uris) = &mut redirect_uris { 121 for redirect_uri in redirect_uris { 122 let _ = redirect_uri.set_scheme("http"); 123 redirect_uri.set_host(Some("127.0.0.1")).unwrap(); 124 } 125 } 126 // determine client_id 127 #[derive(serde::Serialize)] 128 struct Parameters<'a> { 129 #[serde(skip_serializing_if = "Option::is_none")] 130 redirect_uri: Option<Vec<Url>>, 131 #[serde(skip_serializing_if = "Option::is_none")] 132 scope: Option<CowStr<'a>>, 133 } 134 let query = serde_html_form::to_string(Parameters { 135 redirect_uri: redirect_uris.clone(), 136 scope: scopes 137 .as_ref() 138 .map(|s| Scope::serialize_multiple(s.as_slice())), 139 }) 140 
.ok(); 141 let mut client_id = String::from("http://localhost"); 142 if let Some(query) = query 143 && !query.is_empty() 144 { 145 client_id.push_str(&format!("?{query}")); 146 } 147 Self { 148 client_id: Url::parse(&client_id).unwrap(), 149 client_uri: None, 150 redirect_uris: redirect_uris.unwrap_or(vec![ 151 Url::from_str("http://127.0.0.1/").unwrap(), 152 Url::from_str("http://[::1]/").unwrap(), 153 ]), 154 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 155 scopes: scopes.unwrap_or(vec![Scope::Atproto]), 156 jwks_uri: None, 157 } 158 } 159} 160 161pub fn atproto_client_metadata<'m>( 162 metadata: AtprotoClientMetadata<'m>, 163 keyset: &Option<Keyset>, 164) -> Result<OAuthClientMetadata<'m>> { 165 // For non-loopback clients, require a keyset/JWKs. 166 let is_loopback = 167 metadata.client_id.scheme() == "http" && metadata.client_id.host_str() == Some("localhost"); 168 if !is_loopback && keyset.is_none() { 169 return Err(Error::EmptyJwks); 170 } 171 if metadata.redirect_uris.is_empty() { 172 return Err(Error::EmptyRedirectUris); 173 } 174 if !metadata.grant_types.contains(&GrantType::AuthorizationCode) { 175 return Err(Error::InvalidGrantTypes); 176 } 177 if !metadata.scopes.contains(&Scope::Atproto) { 178 return Err(Error::InvalidScope); 179 } 180 let (auth_method, jwks_uri, jwks) = if let Some(keyset) = keyset { 181 let jwks = if metadata.jwks_uri.is_none() { 182 Some(keyset.public_jwks()) 183 } else { 184 None 185 }; 186 (AuthMethod::PrivateKeyJwt, metadata.jwks_uri, jwks) 187 } else { 188 (AuthMethod::None, None, None) 189 }; 190 191 Ok(OAuthClientMetadata { 192 client_id: metadata.client_id, 193 client_uri: metadata.client_uri, 194 redirect_uris: metadata.redirect_uris, 195 token_endpoint_auth_method: Some(auth_method.into()), 196 grant_types: if keyset.is_some() { 197 Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()) 198 } else { 199 None 200 }, 201 scope: 
Some(Scope::serialize_multiple(metadata.scopes.as_slice())), 202 dpop_bound_access_tokens: if keyset.is_some() { Some(true) } else { None }, 203 jwks_uri, 204 jwks, 205 token_endpoint_auth_signing_alg: if keyset.is_some() { 206 Some(CowStr::new_static("ES256")) 207 } else { 208 None 209 }, 210 }) 211} 212 213#[cfg(test)] 214mod tests { 215 use std::str::FromStr; 216 217 use crate::scopes::TransitionScope; 218 219 use super::*; 220 use elliptic_curve::SecretKey; 221 use jose_jwk::{Jwk, Key, Parameters}; 222 use p256::pkcs8::DecodePrivateKey; 223 224 const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY----- 225MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T 2264i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P 227gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3 228-----END PRIVATE KEY-----"#; 229 230 #[test] 231 fn test_localhost_client_metadata_default() { 232 assert_eq!( 233 atproto_client_metadata(AtprotoClientMetadata::new_localhost(None, None), &None) 234 .unwrap(), 235 OAuthClientMetadata { 236 client_id: Url::from_str("http://localhost").unwrap(), 237 client_uri: None, 238 redirect_uris: vec![ 239 Url::from_str("http://127.0.0.1/").unwrap(), 240 Url::from_str("http://[::1]/").unwrap(), 241 ], 242 scope: Some(CowStr::new_static("atproto")), 243 grant_types: None, 244 token_endpoint_auth_method: Some(AuthMethod::None.into()), 245 dpop_bound_access_tokens: None, 246 jwks_uri: None, 247 jwks: None, 248 token_endpoint_auth_signing_alg: None, 249 } 250 ); 251 } 252 253 #[test] 254 fn test_localhost_client_metadata_custom() { 255 assert_eq!( 256 atproto_client_metadata(AtprotoClientMetadata::new_localhost( 257 Some(vec![ 258 Url::from_str("http://127.0.0.1/callback").unwrap(), 259 Url::from_str("http://[::1]/callback").unwrap(), 260 ]), 261 Some( 262 vec![ 263 Scope::Atproto, 264 Scope::Transition(TransitionScope::Generic), 265 Scope::parse("account:email").unwrap() 266 ] 267 ) 268 ), &None) 269 .expect("failed to convert 
metadata"), 270 OAuthClientMetadata { 271 client_id: Url::from_str( 272 "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric" 273 ).unwrap(), 274 client_uri: None, 275 redirect_uris: vec![ 276 Url::from_str("http://127.0.0.1/callback").unwrap(), 277 // TODO: fix this so that it respects IPv6 278 Url::from_str("http://127.0.0.1/callback").unwrap(), 279 ], 280 scope: Some(CowStr::new_static("account:email atproto transition:generic")), 281 grant_types: None, 282 token_endpoint_auth_method: Some(AuthMethod::None.into()), 283 dpop_bound_access_tokens: None, 284 jwks_uri: None, 285 jwks: None, 286 token_endpoint_auth_signing_alg: None, 287 } 288 ); 289 } 290 291 #[test] 292 fn test_localhost_client_metadata_invalid() { 293 // Invalid inputs are coerced to http://localhost rather than failing 294 { 295 let out = atproto_client_metadata( 296 AtprotoClientMetadata::new_localhost( 297 Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]), 298 None, 299 ), 300 &None, 301 ) 302 .expect("should coerce to 127.0.0.1"); 303 assert_eq!( 304 out, 305 OAuthClientMetadata { 306 client_id: Url::from_str( 307 "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2F" 308 ) 309 .unwrap(), 310 client_uri: None, 311 redirect_uris: vec![Url::from_str("http://127.0.0.1/").unwrap()], 312 scope: Some(CowStr::new_static("atproto")), 313 grant_types: None, 314 token_endpoint_auth_method: Some(AuthMethod::None.into()), 315 dpop_bound_access_tokens: None, 316 jwks_uri: None, 317 jwks: None, 318 token_endpoint_auth_signing_alg: None, 319 } 320 ); 321 } 322 { 323 let out = atproto_client_metadata( 324 AtprotoClientMetadata::new_localhost( 325 Some(vec![Url::from_str("http://localhost:8000/").unwrap()]), 326 None, 327 ), 328 &None, 329 ) 330 .expect("should coerce to 127.0.0.1"); 331 assert_eq!( 332 out, 333 OAuthClientMetadata { 334 client_id: Url::from_str( 335 
"http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%3A8000%2F" 336 ) 337 .unwrap(), 338 client_uri: None, 339 redirect_uris: vec![Url::from_str("http://127.0.0.1:8000/").unwrap()], 340 scope: Some(CowStr::new_static("atproto")), 341 grant_types: None, 342 token_endpoint_auth_method: Some(AuthMethod::None.into()), 343 dpop_bound_access_tokens: None, 344 jwks_uri: None, 345 jwks: None, 346 token_endpoint_auth_signing_alg: None, 347 } 348 ); 349 } 350 { 351 let out = atproto_client_metadata( 352 AtprotoClientMetadata::new_localhost( 353 Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]), 354 None, 355 ), 356 &None, 357 ) 358 .expect("should coerce to 127.0.0.1"); 359 assert_eq!( 360 out, 361 OAuthClientMetadata { 362 client_id: Url::from_str( 363 "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2F" 364 ) 365 .unwrap(), 366 client_uri: None, 367 redirect_uris: vec![Url::from_str("http://127.0.0.1/").unwrap()], 368 scope: Some(CowStr::new_static("atproto")), 369 grant_types: None, 370 token_endpoint_auth_method: Some(AuthMethod::None.into()), 371 dpop_bound_access_tokens: None, 372 jwks_uri: None, 373 jwks: None, 374 token_endpoint_auth_signing_alg: None, 375 } 376 ); 377 } 378 } 379 380 #[test] 381 fn test_client_metadata() { 382 let metadata = AtprotoClientMetadata { 383 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 384 client_uri: Some(Url::from_str("https://example.com").unwrap()), 385 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 386 grant_types: vec![GrantType::AuthorizationCode], 387 scopes: vec![Scope::Atproto], 388 jwks_uri: None, 389 }; 390 { 391 // Non-loopback clients without a keyset should fail (must provide JWKS) 392 let metadata = metadata.clone(); 393 let err = atproto_client_metadata(metadata, &None).expect_err("expected to fail"); 394 assert!(matches!(err, Error::EmptyJwks)); 395 } 396 { 397 let metadata = metadata.clone(); 398 let secret_key = 
SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY) 399 .expect("failed to parse private key"); 400 let keys = vec![Jwk { 401 key: Key::from(&secret_key.into()), 402 prm: Parameters { 403 kid: Some(String::from("kid00")), 404 ..Default::default() 405 }, 406 }]; 407 let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset"); 408 assert_eq!( 409 atproto_client_metadata(metadata, &Some(keyset.clone())) 410 .expect("failed to convert metadata"), 411 OAuthClientMetadata { 412 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 413 client_uri: Some(Url::from_str("https://example.com").unwrap()), 414 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 415 scope: Some(CowStr::new_static("atproto")), 416 grant_types: Some(vec![CowStr::new_static("authorization_code")]), 417 token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()), 418 dpop_bound_access_tokens: Some(true), 419 jwks_uri: None, 420 jwks: Some(keyset.public_jwks()), 421 token_endpoint_auth_signing_alg: Some(CowStr::new_static("ES256")), 422 } 423 ); 424 } 425 } 426}