A better Rust ATProto crate
1use std::str::FromStr; 2 3use crate::types::OAuthClientMetadata; 4use crate::{keyset::Keyset, scopes::Scope}; 5use jacquard_common::CowStr; 6use serde::{Deserialize, Serialize}; 7use thiserror::Error; 8use url::Url; 9 10#[derive(Error, Debug)] 11pub enum Error { 12 #[error("`client_id` must be a valid URL")] 13 InvalidClientId, 14 #[error("`grant_types` must include `authorization_code`")] 15 InvalidGrantTypes, 16 #[error("`scope` must not include `atproto`")] 17 InvalidScope, 18 #[error("`redirect_uris` must not be empty")] 19 EmptyRedirectUris, 20 #[error("`private_key_jwt` auth method requires `jwks` keys")] 21 EmptyJwks, 22 #[error( 23 "`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided" 24 )] 25 AuthSigningAlg, 26 #[error(transparent)] 27 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 28 #[error(transparent)] 29 LocalhostClient(#[from] LocalhostClientError), 30} 31 32#[derive(Error, Debug)] 33pub enum LocalhostClientError { 34 #[error("invalid redirect_uri: {0}")] 35 Invalid(#[from] url::ParseError), 36 #[error("loopback client_id must use `http:` redirect_uri")] 37 NotHttpScheme, 38 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")] 39 Localhost, 40 #[error("loopback client_id must not use loopback addresses as redirect_uri")] 41 NotLoopbackHost, 42} 43 44pub type Result<T> = core::result::Result<T, Error>; 45 46#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 47#[serde(rename_all = "snake_case")] 48pub enum AuthMethod { 49 None, 50 // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication 51 PrivateKeyJwt, 52} 53 54impl From<AuthMethod> for CowStr<'static> { 55 fn from(value: AuthMethod) -> Self { 56 match value { 57 AuthMethod::None => CowStr::new_static("none"), 58 AuthMethod::PrivateKeyJwt => CowStr::new_static("private_key_jwt"), 59 } 60 } 61} 62 63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 64#[serde(rename_all = "snake_case")] 65pub enum GrantType { 66 AuthorizationCode, 67 RefreshToken, 68} 69 70impl From<GrantType> for CowStr<'static> { 71 fn from(value: GrantType) -> Self { 72 match value { 73 GrantType::AuthorizationCode => CowStr::new_static("authorization_code"), 74 GrantType::RefreshToken => CowStr::new_static("refresh_token"), 75 } 76 } 77} 78 79#[derive(Clone, Debug, PartialEq, Eq)] 80pub struct AtprotoClientMetadata<'m> { 81 pub client_id: Url, 82 pub client_uri: Option<Url>, 83 pub redirect_uris: Vec<Url>, 84 pub grant_types: Vec<GrantType>, 85 pub scopes: Vec<Scope<'m>>, 86 pub jwks_uri: Option<Url>, 87} 88 89impl<'m> AtprotoClientMetadata<'m> { 90 pub fn new( 91 client_id: Url, 92 client_uri: Option<Url>, 93 redirect_uris: Vec<Url>, 94 grant_types: Vec<GrantType>, 95 scopes: Vec<Scope<'m>>, 96 jwks_uri: Option<Url>, 97 ) -> Self { 98 Self { 99 client_id, 100 client_uri, 101 redirect_uris, 102 grant_types, 103 scopes, 104 jwks_uri, 105 } 106 } 107 108 pub fn new_localhost( 109 mut redirect_uris: Option<Vec<Url>>, 110 scopes: Option<Vec<Scope<'m>>>, 111 ) -> Self { 112 // coerce redirect uris to localhost 113 if let Some(redirect_uris) = &mut redirect_uris { 114 for redirect_uri in redirect_uris { 115 redirect_uri.set_host(Some("http://localhost")).unwrap(); 116 } 117 } 118 // determine client_id 119 #[derive(serde::Serialize)] 120 struct Parameters<'a> { 121 #[serde(skip_serializing_if = "Option::is_none")] 122 redirect_uri: Option<Vec<Url>>, 123 #[serde(skip_serializing_if = "Option::is_none")] 124 scope: Option<CowStr<'a>>, 125 } 126 let query = serde_html_form::to_string(Parameters { 127 redirect_uri: redirect_uris.clone(), 128 scope: scopes 129 .as_ref() 130 .map(|s| Scope::serialize_multiple(s.as_slice())), 131 }) 132 .ok(); 133 let mut client_id = String::from("http://localhost"); 134 if let Some(query) = query 135 && !query.is_empty() 136 { 137 client_id.push_str(&format!("?{query}")); 138 } 139 Self { 140 client_id: Url::parse(&client_id).unwrap(), 141 client_uri: None, 142 redirect_uris: redirect_uris.unwrap_or(vec![ 143 Url::from_str("http://127.0.0.1/").unwrap(), 144 Url::from_str("http://[::1]/").unwrap(), 145 ]), 146 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 147 scopes: scopes.unwrap_or(vec![Scope::Atproto]), 148 jwks_uri: None, 149 } 150 } 151} 152 153pub fn atproto_client_metadata<'m>( 154 metadata: AtprotoClientMetadata<'m>, 155 keyset: &Option<Keyset>, 156) -> Result<OAuthClientMetadata<'m>> { 157 if metadata.redirect_uris.is_empty() { 158 return Err(Error::EmptyRedirectUris); 159 } 160 if !metadata.grant_types.contains(&GrantType::AuthorizationCode) { 161 return Err(Error::InvalidGrantTypes); 162 } 163 if !metadata.scopes.contains(&Scope::Atproto) { 164 return Err(Error::InvalidScope); 165 } 166 let (auth_method, jwks_uri, jwks) = if let Some(keyset) = keyset { 167 let jwks = if metadata.jwks_uri.is_none() { 168 Some(keyset.public_jwks()) 169 } else { 170 None 171 }; 172 (AuthMethod::PrivateKeyJwt, metadata.jwks_uri, jwks) 173 } else { 174 (AuthMethod::None, None, None) 175 }; 176 177 Ok(OAuthClientMetadata { 178 client_id: metadata.client_id, 179 client_uri: metadata.client_uri, 180 redirect_uris: metadata.redirect_uris, 181 token_endpoint_auth_method: Some(auth_method.into()), 182 grant_types: Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()), 183 scope: Some(Scope::serialize_multiple(metadata.scopes.as_slice())), 184 dpop_bound_access_tokens: Some(true), 185 jwks_uri, 186 jwks, 187 token_endpoint_auth_signing_alg: if keyset.is_some() { 188 Some(CowStr::new_static("ES256")) 189 } else { 190 None 191 }, 192 }) 193} 194 195#[cfg(test)] 196mod tests { 197 use std::str::FromStr; 198 199 use crate::scopes::TransitionScope; 200 201 use super::*; 202 use elliptic_curve::SecretKey; 203 use jose_jwk::{Jwk, Key, Parameters}; 204 use p256::pkcs8::DecodePrivateKey; 205 206 const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY----- 207MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T 2084i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P 209gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3 210-----END PRIVATE KEY-----"#; 211 212 #[test] 213 fn test_localhost_client_metadata_default() { 214 assert_eq!( 215 atproto_client_metadata(AtprotoClientMetadata::new_localhost(None, None), &None) 216 .unwrap(), 217 OAuthClientMetadata { 218 client_id: Url::from_str("http://localhost").unwrap(), 219 client_uri: None, 220 redirect_uris: vec![ 221 Url::from_str("http://127.0.0.1/").unwrap(), 222 Url::from_str("http://[::1]/").unwrap(), 223 ], 224 scope: None, 225 grant_types: None, 226 token_endpoint_auth_method: Some(AuthMethod::None.into()), 227 dpop_bound_access_tokens: None, 228 jwks_uri: None, 229 jwks: None, 230 token_endpoint_auth_signing_alg: None, 231 } 232 ); 233 } 234 235 #[test] 236 fn test_localhost_client_metadata_custom() { 237 assert_eq!( 238 atproto_client_metadata(AtprotoClientMetadata::new_localhost( 239 Some(vec![ 240 Url::from_str("http://127.0.0.1/callback").unwrap(), 241 Url::from_str("http://[::1]/callback").unwrap(), 242 ]), 243 Some( 244 vec![ 245 Scope::Atproto, 246 Scope::Transition(TransitionScope::Generic), 247 Scope::parse("account:email").unwrap() 248 ] 249 ) 250 ), &None) 251 .expect("failed to convert metadata"), 252 OAuthClientMetadata { 253 client_id: Url::from_str( 254 "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric" 255 ).unwrap(), 256 client_uri: None, 257 redirect_uris: vec![ 258 Url::from_str("http://127.0.0.1/callback").unwrap(), 259 Url::from_str("http://[::1]/callback").unwrap(), 260 ], 261 scope: None, 262 grant_types: None, 263 token_endpoint_auth_method: Some(AuthMethod::None.into()), 264 dpop_bound_access_tokens: None, 265 jwks_uri: None, 266 jwks: None, 267 token_endpoint_auth_signing_alg: None, 268 } 269 ); 270 } 271 272 #[test] 273 fn test_localhost_client_metadata_invalid() { 274 { 275 let err = atproto_client_metadata( 276 AtprotoClientMetadata::new_localhost( 277 Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]), 278 None, 279 ), 280 &None, 281 ) 282 .expect_err("expected to fail"); 283 assert!(matches!( 284 err, 285 Error::LocalhostClient(LocalhostClientError::NotHttpScheme) 286 )); 287 } 288 { 289 let err = atproto_client_metadata( 290 AtprotoClientMetadata::new_localhost( 291 Some(vec![Url::from_str("http://localhost:8000/").unwrap()]), 292 None, 293 ), 294 &None, 295 ) 296 .expect_err("expected to fail"); 297 assert!(matches!( 298 err, 299 Error::LocalhostClient(LocalhostClientError::Localhost) 300 )); 301 } 302 { 303 let err = atproto_client_metadata( 304 AtprotoClientMetadata::new_localhost( 305 Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]), 306 None, 307 ), 308 &None, 309 ) 310 .expect_err("expected to fail"); 311 assert!(matches!( 312 err, 313 Error::LocalhostClient(LocalhostClientError::NotLoopbackHost) 314 )); 315 } 316 } 317 318 #[test] 319 fn test_client_metadata() { 320 let metadata = AtprotoClientMetadata { 321 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 322 client_uri: Some(Url::from_str("https://example.com").unwrap()), 323 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 324 grant_types: vec![GrantType::AuthorizationCode], 325 scopes: vec![Scope::Atproto], 326 jwks_uri: None, 327 }; 328 { 329 let metadata = metadata.clone(); 330 let err = atproto_client_metadata(metadata, &None).expect_err("expected to fail"); 331 assert!(matches!(err, Error::EmptyJwks)); 332 } 333 { 334 let metadata = metadata.clone(); 335 let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY) 336 .expect("failed to parse private key"); 337 let keys = vec![Jwk { 338 key: Key::from(&secret_key.into()), 339 prm: Parameters { 340 kid: Some(String::from("kid00")), 341 ..Default::default() 342 }, 343 }]; 344 let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset"); 345 assert_eq!( 346 atproto_client_metadata(metadata, &Some(keyset.clone())) 347 .expect("failed to convert metadata"), 348 OAuthClientMetadata { 349 client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(), 350 client_uri: Some(Url::from_str("https://example.com").unwrap()), 351 redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()], 352 scope: Some(CowStr::new_static("atproto")), 353 grant_types: Some(vec![CowStr::new_static("authorization_code")]), 354 token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()), 355 dpop_bound_access_tokens: Some(true), 356 jwks_uri: None, 357 jwks: Some(keyset.public_jwks()), 358 token_endpoint_auth_signing_alg: Some(CowStr::new_static("ES256")), 359 } 360 ); 361 } 362 } 363}