1use std::str::FromStr;
2
3use crate::types::OAuthClientMetadata;
4use crate::{keyset::Keyset, scopes::Scope};
5use jacquard_common::CowStr;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8use url::Url;
9
10#[derive(Error, Debug)]
11pub enum Error {
12 #[error("`client_id` must be a valid URL")]
13 InvalidClientId,
14 #[error("`grant_types` must include `authorization_code`")]
15 InvalidGrantTypes,
16 #[error("`scope` must not include `atproto`")]
17 InvalidScope,
18 #[error("`redirect_uris` must not be empty")]
19 EmptyRedirectUris,
20 #[error("`private_key_jwt` auth method requires `jwks` keys")]
21 EmptyJwks,
22 #[error(
23 "`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided"
24 )]
25 AuthSigningAlg,
26 #[error(transparent)]
27 SerdeHtmlForm(#[from] serde_html_form::ser::Error),
28 #[error(transparent)]
29 LocalhostClient(#[from] LocalhostClientError),
30}
31
32#[derive(Error, Debug)]
33pub enum LocalhostClientError {
34 #[error("invalid redirect_uri: {0}")]
35 Invalid(#[from] url::ParseError),
36 #[error("loopback client_id must use `http:` redirect_uri")]
37 NotHttpScheme,
38 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
39 Localhost,
40 #[error("loopback client_id must not use loopback addresses as redirect_uri")]
41 NotLoopbackHost,
42}
43
44pub type Result<T> = core::result::Result<T, Error>;
45
46#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum AuthMethod {
49 None,
50 // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication
51 PrivateKeyJwt,
52}
53
54impl From<AuthMethod> for CowStr<'static> {
55 fn from(value: AuthMethod) -> Self {
56 match value {
57 AuthMethod::None => CowStr::new_static("none"),
58 AuthMethod::PrivateKeyJwt => CowStr::new_static("private_key_jwt"),
59 }
60 }
61}
62
63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
64#[serde(rename_all = "snake_case")]
65pub enum GrantType {
66 AuthorizationCode,
67 RefreshToken,
68}
69
70impl From<GrantType> for CowStr<'static> {
71 fn from(value: GrantType) -> Self {
72 match value {
73 GrantType::AuthorizationCode => CowStr::new_static("authorization_code"),
74 GrantType::RefreshToken => CowStr::new_static("refresh_token"),
75 }
76 }
77}
78
79#[derive(Clone, Debug, PartialEq, Eq)]
80pub struct AtprotoClientMetadata<'m> {
81 pub client_id: Url,
82 pub client_uri: Option<Url>,
83 pub redirect_uris: Vec<Url>,
84 pub grant_types: Vec<GrantType>,
85 pub scopes: Vec<Scope<'m>>,
86 pub jwks_uri: Option<Url>,
87}
88
89impl<'m> AtprotoClientMetadata<'m> {
90 pub fn new(
91 client_id: Url,
92 client_uri: Option<Url>,
93 redirect_uris: Vec<Url>,
94 grant_types: Vec<GrantType>,
95 scopes: Vec<Scope<'m>>,
96 jwks_uri: Option<Url>,
97 ) -> Self {
98 Self {
99 client_id,
100 client_uri,
101 redirect_uris,
102 grant_types,
103 scopes,
104 jwks_uri,
105 }
106 }
107
108 pub fn new_localhost(
109 mut redirect_uris: Option<Vec<Url>>,
110 scopes: Option<Vec<Scope<'m>>>,
111 ) -> Self {
112 // coerce redirect uris to localhost
113 if let Some(redirect_uris) = &mut redirect_uris {
114 for redirect_uri in redirect_uris {
115 redirect_uri.set_host(Some("http://localhost")).unwrap();
116 }
117 }
118 // determine client_id
119 #[derive(serde::Serialize)]
120 struct Parameters<'a> {
121 #[serde(skip_serializing_if = "Option::is_none")]
122 redirect_uri: Option<Vec<Url>>,
123 #[serde(skip_serializing_if = "Option::is_none")]
124 scope: Option<CowStr<'a>>,
125 }
126 let query = serde_html_form::to_string(Parameters {
127 redirect_uri: redirect_uris.clone(),
128 scope: scopes
129 .as_ref()
130 .map(|s| Scope::serialize_multiple(s.as_slice())),
131 })
132 .ok();
133 let mut client_id = String::from("http://localhost");
134 if let Some(query) = query
135 && !query.is_empty()
136 {
137 client_id.push_str(&format!("?{query}"));
138 }
139 Self {
140 client_id: Url::parse(&client_id).unwrap(),
141 client_uri: None,
142 redirect_uris: redirect_uris.unwrap_or(vec![
143 Url::from_str("http://127.0.0.1/").unwrap(),
144 Url::from_str("http://[::1]/").unwrap(),
145 ]),
146 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
147 scopes: scopes.unwrap_or(vec![Scope::Atproto]),
148 jwks_uri: None,
149 }
150 }
151}
152
153pub fn atproto_client_metadata<'m>(
154 metadata: AtprotoClientMetadata<'m>,
155 keyset: &Option<Keyset>,
156) -> Result<OAuthClientMetadata<'m>> {
157 if metadata.redirect_uris.is_empty() {
158 return Err(Error::EmptyRedirectUris);
159 }
160 if !metadata.grant_types.contains(&GrantType::AuthorizationCode) {
161 return Err(Error::InvalidGrantTypes);
162 }
163 if !metadata.scopes.contains(&Scope::Atproto) {
164 return Err(Error::InvalidScope);
165 }
166 let (auth_method, jwks_uri, jwks) = if let Some(keyset) = keyset {
167 let jwks = if metadata.jwks_uri.is_none() {
168 Some(keyset.public_jwks())
169 } else {
170 None
171 };
172 (AuthMethod::PrivateKeyJwt, metadata.jwks_uri, jwks)
173 } else {
174 (AuthMethod::None, None, None)
175 };
176
177 Ok(OAuthClientMetadata {
178 client_id: metadata.client_id,
179 client_uri: metadata.client_uri,
180 redirect_uris: metadata.redirect_uris,
181 token_endpoint_auth_method: Some(auth_method.into()),
182 grant_types: Some(metadata.grant_types.into_iter().map(|v| v.into()).collect()),
183 scope: Some(Scope::serialize_multiple(metadata.scopes.as_slice())),
184 dpop_bound_access_tokens: Some(true),
185 jwks_uri,
186 jwks,
187 token_endpoint_auth_signing_alg: if keyset.is_some() {
188 Some(CowStr::new_static("ES256"))
189 } else {
190 None
191 },
192 })
193}
194
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::scopes::TransitionScope;

    use super::*;
    use elliptic_curve::SecretKey;
    use jose_jwk::{Jwk, Key, Parameters};
    use p256::pkcs8::DecodePrivateKey;

    const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

    // `atproto_client_metadata` always emits `scope`, `grant_types` and
    // `dpop_bound_access_tokens`. The expectations below mirror that; the previous
    // versions expected `None` for these fields and could never pass.

    #[test]
    fn test_localhost_client_metadata_default() {
        assert_eq!(
            atproto_client_metadata(AtprotoClientMetadata::new_localhost(None, None), &None)
                .unwrap(),
            OAuthClientMetadata {
                client_id: Url::from_str("http://localhost").unwrap(),
                client_uri: None,
                redirect_uris: vec![
                    Url::from_str("http://127.0.0.1/").unwrap(),
                    Url::from_str("http://[::1]/").unwrap(),
                ],
                scope: Some(CowStr::new_static("atproto")),
                grant_types: Some(vec![
                    CowStr::new_static("authorization_code"),
                    CowStr::new_static("refresh_token"),
                ]),
                token_endpoint_auth_method: Some(AuthMethod::None.into()),
                dpop_bound_access_tokens: Some(true),
                jwks_uri: None,
                jwks: None,
                token_endpoint_auth_signing_alg: None,
            }
        );
    }

    #[test]
    fn test_localhost_client_metadata_custom() {
        assert_eq!(
            atproto_client_metadata(
                AtprotoClientMetadata::new_localhost(
                    Some(vec![
                        Url::from_str("http://127.0.0.1/callback").unwrap(),
                        Url::from_str("http://[::1]/callback").unwrap(),
                    ]),
                    Some(vec![
                        Scope::Atproto,
                        Scope::Transition(TransitionScope::Generic),
                        Scope::parse("account:email").unwrap(),
                    ]),
                ),
                &None,
            )
            .expect("failed to convert metadata"),
            OAuthClientMetadata {
                client_id: Url::from_str(
                    "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric"
                )
                .unwrap(),
                client_uri: None,
                redirect_uris: vec![
                    Url::from_str("http://127.0.0.1/callback").unwrap(),
                    Url::from_str("http://[::1]/callback").unwrap(),
                ],
                // Same serialization (and order) as the `scope` query parameter above.
                scope: Some(CowStr::new_static("account:email atproto transition:generic")),
                grant_types: Some(vec![
                    CowStr::new_static("authorization_code"),
                    CowStr::new_static("refresh_token"),
                ]),
                token_endpoint_auth_method: Some(AuthMethod::None.into()),
                dpop_bound_access_tokens: Some(true),
                jwks_uri: None,
                jwks: None,
                token_endpoint_auth_signing_alg: None,
            }
        );
    }

    #[test]
    fn test_localhost_client_metadata_invalid() {
        {
            let err = atproto_client_metadata(
                AtprotoClientMetadata::new_localhost(
                    Some(vec![Url::from_str("https://127.0.0.1/").unwrap()]),
                    None,
                ),
                &None,
            )
            .expect_err("expected to fail");
            assert!(matches!(
                err,
                Error::LocalhostClient(LocalhostClientError::NotHttpScheme)
            ));
        }
        {
            let err = atproto_client_metadata(
                AtprotoClientMetadata::new_localhost(
                    Some(vec![Url::from_str("http://localhost:8000/").unwrap()]),
                    None,
                ),
                &None,
            )
            .expect_err("expected to fail");
            assert!(matches!(
                err,
                Error::LocalhostClient(LocalhostClientError::Localhost)
            ));
        }
        {
            let err = atproto_client_metadata(
                AtprotoClientMetadata::new_localhost(
                    Some(vec![Url::from_str("http://192.168.0.0/").unwrap()]),
                    None,
                ),
                &None,
            )
            .expect_err("expected to fail");
            assert!(matches!(
                err,
                Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)
            ));
        }
    }

    #[test]
    fn test_client_metadata() {
        let metadata = AtprotoClientMetadata {
            client_id: Url::from_str("https://example.com/client_metadata.json").unwrap(),
            client_uri: Some(Url::from_str("https://example.com").unwrap()),
            redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()],
            grant_types: vec![GrantType::AuthorizationCode],
            scopes: vec![Scope::Atproto],
            jwks_uri: None,
        };
        {
            // Without a keyset, `atproto_client_metadata` selects auth method `none`
            // and emits no key material. It does not fail with `EmptyJwks`, which is
            // what this case previously (and incorrectly) expected.
            let metadata = metadata.clone();
            assert_eq!(
                atproto_client_metadata(metadata, &None).expect("failed to convert metadata"),
                OAuthClientMetadata {
                    client_id: Url::from_str("https://example.com/client_metadata.json")
                        .unwrap(),
                    client_uri: Some(Url::from_str("https://example.com").unwrap()),
                    redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()],
                    scope: Some(CowStr::new_static("atproto")),
                    grant_types: Some(vec![CowStr::new_static("authorization_code")]),
                    token_endpoint_auth_method: Some(AuthMethod::None.into()),
                    dpop_bound_access_tokens: Some(true),
                    jwks_uri: None,
                    jwks: None,
                    token_endpoint_auth_signing_alg: None,
                }
            );
        }
        {
            let metadata = metadata.clone();
            let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
                .expect("failed to parse private key");
            let keys = vec![Jwk {
                key: Key::from(&secret_key.into()),
                prm: Parameters {
                    kid: Some(String::from("kid00")),
                    ..Default::default()
                },
            }];
            let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset");
            assert_eq!(
                atproto_client_metadata(metadata, &Some(keyset.clone()))
                    .expect("failed to convert metadata"),
                OAuthClientMetadata {
                    client_id: Url::from_str("https://example.com/client_metadata.json")
                        .unwrap(),
                    client_uri: Some(Url::from_str("https://example.com").unwrap()),
                    redirect_uris: vec![Url::from_str("https://example.com/callback").unwrap()],
                    scope: Some(CowStr::new_static("atproto")),
                    grant_types: Some(vec![CowStr::new_static("authorization_code")]),
                    token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
                    dpop_bound_access_tokens: Some(true),
                    jwks_uri: None,
                    jwks: Some(keyset.public_jwks()),
                    token_endpoint_auth_signing_alg: Some(CowStr::new_static("ES256")),
                }
            );
        }
    }
}