1#![cfg(feature = "loopback")]
2
3use crate::{
4 atproto::AtprotoClientMetadata,
5 authstore::ClientAuthStore,
6 client::OAuthClient,
7 dpop::DpopExt,
8 error::{CallbackError, OAuthError},
9 resolver::OAuthResolver,
10 scopes::Scope,
11 types::{AuthorizeOptions, CallbackParams},
12};
13use jacquard_common::{IntoStatic, cowstr::ToCowStr};
14use rouille::Server;
15use std::net::SocketAddr;
16use tokio::sync::mpsc;
17use url::Url;
18
/// How the loopback callback server chooses its listening port.
#[derive(Clone, Debug)]
pub enum LoopbackPort {
    /// Listen on exactly this port.
    Fixed(u16),
    /// Ask the operating system for any port (the server is bound with port `0`).
    Ephemeral,
}

/// Settings for driving an OAuth login through a local loopback callback server.
#[derive(Clone, Debug)]
pub struct LoopbackConfig {
    /// Host placed in the redirect URI (`http://{host}:{port}/oauth/callback`).
    pub host: String,
    /// Port selection for the callback server.
    pub port: LoopbackPort,
    /// Whether to try opening the authorization URL in a browser.
    pub open_browser: bool,
    /// How long to wait for the callback, in milliseconds.
    pub timeout_ms: u64,
}

impl Default for LoopbackConfig {
    /// IPv4 loopback host, fixed port 4000, browser opening enabled, five-minute timeout.
    fn default() -> Self {
        const FIVE_MINUTES_MS: u64 = 5 * 60 * 1_000;
        Self {
            host: String::from("127.0.0.1"),
            port: LoopbackPort::Fixed(4000),
            open_browser: true,
            timeout_ms: FIVE_MINUTES_MS,
        }
    }
}
43
/// Best-effort attempt to open `url` in the user's default browser.
///
/// Returns `true` only when the `webbrowser` crate reports success.
#[cfg(feature = "browser-open")]
fn try_open_in_browser(url: &str) -> bool {
    let outcome = webbrowser::open(url);
    outcome.is_ok()
}

/// Stand-in used when the `browser-open` feature is disabled: opens nothing and
/// always reports `false`.
#[cfg(not(feature = "browser-open"))]
fn try_open_in_browser(_target: &str) -> bool {
    false
}
52
53pub fn create_callback_router(
54 request: &rouille::Request,
55 tx: mpsc::Sender<CallbackParams>,
56) -> rouille::Response {
57 rouille::router!(request,
58 (GET) (/oauth/callback) => {
59 let state = request.get_param("state").unwrap();
60 let code = request.get_param("code").unwrap();
61 let iss = request.get_param("iss").unwrap();
62 let callback_params = CallbackParams {
63 state: Some(state.to_cowstr().into_static()),
64 code: code.to_cowstr().into_static(),
65 iss: Some(iss.to_cowstr().into_static()),
66 };
67 tx.try_send(callback_params).unwrap();
68 rouille::Response::text("Logged in!")
69 },
70 _ => rouille::Response::empty_404()
71 )
72}
73
/// Handles to a running one-shot loopback callback server.
struct CallbackHandle {
    // Stored but never read (the server thread is never joined), hence `allow(dead_code)`.
    #[allow(dead_code)]
    server_handle: std::thread::JoinHandle<()>,
    /// Sending `()` on this channel asks the server thread started by `Server::stoppable`
    /// to shut down.
    server_stop: std::sync::mpsc::Sender<()>,
    /// Receives the `CallbackParams` forwarded by the `/oauth/callback` route.
    callback_rx: mpsc::Receiver<CallbackParams<'static>>,
}
80
81fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) {
82 let (tx, callback_rx) = mpsc::channel(5);
83 let server = Server::new(addr, move |request| {
84 create_callback_router(request, tx.clone())
85 })
86 .expect("Could not start server");
87 let (server_handle, server_stop) = server.stoppable();
88 let handle = CallbackHandle {
89 server_handle,
90 server_stop,
91 callback_rx,
92 };
93 (addr, handle)
94}
95
96impl<T, S> OAuthClient<T, S>
97where
98 T: OAuthResolver + DpopExt + Send + Sync + 'static,
99 S: ClientAuthStore + Send + Sync + 'static,
100{
101 /// Drive the full OAuth flow using a local loopback server.
102 pub async fn login_with_local_server(
103 &self,
104 input: impl AsRef<str>,
105 opts: AuthorizeOptions<'_>,
106 cfg: LoopbackConfig,
107 ) -> crate::error::Result<super::client::OAuthSession<T, S>> {
108 let port = match cfg.port {
109 LoopbackPort::Fixed(p) => p,
110 LoopbackPort::Ephemeral => 0,
111 };
112 // TODO: fix this to it also accepts ipv6 and properly finds a free port
113 let bind_addr: SocketAddr = format!("0.0.0.0:{}", port)
114 .parse()
115 .expect("invalid loopback host/port");
116 let (local_addr, handle) = one_shot_server(bind_addr);
117 println!("Listening on {}", local_addr);
118 // build redirect uri
119 let redirect = Url::parse(&format!(
120 "http://{}:{}/oauth/callback",
121 cfg.host,
122 local_addr.port(),
123 ))
124 .unwrap();
125 let client_data = crate::session::ClientData {
126 keyset: self.registry.client_data.keyset.clone(),
127 config: AtprotoClientMetadata::new_localhost(
128 Some(vec![redirect.clone()]),
129 Some(vec![
130 Scope::Atproto,
131 Scope::Transition(crate::scopes::TransitionScope::Generic),
132 ]),
133 ),
134 };
135
136 // Build client using store and resolver
137 let flow_client = OAuthClient::new_with_shared(
138 self.registry.store.clone(),
139 self.client.clone(),
140 client_data.clone(),
141 );
142
143 // Start auth and get authorization URL
144 let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
145 // Print URL for copy/paste
146 println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
147 // Optionally open browser
148 if cfg.open_browser {
149 let _ = try_open_in_browser(&auth_url);
150 }
151
152 // Await callback or timeout
153 let mut callback_rx = handle.callback_rx;
154 let cb = tokio::time::timeout(
155 std::time::Duration::from_millis(cfg.timeout_ms),
156 callback_rx.recv(),
157 )
158 .await;
159 // trigger shutdown
160 let _ = handle.server_stop.send(());
161 if let Ok(Some(cb)) = cb {
162 // Handle callback and create a session
163 Ok(flow_client.callback(cb).await?)
164 } else {
165 Err(OAuthError::Callback(CallbackError::Timeout))
166 }
167 }
168}