A better Rust ATProto crate
1use crate::corpus::LexiconCorpus; 2use crate::error::{CodegenError, Result}; 3use crate::lexicon::{LexArrayItem, LexUserType}; 4use proc_macro2::TokenStream; 5use quote::quote; 6 7mod utils; 8mod names; 9mod lifetime; 10mod types; 11mod structs; 12mod xrpc; 13mod output; 14 15/// Code generator for lexicon types 16pub struct CodeGenerator<'c> { 17 corpus: &'c LexiconCorpus, 18 root_module: String, 19 /// Track namespace dependencies (namespace -> set of namespaces it depends on) 20 namespace_deps: 21 std::cell::RefCell<std::collections::HashMap<String, std::collections::HashSet<String>>>, 22} 23 24impl<'c> CodeGenerator<'c> { 25 /// Create a new code generator 26 pub fn new(corpus: &'c LexiconCorpus, root_module: impl Into<String>) -> Self { 27 Self { 28 corpus, 29 root_module: root_module.into(), 30 namespace_deps: std::cell::RefCell::new(std::collections::HashMap::new()), 31 } 32 } 33 34 /// Generate doc comment from optional description (wrapper for utils function) 35 fn generate_doc_comment(&self, desc: Option<&jacquard_common::CowStr>) -> TokenStream { 36 utils::generate_doc_comment(desc) 37 } 38 39 /// Generate code for a lexicon def 40 pub fn generate_def( 41 &self, 42 nsid: &str, 43 def_name: &str, 44 def: &LexUserType<'static>, 45 ) -> Result<TokenStream> { 46 match def { 47 LexUserType::Record(record) => self.generate_record(nsid, def_name, record), 48 LexUserType::Object(obj) => self.generate_object(nsid, def_name, obj), 49 LexUserType::XrpcQuery(query) => self.generate_query(nsid, def_name, query), 50 LexUserType::XrpcProcedure(proc) => self.generate_procedure(nsid, def_name, proc), 51 LexUserType::Token(token) => { 52 // Token types are marker structs that can be used as union refs 53 let type_name = self.def_to_type_name(nsid, def_name); 54 let ident = syn::Ident::new(&type_name, proc_macro2::Span::call_site()); 55 let doc = self.generate_doc_comment(token.description.as_ref()); 56 57 // Token name for Display impl (just the def name, not the full ref) 58 let token_name = def_name; 59 60 Ok(quote! { 61 #doc 62 #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, Hash, jacquard_derive::IntoStatic)] 63 pub struct #ident; 64 65 impl std::fmt::Display for #ident { 66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 67 write!(f, #token_name) 68 } 69 } 70 }) 71 } 72 LexUserType::String(s) if s.known_values.is_some() => { 73 self.generate_known_values_enum(nsid, def_name, s) 74 } 75 LexUserType::String(s) => { 76 // Plain string type alias 77 let type_name = self.def_to_type_name(nsid, def_name); 78 let ident = syn::Ident::new(&type_name, proc_macro2::Span::call_site()); 79 let rust_type = self.string_to_rust_type(s); 80 let doc = self.generate_doc_comment(s.description.as_ref()); 81 Ok(quote! { 82 #doc 83 pub type #ident<'a> = #rust_type; 84 }) 85 } 86 LexUserType::Integer(i) if i.r#enum.is_some() => { 87 self.generate_integer_enum(nsid, def_name, i) 88 } 89 LexUserType::Array(array) => { 90 // Top-level array becomes type alias to Vec<ItemType> 91 let type_name = self.def_to_type_name(nsid, def_name); 92 let ident = syn::Ident::new(&type_name, proc_macro2::Span::call_site()); 93 let doc = self.generate_doc_comment(array.description.as_ref()); 94 let needs_lifetime = self.array_item_needs_lifetime(&array.items); 95 96 // Check if items are a union - if so, generate the union enum first 97 if let LexArrayItem::Union(union) = &array.items { 98 let union_name = format!("{}Item", type_name); 99 let refs: Vec<_> = union.refs.iter().cloned().collect(); 100 let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?; 101 102 let union_ident = syn::Ident::new(&union_name, proc_macro2::Span::call_site()); 103 if needs_lifetime { 104 Ok(quote! { 105 #union_def 106 107 #doc 108 pub type #ident<'a> = Vec<#union_ident<'a>>; 109 }) 110 } else { 111 Ok(quote! { 112 #union_def 113 114 #doc 115 pub type #ident = Vec<#union_ident>; 116 }) 117 } 118 } else { 119 // Regular array item type 120 let item_type = self.array_item_to_rust_type(nsid, &array.items)?; 121 if needs_lifetime { 122 Ok(quote! { 123 #doc 124 pub type #ident<'a> = Vec<#item_type>; 125 }) 126 } else { 127 Ok(quote! { 128 #doc 129 pub type #ident = Vec<#item_type>; 130 }) 131 } 132 } 133 } 134 LexUserType::Boolean(_) 135 | LexUserType::Integer(_) 136 | LexUserType::Bytes(_) 137 | LexUserType::CidLink(_) 138 | LexUserType::Unknown(_) => { 139 // These are rarely top-level defs, but if they are, make type aliases 140 let type_name = self.def_to_type_name(nsid, def_name); 141 let ident = syn::Ident::new(&type_name, proc_macro2::Span::call_site()); 142 let (rust_type, needs_lifetime) = match def { 143 LexUserType::Boolean(_) => (quote! { bool }, false), 144 LexUserType::Integer(_) => (quote! { i64 }, false), 145 LexUserType::Bytes(_) => (quote! { bytes::Bytes }, false), 146 LexUserType::CidLink(_) => { 147 (quote! { jacquard_common::types::cid::CidLink<'a> }, true) 148 } 149 LexUserType::Unknown(_) => { 150 (quote! { jacquard_common::types::value::Data<'a> }, true) 151 } 152 _ => unreachable!(), 153 }; 154 if needs_lifetime { 155 Ok(quote! { 156 pub type #ident<'a> = #rust_type; 157 }) 158 } else { 159 Ok(quote! { 160 pub type #ident = #rust_type; 161 }) 162 } 163 } 164 LexUserType::Blob(_) => Err(CodegenError::unsupported( 165 format!("top-level def type {:?}", def), 166 nsid, 167 None::<String>, 168 )), 169 LexUserType::XrpcSubscription(sub) => self.generate_subscription(nsid, def_name, sub), 170 } 171 } 172} 173 174#[cfg(test)] 175mod tests { 176 use super::*; 177 178 #[test] 179 fn test_generate_record() { 180 let corpus = 181 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 182 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 183 184 let doc = corpus.get("app.bsky.feed.post").expect("get post"); 185 let def = doc.defs.get("main").expect("get main def"); 186 187 let tokens = codegen 188 .generate_def("app.bsky.feed.post", "main", def) 189 .expect("generate"); 190 191 // Format and print for inspection 192 let file: syn::File = syn::parse2(tokens).expect("parse tokens"); 193 let formatted = prettyplease::unparse(&file); 194 println!("\n{}\n", formatted); 195 196 // Check basic structure 197 assert!(formatted.contains("struct Post")); 198 assert!(formatted.contains("pub text")); 199 assert!(formatted.contains("CowStr<'a>")); 200 } 201 202 #[test] 203 fn test_generate_union() { 204 let corpus = 205 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 206 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 207 208 // Create a union with embed types 209 let refs = vec![ 210 "app.bsky.embed.images".into(), 211 "app.bsky.embed.video".into(), 212 "app.bsky.embed.external".into(), 213 ]; 214 215 let tokens = codegen 216 .generate_union( 217 "app.bsky.feed.post", 218 "RecordEmbed", 219 &refs, 220 Some("Post embed union"), 221 None, 222 ) 223 .expect("generate union"); 224 225 let file: syn::File = syn::parse2(tokens).expect("parse tokens"); 226 let formatted = prettyplease::unparse(&file); 227 println!("\n{}\n", formatted); 228 229 // Check structure 230 assert!(formatted.contains("enum RecordEmbed")); 231 assert!(formatted.contains("Images")); 232 assert!(formatted.contains("Video")); 233 assert!(formatted.contains("External")); 234 assert!(formatted.contains("#[serde(tag = \"$type\")]")); 235 assert!(formatted.contains("#[jacquard_derive::open_union]")); 236 } 237 238 #[test] 239 fn test_generate_query() { 240 let corpus = 241 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 242 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 243 244 let doc = corpus 245 .get("app.bsky.feed.getAuthorFeed") 246 .expect("get getAuthorFeed"); 247 let def = doc.defs.get("main").expect("get main def"); 248 249 let tokens = codegen 250 .generate_def("app.bsky.feed.getAuthorFeed", "main", def) 251 .expect("generate"); 252 253 let file: syn::File = syn::parse2(tokens).expect("parse tokens"); 254 let formatted = prettyplease::unparse(&file); 255 println!("\n{}\n", formatted); 256 257 // Check structure 258 assert!(formatted.contains("struct GetAuthorFeed")); 259 assert!(formatted.contains("struct GetAuthorFeedOutput")); 260 assert!(formatted.contains("enum GetAuthorFeedError")); 261 assert!(formatted.contains("pub actor")); 262 assert!(formatted.contains("pub limit")); 263 assert!(formatted.contains("pub cursor")); 264 assert!(formatted.contains("pub feed")); 265 assert!(formatted.contains("BlockedActor")); 266 assert!(formatted.contains("BlockedByActor")); 267 } 268 269 #[test] 270 fn test_generate_known_values_enum() { 271 let corpus = 272 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 273 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 274 275 let doc = corpus 276 .get("com.atproto.label.defs") 277 .expect("get label defs"); 278 let def = doc.defs.get("labelValue").expect("get labelValue def"); 279 280 let tokens = codegen 281 .generate_def("com.atproto.label.defs", "labelValue", def) 282 .expect("generate"); 283 284 let file: syn::File = syn::parse2(tokens).expect("parse tokens"); 285 let formatted = prettyplease::unparse(&file); 286 println!("\n{}\n", formatted); 287 288 // Check structure 289 assert!(formatted.contains("enum LabelValue")); 290 assert!(formatted.contains("Hide")); 291 assert!(formatted.contains("NoPromote")); 292 assert!(formatted.contains("Warn")); 293 assert!(formatted.contains("DmcaViolation")); 294 assert!(formatted.contains("Other(jacquard_common::CowStr")); 295 assert!(formatted.contains("impl<'a> From<&'a str>")); 296 assert!(formatted.contains("fn as_str(&self)")); 297 } 298 299 #[test] 300 fn test_nsid_to_file_path() { 301 let corpus = 302 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 303 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 304 305 // Regular paths 306 assert_eq!( 307 codegen.nsid_to_file_path("app.bsky.feed.post"), 308 std::path::PathBuf::from("app_bsky/feed/post.rs") 309 ); 310 311 assert_eq!( 312 codegen.nsid_to_file_path("app.bsky.feed.getAuthorFeed"), 313 std::path::PathBuf::from("app_bsky/feed/get_author_feed.rs") 314 ); 315 316 // Defs paths - should go in parent 317 assert_eq!( 318 codegen.nsid_to_file_path("com.atproto.label.defs"), 319 std::path::PathBuf::from("com_atproto/label.rs") 320 ); 321 } 322 323 #[test] 324 fn test_write_to_disk() { 325 let corpus = 326 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 327 let codegen = CodeGenerator::new(&corpus, "test_generated"); 328 329 let tmp_dir = 330 tempfile::tempdir().expect("should be able to create temp directory for output"); 331 let output_dir = std::path::PathBuf::from(tmp_dir.path()); 332 333 // Clean up any previous test output 334 let _ = std::fs::remove_dir_all(&output_dir); 335 336 // Generate and write 337 codegen.write_to_disk(&output_dir).expect("write to disk"); 338 339 // Verify some files were created 340 assert!(output_dir.join("app_bsky/feed/post.rs").exists()); 341 assert!(output_dir.join("app_bsky/feed/get_author_feed.rs").exists()); 342 assert!(output_dir.join("com_atproto/label.rs").exists()); 343 344 // Verify module files were created 345 assert!(output_dir.join("lib.rs").exists()); 346 assert!(output_dir.join("app_bsky.rs").exists()); 347 348 // Read and verify post.rs contains expected content 349 let post_content = std::fs::read_to_string(output_dir.join("app_bsky/feed/post.rs")) 350 .expect("read post.rs"); 351 assert!(post_content.contains("pub struct Post")); 352 assert!(post_content.contains("jacquard_common")); 353 } 354}