// Forked from microcosm.blue/microcosm-rs
// (Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm).
1use jose_jwk::Class;
2use jose_jwk::Jwk;
3use jose_jwk::Key;
4use jose_jwk::Parameters;
5use std::fs;
6use std::path::PathBuf;
7// use p256::SecretKey;
8use atrium_api::{agent::SessionManager, types::string::Did};
9use atrium_common::resolver::Resolver;
10use atrium_identity::{
11 did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL},
12 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver},
13};
14use atrium_oauth::{
15 AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, AuthorizeOptions,
16 CallbackParams, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig,
17 OAuthClientMetadata, OAuthResolverConfig, Scope,
18 store::{session::MemorySessionStore, state::MemoryStateStore},
19};
20use elliptic_curve::SecretKey;
21use hickory_resolver::{ResolveError, TokioResolver};
22use jose_jwk::JwkSet;
23use pkcs8::DecodePrivateKey;
24use serde::Deserialize;
25use std::sync::Arc;
26use thiserror::Error;
27
28const READONLY_SCOPE: [Scope; 1] = [Scope::Known(KnownScope::Atproto)];
29
30#[derive(Debug, Deserialize)]
31pub struct CallbackErrorParams {
32 error: String,
33 error_description: Option<String>,
34 #[allow(dead_code)]
35 state: Option<String>, // TODO: we _should_ use state to associate the auth request but how to do that with atrium is unclear
36 iss: Option<String>,
37}
38
39#[derive(Debug, Deserialize)]
40#[serde(untagged)]
41pub enum OAuthCallbackParams {
42 Granted(CallbackParams),
43 Failed(CallbackErrorParams),
44}
45
46type Client = OAuthClient<
47 MemoryStateStore,
48 MemorySessionStore,
49 CommonDidResolver<DefaultHttpClient>,
50 AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>,
51>;
52
53#[derive(Clone)]
54pub struct OAuth {
55 client: Arc<Client>,
56 did_resolver: Arc<CommonDidResolver<DefaultHttpClient>>,
57}
58
59#[derive(Debug, Error)]
60pub enum AuthSetupError {
61 #[error("failed to intiialize atrium client: {0}")]
62 AtriumClientError(atrium_oauth::Error),
63 #[error("failed to initialize hickory dns resolver: {0}")]
64 HickoryResolverError(ResolveError),
65}
66
67#[derive(Debug, Error)]
68pub enum OAuthCompleteError {
69 #[error("the user denied request: {description:?} (from {issuer:?})")]
70 Denied {
71 description: Option<String>,
72 issuer: Option<String>,
73 },
74 #[error("the request failed: {error}: {description:?} (from {issuer:?})")]
75 Failed {
76 error: String,
77 description: Option<String>,
78 issuer: Option<String>,
79 },
80 #[error("failed to complete oauth callback: {0}")]
81 CallbackFailed(atrium_oauth::Error),
82 #[error("the authorized session did not contain a DID")]
83 NoDid,
84}
85
86#[derive(Debug, Error)]
87pub enum ResolveHandleError {
88 #[error("failed to resolve: {0}")]
89 ResolutionFailed(#[from] atrium_identity::Error),
90 #[error("identity resolved but no handle found for user")]
91 NoHandle,
92 #[error("found handle {0:?} but it appears invalid: {1}")]
93 InvalidHandle(String, &'static str),
94}
95
96impl OAuth {
97 pub fn new(oauth_private_key: Option<PathBuf>, base: String) -> Result<Self, AuthSetupError> {
98 let http_client = Arc::new(DefaultHttpClient::default());
99 let did_resolver = || {
100 CommonDidResolver::new(CommonDidResolverConfig {
101 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
102 http_client: http_client.clone(),
103 })
104 };
105 let dns_txt_resolver =
106 HickoryDnsTxtResolver::new().map_err(AuthSetupError::HickoryResolverError)?;
107
108 let resolver = OAuthResolverConfig {
109 did_resolver: did_resolver(),
110 handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
111 dns_txt_resolver,
112 http_client: Arc::clone(&http_client),
113 }),
114 authorization_server_metadata: Default::default(),
115 protected_resource_metadata: Default::default(),
116 };
117
118 let state_store = MemoryStateStore::default();
119 let session_store = MemorySessionStore::default();
120
121 let client = if let Some(path) = oauth_private_key {
122 let key_contents: Vec<u8> = fs::read(path).unwrap();
123 let key_string = String::from_utf8(key_contents).unwrap();
124 let key = SecretKey::<p256::NistP256>::from_pkcs8_pem(&key_string)
125 .map(|secret_key| Jwk {
126 key: Key::from(&secret_key.into()),
127 prm: Parameters {
128 kid: Some("at-oauth-00".to_string()),
129 cls: Some(Class::Signing),
130 ..Default::default()
131 },
132 })
133 .expect("to get private key");
134 OAuthClient::new(OAuthClientConfig {
135 client_metadata: AtprotoClientMetadata {
136 client_id: format!("{base}/client-metadata.json"),
137 client_uri: Some(base.clone()),
138 redirect_uris: vec![format!("{base}/authorized")],
139 token_endpoint_auth_method: AuthMethod::PrivateKeyJwt,
140 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
141 scopes: READONLY_SCOPE.to_vec(),
142 jwks_uri: Some(format!("{base}/.well-known/jwks.json")),
143 token_endpoint_auth_signing_alg: Some(String::from("ES256")),
144 },
145 keys: Some(vec![key]),
146 resolver,
147 state_store,
148 session_store,
149 })
150 .map_err(AuthSetupError::AtriumClientError)?
151 } else {
152 OAuthClient::new(OAuthClientConfig {
153 client_metadata: AtprotoLocalhostClientMetadata {
154 redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]),
155 scopes: Some(READONLY_SCOPE.to_vec()),
156 },
157 keys: None,
158 resolver,
159 state_store,
160 session_store,
161 })
162 .map_err(AuthSetupError::AtriumClientError)?
163 };
164
165 Ok(Self {
166 client: Arc::new(client),
167 did_resolver: Arc::new(did_resolver()),
168 })
169 }
170
171 pub fn client_metadata(&self) -> OAuthClientMetadata {
172 self.client.client_metadata.clone()
173 }
174
175 pub fn jwks(&self) -> JwkSet {
176 self.client.jwks()
177 }
178
179 pub async fn begin(&self, handle: &str) -> Result<String, atrium_oauth::Error> {
180 let auth_opts = AuthorizeOptions {
181 scopes: READONLY_SCOPE.to_vec(),
182 ..Default::default()
183 };
184 self.client.authorize(handle, auth_opts).await
185 }
186
187 /// Finally, resolve the oauth flow to a verified DID
188 pub async fn complete(&self, params: OAuthCallbackParams) -> Result<Did, OAuthCompleteError> {
189 let params = match params {
190 OAuthCallbackParams::Granted(params) => params,
191 OAuthCallbackParams::Failed(p) if p.error == "access_denied" => {
192 return Err(OAuthCompleteError::Denied {
193 description: p.error_description.clone(),
194 issuer: p.iss.clone(),
195 });
196 }
197 OAuthCallbackParams::Failed(p) => {
198 return Err(OAuthCompleteError::Failed {
199 error: p.error.clone(),
200 description: p.error_description.clone(),
201 issuer: p.iss.clone(),
202 });
203 }
204 };
205 let (session, _) = self
206 .client
207 .callback(params)
208 .await
209 .map_err(OAuthCompleteError::CallbackFailed)?;
210 let Some(did) = session.did().await else {
211 return Err(OAuthCompleteError::NoDid);
212 };
213 Ok(did)
214 }
215
216 pub async fn resolve_handle(&self, did: Did) -> Result<String, ResolveHandleError> {
217 // TODO: this is only half the resolution? or is atrium checking dns?
218 let doc = self.did_resolver.resolve(&did).await?;
219 let Some(aka) = doc.also_known_as else {
220 return Err(ResolveHandleError::NoHandle);
221 };
222 let Some(at_uri_handle) = aka.first() else {
223 return Err(ResolveHandleError::NoHandle);
224 };
225 if aka.len() > 1 {
226 eprintln!("more than one handle found for {did:?}");
227 }
228 let Some(bare_handle) = at_uri_handle.strip_prefix("at://") else {
229 return Err(ResolveHandleError::InvalidHandle(
230 at_uri_handle.to_string(),
231 "did not start with 'at://'",
232 ));
233 };
234 if bare_handle.is_empty() {
235 return Err(ResolveHandleError::InvalidHandle(
236 at_uri_handle.to_string(),
237 "empty handle",
238 ));
239 }
240 Ok(bare_handle.to_string())
241 }
242}
243
244pub struct HickoryDnsTxtResolver(TokioResolver);
245
246impl HickoryDnsTxtResolver {
247 fn new() -> Result<Self, ResolveError> {
248 Ok(Self(TokioResolver::builder_tokio()?.build()))
249 }
250}
251
252impl DnsTxtResolver for HickoryDnsTxtResolver {
253 async fn resolve(
254 &self,
255 query: &str,
256 ) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
257 match self.0.txt_lookup(query).await {
258 Ok(r) => {
259 metrics::counter!("whoami_resolve_dns_txt", "success" => "true").increment(1);
260 Ok(r.iter().map(|r| r.to_string()).collect())
261 }
262 Err(e) => {
263 metrics::counter!("whoami_resolve_dns_txt", "success" => "false").increment(1);
264 Err(e.into())
265 }
266 }
267 }
268}