A better Rust ATProto crate
1use chrono::{TimeDelta, Utc}; 2use http::{Method, Request, StatusCode}; 3use jacquard_common::{ 4 CowStr, IntoStatic, 5 cowstr::ToCowStr, 6 http_client::HttpClient, 7 session::SessionStoreError, 8 types::{ 9 did::Did, 10 string::{AtStrError, Datetime}, 11 }, 12}; 13use jacquard_identity::resolver::IdentityError; 14use serde::Serialize; 15use serde_json::Value; 16use smol_str::ToSmolStr; 17use thiserror::Error; 18 19use crate::{ 20 FALLBACK_ALG, 21 atproto::atproto_client_metadata, 22 dpop::DpopExt, 23 jose::jwt::{RegisteredClaims, RegisteredClaimsAud}, 24 keyset::Keyset, 25 resolver::OAuthResolver, 26 scopes::Scope, 27 session::{ 28 AuthRequestData, ClientData, ClientSessionData, DpopClientData, DpopDataSource, DpopReqData, 29 }, 30 types::{ 31 AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptionPrompt, 32 OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthParResponse, 33 OAuthTokenResponse, ParParameters, RefreshRequestParameters, RevocationRequestParameters, 34 TokenGrantType, TokenRequestParameters, TokenSet, 35 }, 36 utils::{compare_algos, generate_dpop_key, generate_nonce, generate_pkce}, 37}; 38 39// https://datatracker.ietf.org/doc/html/rfc7523#section-2.2 40const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = 41 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; 42 43#[derive(Error, Debug)] 44pub enum Error { 45 #[error("no {0} endpoint available")] 46 NoEndpoint(CowStr<'static>), 47 #[error("token response verification failed")] 48 Token(CowStr<'static>), 49 #[error("unsupported authentication method")] 50 UnsupportedAuthMethod, 51 #[error("no refresh token available")] 52 TokenRefresh, 53 #[error("failed to parse DID: {0}")] 54 InvalidDid(#[from] AtStrError), 55 #[error(transparent)] 56 DpopClient(#[from] crate::dpop::Error), 57 #[error(transparent)] 58 Storage(#[from] SessionStoreError), 59 60 #[error(transparent)] 61 ResolverError(#[from] crate::resolver::ResolverError), 62 // #[error(transparent)] 63 // OAuthSession(#[from] crate::oauth_session::Error), 64 #[error(transparent)] 65 Http(#[from] http::Error), 66 #[error("http client error: {0}")] 67 HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>), 68 #[error("http status: {0}")] 69 HttpStatus(StatusCode), 70 #[error("http status: {0}, body: {1:?}")] 71 HttpStatusWithBody(StatusCode, Value), 72 #[error(transparent)] 73 Identity(#[from] IdentityError), 74 #[error(transparent)] 75 Keyset(#[from] crate::keyset::Error), 76 #[error(transparent)] 77 SerdeHtmlForm(#[from] serde_html_form::ser::Error), 78 #[error(transparent)] 79 SerdeJson(#[from] serde_json::Error), 80 #[error(transparent)] 81 Atproto(#[from] crate::atproto::Error), 82} 83 84pub type Result<T> = core::result::Result<T, Error>; 85 86#[allow(dead_code)] 87pub enum OAuthRequest<'a> { 88 Token(TokenRequestParameters<'a>), 89 Refresh(RefreshRequestParameters<'a>), 90 Revocation(RevocationRequestParameters<'a>), 91 Introspection, 92 PushedAuthorizationRequest(ParParameters<'a>), 93} 94 95impl OAuthRequest<'_> { 96 pub fn name(&self) -> CowStr<'static> { 97 CowStr::new_static(match self { 98 Self::Token(_) => "token", 99 Self::Refresh(_) => "refresh", 100 Self::Revocation(_) => "revocation", 101 Self::Introspection => "introspection", 102 Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", 103 }) 104 } 105 pub fn expected_status(&self) -> StatusCode { 106 match self { 107 Self::Token(_) | Self::Refresh(_) => StatusCode::OK, 108 Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, 109 // Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`. 110 Self::Revocation(_) => StatusCode::NO_CONTENT, 111 _ => unimplemented!(), 112 } 113 } 114} 115 116#[derive(Debug, Serialize)] 117pub struct RequestPayload<'a, T> 118where 119 T: Serialize, 120{ 121 client_id: CowStr<'a>, 122 #[serde(skip_serializing_if = "Option::is_none")] 123 client_assertion_type: Option<CowStr<'a>>, 124 #[serde(skip_serializing_if = "Option::is_none")] 125 client_assertion: Option<CowStr<'a>>, 126 #[serde(flatten)] 127 parameters: T, 128} 129 130#[derive(Debug, Clone)] 131pub struct OAuthMetadata { 132 pub server_metadata: OAuthAuthorizationServerMetadata<'static>, 133 pub client_metadata: OAuthClientMetadata<'static>, 134 pub keyset: Option<Keyset>, 135} 136 137impl OAuthMetadata { 138 pub async fn new<'r, T: HttpClient + OAuthResolver + Send + Sync>( 139 client: &T, 140 ClientData { keyset, config }: &ClientData<'r>, 141 session_data: &ClientSessionData<'r>, 142 ) -> Result<Self> { 143 Ok(OAuthMetadata { 144 server_metadata: client 145 .get_authorization_server_metadata(&session_data.authserver_url) 146 .await?, 147 client_metadata: atproto_client_metadata(config.clone(), &keyset) 148 .unwrap() 149 .into_static(), 150 keyset: keyset.clone(), 151 }) 152 } 153} 154 155pub async fn par<'r, T: OAuthResolver + DpopExt + Send + Sync + 'static>( 156 client: &T, 157 login_hint: Option<CowStr<'r>>, 158 prompt: Option<AuthorizeOptionPrompt>, 159 metadata: &OAuthMetadata, 160) -> crate::request::Result<AuthRequestData<'r>> { 161 let state = generate_nonce(); 162 let (code_challenge, verifier) = generate_pkce(); 163 164 let Some(dpop_key) = generate_dpop_key(&metadata.server_metadata) else { 165 return Err(Error::Token("none of the algorithms worked".into())); 166 }; 167 let mut dpop_data = DpopReqData { 168 dpop_key, 169 dpop_authserver_nonce: None, 170 }; 171 let parameters = ParParameters { 172 response_type: AuthorizationResponseType::Code, 173 redirect_uri: metadata.client_metadata.redirect_uris[0].to_cowstr(), 174 state: state.clone(), 175 scope: metadata.client_metadata.scope.clone(), 176 response_mode: None, 177 code_challenge, 178 code_challenge_method: AuthorizationCodeChallengeMethod::S256, 179 login_hint: login_hint, 180 prompt: prompt.map(CowStr::from), 181 }; 182 if metadata 183 .server_metadata 184 .pushed_authorization_request_endpoint 185 .is_some() 186 { 187 let par_response = oauth_request::<OAuthParResponse, T, DpopReqData>( 188 &client, 189 &mut dpop_data, 190 OAuthRequest::PushedAuthorizationRequest(parameters), 191 metadata, 192 ) 193 .await?; 194 195 let scopes = if let Some(scope) = &metadata.client_metadata.scope { 196 Scope::parse_multiple_reduced(&scope) 197 .expect("Failed to parse scopes") 198 .into_static() 199 } else { 200 vec![] 201 }; 202 let auth_req_data = AuthRequestData { 203 state, 204 authserver_url: url::Url::parse(&metadata.server_metadata.issuer) 205 .expect("Failed to parse issuer URL"), 206 account_did: None, 207 scopes, 208 request_uri: par_response.request_uri.to_cowstr().into_static(), 209 authserver_token_endpoint: metadata.server_metadata.token_endpoint.clone(), 210 authserver_revocation_endpoint: metadata.server_metadata.revocation_endpoint.clone(), 211 pkce_verifier: verifier, 212 dpop_data, 213 }; 214 215 Ok(auth_req_data) 216 } else if metadata 217 .server_metadata 218 .require_pushed_authorization_requests 219 == Some(true) 220 { 221 Err(Error::NoEndpoint(CowStr::new_static( 222 "server requires PAR but no endpoint is available", 223 ))) 224 } else { 225 todo!("use of PAR is mandatory") 226 } 227} 228 229pub async fn refresh<'r, T>( 230 client: &T, 231 mut session_data: ClientSessionData<'r>, 232 metadata: &OAuthMetadata, 233) -> Result<ClientSessionData<'r>> 234where 235 T: OAuthResolver + DpopExt + Send + Sync + 'static, 236{ 237 let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else { 238 return Err(Error::TokenRefresh); 239 }; 240 241 // /!\ IMPORTANT /!\ 242 // 243 // The "sub" MUST be a DID, whose issuer authority is indeed the server we 244 // are trying to obtain credentials from. Note that we are doing this 245 // *before* we actually try to refresh the token: 246 // 1) To avoid unnecessary refresh 247 // 2) So that the refresh is the last async operation, ensuring as few 248 // async operations happen before the result gets a chance to be stored. 249 let aud = client 250 .verify_issuer(&metadata.server_metadata, &session_data.token_set.sub) 251 .await?; 252 let iss = metadata.server_metadata.issuer.clone(); 253 254 let response = oauth_request::<OAuthTokenResponse, T, DpopClientData>( 255 client, 256 &mut session_data.dpop_data, 257 OAuthRequest::Refresh(RefreshRequestParameters { 258 grant_type: TokenGrantType::RefreshToken, 259 refresh_token: refresh_token.clone(), 260 scope: None, 261 }), 262 metadata, 263 ) 264 .await?; 265 266 let expires_at = response.expires_in.and_then(|expires_in| { 267 let now = Datetime::now(); 268 now.as_ref() 269 .checked_add_signed(TimeDelta::seconds(expires_in)) 270 .map(Datetime::new) 271 }); 272 273 session_data.update_with_tokens(TokenSet { 274 iss, 275 sub: session_data.token_set.sub.clone(), 276 aud: CowStr::Owned(aud.to_smolstr()), 277 scope: response.scope.map(CowStr::Owned), 278 access_token: CowStr::Owned(response.access_token), 279 refresh_token: response.refresh_token.map(CowStr::Owned), 280 token_type: response.token_type, 281 expires_at, 282 }); 283 284 Ok(session_data) 285} 286 287pub async fn exchange_code<'r, T, D>( 288 client: &T, 289 data_source: &'r mut D, 290 code: &str, 291 verifier: &str, 292 metadata: &OAuthMetadata, 293) -> Result<TokenSet<'r>> 294where 295 T: OAuthResolver + DpopExt + Send + Sync + 'static, 296 D: DpopDataSource, 297{ 298 let token_response = oauth_request::<OAuthTokenResponse, T, D>( 299 client, 300 data_source, 301 OAuthRequest::Token(TokenRequestParameters { 302 grant_type: TokenGrantType::AuthorizationCode, 303 code: code.into(), 304 redirect_uri: CowStr::Owned( 305 metadata.client_metadata.redirect_uris[0] 306 .clone() 307 .to_smolstr(), 308 ), // ? 309 code_verifier: verifier.into(), 310 }), 311 metadata, 312 ) 313 .await?; 314 let Some(sub) = token_response.sub else { 315 return Err(Error::Token("missing `sub` in token response".into())); 316 }; 317 let sub = Did::new_owned(sub)?; 318 let iss = metadata.server_metadata.issuer.clone(); 319 // /!\ IMPORTANT /!\ 320 // 321 // The token_response MUST always be valid before the "sub" it contains 322 // can be trusted (see Atproto's OAuth spec for details). 323 let aud = client 324 .verify_issuer(&metadata.server_metadata, &sub) 325 .await?; 326 327 let expires_at = token_response.expires_in.and_then(|expires_in| { 328 Datetime::now() 329 .as_ref() 330 .checked_add_signed(TimeDelta::seconds(expires_in)) 331 .map(Datetime::new) 332 }); 333 Ok(TokenSet { 334 iss, 335 sub, 336 aud: CowStr::Owned(aud.to_smolstr()), 337 scope: token_response.scope.map(CowStr::Owned), 338 access_token: CowStr::Owned(token_response.access_token), 339 refresh_token: token_response.refresh_token.map(CowStr::Owned), 340 token_type: token_response.token_type, 341 expires_at, 342 }) 343} 344 345pub async fn revoke<'r, T, D>( 346 client: &T, 347 data_source: &'r mut D, 348 token: &str, 349 metadata: &OAuthMetadata, 350) -> Result<()> 351where 352 T: OAuthResolver + DpopExt + Send + Sync + 'static, 353 D: DpopDataSource, 354{ 355 oauth_request::<(), T, D>( 356 client, 357 data_source, 358 OAuthRequest::Revocation(RevocationRequestParameters { 359 token: token.into(), 360 }), 361 metadata, 362 ) 363 .await?; 364 Ok(()) 365} 366 367pub async fn oauth_request<'de: 'r, 'r, O, T, D>( 368 client: &T, 369 data_source: &'r mut D, 370 request: OAuthRequest<'r>, 371 metadata: &OAuthMetadata, 372) -> Result<O> 373where 374 T: OAuthResolver + DpopExt + Send + Sync + 'static, 375 O: serde::de::DeserializeOwned, 376 D: DpopDataSource, 377{ 378 let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else { 379 return Err(Error::NoEndpoint(request.name())); 380 }; 381 let client_assertions = build_auth( 382 metadata.keyset.as_ref(), 383 &metadata.server_metadata, 384 &metadata.client_metadata, 385 )?; 386 let body = match &request { 387 OAuthRequest::Token(params) => build_oauth_req_body(client_assertions, params)?, 388 OAuthRequest::Refresh(params) => build_oauth_req_body(client_assertions, params)?, 389 OAuthRequest::Revocation(params) => build_oauth_req_body(client_assertions, params)?, 390 OAuthRequest::PushedAuthorizationRequest(params) => { 391 build_oauth_req_body(client_assertions, params)? 392 } 393 _ => unimplemented!(), 394 }; 395 let req = Request::builder() 396 .uri(url.to_string()) 397 .method(Method::POST) 398 .header("Content-Type", "application/x-www-form-urlencoded") 399 .body(body.into_bytes())?; 400 let res = client 401 .dpop_server_call(data_source) 402 .send(req) 403 .await 404 .map_err(Error::DpopClient)?; 405 if res.status() == request.expected_status() { 406 let body = res.body(); 407 if body.is_empty() { 408 // since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`. 409 Ok(serde_json::from_slice(b"null")?) 410 } else { 411 let output: O = serde_json::from_slice(body)?; 412 Ok(output) 413 } 414 } else if res.status().is_client_error() { 415 Err(Error::HttpStatusWithBody( 416 res.status(), 417 serde_json::from_slice(res.body())?, 418 )) 419 } else { 420 Err(Error::HttpStatus(res.status())) 421 } 422} 423 424#[inline] 425fn endpoint_for_req<'a, 'r>( 426 server_metadata: &'r OAuthAuthorizationServerMetadata<'a>, 427 request: &'r OAuthRequest, 428) -> Option<&'r CowStr<'a>> { 429 match request { 430 OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => Some(&server_metadata.token_endpoint), 431 OAuthRequest::Revocation(_) => server_metadata.revocation_endpoint.as_ref(), 432 OAuthRequest::Introspection => server_metadata.introspection_endpoint.as_ref(), 433 OAuthRequest::PushedAuthorizationRequest(_) => server_metadata 434 .pushed_authorization_request_endpoint 435 .as_ref(), 436 } 437} 438 439#[inline] 440fn build_oauth_req_body<'a, S>(client_assertions: ClientAuth<'a>, parameters: S) -> Result<String> 441where 442 S: Serialize, 443{ 444 Ok(serde_html_form::to_string(RequestPayload { 445 client_id: client_assertions.client_id, 446 client_assertion_type: client_assertions.assertion_type, 447 client_assertion: client_assertions.assertion, 448 parameters, 449 })?) 450} 451 452#[derive(Debug, Clone, Default)] 453pub struct ClientAuth<'a> { 454 client_id: CowStr<'a>, 455 assertion_type: Option<CowStr<'a>>, // either none or `CLIENT_ASSERTION_TYPE_JWT_BEARER` 456 assertion: Option<CowStr<'a>>, 457} 458 459impl<'s> ClientAuth<'s> { 460 pub fn new_id(client_id: CowStr<'s>) -> Self { 461 Self { 462 client_id, 463 assertion_type: None, 464 assertion: None, 465 } 466 } 467} 468 469fn build_auth<'a>( 470 keyset: Option<&Keyset>, 471 server_metadata: &OAuthAuthorizationServerMetadata<'a>, 472 client_metadata: &OAuthClientMetadata<'a>, 473) -> Result<ClientAuth<'a>> { 474 let method_supported = server_metadata 475 .token_endpoint_auth_methods_supported 476 .as_ref(); 477 478 let client_id = client_metadata.client_id.to_cowstr().into_static(); 479 if let Some(method) = client_metadata.token_endpoint_auth_method.as_ref() { 480 match (*method).as_ref() { 481 "private_key_jwt" 482 if method_supported 483 .as_ref() 484 .is_some_and(|v| v.contains(&CowStr::new_static("private_key_jwt"))) => 485 { 486 if let Some(keyset) = &keyset { 487 let mut algs = server_metadata 488 .token_endpoint_auth_signing_alg_values_supported 489 .clone() 490 .unwrap_or(vec![FALLBACK_ALG.into()]); 491 algs.sort_by(compare_algos); 492 let iat = Utc::now().timestamp(); 493 return Ok(ClientAuth { 494 client_id: client_id.clone(), 495 assertion_type: Some(CowStr::new_static(CLIENT_ASSERTION_TYPE_JWT_BEARER)), 496 assertion: Some( 497 keyset.create_jwt( 498 &algs, 499 // https://datatracker.ietf.org/doc/html/rfc7523#section-3 500 RegisteredClaims { 501 iss: Some(client_id.clone()), 502 sub: Some(client_id), 503 aud: Some(RegisteredClaimsAud::Single( 504 server_metadata.issuer.clone(), 505 )), 506 exp: Some(iat + 60), 507 // "iat" is required and **MUST** be less than one minute 508 // https://datatracker.ietf.org/doc/html/rfc9101 509 iat: Some(iat), 510 // atproto oauth-provider requires "jti" to be present 511 jti: Some(generate_nonce()), 512 ..Default::default() 513 } 514 .into(), 515 )?, 516 ), 517 }); 518 } 519 } 520 "none" 521 if method_supported 522 .as_ref() 523 .is_some_and(|v| v.contains(&CowStr::new_static("none"))) => 524 { 525 return Ok(ClientAuth::new_id(client_id)); 526 } 527 _ => {} 528 } 529 } 530 531 Err(Error::UnsupportedAuthMethod) 532}