A better Rust ATProto crate
at lifetimes 5.2 kB view raw
#![cfg(feature = "loopback")]

use crate::{
    atproto::AtprotoClientMetadata,
    authstore::ClientAuthStore,
    client::OAuthClient,
    dpop::DpopExt,
    error::{CallbackError, OAuthError},
    resolver::OAuthResolver,
    scopes::Scope,
    types::{AuthorizeOptions, CallbackParams},
};
use jacquard_common::{IntoStatic, cowstr::ToCowStr};
use rouille::Server;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::sync::mpsc;
use url::Url;

/// Which TCP port the loopback callback server listens on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LoopbackPort {
    /// Listen on exactly this port.
    Fixed(u16),
    /// Bind port `0` and let the operating system choose a free port; the port
    /// actually assigned is the one written into the redirect URI.
    Ephemeral,
}

/// Settings for [`OAuthClient::login_with_local_server`].
#[derive(Clone, Debug)]
pub struct LoopbackConfig {
    /// Host written into the `http://{host}:{port}/oauth/callback` redirect URI.
    /// When it is an IP literal or `localhost`, the callback server also binds
    /// only to that address.
    pub host: String,
    /// Port selection for the callback server.
    pub port: LoopbackPort,
    /// Try to open the authorization URL in a browser. Only has an effect when
    /// the crate is built with the `browser-open` feature.
    pub open_browser: bool,
    /// How long to wait for the OAuth callback, in milliseconds.
    pub timeout_ms: u64,
}

impl Default for LoopbackConfig {
    /// `127.0.0.1`, fixed port 4000, browser opening enabled, five-minute timeout.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".into(),
            port: LoopbackPort::Fixed(4000),
            open_browser: true,
            timeout_ms: 5 * 60 * 1000,
        }
    }
}

/// Opens `url` in the user's browser; returns whether that succeeded.
#[cfg(feature = "browser-open")]
fn try_open_in_browser(url: &str) -> bool {
    webbrowser::open(url).is_ok()
}

/// Browser support is compiled out: never opens anything and returns `false`.
#[cfg(not(feature = "browser-open"))]
fn try_open_in_browser(_url: &str) -> bool {
    false
}

/// Handles one HTTP request received by the loopback callback server.
///
/// `GET /oauth/callback` forwards the `code`, `state` and `iss` query
/// parameters to `tx` and answers "Logged in!". `state` and `iss` are optional;
/// a request without `code` (for example an authorization-server error
/// redirect carrying `error=...`) gets a 400 response and nothing is sent, so
/// a waiting login keeps waiting until its timeout. If `tx` cannot accept the
/// callback (channel full, or the receiver has gone away) a 503 is returned.
/// Every other path gets a 404.
pub fn create_callback_router(
    request: &rouille::Request,
    tx: mpsc::Sender<CallbackParams>,
) -> rouille::Response {
    rouille::router!(request,
        (GET) (/oauth/callback) => {
            match request.get_param("code") {
                Some(code) => {
                    let callback_params = CallbackParams {
                        state: request
                            .get_param("state")
                            .map(|state| state.to_cowstr().into_static()),
                        code: code.to_cowstr().into_static(),
                        iss: request
                            .get_param("iss")
                            .map(|iss| iss.to_cowstr().into_static()),
                    };
                    match tx.try_send(callback_params) {
                        Ok(()) => rouille::Response::text("Logged in!"),
                        // The login flow is no longer waiting for this callback.
                        Err(_) => rouille::Response::text(
                            "This login attempt is no longer accepting callbacks",
                        )
                        .with_status_code(503),
                    }
                }
                None => {
                    // Query values are echoed back as text/plain only.
                    let reason = request
                        .get_param("error")
                        .unwrap_or_else(|| "missing authorization code".to_owned());
                    rouille::Response::text(format!("Login failed: {}", reason))
                        .with_status_code(400)
                }
            }
        },
        _ => rouille::Response::empty_404()
    )
}

/// A running callback server. Dropping the handle stops the server.
struct CallbackHandle {
    // Owned so the server thread's handle lives as long as the server. It is
    // never joined because joining would block the async caller.
    #[allow(dead_code)]
    server_handle: std::thread::JoinHandle<()>,
    server_stop: std::sync::mpsc::Sender<()>,
    callback_rx: mpsc::Receiver<CallbackParams<'static>>,
}

impl Drop for CallbackHandle {
    fn drop(&mut self) {
        // rouille's `stoppable` loop only exits after *receiving* a message; a
        // disconnected channel keeps it running (and the port bound), so every
        // exit path, including early `?` returns, must send explicitly.
        // An error just means the server thread is already gone.
        let _ = self.server_stop.send(());
    }
}

