Repo of no-std crates for my personal embedded projects
at main 11 kB view raw
1use core::{ 2 convert::Infallible, 3 net::{Ipv4Addr, Ipv6Addr}, 4 str, 5}; 6 7use alloc::vec::Vec; 8use winnow::token::take; 9use winnow::{ModalResult, Parser}; 10use winnow::{binary::be_u8, error::ContextError}; 11use winnow::{binary::be_u16, error::FromExternalError}; 12 13use super::label::Label; 14use crate::{ 15 dns::traits::{DnsParse, DnsParseKind, DnsSerialize}, 16 encoder::{DnsError, Encoder}, 17}; 18 19#[derive(Debug, Clone, Copy, PartialEq, Eq)] 20#[repr(u16)] 21#[allow(clippy::upper_case_acronyms)] 22pub enum QType { 23 A = 1, 24 AAAA = 28, 25 PTR = 12, 26 TXT = 16, 27 SRV = 33, 28 Any = 255, 29 Unknown(u16), 30} 31 32impl<'a> DnsParse<'a> for QType { 33 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 34 be_u16.map(QType::from_u16).parse_next(input) 35 } 36} 37 38impl<'a> DnsSerialize<'a> for QType { 39 type Error = Infallible; 40 41 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 42 encoder.write(&self.to_u16().to_be_bytes()); 43 Ok(()) 44 } 45 46 fn size(&self) -> usize { 47 core::mem::size_of::<QType>() 48 } 49} 50 51impl<'a> DnsParseKind<'a> for QType { 52 type Output = Record<'a>; 53 54 fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output> { 55 match self { 56 QType::A => { 57 let record = A::parse(input, context)?; 58 Ok(Record::A(record)) 59 } 60 QType::AAAA => { 61 let record = AAAA::parse(input, context)?; 62 Ok(Record::AAAA(record)) 63 } 64 QType::PTR => { 65 let record = PTR::parse(input, context)?; 66 Ok(Record::PTR(record)) 67 } 68 QType::TXT => { 69 let record = TXT::parse(input, context)?; 70 Ok(Record::TXT(record)) 71 } 72 QType::SRV => { 73 let record = SRV::parse(input, context)?; 74 Ok(Record::SRV(record)) 75 } 76 QType::Any => Err(winnow::error::ErrMode::Backtrack( 77 ContextError::from_external_error(input, DnsError::Unsupported), 78 )), 79 QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack( 80 ContextError::from_external_error(input, DnsError::Unsupported), 81 )), 82 } 83 } 84} 85 86impl QType { 87 fn from_u16(value: u16) -> Self { 88 match value { 89 1 => QType::A, 90 28 => QType::AAAA, 91 12 => QType::PTR, 92 16 => QType::TXT, 93 33 => QType::SRV, 94 255 => QType::Any, 95 _ => QType::Unknown(value), 96 } 97 } 98 99 fn to_u16(self) -> u16 { 100 match self { 101 QType::A => 1, 102 QType::AAAA => 28, 103 QType::PTR => 12, 104 QType::TXT => 16, 105 QType::SRV => 33, 106 QType::Any => 255, 107 QType::Unknown(value) => value, 108 } 109 } 110} 111 112#[derive(Debug, PartialEq, Eq)] 113#[allow(clippy::upper_case_acronyms)] 114// Enum for DNS-SD records 115pub enum Record<'a> { 116 A(A), 117 AAAA(AAAA), 118 PTR(PTR<'a>), 119 TXT(TXT<'a>), 120 SRV(SRV<'a>), 121} 122 123impl<'a> DnsSerialize<'a> for Record<'a> { 124 type Error = DnsError; 125 126 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 127 match self { 128 Record::A(record) => { 129 record.serialize(encoder).ok(); 130 } 131 Record::AAAA(record) => { 132 record.serialize(encoder).ok(); 133 } 134 Record::PTR(record) => { 135 record.serialize(encoder)?; 136 } 137 Record::TXT(record) => { 138 record.serialize(encoder).ok(); 139 } 140 Record::SRV(record) => { 141 record.serialize(encoder)?; 142 } 143 }; 144 145 Ok(()) 146 } 147 148 fn size(&self) -> usize { 149 match self { 150 Self::A(a) => a.size(), 151 Self::AAAA(aaaa) => aaaa.size(), 152 Self::PTR(ptr) => ptr.size(), 153 Self::TXT(txt) => txt.size(), 154 Self::SRV(srv) => srv.size(), 155 } 156 } 157} 158 159// Struct for A record 160#[derive(Debug, PartialEq, Eq)] 161pub struct A { 162 pub address: Ipv4Addr, 163} 164 165impl<'a> DnsParse<'a> for A { 166 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 167 let len = be_u16.parse_next(input)?; 168 let address = take(len) 169 .try_map(<[u8; 4]>::try_from) 170 .map(Ipv4Addr::from) 171 .parse_next(input)?; 172 173 Ok(A { address }) 174 } 175} 176 177impl<'a> DnsSerialize<'a> for A { 178 type Error = Infallible; 179 180 fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { 181 let len = 4u16.to_be_bytes(); 182 writer.write(&len); 183 writer.write(&self.address.octets()); 184 Ok(()) 185 } 186 187 fn size(&self) -> usize { 188 core::mem::size_of::<Ipv4Addr>() + core::mem::size_of::<u16>() 189 } 190} 191 192// Struct for AAAA record 193#[derive(Debug, PartialEq, Eq)] 194#[allow(clippy::upper_case_acronyms)] 195pub struct AAAA { 196 pub address: Ipv6Addr, 197} 198 199impl<'a> DnsParse<'a> for AAAA { 200 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 201 let len = be_u16.parse_next(input)?; 202 let address = take(len) 203 .try_map(<[u8; 16]>::try_from) 204 .map(Ipv6Addr::from) 205 .parse_next(input)?; 206 207 Ok(AAAA { address }) 208 } 209} 210 211impl<'a> DnsSerialize<'a> for AAAA { 212 type Error = Infallible; 213 214 fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { 215 let len = 16u16.to_be_bytes(); 216 writer.write(&len); 217 writer.write(&self.address.octets()); 218 Ok(()) 219 } 220 221 fn size(&self) -> usize { 222 core::mem::size_of::<Ipv6Addr>() + core::mem::size_of::<u16>() 223 } 224} 225 226// Struct for PTR record 227#[derive(Debug, PartialEq, Eq)] 228#[allow(clippy::upper_case_acronyms)] 229pub struct PTR<'a> { 230 pub name: Label<'a>, 231} 232 233impl<'a> DnsParse<'a> for PTR<'a> { 234 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 235 let _ = be_u16.parse_next(input)?; 236 let name = Label::parse(input, context)?; 237 Ok(PTR { name }) 238 } 239} 240 241impl<'a> DnsSerialize<'a> for PTR<'a> { 242 type Error = DnsError; 243 244 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 245 encoder.with_record_length(|enc| self.name.serialize(enc)) 246 } 247 248 fn size(&self) -> usize { 249 self.name.size() + core::mem::size_of::<u16>() 250 } 251} 252 253// Struct for TXT record 254#[derive(Debug, PartialEq, Eq)] 255#[allow(clippy::upper_case_acronyms)] 256pub struct TXT<'a> { 257 pub text: Vec<&'a str>, 258} 259 260impl<'a> DnsParse<'a> for TXT<'a> { 261 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 262 let text_len = be_u16.parse_next(input)?; 263 264 let mut total = 0u16; 265 let mut text = Vec::new(); 266 267 while total < text_len { 268 let len = be_u8(input)?; 269 270 total += 1 + len as u16; 271 272 if len > 0 { 273 let part = take(len).try_map(core::str::from_utf8).parse_next(input)?; 274 text.push(part); 275 } 276 } 277 278 Ok(TXT { text }) 279 } 280} 281 282impl<'a> DnsSerialize<'a> for TXT<'a> { 283 type Error = DnsError; 284 285 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 286 encoder.with_record_length(|enc| { 287 self.text.iter().try_for_each(|&part| { 288 let text_len = u8::try_from(part.len()) 289 .map_err(|_| DnsError::InvalidTxt) 290 .map(u8::to_be_bytes)?; 291 292 enc.write(&text_len); 293 enc.write(part.as_bytes()); 294 295 Ok(()) 296 }) 297 }) 298 } 299 300 fn size(&self) -> usize { 301 let len_size = core::mem::size_of::<u16>(); 302 303 let text_size = if self.text.is_empty() { 304 1 305 } else { 306 self.text.iter().map(|part| part.len() + 1).sum() 307 }; 308 309 len_size + text_size 310 } 311} 312 313// Struct for SRV record 314#[derive(Debug, PartialEq, Eq)] 315#[allow(clippy::upper_case_acronyms)] 316pub struct SRV<'a> { 317 pub priority: u16, 318 pub weight: u16, 319 pub port: u16, 320 pub target: Label<'a>, 321} 322 323impl<'a> DnsParse<'a> for SRV<'a> { 324 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 325 let _ = be_u16.parse_next(input)?; 326 let priority = be_u16.parse_next(input)?; 327 let weight = be_u16.parse_next(input)?; 328 let port = be_u16.parse_next(input)?; 329 let target = Label::parse(input, context)?; 330 331 Ok(SRV { 332 priority, 333 weight, 334 port, 335 target, 336 }) 337 } 338} 339 340impl<'a> DnsSerialize<'a> for SRV<'a> { 341 type Error = DnsError; 342 343 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 344 encoder.with_record_length(|enc| { 345 enc.write(&self.priority.to_be_bytes()); 346 enc.write(&self.weight.to_be_bytes()); 347 enc.write(&self.port.to_be_bytes()); 348 349 self.target.serialize(enc) 350 }) 351 } 352 353 fn size(&self) -> usize { 354 (core::mem::size_of::<u16>() * 4) + self.target.size() 355 } 356} 357 358#[cfg(feature = "defmt")] 359impl defmt::Format for A { 360 fn format(&self, fmt: defmt::Formatter) { 361 // use crate::format::FormatIpv4Addr; 362 defmt::write!(fmt, "A({})", self.address) 363 } 364} 365 366#[cfg(feature = "defmt")] 367impl defmt::Format for AAAA { 368 fn format(&self, fmt: defmt::Formatter) { 369 // use crate::format::FormatIpv6Addr; 370 defmt::write!(fmt, "AAAA({})", self.address) 371 } 372} 373 374#[cfg(feature = "defmt")] 375impl<'a> defmt::Format for Record<'a> { 376 fn format(&self, fmt: defmt::Formatter) { 377 match self { 378 Record::A(record) => defmt::write!(fmt, "Record::A({:?})", record), 379 Record::AAAA(record) => defmt::write!(fmt, "Record::AAAA({:?})", record), 380 Record::PTR(record) => defmt::write!(fmt, "Record::PTR({:?})", record), 381 Record::TXT(record) => defmt::write!(fmt, "Record::TXT({:?})", record), 382 Record::SRV(record) => defmt::write!(fmt, "Record::SRV({:?})", record), 383 } 384 } 385} 386 387#[cfg(feature = "defmt")] 388impl<'a> defmt::Format for PTR<'a> { 389 fn format(&self, fmt: defmt::Formatter) { 390 defmt::write!(fmt, "PTR {{ name: {:?} }}", self.name); 391 } 392} 393 394#[cfg(feature = "defmt")] 395impl<'a> defmt::Format for TXT<'a> { 396 fn format(&self, fmt: defmt::Formatter) { 397 defmt::write!(fmt, "TXT {{ text: {:?} }}", self.text); 398 } 399} 400 401#[cfg(feature = "defmt")] 402impl<'a> defmt::Format for SRV<'a> { 403 fn format(&self, fmt: defmt::Formatter) { 404 defmt::write!( 405 fmt, 406 "SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}", 407 self.priority, 408 self.weight, 409 self.port, 410 self.target 411 ); 412 } 413}