Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm

consolidate oauth handling

+1
Cargo.lock
···
"rand 0.9.1",
"serde",
"serde_json",
"tokio",
"tokio-util",
"url",
···
"rand 0.9.1",
"serde",
"serde_json",
+
"thiserror 2.0.12",
"tokio",
"tokio-util",
"url",
+1
who-am-i/Cargo.toml
···
rand = "0.9.1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
tokio = { version = "1.45.1", features = ["full", "macros"] }
tokio-util = "0.7.15"
url = "2.5.4"
···
rand = "0.9.1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
+
thiserror = "2.0.12"
tokio = { version = "1.45.1", features = ["full", "macros"] }
tokio-util = "0.7.15"
url = "2.5.4"
-34
who-am-i/src/dns_resolver.rs
···
-
// originally from weaver: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/crates/weaver-common/src/resolver.rs
-
// MPL 2.0: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/LICENSE
-
-
use atrium_identity::handle::DnsTxtResolver;
-
use hickory_resolver::TokioResolver;
-
-
pub struct HickoryDnsTxtResolver {
-
resolver: TokioResolver,
-
}
-
-
impl Default for HickoryDnsTxtResolver {
-
fn default() -> Self {
-
Self {
-
resolver: TokioResolver::builder_tokio()
-
.expect("failed to create resolver")
-
.build(),
-
}
-
}
-
}
-
-
impl DnsTxtResolver for HickoryDnsTxtResolver {
-
async fn resolve(
-
&self,
-
query: &str,
-
) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
-
Ok(self
-
.resolver
-
.txt_lookup(query)
-
.await?
-
.iter()
-
.map(|txt| txt.to_string())
-
.collect())
-
}
-
}
···
+17 -2
who-am-i/src/expiring_task_map.rs
···
use tokio::time::sleep;
use tokio_util::sync::{CancellationToken, DropGuard};
-
#[derive(Clone)]
pub struct ExpiringTaskMap<T>(TaskMap<T>);
impl<T: Send + 'static> ExpiringTaskMap<T> {
pub fn new(expiration: Duration) -> Self {
···
}
}
-
#[derive(Clone)]
struct TaskMap<T> {
map: Arc<DashMap<String, (DropGuard, JoinHandle<T>)>>,
expiration: Duration,
}
···
use tokio::time::sleep;
use tokio_util::sync::{CancellationToken, DropGuard};
pub struct ExpiringTaskMap<T>(TaskMap<T>);
+
+
/// need to manually implement clone because T is allowed to not be clone
+
impl<T> Clone for ExpiringTaskMap<T> {
+
fn clone(&self) -> Self {
+
Self(self.0.clone())
+
}
+
}
impl<T: Send + 'static> ExpiringTaskMap<T> {
pub fn new(expiration: Duration) -> Self {
···
}
}
struct TaskMap<T> {
map: Arc<DashMap<String, (DropGuard, JoinHandle<T>)>>,
expiration: Duration,
}
+
+
/// need to manually implement clone because T is allowed to not be clone
+
impl<T> Clone for TaskMap<T> {
+
fn clone(&self) -> Self {
+
Self {
+
map: self.map.clone(),
+
expiration: self.expiration,
+
}
+
}
+
}
-22
who-am-i/src/identity_resolver.rs
···
-
use atrium_api::types::string::Did;
-
use atrium_common::resolver::Resolver;
-
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
-
use atrium_oauth::DefaultHttpClient;
-
use std::sync::Arc;
-
-
pub async fn resolve_identity(did: String) -> Option<String> {
-
let http_client = Arc::new(DefaultHttpClient::default());
-
let resolver = CommonDidResolver::new(CommonDidResolverConfig {
-
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
-
http_client: Arc::clone(&http_client),
-
});
-
let doc = resolver.resolve(&Did::new(did).unwrap()).await.unwrap(); // TODO: this is only half the resolution? or is atrium checking dns?
-
// tokio::time::sleep(std::time::Duration::from_secs(2)).await;
-
doc.also_known_as.and_then(|mut aka| {
-
if aka.is_empty() {
-
None
-
} else {
-
Some(aka.remove(0))
-
}
-
})
-
}
···
+1 -5
who-am-i/src/lib.rs
···
-
mod dns_resolver;
mod expiring_task_map;
-
mod identity_resolver;
mod oauth;
mod server;
-
pub use dns_resolver::HickoryDnsTxtResolver;
pub use expiring_task_map::ExpiringTaskMap;
-
pub use identity_resolver::resolve_identity;
-
pub use oauth::{Client, authorize, client};
pub use server::serve;
···
mod expiring_task_map;
mod oauth;
mod server;
pub use expiring_task_map::ExpiringTaskMap;
+
pub use oauth::{OAuth, OauthCallbackParams, ResolveHandleError};
pub use server::serve;
+197 -47
who-am-i/src/oauth.rs
···
-
use crate::HickoryDnsTxtResolver;
use atrium_identity::{
did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL},
-
handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig},
};
use atrium_oauth::{
-
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
-
OAuthClientConfig, OAuthResolverConfig, Scope,
store::{session::MemorySessionStore, state::MemoryStateStore},
};
use std::sync::Arc;
-
pub type Client = OAuthClient<
MemoryStateStore,
MemorySessionStore,
CommonDidResolver<DefaultHttpClient>,
AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>,
>;
-
pub fn client() -> Client {
-
let http_client = Arc::new(DefaultHttpClient::default());
-
let config = OAuthClientConfig {
-
client_metadata: AtprotoLocalhostClientMetadata {
-
redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]),
-
scopes: Some(vec![Scope::Known(KnownScope::Atproto)]),
-
},
-
keys: None,
-
resolver: OAuthResolverConfig {
-
did_resolver: CommonDidResolver::new(CommonDidResolverConfig {
-
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
-
http_client: Arc::clone(&http_client),
-
}),
-
handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
-
dns_txt_resolver: HickoryDnsTxtResolver::default(),
-
http_client: Arc::clone(&http_client),
-
}),
-
authorization_server_metadata: Default::default(),
-
protected_resource_metadata: Default::default(),
-
},
-
// A store for saving state data while the user is being redirected to the authorization server.
-
state_store: MemoryStateStore::default(),
-
// A store for saving session data.
-
session_store: MemorySessionStore::default(),
-
};
-
let Ok(client) = OAuthClient::new(config) else {
-
panic!("failed to create oauth client");
-
};
-
client
}
-
pub async fn authorize(client: &Client, handle: &str) -> String {
-
let Ok(url) = client
-
.authorize(
-
handle,
-
AuthorizeOptions {
-
scopes: vec![Scope::Known(KnownScope::Atproto)],
-
..Default::default()
},
-
)
-
.await
-
else {
-
panic!("failed to authorize");
-
};
-
url
}
···
+
use atrium_api::{agent::SessionManager, types::string::Did};
+
use atrium_common::resolver::Resolver;
use atrium_identity::{
did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL},
+
handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver},
};
use atrium_oauth::{
+
AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, DefaultHttpClient,
+
KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope,
store::{session::MemorySessionStore, state::MemoryStateStore},
};
+
use hickory_resolver::TokioResolver;
+
use serde::Deserialize;
use std::sync::Arc;
+
use thiserror::Error;
+
const READONLY_SCOPE: [Scope; 1] = [Scope::Known(KnownScope::Atproto)];
+
+
#[derive(Debug, Deserialize)]
+
pub struct CallbackErrorParams {
+
error: String,
+
error_description: Option<String>,
+
#[allow(dead_code)]
+
state: Option<String>, // TODO: we _should_ use state to associate the auth request but how to do that with atrium is unclear
+
iss: Option<String>,
+
}
+
+
#[derive(Debug, Deserialize)]
+
#[serde(untagged)]
+
pub enum OauthCallbackParams {
+
Granted(CallbackParams),
+
Failed(CallbackErrorParams),
+
}
+
+
type Client = OAuthClient<
MemoryStateStore,
MemorySessionStore,
CommonDidResolver<DefaultHttpClient>,
AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>,
>;
+
#[derive(Clone)]
+
pub struct OAuth {
+
client: Arc<Client>,
+
did_resolver: Arc<CommonDidResolver<DefaultHttpClient>>,
}
+
#[derive(Debug, Error)]
+
#[error(transparent)]
+
pub struct AuthSetupError(#[from] atrium_oauth::Error);
+
+
#[derive(Debug, Error)]
+
#[error(transparent)]
+
pub struct AuthStartError(#[from] atrium_oauth::Error);
+
+
#[derive(Debug, Error)]
+
pub enum AuthCompleteError {
+
#[error("the user denied request: {description:?} (from {issuer:?})")]
+
Denied {
+
description: Option<String>,
+
issuer: Option<String>,
+
},
+
#[error(
+
"the request was denied for another reason: {error}: {description:?} (from {issuer:?})"
+
)]
+
Failed {
+
error: String,
+
description: Option<String>,
+
issuer: Option<String>,
+
},
+
#[error("failed to complete oauth callback: {0}")]
+
CallbackFailed(atrium_oauth::Error),
+
#[error("the authorized session did not contain a DID")]
+
NoDid,
+
}
+
+
#[derive(Debug, Error)]
+
pub enum ResolveHandleError {
+
#[error("failed to resolve: {0}")]
+
ResolutionFailed(#[from] atrium_identity::Error),
+
#[error("identity resolved but no handle found for user")]
+
NoHandle,
+
#[error("found handle {0:?} but it appears invalid: {1}")]
+
InvalidHandle(String, &'static str),
+
}
+
+
impl OAuth {
+
pub fn new() -> Result<Self, AuthSetupError> {
+
let http_client = Arc::new(DefaultHttpClient::default());
+
let did_resolver = || {
+
CommonDidResolver::new(CommonDidResolverConfig {
+
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
+
http_client: http_client.clone(),
+
})
+
};
+
let client_config = OAuthClientConfig {
+
client_metadata: AtprotoLocalhostClientMetadata {
+
redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]),
+
scopes: Some(READONLY_SCOPE.to_vec()),
},
+
keys: None,
+
resolver: OAuthResolverConfig {
+
did_resolver: did_resolver(),
+
handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
+
dns_txt_resolver: HickoryDnsTxtResolver::default(),
+
http_client: Arc::clone(&http_client),
+
}),
+
authorization_server_metadata: Default::default(),
+
protected_resource_metadata: Default::default(),
+
},
+
state_store: MemoryStateStore::default(),
+
session_store: MemorySessionStore::default(),
+
};
+
+
let client = OAuthClient::new(client_config)?;
+
+
Ok(Self {
+
client: Arc::new(client),
+
did_resolver: Arc::new(did_resolver()),
+
})
+
}
+
+
pub async fn begin(&self, handle: &str) -> Result<String, AuthStartError> {
+
let auth_opts = AuthorizeOptions {
+
scopes: READONLY_SCOPE.to_vec(),
+
..Default::default()
+
};
+
Ok(self.client.authorize(handle, auth_opts).await?)
+
}
+
+
/// Finally, resolve the oauth flow to a verified DID
+
pub async fn complete(&self, params: OauthCallbackParams) -> Result<Did, AuthCompleteError> {
+
let params = match params {
+
OauthCallbackParams::Granted(params) => params,
+
OauthCallbackParams::Failed(p) if p.error == "access_denied" => {
+
return Err(AuthCompleteError::Denied {
+
description: p.error_description.clone(),
+
issuer: p.iss.clone(),
+
});
+
}
+
OauthCallbackParams::Failed(p) => {
+
return Err(AuthCompleteError::Failed {
+
error: p.error.clone(),
+
description: p.error_description.clone(),
+
issuer: p.iss.clone(),
+
});
+
}
+
};
+
let (session, _) = self
+
.client
+
.callback(params)
+
.await
+
.map_err(AuthCompleteError::CallbackFailed)?;
+
let Some(did) = session.did().await else {
+
return Err(AuthCompleteError::NoDid);
+
};
+
Ok(did)
+
}
+
+
pub async fn resolve_handle(&self, did: Did) -> Result<String, ResolveHandleError> {
+
// TODO: this is only half the resolution? or is atrium checking dns?
+
let doc = self.did_resolver.resolve(&did).await?;
+
let Some(aka) = doc.also_known_as else {
+
return Err(ResolveHandleError::NoHandle);
+
};
+
let Some(at_uri_handle) = aka.first() else {
+
return Err(ResolveHandleError::NoHandle);
+
};
+
if aka.len() > 1 {
+
eprintln!("more than one handle found for {did:?}");
+
}
+
let Some(bare_handle) = at_uri_handle.strip_prefix("at://") else {
+
return Err(ResolveHandleError::InvalidHandle(
+
at_uri_handle.to_string(),
+
"did not start with 'at://'",
+
));
+
};
+
if bare_handle.is_empty() {
+
return Err(ResolveHandleError::InvalidHandle(
+
bare_handle.to_string(),
+
"empty handle",
+
));
+
}
+
Ok(bare_handle.to_string())
+
}
+
}
+
+
pub struct HickoryDnsTxtResolver {
+
resolver: TokioResolver,
+
}
+
+
impl Default for HickoryDnsTxtResolver {
+
fn default() -> Self {
+
Self {
+
resolver: TokioResolver::builder_tokio()
+
.expect("failed to create resolver")
+
.build(),
+
}
+
}
+
}
+
+
impl DnsTxtResolver for HickoryDnsTxtResolver {
+
async fn resolve(
+
&self,
+
query: &str,
+
) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
+
Ok(self
+
.resolver
+
.txt_lookup(query)
+
.await?
+
.iter()
+
.map(|txt| txt.to_string())
+
.collect())
+
}
}
+41 -27
who-am-i/src/server.rs
···
-
use atrium_api::agent::SessionManager;
-
use atrium_oauth::CallbackParams;
use axum::{
Router,
extract::{FromRef, Query, State},
···
use tokio_util::sync::CancellationToken;
use url::Url;
-
use crate::{Client, ExpiringTaskMap, authorize, client, resolve_identity};
const FAVICON: &[u8] = include_bytes!("../static/favicon.ico");
const INDEX_HTML: &str = include_str!("../static/index.html");
···
pub key: Key,
pub one_clicks: Arc<HashSet<String>>,
pub engine: AppEngine,
-
pub client: Arc<Client>,
-
pub resolving: ExpiringTaskMap<Option<String>>,
pub shutdown: CancellationToken,
}
···
// clients have to pick up their identity-resolving tasks within this period
let task_pickup_expiration = Duration::from_secs(15);
let state = AppState {
engine: Engine::new(hbs),
key: Key::from(app_secret.as_bytes()), // TODO: via config
one_clicks: Arc::new(HashSet::from_iter(one_click)),
-
client: Arc::new(client()),
resolving: ExpiringTaskMap::new(task_pickup_expiration),
shutdown: shutdown.clone(),
};
···
async fn prompt(
State(AppState {
engine,
-
one_clicks,
resolving,
shutdown,
..
···
.into_response();
}
if let Some(did) = jar.get(DID_COOKIE_KEY) {
-
let did = did.value_trimmed().to_string();
-
let task_shutdown = shutdown.child_token();
-
let fetch_key = resolving.dispatch(resolve_identity(did.clone()), task_shutdown);
RenderHtml(
"prompt-known",
···
let Some(task_handle) = resolving.take(&params.fetch_key) else {
return "oops, task does not exist or is gone".into_response();
};
-
if let Some(handle) = task_handle.await.unwrap() {
-
// TODO: get active state etc.
-
// ...but also, that's a bsky thing?
-
let Some(handle) = handle.strip_prefix("at://") else {
-
return "hmm, handle did not start with at://".into_response();
-
};
Json(json!({ "handle": handle })).into_response()
} else {
"no handle?".into_response()
···
handle: String,
}
async fn start_oauth(
-
State(state): State<AppState>,
Query(params): Query<BeginOauthParams>,
jar: SignedCookieJar,
) -> (SignedCookieJar, Redirect) {
// if any existing session was active, clear it first
let jar = jar.remove(DID_COOKIE_KEY);
-
let auth_url = authorize(&state.client, &params.handle).await;
(jar, Redirect::to(&auth_url))
}
async fn complete_oauth(
-
State(state): State<AppState>,
-
Query(params): Query<CallbackParams>,
jar: SignedCookieJar,
) -> (SignedCookieJar, impl IntoResponse) {
-
let Ok((oauth_session, _)) = state.client.callback(params).await else {
panic!("failed to do client callback");
};
-
let did = oauth_session.did().await.expect("a did to be present");
let cookie = Cookie::build((DID_COOKIE_KEY, did.to_string()))
.http_only(true)
···
let jar = jar.add(cookie);
-
let task_shutdown = state.shutdown.child_token();
-
let fetch_key = state
-
.resolving
-
.dispatch(resolve_identity(did.to_string()), task_shutdown);
(
jar,
RenderHtml(
"authorized",
-
state.engine,
json!({
"did": did,
"fetch_key": fetch_key,
···
+
use atrium_api::types::string::Did;
use axum::{
Router,
extract::{FromRef, Query, State},
···
use tokio_util::sync::CancellationToken;
use url::Url;
+
use crate::{ExpiringTaskMap, OAuth, OauthCallbackParams, ResolveHandleError};
const FAVICON: &[u8] = include_bytes!("../static/favicon.ico");
const INDEX_HTML: &str = include_str!("../static/index.html");
···
pub key: Key,
pub one_clicks: Arc<HashSet<String>>,
pub engine: AppEngine,
+
pub oauth: Arc<OAuth>,
+
pub resolving: ExpiringTaskMap<Result<String, ResolveHandleError>>,
pub shutdown: CancellationToken,
}
···
// clients have to pick up their identity-resolving tasks within this period
let task_pickup_expiration = Duration::from_secs(15);
+
let oauth = OAuth::new().unwrap();
+
let state = AppState {
engine: Engine::new(hbs),
key: Key::from(app_secret.as_bytes()), // TODO: via config
one_clicks: Arc::new(HashSet::from_iter(one_click)),
+
oauth: Arc::new(oauth),
resolving: ExpiringTaskMap::new(task_pickup_expiration),
shutdown: shutdown.clone(),
};
···
async fn prompt(
State(AppState {
+
one_clicks,
engine,
+
oauth,
resolving,
shutdown,
..
···
.into_response();
}
if let Some(did) = jar.get(DID_COOKIE_KEY) {
+
let Ok(did) = Did::new(did.value_trimmed().to_string()) else {
+
return "did from cookie failed to parse".into_response();
+
};
+
let fetch_key = resolving.dispatch(
+
{
+
let oauth = oauth.clone();
+
let did = did.clone();
+
async move { oauth.resolve_handle(did.clone()).await }
+
},
+
shutdown.child_token(),
+
);
RenderHtml(
"prompt-known",
···
let Some(task_handle) = resolving.take(&params.fetch_key) else {
return "oops, task does not exist or is gone".into_response();
};
+
if let Ok(handle) = task_handle.await.unwrap() {
Json(json!({ "handle": handle })).into_response()
} else {
"no handle?".into_response()
···
handle: String,
}
async fn start_oauth(
+
State(AppState { oauth, .. }): State<AppState>,
Query(params): Query<BeginOauthParams>,
jar: SignedCookieJar,
) -> (SignedCookieJar, Redirect) {
// if any existing session was active, clear it first
let jar = jar.remove(DID_COOKIE_KEY);
+
let auth_url = oauth.begin(&params.handle).await.unwrap();
(jar, Redirect::to(&auth_url))
}
async fn complete_oauth(
+
State(AppState {
+
engine,
+
resolving,
+
oauth,
+
shutdown,
+
..
+
}): State<AppState>,
+
Query(params): Query<OauthCallbackParams>,
jar: SignedCookieJar,
) -> (SignedCookieJar, impl IntoResponse) {
+
let Ok(did) = oauth.complete(params).await else {
panic!("failed to do client callback");
};
let cookie = Cookie::build((DID_COOKIE_KEY, did.to_string()))
.http_only(true)
···
let jar = jar.add(cookie);
+
let fetch_key = resolving.dispatch(
+
{
+
let oauth = oauth.clone();
+
let did = did.clone();
+
async move { oauth.resolve_handle(did.clone()).await }
+
},
+
shutdown.child_token(),
+
);
(
jar,
RenderHtml(
"authorized",
+
engine,
json!({
"did": did,
"fetch_key": fetch_key,