Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm
1use jose_jwk::Class; 2use jose_jwk::Jwk; 3use jose_jwk::Key; 4use jose_jwk::Parameters; 5use std::fs; 6use std::path::PathBuf; 7// use p256::SecretKey; 8use atrium_api::{agent::SessionManager, types::string::Did}; 9use atrium_common::resolver::Resolver; 10use atrium_identity::{ 11 did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, 12 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}, 13}; 14use atrium_oauth::{ 15 AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, AuthorizeOptions, 16 CallbackParams, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, 17 OAuthClientMetadata, OAuthResolverConfig, Scope, 18 store::{session::MemorySessionStore, state::MemoryStateStore}, 19}; 20use elliptic_curve::SecretKey; 21use hickory_resolver::{ResolveError, TokioResolver}; 22use jose_jwk::JwkSet; 23use pkcs8::DecodePrivateKey; 24use serde::Deserialize; 25use std::sync::Arc; 26use thiserror::Error; 27 28const READONLY_SCOPE: [Scope; 1] = [Scope::Known(KnownScope::Atproto)]; 29 30#[derive(Debug, Deserialize)] 31pub struct CallbackErrorParams { 32 error: String, 33 error_description: Option<String>, 34 #[allow(dead_code)] 35 state: Option<String>, // TODO: we _should_ use state to associate the auth request but how to do that with atrium is unclear 36 iss: Option<String>, 37} 38 39#[derive(Debug, Deserialize)] 40#[serde(untagged)] 41pub enum OAuthCallbackParams { 42 Granted(CallbackParams), 43 Failed(CallbackErrorParams), 44} 45 46type Client = OAuthClient< 47 MemoryStateStore, 48 MemorySessionStore, 49 CommonDidResolver<DefaultHttpClient>, 50 AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>, 51>; 52 53#[derive(Clone)] 54pub struct OAuth { 55 client: Arc<Client>, 56 did_resolver: Arc<CommonDidResolver<DefaultHttpClient>>, 57} 58 59#[derive(Debug, Error)] 60pub enum AuthSetupError { 61 #[error("failed to intiialize atrium client: {0}")] 62 AtriumClientError(atrium_oauth::Error), 63 #[error("failed to initialize hickory dns resolver: {0}")] 64 HickoryResolverError(ResolveError), 65} 66 67#[derive(Debug, Error)] 68pub enum OAuthCompleteError { 69 #[error("the user denied request: {description:?} (from {issuer:?})")] 70 Denied { 71 description: Option<String>, 72 issuer: Option<String>, 73 }, 74 #[error("the request failed: {error}: {description:?} (from {issuer:?})")] 75 Failed { 76 error: String, 77 description: Option<String>, 78 issuer: Option<String>, 79 }, 80 #[error("failed to complete oauth callback: {0}")] 81 CallbackFailed(atrium_oauth::Error), 82 #[error("the authorized session did not contain a DID")] 83 NoDid, 84} 85 86#[derive(Debug, Error)] 87pub enum ResolveHandleError { 88 #[error("failed to resolve: {0}")] 89 ResolutionFailed(#[from] atrium_identity::Error), 90 #[error("identity resolved but no handle found for user")] 91 NoHandle, 92 #[error("found handle {0:?} but it appears invalid: {1}")] 93 InvalidHandle(String, &'static str), 94} 95 96impl OAuth { 97 pub fn new(oauth_private_key: Option<PathBuf>, base: String) -> Result<Self, AuthSetupError> { 98 let http_client = Arc::new(DefaultHttpClient::default()); 99 let did_resolver = || { 100 CommonDidResolver::new(CommonDidResolverConfig { 101 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 102 http_client: http_client.clone(), 103 }) 104 }; 105 let dns_txt_resolver = 106 HickoryDnsTxtResolver::new().map_err(AuthSetupError::HickoryResolverError)?; 107 108 let resolver = OAuthResolverConfig { 109 did_resolver: did_resolver(), 110 handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 111 dns_txt_resolver, 112 http_client: Arc::clone(&http_client), 113 }), 114 authorization_server_metadata: Default::default(), 115 protected_resource_metadata: Default::default(), 116 }; 117 118 let state_store = MemoryStateStore::default(); 119 let session_store = MemorySessionStore::default(); 120 121 let client = if let Some(path) = oauth_private_key { 122 let key_contents: Vec<u8> = fs::read(path).unwrap(); 123 let key_string = String::from_utf8(key_contents).unwrap(); 124 let key = SecretKey::<p256::NistP256>::from_pkcs8_pem(&key_string) 125 .map(|secret_key| Jwk { 126 key: Key::from(&secret_key.into()), 127 prm: Parameters { 128 kid: Some("at-oauth-00".to_string()), 129 cls: Some(Class::Signing), 130 ..Default::default() 131 }, 132 }) 133 .expect("to get private key"); 134 OAuthClient::new(OAuthClientConfig { 135 client_metadata: AtprotoClientMetadata { 136 client_id: format!("{base}/client-metadata.json"), 137 client_uri: Some(base.clone()), 138 redirect_uris: vec![format!("{base}/authorized")], 139 token_endpoint_auth_method: AuthMethod::PrivateKeyJwt, 140 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 141 scopes: READONLY_SCOPE.to_vec(), 142 jwks_uri: Some(format!("{base}/.well-known/jwks.json")), 143 token_endpoint_auth_signing_alg: Some(String::from("ES256")), 144 }, 145 keys: Some(vec![key]), 146 resolver, 147 state_store, 148 session_store, 149 }) 150 .map_err(AuthSetupError::AtriumClientError)? 151 } else { 152 OAuthClient::new(OAuthClientConfig { 153 client_metadata: AtprotoLocalhostClientMetadata { 154 redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]), 155 scopes: Some(READONLY_SCOPE.to_vec()), 156 }, 157 keys: None, 158 resolver, 159 state_store, 160 session_store, 161 }) 162 .map_err(AuthSetupError::AtriumClientError)? 163 }; 164 165 Ok(Self { 166 client: Arc::new(client), 167 did_resolver: Arc::new(did_resolver()), 168 }) 169 } 170 171 pub fn client_metadata(&self) -> OAuthClientMetadata { 172 self.client.client_metadata.clone() 173 } 174 175 pub fn jwks(&self) -> JwkSet { 176 self.client.jwks() 177 } 178 179 pub async fn begin(&self, handle: &str) -> Result<String, atrium_oauth::Error> { 180 let auth_opts = AuthorizeOptions { 181 scopes: READONLY_SCOPE.to_vec(), 182 ..Default::default() 183 }; 184 self.client.authorize(handle, auth_opts).await 185 } 186 187 /// Finally, resolve the oauth flow to a verified DID 188 pub async fn complete(&self, params: OAuthCallbackParams) -> Result<Did, OAuthCompleteError> { 189 let params = match params { 190 OAuthCallbackParams::Granted(params) => params, 191 OAuthCallbackParams::Failed(p) if p.error == "access_denied" => { 192 return Err(OAuthCompleteError::Denied { 193 description: p.error_description.clone(), 194 issuer: p.iss.clone(), 195 }); 196 } 197 OAuthCallbackParams::Failed(p) => { 198 return Err(OAuthCompleteError::Failed { 199 error: p.error.clone(), 200 description: p.error_description.clone(), 201 issuer: p.iss.clone(), 202 }); 203 } 204 }; 205 let (session, _) = self 206 .client 207 .callback(params) 208 .await 209 .map_err(OAuthCompleteError::CallbackFailed)?; 210 let Some(did) = session.did().await else { 211 return Err(OAuthCompleteError::NoDid); 212 }; 213 Ok(did) 214 } 215 216 pub async fn resolve_handle(&self, did: Did) -> Result<String, ResolveHandleError> { 217 // TODO: this is only half the resolution? or is atrium checking dns? 218 let doc = self.did_resolver.resolve(&did).await?; 219 let Some(aka) = doc.also_known_as else { 220 return Err(ResolveHandleError::NoHandle); 221 }; 222 let Some(at_uri_handle) = aka.first() else { 223 return Err(ResolveHandleError::NoHandle); 224 }; 225 if aka.len() > 1 { 226 eprintln!("more than one handle found for {did:?}"); 227 } 228 let Some(bare_handle) = at_uri_handle.strip_prefix("at://") else { 229 return Err(ResolveHandleError::InvalidHandle( 230 at_uri_handle.to_string(), 231 "did not start with 'at://'", 232 )); 233 }; 234 if bare_handle.is_empty() { 235 return Err(ResolveHandleError::InvalidHandle( 236 at_uri_handle.to_string(), 237 "empty handle", 238 )); 239 } 240 Ok(bare_handle.to_string()) 241 } 242} 243 244pub struct HickoryDnsTxtResolver(TokioResolver); 245 246impl HickoryDnsTxtResolver { 247 fn new() -> Result<Self, ResolveError> { 248 Ok(Self(TokioResolver::builder_tokio()?.build())) 249 } 250} 251 252impl DnsTxtResolver for HickoryDnsTxtResolver { 253 async fn resolve( 254 &self, 255 query: &str, 256 ) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> { 257 match self.0.txt_lookup(query).await { 258 Ok(r) => { 259 metrics::counter!("whoami_resolve_dns_txt", "success" => "true").increment(1); 260 Ok(r.iter().map(|r| r.to_string()).collect()) 261 } 262 Err(e) => { 263 metrics::counter!("whoami_resolve_dns_txt", "success" => "false").increment(1); 264 Err(e.into()) 265 } 266 } 267 } 268}