Fetch User Keys - simple tool for fetching SSH keys from various sources
1// SPDX-FileCopyrightText: 2025 Łukasz Niemier <#@hauleth.dev> 2// 3// SPDX-License-Identifier: EUPL-1.2 4 5use std::fmt; 6use std::str::FromStr; 7 8use super::helpers; 9 10use serde::Deserialize; 11use ssh_key::PublicKey; 12 13#[derive(Debug)] 14pub struct DID { 15 method: String, 16 id: String, 17} 18 19impl FromStr for DID { 20 type Err = (); 21 22 fn from_str(input: &str) -> Result<Self, ()> { 23 if !input.starts_with("did:") { 24 return Err(()); 25 } 26 if input.ends_with(":") { 27 return Err(()); 28 } 29 30 let chunks: Box<[_]> = input.splitn(3, ":").collect(); 31 32 if chunks.len() != 3 { 33 return Err(()); 34 } 35 36 Ok(DID { 37 method: chunks[1].into(), 38 id: chunks[2].into(), 39 }) 40 } 41} 42 43impl fmt::Display for DID { 44 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { 45 write!(f, "did:{}:{}", self.method, self.id) 46 } 47} 48 49#[derive(Debug)] 50pub struct Handle(String); 51 52impl fmt::Display for Handle { 53 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { 54 write!(f, "{}", self.0) 55 } 56} 57 58impl FromStr for Handle { 59 type Err = (); 60 61 fn from_str(input: &str) -> Result<Self, ()> { 62 if input.len() > 253 { 63 return Err(()); 64 } 65 66 let input = { 67 let mut input = input.to_ascii_lowercase(); 68 if input.starts_with("@") { 69 input.remove(0); 70 } 71 72 input 73 }; 74 75 let segments: Box<[_]> = input.split('.').collect(); 76 77 if segments.len() < 2 { 78 return Err(()); 79 } 80 81 if !segments.iter().all(|&s| legal_segment(s)) { 82 return Err(()); 83 } 84 85 if (b'0'..=b'9').contains(&segments.last().unwrap().as_bytes()[0]) { 86 return Err(()); 87 } 88 89 Ok(Handle(input)) 90 } 91} 92 93pub struct InvalidHandle(String); 94 95impl fmt::Display for InvalidHandle { 96 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { 97 write!(f, "Invalid handle: {}", self.0) 98 } 99} 100 101#[derive(Debug)] 102pub enum Identifier { 103 DID(DID), 104 Handle(Handle), 105} 106 107impl FromStr for Identifier { 108 type Err = InvalidHandle; 109 110 fn from_str(input: &str) -> Result<Self, Self::Err> { 111 input 112 .parse() 113 .map(Identifier::DID) 114 .or_else(|_| input.parse().map(Identifier::Handle)) 115 .map_err(|_| InvalidHandle(input.into())) 116 } 117} 118 119impl fmt::Display for Identifier { 120 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { 121 match *self { 122 Identifier::DID(ref did) => write!(f, "{}", did), 123 Identifier::Handle(ref handle) => write!(f, "{}", handle), 124 } 125 } 126} 127 128#[derive(Debug, Deserialize)] 129pub struct ATProto { 130 #[serde(default = "default_atproto")] 131 pub host: String, 132 #[serde(deserialize_with = "helpers::from_str")] 133 pub handle: Identifier, 134} 135 136impl FromStr for ATProto { 137 type Err = InvalidHandle; 138 139 fn from_str(input: &str) -> Result<Self, InvalidHandle> { 140 Ok(ATProto { 141 host: default_atproto(), 142 handle: input.parse().map_err(|_| InvalidHandle(input.into()))?, 143 }) 144 } 145} 146 147fn legal_segment(segment: &str) -> bool { 148 let bytes = segment.as_bytes(); 149 segment != "" 150 && bytes.into_iter().all(|&b| allowed_byte(b)) 151 && *bytes.first().unwrap() != b'-' 152 && *bytes.last().unwrap() != b'-' 153} 154 155fn allowed_byte(c: u8) -> bool { 156 (b'0'..=b'9').contains(&c) || (b'a'..=b'z').contains(&c) || c == b'-' 157} 158 159fn default_atproto() -> String { 160 "https://bsky.social/".into() 161} 162 163mod resp { 164 use serde::Deserialize; 165 use ssh_key::PublicKey; 166 167 #[derive(Debug, Deserialize)] 168 pub struct Resp { 169 pub records: Box<[Record]>, 170 } 171 172 #[derive(Debug, Deserialize)] 173 pub struct Record { 174 value: Value, 175 } 176 177 #[derive(Debug, Deserialize)] 178 pub struct Value { 179 key: String, 180 } 181 182 impl Into<PublicKey> for &Record { 183 fn into(self) -> PublicKey { 184 PublicKey::from_openssh(&self.value.key).unwrap() 185 } 186 } 187} 188 189impl super::Fetch for ATProto { 190 fn fetch(&self) -> Vec<PublicKey> { 191 let mut url = url::Url::parse(&self.host).unwrap(); 192 193 url.query_pairs_mut() 194 .append_pair("repo", &self.handle.to_string()) 195 .append_pair("collection", "sh.tangled.publicKey"); 196 197 url.set_path("xrpc/com.atproto.repo.listRecords"); 198 199 let data = ureq::get(&url.to_string()) 200 .call() 201 .unwrap() 202 .body_mut() 203 .read_to_string() 204 .unwrap(); 205 206 let decoded: resp::Resp = serde_json::from_str(&data).unwrap(); 207 208 decoded.records.iter().map(Into::into).collect() 209 } 210}