Fetch User Keys - simple tool for fetching SSH keys from various sources
1// SPDX-FileCopyrightText: 2025 Łukasz Niemier <#@hauleth.dev>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::fmt;
6use std::str::FromStr;
7
8use super::helpers;
9
10use serde::Deserialize;
11use ssh_key::PublicKey;
12
/// A decentralized identifier of the form `did:<method>:<id>`.
#[derive(Debug)]
pub struct DID {
    /// Method name: the segment between the `did:` prefix and the next colon.
    method: String,
    /// Method-specific identifier: everything after the method's colon. It
    /// may itself contain colons.
    id: String,
}

impl FromStr for DID {
    type Err = ();

    /// Parses `did:<method>:<id>`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when:
    /// - the input does not start with the lowercase `did:` prefix,
    /// - the method is empty or contains anything but `a-z` / `0-9`,
    /// - the identifier is empty, ends with `:`, or contains a byte outside
    ///   ASCII letters, digits and `._:%-`.
    ///
    /// Percent-encoded sequences inside the identifier are not checked for
    /// well-formedness.
    fn from_str(input: &str) -> Result<Self, ()> {
        let rest = input.strip_prefix("did:").ok_or(())?;
        // Split on the first colon only: the identifier keeps any later ones.
        let (method, id) = rest.split_once(':').ok_or(())?;

        let method_ok = !method.is_empty()
            && method
                .bytes()
                .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit());

        let id_ok = !id.is_empty()
            && !id.ends_with(':')
            && id.bytes().all(|b| {
                b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b':' | b'%' | b'-')
            });

        if !(method_ok && id_ok) {
            return Err(());
        }

        Ok(DID {
            method: method.to_owned(),
            id: id.to_owned(),
        })
    }
}

impl fmt::Display for DID {
    /// Writes the identifier back in its `did:<method>:<id>` form.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(f, "did:{}:{}", self.method, self.id)
    }
}
48
/// A handle kept as a plain string.
///
/// The `FromStr` implementation in this file stores it lowercased and without
/// a leading `@`.
#[derive(Debug)]
pub struct Handle(String);

impl fmt::Display for Handle {
    /// Emits the stored handle text verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_str(&self.0)
    }
}
57
58impl FromStr for Handle {
59 type Err = ();
60
61 fn from_str(input: &str) -> Result<Self, ()> {
62 if input.len() > 253 {
63 return Err(());
64 }
65
66 let input = {
67 let mut input = input.to_ascii_lowercase();
68 if input.starts_with("@") {
69 input.remove(0);
70 }
71
72 input
73 };
74
75 let segments: Box<[_]> = input.split('.').collect();
76
77 if segments.len() < 2 {
78 return Err(());
79 }
80
81 if !segments.iter().all(|&s| legal_segment(s)) {
82 return Err(());
83 }
84
85 if (b'0'..=b'9').contains(&segments.last().unwrap().as_bytes()[0]) {
86 return Err(());
87 }
88
89 Ok(Handle(input))
90 }
91}
92
/// Error carrying the input text that could not be parsed.
///
/// Derives `Debug` and implements `std::error::Error` so results using it can
/// be `unwrap`ped, `expect`ed, or boxed as `dyn Error`.
#[derive(Debug)]
pub struct InvalidHandle(String);

impl fmt::Display for InvalidHandle {
    /// Writes `Invalid handle: <input>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(f, "Invalid handle: {}", self.0)
    }
}

impl std::error::Error for InvalidHandle {}
100
101#[derive(Debug)]
102pub enum Identifier {
103 DID(DID),
104 Handle(Handle),
105}
106
107impl FromStr for Identifier {
108 type Err = InvalidHandle;
109
110 fn from_str(input: &str) -> Result<Self, Self::Err> {
111 input
112 .parse()
113 .map(Identifier::DID)
114 .or_else(|_| input.parse().map(Identifier::Handle))
115 .map_err(|_| InvalidHandle(input.into()))
116 }
117}
118
119impl fmt::Display for Identifier {
120 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
121 match *self {
122 Identifier::DID(ref did) => write!(f, "{}", did),
123 Identifier::Handle(ref handle) => write!(f, "{}", handle),
124 }
125 }
126}
127
128#[derive(Debug, Deserialize)]
129pub struct ATProto {
130 #[serde(default = "default_atproto")]
131 pub host: String,
132 #[serde(deserialize_with = "helpers::from_str")]
133 pub handle: Identifier,
134}
135
136impl FromStr for ATProto {
137 type Err = InvalidHandle;
138
139 fn from_str(input: &str) -> Result<Self, InvalidHandle> {
140 Ok(ATProto {
141 host: default_atproto(),
142 handle: input.parse().map_err(|_| InvalidHandle(input.into()))?,
143 })
144 }
145}
146
/// Checks a single dot-separated segment of a handle.
///
/// A segment is legal when it is 1 to 63 bytes long, consists only of bytes
/// accepted by [`allowed_byte`], and neither starts nor ends with `-`.
fn legal_segment(segment: &str) -> bool {
    // Same limit as a DNS label.
    const MAX_SEGMENT_LEN: usize = 63;

    (1..=MAX_SEGMENT_LEN).contains(&segment.len())
        && !segment.starts_with('-')
        && !segment.ends_with('-')
        && segment.bytes().all(allowed_byte)
}

/// Returns `true` for ASCII digits, lowercase ASCII letters and `-`.
fn allowed_byte(c: u8) -> bool {
    c.is_ascii_digit() || c.is_ascii_lowercase() || c == b'-'
}
158
/// Host used when none is configured explicitly.
fn default_atproto() -> String {
    String::from("https://bsky.social/")
}
162
163mod resp {
164 use serde::Deserialize;
165 use ssh_key::PublicKey;
166
167 #[derive(Debug, Deserialize)]
168 pub struct Resp {
169 pub records: Box<[Record]>,
170 }
171
172 #[derive(Debug, Deserialize)]
173 pub struct Record {
174 value: Value,
175 }
176
177 #[derive(Debug, Deserialize)]
178 pub struct Value {
179 key: String,
180 }
181
182 impl Into<PublicKey> for &Record {
183 fn into(self) -> PublicKey {
184 PublicKey::from_openssh(&self.value.key).unwrap()
185 }
186 }
187}
188
189impl super::Fetch for ATProto {
190 fn fetch(&self) -> Vec<PublicKey> {
191 let mut url = url::Url::parse(&self.host).unwrap();
192
193 url.query_pairs_mut()
194 .append_pair("repo", &self.handle.to_string())
195 .append_pair("collection", "sh.tangled.publicKey");
196
197 url.set_path("xrpc/com.atproto.repo.listRecords");
198
199 let data = ureq::get(&url.to_string())
200 .call()
201 .unwrap()
202 .body_mut()
203 .read_to_string()
204 .unwrap();
205
206 let decoded: resp::Resp = serde_json::from_str(&data).unwrap();
207
208 decoded.records.iter().map(Into::into).collect()
209 }
210}