/// Chooses the socket address the callback server listens on.
///
/// IP literals (optionally in `[...]` brackets) are bound directly, and
/// `localhost` maps to `127.0.0.1`, so the server is not reachable from other
/// machines. Any other host name keeps the previous behaviour of listening on
/// all IPv4 interfaces, since it may name a non-loopback interface.
fn bind_address(host: &str, port: u16) -> SocketAddr {
    let bare = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    let ip = if let Ok(ip) = bare.parse::<IpAddr>() {
        ip
    } else if bare.eq_ignore_ascii_case("localhost") {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    };
    SocketAddr::new(ip, port)
}

/// Formats `host` for the authority part of a URL by bracketing bare IPv6 literals.
fn url_host(host: &str) -> String {
    if host.parse::<Ipv6Addr>().is_ok() {
        format!("[{}]", host)
    } else {
        host.to_owned()
    }
}

/// Starts the callback server on `addr` in a background thread.
///
/// Returns the address the server actually bound (so an ephemeral port `0` is
/// replaced by the port the OS assigned) and a handle that receives callbacks
/// and stops the server when dropped.
///
/// # Panics
///
/// Panics if the server cannot bind `addr` (for example, the port is in use).
fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) {
    let (tx, callback_rx) = mpsc::channel(5);
    let server = Server::new(addr, move |request| {
        create_callback_router(request, tx.clone())
    })
    .expect("Could not start server");
    let local_addr = server.server_addr();
    let (server_handle, server_stop) = server.stoppable();
    let handle = CallbackHandle {
        server_handle,
        server_stop,
        callback_rx,
    };
    (local_addr, handle)
}

impl<T, S> OAuthClient<T, S>
where
    T: OAuthResolver + DpopExt + Send + Sync + 'static,
    S: ClientAuthStore + Send + Sync + 'static,
{
    /// Drive the full OAuth flow using a local loopback server.
    ///
    /// Starts a callback server (see [`LoopbackConfig`]) and builds a flow
    /// client with localhost client metadata. The metadata uses the redirect
    /// URI `http://{host}:{port}/oauth/callback` and requests the `Atproto`
    /// and generic `Transition` scopes. The authorization URL is printed (and
    /// optionally opened in a browser). The method then waits up to
    /// `cfg.timeout_ms` for the callback and exchanges it for a session.
    ///
    /// # Errors
    ///
    /// Propagates errors from `start_auth` and `callback`. Returns
    /// `OAuthError::Callback(CallbackError::Timeout)` if no callback arrives in
    /// time or the server stops without delivering one.
    ///
    /// # Panics
    ///
    /// Panics if the callback server cannot bind its port, or if `cfg.host`
    /// does not form a valid redirect URL.
    pub async fn login_with_local_server(
        &self,
        input: impl AsRef<str>,
        opts: AuthorizeOptions<'_>,
        cfg: LoopbackConfig,
    ) -> crate::error::Result<super::client::OAuthSession<T, S>> {
        let port = match cfg.port {
            LoopbackPort::Fixed(p) => p,
            LoopbackPort::Ephemeral => 0,
        };
        let (local_addr, mut handle) = one_shot_server(bind_address(&cfg.host, port));
        println!("Listening on {}", local_addr);
        // Use the port actually bound, which matters for `LoopbackPort::Ephemeral`.
        let redirect = Url::parse(&format!(
            "http://{}:{}/oauth/callback",
            url_host(&cfg.host),
            local_addr.port(),
        ))
        .expect("loopback host does not form a valid redirect URL");
        let client_data = crate::session::ClientData {
            keyset: self.registry.client_data.keyset.clone(),
            config: AtprotoClientMetadata::new_localhost(
                Some(vec![redirect.clone()]),
                Some(vec![
                    Scope::Atproto,
                    Scope::Transition(crate::scopes::TransitionScope::Generic),
                ]),
            ),
        };

        // Build client using store and resolver
        let flow_client = OAuthClient::new_with_shared(
            self.registry.store.clone(),
            self.client.clone(),
            client_data.clone(),
        );

        // Start auth and get the authorization URL. On error, dropping `handle`
        // stops the server.
        let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
        // Print URL for copy/paste
        println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
        // Best-effort: the URL has already been printed if this fails.
        if cfg.open_browser {
            let _ = try_open_in_browser(&auth_url);
        }

        // Await callback or timeout
        let cb = tokio::time::timeout(
            std::time::Duration::from_millis(cfg.timeout_ms),
            handle.callback_rx.recv(),
        )
        .await;
        // Shut the server down now; no further callbacks are needed.
        drop(handle);
        match cb {
            // Handle callback and create a session
            Ok(Some(cb)) => Ok(flow_client.callback(cb).await?),
            // Timed out, or every sender was dropped without a callback.
            _ => Err(OAuthError::Callback(CallbackError::Timeout)),
        }
    }
}