Repository of `no_std` crates for my personal embedded projects.
(File on branch `main`, 16 kB; raw view.)
1use core::{fmt, str}; 2 3use winnow::{ 4 ModalResult, Parser, 5 binary::be_u8, 6 error::{ContextError, ErrMode, FromExternalError}, 7 stream::Offset, 8 token::take, 9}; 10 11use crate::{ 12 dns::traits::{DnsParse, DnsSerialize}, 13 encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK}, 14}; 15 16#[derive(Clone, Copy)] 17pub struct Label<'a> { 18 repr: LabelRepr<'a>, 19} 20 21impl<'a> From<&'a str> for Label<'a> { 22 fn from(value: &'a str) -> Self { 23 Self { 24 repr: LabelRepr::Str(value), 25 } 26 } 27} 28 29impl<'a> DnsSerialize<'a> for Label<'a> { 30 type Error = DnsError; 31 32 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 33 match self.repr { 34 LabelRepr::Bytes { 35 context, 36 start, 37 end, 38 } => { 39 encoder.write(&context[start..end]); 40 41 Ok(()) 42 } 43 LabelRepr::Str(label) => encoder.write_label(label), 44 } 45 } 46 47 fn size(&self) -> usize { 48 match self.repr { 49 LabelRepr::Bytes { 50 context, 51 start, 52 end, 53 } => core::mem::size_of_val(&context[start..end]), 54 LabelRepr::Str(label) => core::mem::size_of_val(label) + 1, 55 } 56 } 57} 58 59impl<'a> DnsParse<'a> for Label<'a> { 60 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 61 let start = input.offset_from(&context); 62 let mut end = start; 63 64 loop { 65 match LabelSegment::parse(input)? { 66 LabelSegment::Empty => { 67 end += 1; 68 break; 69 } 70 LabelSegment::String(label) => { 71 end += 1 + label.len(); 72 } 73 LabelSegment::Pointer(_) => { 74 end += 2; 75 break; 76 } 77 } 78 } 79 80 Ok(Self { 81 repr: LabelRepr::Bytes { 82 context, 83 start, 84 end, 85 }, 86 }) 87 } 88} 89 90impl Label<'_> { 91 pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'_>> { 92 self.repr.iter() 93 } 94 95 pub fn names(&self) -> impl Iterator<Item = &'_ str> { 96 match self.repr { 97 LabelRepr::Str(view) => Either::A(view.split('.')), 98 LabelRepr::Bytes { context, start, .. 
} => Either::B( 99 LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()), 100 ), 101 } 102 } 103 104 pub fn is_empty(&self) -> bool { 105 self.repr.iter().next().is_none() 106 } 107} 108 109#[derive(Clone, Copy, PartialEq, Eq)] 110enum LabelRepr<'a> { 111 Bytes { 112 context: &'a [u8], 113 start: usize, 114 end: usize, 115 }, 116 Str(&'a str), 117} 118 119/// A DNS-compatible label segment. 120#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 121pub enum LabelSegment<'a> { 122 /// The empty terminator. 123 Empty, 124 125 /// A string label. 126 String(&'a str), 127 128 /// A pointer to a previous name. 129 Pointer(u16), 130} 131 132impl<'a> LabelSegment<'a> { 133 fn parse(input: &mut &'a [u8]) -> ModalResult<Self> { 134 let b1 = be_u8(input)?; 135 136 match b1 { 137 0 => Ok(Self::Empty), 138 b1 if b1 & PTR_MASK == PTR_MASK => { 139 let b2 = be_u8(input)?; 140 141 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); 142 143 Ok(Self::Pointer(ptr)) 144 } 145 len => { 146 if len > MAX_STR_LEN { 147 return Err(ErrMode::Cut(ContextError::from_external_error( 148 input, 149 DnsError::LabelTooLong, 150 ))); 151 } 152 153 let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?; 154 155 Ok(Self::String(segment)) 156 } 157 } 158 } 159 160 /// ## Safety 161 /// The caller upholds that this function is not called when parsing from newly received data. Data that 162 /// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`], 163 /// and that the entire data/context has been validated, not just a portion of it. 164 #[inline] 165 unsafe fn parse_unchecked(input: &'a [u8]) -> Option<Self> { 166 input.split_first().map(|(b1, input)| match *b1 { 167 0 => Self::Empty, 168 b1 if b1 & PTR_MASK == PTR_MASK => { 169 // SAFETY: The caller has already validated that a second byte is available for a 170 // Pointer segment. 
171 let b2 = unsafe { *input.get_unchecked(0) }; 172 173 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); 174 175 Self::Pointer(ptr) 176 } 177 len => { 178 // SAFETY: The caller has validated that this length value is correct and will only 179 // access within the bounds of the provided slice. 180 let segment = unsafe { input.get_unchecked(0..(len as usize)) }; 181 // SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before. 182 let segment = unsafe { core::str::from_utf8_unchecked(segment) }; 183 184 Self::String(segment) 185 } 186 }) 187 } 188 189 fn as_str(&self) -> Option<&'a str> { 190 match self { 191 Self::String(label) => Some(*label), 192 _ => None, 193 } 194 } 195} 196 197pub struct LabelSegmentBytesIter<'a> { 198 context: &'a [u8], 199 start: usize, 200} 201 202impl<'a> LabelSegmentBytesIter<'a> { 203 pub(crate) fn new(context: &'a [u8], start: usize) -> Self { 204 Self { context, start } 205 } 206} 207 208impl<'a> Iterator for LabelSegmentBytesIter<'a> { 209 type Item = LabelSegment<'a>; 210 211 fn next(&mut self) -> Option<Self::Item> { 212 loop { 213 let view = &self.context[self.start..]; 214 215 // SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes 216 let segment = unsafe { LabelSegment::parse_unchecked(view)? }; 217 218 match segment { 219 LabelSegment::String(label) => { 220 self.start = self.start.saturating_add(label.len() + 1); 221 return Some(segment); 222 } 223 LabelSegment::Pointer(ptr) => { 224 self.start = ptr as usize; 225 } 226 LabelSegment::Empty => { 227 // Set the index offset to be len() so that the view is empty and terminates the loop 228 self.start = self.context.len(); 229 return Some(LabelSegment::Empty); 230 } 231 } 232 } 233 } 234} 235 236impl<'a> LabelRepr<'a> { 237 fn iter(&self) -> impl Iterator<Item = LabelSegment<'a>> { 238 match *self { 239 LabelRepr::Bytes { context, start, .. 
} => { 240 Either::A(LabelSegmentBytesIter::new(context, start)) 241 } 242 LabelRepr::Str(view) => Either::B( 243 view.split('.') 244 .map(LabelSegment::String) 245 .chain(Some(LabelSegment::Empty)), 246 ), 247 } 248 } 249} 250 251impl fmt::Debug for Label<'_> { 252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 253 struct LabelFmt<'a>(&'a Label<'a>); 254 255 impl fmt::Debug for LabelFmt<'_> { 256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 257 fmt::Display::fmt(self.0, f) 258 } 259 } 260 261 f.debug_tuple("Label").field(&LabelFmt(self)).finish() 262 } 263} 264 265impl fmt::Display for Label<'_> { 266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 267 let mut names = self.names(); 268 269 if let Some(name) = names.next() { 270 f.write_str(name)?; 271 272 names.try_for_each(|name| { 273 f.write_str(".")?; 274 f.write_str(name) 275 }) 276 } else { 277 Ok(()) 278 } 279 } 280} 281 282impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> { 283 fn eq(&self, other: &Label<'a>) -> bool { 284 self.segments().eq(other.segments()) 285 } 286} 287 288impl Eq for Label<'_> {} 289 290impl PartialEq<&str> for Label<'_> { 291 fn eq(&self, other: &&str) -> bool { 292 let mut self_iter = self.names(); 293 let mut other_iter = other.split('.'); 294 295 loop { 296 match (self_iter.next(), other_iter.next()) { 297 (Some(self_part), Some(other_part)) => { 298 if self_part != other_part { 299 return false; 300 } 301 } 302 (None, None) => return true, 303 _ => return false, 304 } 305 } 306 } 307} 308 309#[cfg(feature = "defmt")] 310impl defmt::Format for Label<'_> { 311 fn format(&self, fmt: defmt::Formatter) { 312 defmt::write!(fmt, "Label("); 313 let mut iter = self.names(); 314 if let Some(first) = iter.next() { 315 defmt::write!(fmt, "{}", first); 316 317 iter.for_each(|part| defmt::write!(fmt, ".{}", part)); 318 } 319 defmt::write!(fmt, ")"); 320 } 321} 322 323/// One iterator or another. 
///
/// Lets a function return `impl Iterator` while choosing between two different concrete
/// iterator types at runtime, without boxing. Every call is forwarded to whichever variant is
/// held.
enum Either<A, B> {
    A(A),
    B(B),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }

    // Forwarded so the inner iterator's bounds are not lost behind the wrapper.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(left) => left.size_hint(),
            Self::B(right) => right.size_hint(),
        }
    }

    // Forwarded so internal iteration (e.g. `for_each`, `count`) can use the inner iterator's
    // own `fold`.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Self::A(left) => left.fold(init, f),
            Self::B(right) => right.fold(init, f),
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// `example.com` written out in full at offset 0, followed by a two-byte compression pointer
    /// back to offset 0.
    const EXAMPLE_WITH_PTR: &[u8] = b"\x07example\x03com\x00\xC0\x00";

    /// Offset of the compression pointer within [`EXAMPLE_WITH_PTR`].
    const PTR_AT: usize = 13;

    /// Parses only the trailing pointer of [`EXAMPLE_WITH_PTR`] while passing the whole buffer as
    /// context. The data is entirely valid even though only a portion of it is parsed.
    fn parse_pointer_label() -> Label<'static> {
        Label::parse(&mut &EXAMPLE_WITH_PTR[PTR_AT..], EXAMPLE_WITH_PTR).unwrap()
    }

    #[test]
    fn segments_iter_test() {
        let from_str: Label<'static> = Label::from("_service._udp.local");
        let mut segments = from_str.segments();

        assert_eq!(segments.next(), Some(LabelSegment::String("_service")));
        assert_eq!(segments.next(), Some(LabelSegment::String("_udp")));
        assert_eq!(segments.next(), Some(LabelSegment::String("local")));
        assert_eq!(segments.next(), Some(LabelSegment::Empty));
        assert_eq!(segments.next(), None);

        let parsed = parse_pointer_label();
        let mut segments = parsed.segments();

        assert_eq!(segments.next(), Some(LabelSegment::String("example")));
        assert_eq!(segments.next(), Some(LabelSegment::String("com")));
        assert_eq!(segments.next(), Some(LabelSegment::Empty));
        assert_eq!(segments.next(), None);
    }

    #[test]
    fn names_iter_test() {
        let from_str: Label<'static> = Label::from("_service._udp.local");
        let mut names = from_str.names();

        assert_eq!(names.next(), Some("_service"));
        assert_eq!(names.next(), Some("_udp"));
        assert_eq!(names.next(), Some("local"));
        assert_eq!(names.next(), None);

        let parsed = parse_pointer_label();
        let mut names = parsed.names();

        assert_eq!(names.next(), Some("example"));
        assert_eq!(names.next(), Some("com"));
        assert_eq!(names.next(), None);
    }

    #[test]
    fn serialize_str_label() {
        let label: Label<'static> = Label::from("_service._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);

        label.serialize(&mut encoder).unwrap();

        assert_eq!(encoder.finish(), b"\x08_service\x04_udp\x05local\x00");
    }

    #[test]
    fn serialize_compressed_str_label() {
        let first: Label<'static> = Label::from("_service._udp.local");
        let second: Label<'static> = Label::from("other._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);

        first.serialize(&mut encoder).unwrap();
        second.serialize(&mut encoder).unwrap();

        assert_eq!(
            encoder.finish(),
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );
    }

    #[test]
    fn round_trip_compressed_str_label() {
        let first: Label<'static> = Label::from("_service._udp.local");
        let second: Label<'static> = Label::from("other._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);

        first.serialize(&mut encoder).unwrap();
        second.serialize(&mut encoder).unwrap();

        let context = encoder.finish();
        assert_eq!(
            context,
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );

        let cursor = &mut &context[..];
        let parsed_first = Label::parse(cursor, context).unwrap();
        let parsed_second = Label::parse(cursor, context).unwrap();

        // Both have the same number of segments, even though the second ends in a pointer.
        assert_eq!(
            parsed_first.segments().count(),
            parsed_second.segments().count()
        );

        assert_eq!(parsed_first, first);
        assert_eq!(parsed_second, second);
    }

    #[test]
    fn label_byte_repr_serialization_quick_path() {
        let label = parse_pointer_label();
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);

        label.serialize(&mut encoder).unwrap();

        // A parsed label that is just a pointer is re-emitted as that same pointer. This is only
        // meaningful if the original data is also present in the output.
        assert_eq!(encoder.finish(), b"\xC0\x00");
    }

    #[test]
    fn parse_and_eq_created_label() {
        let parsed_label = parse_pointer_label();
        let created_label = Label::from("example.com");

        assert_eq!(parsed_label, created_label);
    }

    #[test]
    fn parse_and_eq_label_with_str() {
        let data = b"\x07example\x03com\x00";
        let context = &data[..];

        let parsed_label = Label::parse(&mut &data[..], context).unwrap();

        assert_eq!(parsed_label, "example.com");
    }

    #[test]
    fn parse_ptr_label_and_eq_with_str() {
        let parsed_label = parse_pointer_label();

        assert_eq!(parsed_label, "example.com");
    }

    #[test]
    fn label_new_without_dot_is_not_empty() {
        let label: Label = Label::from("example");

        assert!(!label.is_empty());
    }
}