A better Rust ATProto crate
at oauth 15 kB view raw
1use std::str::FromStr; 2 3use crate::types::OAuthClientMetadata; 4use crate::{keyset::Keyset, scopes::Scope}; 5use jacquard_common::CowStr; 6use serde::{Deserialize, Serialize}; 7use thiserror::Error; 8use url::Url; 9 10#[derive(Error, Debug)] 11pub enum Error { 12 #[error("`client_id` must be a valid URL")] 13 InvalidClientId, 14 #[error("`grant_types` must include `authorization_code`")] 15 InvalidGrantTypes, 16 #[error("`scope` must not include `atproto`")] 17 InvalidScope, 18 #[error("`redirect_uris` must not be empty")] 19 EmptyRedirectUris, 20 #[error("`private_key_jwt` auth method requires `jwks` keys")] 21 EmptyJwks, 22 #[error( 23 "`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided" 24 )] 25 AuthSigningAlg, 26 #[error(transparent)] 27 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 28 #[error(transparent)] 29 LocalhostClient(#[from] LocalhostClientError), 30} 31 32#[derive(Error, Debug)] 33pub enum LocalhostClientError { 34 #[error("invalid redirect_uri: {0}")] 35 Invalid(#[from] url::ParseError), 36 #[error("loopback client_id must use `http:` redirect_uri")] 37 NotHttpScheme, 38 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")] 39 Localhost, 40 #[error("loopback client_id must not use loopback addresses as redirect_uri")] 41 NotLoopbackHost, 42} 43 44pub type Result<T> = core::result::Result<T, Error>; 45 46#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 47#[serde(rename_all = "snake_case")] 48pub enum AuthMethod { 49 None, 50 // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication 51 PrivateKeyJwt, 52} 53 54impl From<AuthMethod> for CowStr<'static> { 55 fn from(value: AuthMethod) -> Self { 56 match value { 57 AuthMethod::None => CowStr::new_static("none"), 58 AuthMethod::PrivateKeyJwt => CowStr::new_static("private_key_jwt"), 59 } 60 } 61} 62 63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 64#[serde(rename_all = "snake_case")] 65pub enum GrantType { 66 AuthorizationCode, 67 RefreshToken, 68} 69 70impl From<GrantType> for CowStr<'static> { 71 fn from(value: GrantType) -> Self { 72 match value { 73 GrantType::AuthorizationCode => CowStr::new_static("authorization_code"), 74 GrantType::RefreshToken => CowStr::new_static("refresh_token"), 75 } 76 } 77} 78 79#[derive(Clone, Debug, PartialEq, Eq)] 80pub struct AtprotoClientMetadata<'m> { 81 pub client_id: Url, 82 pub client_uri: Option<Url>, 83 pub redirect_uris: Vec<Url>, 84 pub grant_types: Vec<GrantType>, 85 pub scopes: Vec<Scope<'m>>, 86 pub jwks_uri: Option<Url>, 87} 88 89impl<'m> AtprotoClientMetadata<'m> { 90 pub fn new( 91 client_id: Url, 92 client_uri: Option<Url>, 93 redirect_uris: Vec<Url>, 94 grant_types: Vec<GrantType>, 95 scopes: Vec<Scope<'m>>, 96 jwks_uri: Option<Url>, 97 ) -> Self { 98 Self { 99 client_id, 100 client_uri, 101 redirect_uris, 102 grant_types, 103 scopes, 104 jwks_uri, 105 } 106 } 107 108 pub fn new_localhost( 109 mut redirect_uris: Option<Vec<Url>>, 110 scopes: Option<Vec<Scope<'m>>>, 111 ) -> Self { 112 // Coerce provided redirect URIs to http://localhost while preserving path 113 if let Some(redirect_uris) = &mut redirect_uris { 114 for redirect_uri in redirect_uris { 115 let _ = redirect_uri.set_scheme("http"); 116 redirect_uri.set_host(Some("localhost")).unwrap(); 117 let _ = redirect_uri.set_port(None); 118 } 119 } 120 // determine client_id 121 #[derive(serde::Serialize)] 122 struct Parameters<'a> { 123 #[serde(skip_serializing_if = "Option::is_none")] 124 redirect_uri: Option<Vec<Url>>, 125 #[serde(skip_serializing_if = "Option::is_none")] 126 scope: Option<CowStr<'a>>, 127 } 128 let query = serde_html_form::to_string(Parameters { 129 redirect_uri: redirect_uris.clone(), 130 scope: scopes 131 .as_ref() 132 .map(|s| Scope::serialize_multiple(s.as_slice())), 133 }) 134 .ok(); 135 let mut client_id = String::from("http://localhost"); 136 if let Some(query) = query 137 && !query.is_empty() 138 { 139 client_id.push_str(&format!("?{query}")); 140 } 141 Self { 142 client_id: Url::parse(&client_id).unwrap(), 143 client_uri: None, 144 redirect_uris: redirect_uris.unwrap_or(vec![ 145 Url::from_str("http://127.0.0.1/").unwrap(), 146 Url::from_str("http://[::1]/").unwrap(), 147 ]), 148 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 149 scopes: scopes.unwrap_or(vec![Scope::Atproto]), 150 jwks_uri: None, 151 } 152 } 153} 154 155pub fn atproto_client_metadata<'m>( 156 metadata: AtprotoClientMetadata<'m>, 157 keyset: &Option<Keyset>, 158) -> Result<OAuthClientMetadata<'m>> { 159 // For non-loopback clients, require a keyset/JWKs. 160 let is_loopback = metadata.client_id.scheme() == "http" 161 && metadata.client_id.host_str() == Some("localhost"); 162 if !is_loopback && keyset.is_none() { 163 return Err(Error::EmptyJwks); 164 } 165 if metadata.redirect_uris.is_empty() { 166 return Err(Error::EmptyRedirectUris); 167 } 168 if !metadata.grant_types.contains(&GrantType::AuthorizationCode) { 169 return Err(Error::InvalidGrantTypes); 170 } 171 if !metadata.scopes.contains(&Scope::Atproto) { 172 return Err(Error::InvalidScope); 173 } 174 let (auth_method, jwks_uri, jwks) = if let Some(keyset) = keyset { 175 let jwks = if metadata.jwks_uri.is_none() { 176 Some(keyset.public_jwks()) 177 } else { 178 None 179 }; 180 (AuthMethod::PrivateKeyJwt, metadata.jwks_uri, jwks) 181 } else { 182 (AuthMethod::None, None, None) 183 }; 184 185 Ok(OAuthClientMetadata { 186 client_id: metadata.client_id, 187 client_uri: metadata.client_uri, 188 redirect_uris: metadata.redirect_uris, 189 token_endpoint_auth_method: Some(auth_method.into()), 190 grant_types: if keyset.is_some() { 191 Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()) 192 } else { 193 None 194 }, 195 scope: if keyset.is_some() { 196 Some(Scope::serialize_multiple(metadata.scopes.as_slice())) 197 } else { 198 None 199 }, 200 dpop_bound_access_tokens: if keyset.is_some() { Some(true) } else { None }, 201 jwks_uri, 202 jwks, 203 token_endpoint_auth_signing_alg: if keyset.is_some() { 204 Some(CowStr::new_static("ES256")) 205 } else { 206 None 207 }, 208 }) 209} 210 211#[cfg(test)] 212mod tests { 213 use std::str::FromStr; 214 215 use crate::scopes::TransitionScope; 216 217 use super::*; 218 use elliptic_curve::SecretKey; 219 use jose_jwk::{Jwk, Key, Parameters}; 220 use p256::pkcs8::DecodePrivateKey; 221 222 const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY----- 223MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T 2244i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P 225gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3 226-----END PRIVATE KEY-----"#; 227 228 #[test] 229 fn test_localhost_client_metadata_default() { 230 assert_eq!( 231 atproto_client_metadata(AtprotoClientMetadata::new_localhost(None, None), &None) 232 .unwrap(), 233 OAuthClientMetadata { 234 client_id: Url::from_str("http://localhost").unwrap(), 235 client_uri: None, 236 redirect_uris: vec![ 237 Url::from_str("http://127.0.0.1/").unwrap(), 238 Url::from_str("http://[::1]/").unwrap(), 239 ], 240 scope: None, 241 grant_types: None, 242 token_endpoint_auth_method: Some(AuthMethod::None.into()), 243 dpop_bound_access_tokens: None, 244 jwks_uri: None, 245 jwks: None, 246 token_endpoint_auth_signing_alg: None, 247 } 248 ); 249 } 250 251 #[test] 252 fn test_localhost_client_metadata_custom() { 253 assert_eq!( 254 atproto_client_metadata(AtprotoClientMetadata::new_localhost( 255 Some(vec![ 256 Url::from_str("http://127.0.0.1/callback").unwrap(), 257 Url::from_str("http://[::1]/callback").unwrap(), 258 ]), 259 Some( 260 vec![ 261 Scope::Atproto, 262 Scope::Transition(TransitionScope::Generic), 263 Scope::parse("account:email").unwrap() 264 ] 265 ) 266 ), &None) 267 .expect("failed to convert metadata"), 268 OAuthClientMetadata { 269 client_id: Url::from_str( 270 "http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&redirect_uri=http%3A%2F%2Flocalhost%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric" 271 ).unwrap(), 272 client_uri: None, 273 redirect_uris: vec![ 274 Url::from_str("http://localhost/callback").unwrap(), 275 Url::from_str("http://localhost/callback").unwrap(), 276 ], 277 scope: None, 278 grant_types: None, 279 token_endpoint_auth_method: Some(AuthMethod::None.into()), 280 dpop_bound_access_tokens: None, 281 jwks_uri: None, 282 jwks: None, 283 token_endpoint_auth_signing_alg: None, 284 } 285 ); 286 } 287 288 #[test] 289 fn test_localhost_client_metadata_invalid() { 290 // Invalid inputs are coerced to http://localhost rather than failing 291 { 292 let out = atproto_client_metadata( 293 AtprotoClientMetadata::new_localhost( 294 Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]), 295 None, 296 ), 297 &None, 298 ) 299 .expect("should coerce to localhost"); 300 assert_eq!( 301 out, 302 OAuthClientMetadata { 303 client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 304 client_uri: None, 305 redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 306 scope: None, 307 grant_types: None, 308 token_endpoint_auth_method: Some(AuthMethod::None.into()), 309 dpop_bound_access_tokens: None, 310 jwks_uri: None, 311 jwks: None, 312 token_endpoint_auth_signing_alg: None, 313 } 314 ); 315 } 316 { 317 let out = atproto_client_metadata( 318 AtprotoClientMetadata::new_localhost( 319 Some(vec![Url::from_str("http://localhost:8000/").unwrap()]), 320 None, 321 ), 322 &None, 323 ) 324 .expect("should coerce to localhost"); 325 assert_eq!( 326 out, 327 OAuthClientMetadata { 328 client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 329 client_uri: None, 330 redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 331 scope: None, 332 grant_types: None, 333 token_endpoint_auth_method: Some(AuthMethod::None.into()), 334 dpop_bound_access_tokens: None, 335 jwks_uri: None, 336 jwks: None, 337 token_endpoint_auth_signing_alg: None, 338 } 339 ); 340 } 341 { 342 let out = atproto_client_metadata( 343 AtprotoClientMetadata::new_localhost( 344 Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]), 345 None, 346 ), 347 &None, 348 ) 349 .expect("should coerce to localhost"); 350 assert_eq!( 351 out, 352 OAuthClientMetadata { 353 client_id: Url::from_str("http://localhost?redirect_uri=http%3A%2F%2Flocalhost%2F").unwrap(), 354 client_uri: None, 355 redirect_uris: vec![Url::from_str("http://localhost/").unwrap()], 356 scope: None, 357 grant_types: None, 358 token_endpoint_auth_method: Some(AuthMethod::None.into()), 359 dpop_bound_access_tokens: None, 360 jwks_uri: None, 361 jwks: None, 362 token_endpoint_auth_signing_alg: None, 363 } 364 ); 365 } 366 } 367 368 #[test] 369 fn test_client_metadata() { 370 let metadata = AtprotoClientMetadata { 371 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 372 client_uri: Some(Url::from_str("https://example.com").unwrap()), 373 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 374 grant_types: vec![GrantType::AuthorizationCode], 375 scopes: vec![Scope::Atproto], 376 jwks_uri: None, 377 }; 378 { 379 // Non-loopback clients without a keyset should fail (must provide JWKS) 380 let metadata = metadata.clone(); 381 let err = atproto_client_metadata(metadata, &None).expect_err("expected to fail"); 382 assert!(matches!(err, Error::EmptyJwks)); 383 } 384 { 385 let metadata = metadata.clone(); 386 let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY) 387 .expect("failed to parse private key"); 388 let keys = vec![Jwk { 389 key: Key::from(&secret_key.into()), 390 prm: Parameters { 391 kid: Some(String::from("kid00")), 392 ..Default::default() 393 }, 394 }]; 395 let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset"); 396 assert_eq!( 397 atproto_client_metadata(metadata, &Some(keyset.clone())) 398 .expect("failed to convert metadata"), 399 OAuthClientMetadata { 400 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 401 client_uri: Some(Url::from_str("https://example.com").unwrap()), 402 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 403 scope: Some(CowStr::new_static("atproto")), 404 grant_types: Some(vec![CowStr::new_static("authorization_code")]), 405 token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()), 406 dpop_bound_access_tokens: Some(true), 407 jwks_uri: None, 408 jwks: Some(keyset.public_jwks()), 409 token_endpoint_auth_signing_alg: Some(CowStr::new_static("ES256")), 410 } 411 ); 412 } 413 } 414}