use core::{fmt, str}; use winnow::{ ModalResult, Parser, binary::be_u8, error::{ContextError, ErrMode, FromExternalError}, stream::Offset, token::take, }; use crate::{ dns::traits::{DnsParse, DnsSerialize}, encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK}, }; #[derive(Clone, Copy)] pub struct Label<'a> { repr: LabelRepr<'a>, } impl<'a> From<&'a str> for Label<'a> { fn from(value: &'a str) -> Self { Self { repr: LabelRepr::Str(value), } } } impl<'a> DnsSerialize<'a> for Label<'a> { type Error = DnsError; fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { match self.repr { LabelRepr::Bytes { context, start, end, } => { encoder.write(&context[start..end]); Ok(()) } LabelRepr::Str(label) => encoder.write_label(label), } } fn size(&self) -> usize { match self.repr { LabelRepr::Bytes { context, start, end, } => core::mem::size_of_val(&context[start..end]), LabelRepr::Str(label) => core::mem::size_of_val(label) + 1, } } } impl<'a> DnsParse<'a> for Label<'a> { fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { let start = input.offset_from(&context); let mut end = start; loop { match LabelSegment::parse(input)? { LabelSegment::Empty => { end += 1; break; } LabelSegment::String(label) => { end += 1 + label.len(); } LabelSegment::Pointer(_) => { end += 2; break; } } } Ok(Self { repr: LabelRepr::Bytes { context, start, end, }, }) } } impl Label<'_> { pub fn segments(&self) -> impl Iterator> { self.repr.iter() } pub fn names(&self) -> impl Iterator { match self.repr { LabelRepr::Str(view) => Either::A(view.split('.')), LabelRepr::Bytes { context, start, .. } => Either::B( LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()), ), } } pub fn is_empty(&self) -> bool { self.repr.iter().next().is_none() } } #[derive(Clone, Copy, PartialEq, Eq)] enum LabelRepr<'a> { Bytes { context: &'a [u8], start: usize, end: usize, }, Str(&'a str), } /// A DNS-compatible label segment. 
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum LabelSegment<'a> { /// The empty terminator. Empty, /// A string label. String(&'a str), /// A pointer to a previous name. Pointer(u16), } impl<'a> LabelSegment<'a> { fn parse(input: &mut &'a [u8]) -> ModalResult { let b1 = be_u8(input)?; match b1 { 0 => Ok(Self::Empty), b1 if b1 & PTR_MASK == PTR_MASK => { let b2 = be_u8(input)?; let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); Ok(Self::Pointer(ptr)) } len => { if len > MAX_STR_LEN { return Err(ErrMode::Cut(ContextError::from_external_error( input, DnsError::LabelTooLong, ))); } let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?; Ok(Self::String(segment)) } } } /// ## Safety /// The caller upholds that this function is not called when parsing from newly received data. Data that /// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`], /// and that the entire data/context has been validated, not just a portion of it. #[inline] unsafe fn parse_unchecked(input: &'a [u8]) -> Option { input.split_first().map(|(b1, input)| match *b1 { 0 => Self::Empty, b1 if b1 & PTR_MASK == PTR_MASK => { // SAFETY: The caller has already validated that a second byte is available for a // Pointer segment. let b2 = unsafe { *input.get_unchecked(0) }; let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); Self::Pointer(ptr) } len => { // SAFETY: The caller has validated that this length value is correct and will only // access within the bounds of the provided slice. let segment = unsafe { input.get_unchecked(0..(len as usize)) }; // SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before. 
let segment = unsafe { core::str::from_utf8_unchecked(segment) }; Self::String(segment) } }) } fn as_str(&self) -> Option<&'a str> { match self { Self::String(label) => Some(*label), _ => None, } } } pub struct LabelSegmentBytesIter<'a> { context: &'a [u8], start: usize, } impl<'a> LabelSegmentBytesIter<'a> { pub(crate) fn new(context: &'a [u8], start: usize) -> Self { Self { context, start } } } impl<'a> Iterator for LabelSegmentBytesIter<'a> { type Item = LabelSegment<'a>; fn next(&mut self) -> Option { loop { let view = &self.context[self.start..]; // SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes let segment = unsafe { LabelSegment::parse_unchecked(view)? }; match segment { LabelSegment::String(label) => { self.start = self.start.saturating_add(label.len() + 1); return Some(segment); } LabelSegment::Pointer(ptr) => { self.start = ptr as usize; } LabelSegment::Empty => { // Set the index offset to be len() so that the view is empty and terminates the loop self.start = self.context.len(); return Some(LabelSegment::Empty); } } } } } impl<'a> LabelRepr<'a> { fn iter(&self) -> impl Iterator> { match *self { LabelRepr::Bytes { context, start, .. 
} => { Either::A(LabelSegmentBytesIter::new(context, start)) } LabelRepr::Str(view) => Either::B( view.split('.') .map(LabelSegment::String) .chain(Some(LabelSegment::Empty)), ), } } } impl fmt::Debug for Label<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct LabelFmt<'a>(&'a Label<'a>); impl fmt::Debug for LabelFmt<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self.0, f) } } f.debug_tuple("Label").field(&LabelFmt(self)).finish() } } impl fmt::Display for Label<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut names = self.names(); if let Some(name) = names.next() { f.write_str(name)?; names.try_for_each(|name| { f.write_str(".")?; f.write_str(name) }) } else { Ok(()) } } } impl<'a, 'b> PartialEq> for Label<'b> { fn eq(&self, other: &Label<'a>) -> bool { self.segments().eq(other.segments()) } } impl Eq for Label<'_> {} impl PartialEq<&str> for Label<'_> { fn eq(&self, other: &&str) -> bool { let mut self_iter = self.names(); let mut other_iter = other.split('.'); loop { match (self_iter.next(), other_iter.next()) { (Some(self_part), Some(other_part)) => { if self_part != other_part { return false; } } (None, None) => return true, _ => return false, } } } } #[cfg(feature = "defmt")] impl defmt::Format for Label<'_> { fn format(&self, fmt: defmt::Formatter) { defmt::write!(fmt, "Label("); let mut iter = self.names(); if let Some(first) = iter.next() { defmt::write!(fmt, "{}", first); iter.for_each(|part| defmt::write!(fmt, ".{}", part)); } defmt::write!(fmt, ")"); } } /// One iterator or another. 
enum Either<A, B> {
    /// The first kind of iterator.
    A(A),
    /// The second kind of iterator.
    B(B),
}

/// Delegates every call to whichever iterator is stored; both must yield the same
/// item type.
impl<A, B> Iterator for Either<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Either::A(a) => a.next(),
            Either::B(b) => b.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Either::A(a) => a.size_hint(),
            Either::B(b) => b.size_hint(),
        }
    }

    // Forwarded explicitly so that the inner iterator's own `fold` is used.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Either::A(a) => a.fold(init, f),
            Either::B(b) => b.fold(init, f),
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn segments_iter_test() {
        let label: Label<'static> = Label::from("_service._udp.local");

        let mut segments = label.segments();
        assert_eq!(segments.next(), Some(LabelSegment::String("_service")));
        assert_eq!(segments.next(), Some(LabelSegment::String("_udp")));
        assert_eq!(segments.next(), Some(LabelSegment::String("local")));
        assert_eq!(segments.next(), Some(LabelSegment::Empty));
        assert_eq!(segments.next(), None);

        // example.com with a pointer to the start
        let data = b"\x07example\x03com\x00\xC0\x00";
        let context = &data[..];
        // The data here is entirely valid, even though we parse only a portion of it.
        let label = Label::parse(&mut &data[13..], context).unwrap();

        let mut segments = label.segments();
        assert_eq!(segments.next(), Some(LabelSegment::String("example")));
        assert_eq!(segments.next(), Some(LabelSegment::String("com")));
        assert_eq!(segments.next(), Some(LabelSegment::Empty));
        assert_eq!(segments.next(), None);
    }

    #[test]
    fn names_iter_test() {
        let label: Label<'static> = Label::from("_service._udp.local");

        let mut names = label.names();
        assert_eq!(names.next(), Some("_service"));
        assert_eq!(names.next(), Some("_udp"));
        assert_eq!(names.next(), Some("local"));
        assert_eq!(names.next(), None);

        let data = b"\x07example\x03com\x00\xC0\x00";
        let context = &data[..];
        // The data here is entirely valid, even though we parse only a portion of it.
        let label = Label::parse(&mut &data[13..], context).unwrap();

        let mut names = label.names();
        assert_eq!(names.next(), Some("example"));
        assert_eq!(names.next(), Some("com"));
        assert_eq!(names.next(), None);
    }

    #[test]
    fn serialize_str_label() {
        let label: Label<'static> = Label::from("_service._udp.local");

        let mut buffer = [0u8; 256];
        let mut buffer = Encoder::new(&mut buffer);
        label.serialize(&mut buffer).unwrap();

        assert_eq!(buffer.finish(), b"\x08_service\x04_udp\x05local\x00");
    }

    #[test]
    fn serialize_compressed_str_label() {
        let label: Label<'static> = Label::from("_service._udp.local");
        let label2: Label<'static> = Label::from("other._udp.local");

        let mut buffer = [0u8; 256];
        let mut buffer = Encoder::new(&mut buffer);
        label.serialize(&mut buffer).unwrap();
        label2.serialize(&mut buffer).unwrap();

        assert_eq!(
            buffer.finish(),
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );
    }

    #[test]
    fn round_trip_compressed_str_label() {
        let label: Label<'static> = Label::from("_service._udp.local");
        let label2: Label<'static> = Label::from("other._udp.local");

        let mut buffer = [0u8; 256];
        let mut buffer = Encoder::new(&mut buffer);
        label.serialize(&mut buffer).unwrap();
        label2.serialize(&mut buffer).unwrap();

        let context = buffer.finish();
        assert_eq!(
            context,
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );

        let view = &mut &context[..];
        let parsed_label = Label::parse(view, context).unwrap();
        let parsed_label2 = Label::parse(view, context).unwrap();

        let parsed_label_count = parsed_label.segments().count();
        let parsed_label2_count = parsed_label2.segments().count();

        // Both have same amount of segments.
        assert_eq!(parsed_label_count, parsed_label2_count);

        assert_eq!(parsed_label, label);
        assert_eq!(parsed_label2, label2);
    }

    #[test]
    fn label_byte_repr_serialization_quick_path() {
        let data = b"\x07example\x03com\x00\xC0\x00";
        let context = &data[..];
        // The data here is entirely valid, even though we parse only a portion of it.
        let label = Label::parse(&mut &data[13..], context).unwrap();

        let mut buffer = [0u8; 256];
        let mut buffer = Encoder::new(&mut buffer);
        label.serialize(&mut buffer).unwrap();

        // If the original Label is just a pointer, the new output will be a pointer, assuming
        // the original data is also present in the output
        assert_eq!(buffer.finish(), b"\xC0\x00");
    }

    #[test]
    fn parse_and_eq_created_label() {
        let data = b"\x07example\x03com\x00\xC0\x00";
        let context = &data[..];
        // The data here is entirely valid, even though we parse only a portion of it.
        let parsed_label = Label::parse(&mut &data[13..], context).unwrap();
        let created_label = Label::from("example.com");

        assert_eq!(parsed_label, created_label);
    }

    #[test]
    fn parse_and_eq_label_with_str() {
        let data = b"\x07example\x03com\x00";
        let context = &data[..];
        let parsed_label = Label::parse(&mut &data[..], context).unwrap();

        assert_eq!(parsed_label, "example.com");
    }

    #[test]
    fn parse_ptr_label_and_eq_with_str() {
        let data = b"\x07example\x03com\x00\xC0\x00";
        let context = &data[..];
        // The data here is entirely valid, even though we parse only a portion of it.
        let parsed_label = Label::parse(&mut &data[13..], context).unwrap();

        assert_eq!(parsed_label, "example.com");
    }

    #[test]
    fn label_new_without_dot_is_not_empty() {
        let label: Label = Label::from("example");
        assert!(!label.is_empty());
    }
}