1use super::utils::sanitize_name;
2use super::CodeGenerator;
3use heck::{ToPascalCase, ToSnakeCase};
4
5impl<'c> CodeGenerator<'c> {
6 /// Check if a single-variant union is self-referential
7 pub(super) fn is_self_referential_union(
8 &self,
9 nsid: &str,
10 parent_type_name: &str,
11 union: &crate::lexicon::LexRefUnion,
12 ) -> bool {
13 if union.refs.len() != 1 {
14 return false;
15 }
16
17 let ref_str = if union.refs[0].starts_with('#') {
18 format!("{}{}", nsid, union.refs[0])
19 } else {
20 union.refs[0].to_string()
21 };
22
23 let (ref_nsid, ref_def) = if let Some((nsid_part, fragment)) = ref_str.split_once('#') {
24 (nsid_part, fragment)
25 } else {
26 (ref_str.as_str(), "main")
27 };
28
29 let ref_type_name = self.def_to_type_name(ref_nsid, ref_def);
30 ref_type_name == parent_type_name
31 }
32
33 /// Helper to generate field-based type name with collision detection
34 pub(super) fn generate_field_type_name(
35 &self,
36 nsid: &str,
37 parent_type_name: &str,
38 field_name: &str,
39 suffix: &str, // "" for union/object, "Item" for array unions
40 ) -> String {
41 let base_name = format!("{}{}{}", parent_type_name, field_name.to_pascal_case(), suffix);
42
43 // Check for collisions with lexicon defs
44 if let Some(doc) = self.corpus.get(nsid) {
45 let def_names: std::collections::HashSet<String> = doc
46 .defs
47 .keys()
48 .map(|name| self.def_to_type_name(nsid, name.as_ref()))
49 .collect();
50
51 if def_names.contains(&base_name) {
52 // Use "Union" suffix for union types, "Record" for objects
53 let disambiguator = if suffix.is_empty() && !parent_type_name.is_empty() {
54 "Union"
55 } else {
56 "Record"
57 };
58 return format!("{}{}{}{}", parent_type_name, disambiguator, field_name.to_pascal_case(), suffix);
59 }
60 }
61
62 base_name
63 }
64
65 /// Convert lexicon def name to Rust type name
66 pub(super) fn def_to_type_name(&self, nsid: &str, def_name: &str) -> String {
67 if def_name == "main" {
68 // Use last segment of NSID
69 let base_name = nsid.split('.').last().unwrap().to_pascal_case();
70
71 // Check if any other def would collide with this name
72 if let Some(doc) = self.corpus.get(nsid) {
73 let has_collision = doc.defs.keys().any(|other_def| {
74 let other_def_str: &str = other_def.as_ref();
75 other_def_str != "main" && other_def_str.to_pascal_case() == base_name
76 });
77
78 if has_collision {
79 return format!("{}Record", base_name);
80 }
81 }
82
83 base_name
84 } else {
85 def_name.to_pascal_case()
86 }
87 }
88
89 /// Convert NSID to file path relative to output directory
90 ///
91 /// - `app.bsky.feed.post` → `app_bsky/feed/post.rs`
92 /// - `com.atproto.label.defs` → `com_atproto/label.rs` (defs go in parent)
93 pub(super) fn nsid_to_file_path(&self, nsid: &str) -> std::path::PathBuf {
94 let parts: Vec<&str> = nsid.split('.').collect();
95
96 if parts.len() < 2 {
97 // Shouldn't happen with valid NSIDs, but handle gracefully
98 return format!("{}.rs", sanitize_name(parts[0])).into();
99 }
100
101 let last = parts.last().unwrap();
102
103 if *last == "defs" && parts.len() >= 3 {
104 // defs go in parent module: com.atproto.label.defs → com_atproto/label.rs
105 let first_two = format!("{}_{}", sanitize_name(parts[0]), sanitize_name(parts[1]));
106 if parts.len() == 3 {
107 // com.atproto.defs → com_atproto.rs
108 format!("{}.rs", first_two).into()
109 } else {
110 // com.atproto.label.defs → com_atproto/label.rs
111 let middle: Vec<&str> = parts[2..parts.len() - 1].iter().copied().collect();
112 let mut path = std::path::PathBuf::from(first_two);
113 for segment in &middle[..middle.len() - 1] {
114 path.push(sanitize_name(segment));
115 }
116 path.push(format!("{}.rs", sanitize_name(middle.last().unwrap())));
117 path
118 }
119 } else {
120 // Regular path: app.bsky.feed.post → app_bsky/feed/post.rs
121 let first_two = format!("{}_{}", sanitize_name(parts[0]), sanitize_name(parts[1]));
122 let mut path = std::path::PathBuf::from(first_two);
123
124 for segment in &parts[2..parts.len() - 1] {
125 path.push(sanitize_name(segment));
126 }
127
128 path.push(format!("{}.rs", sanitize_name(&last.to_snake_case())));
129 path
130 }
131 }
132}