Repository of `no_std` crates for my personal embedded projects.
1use core::{fmt, str};
2
3use winnow::{
4 ModalResult, Parser,
5 binary::be_u8,
6 error::{ContextError, ErrMode, FromExternalError},
7 stream::Offset,
8 token::take,
9};
10
11use crate::{
12 dns::traits::{DnsParse, DnsSerialize},
13 encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK},
14};
15
16#[derive(Clone, Copy)]
17pub struct Label<'a> {
18 repr: LabelRepr<'a>,
19}
20
21impl<'a> From<&'a str> for Label<'a> {
22 fn from(value: &'a str) -> Self {
23 Self {
24 repr: LabelRepr::Str(value),
25 }
26 }
27}
28
29impl<'a> DnsSerialize<'a> for Label<'a> {
30 type Error = DnsError;
31
32 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
33 match self.repr {
34 LabelRepr::Bytes {
35 context,
36 start,
37 end,
38 } => {
39 encoder.write(&context[start..end]);
40
41 Ok(())
42 }
43 LabelRepr::Str(label) => encoder.write_label(label),
44 }
45 }
46
47 fn size(&self) -> usize {
48 match self.repr {
49 LabelRepr::Bytes {
50 context,
51 start,
52 end,
53 } => core::mem::size_of_val(&context[start..end]),
54 LabelRepr::Str(label) => core::mem::size_of_val(label) + 1,
55 }
56 }
57}
58
59impl<'a> DnsParse<'a> for Label<'a> {
60 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
61 let start = input.offset_from(&context);
62 let mut end = start;
63
64 loop {
65 match LabelSegment::parse(input)? {
66 LabelSegment::Empty => {
67 end += 1;
68 break;
69 }
70 LabelSegment::String(label) => {
71 end += 1 + label.len();
72 }
73 LabelSegment::Pointer(_) => {
74 end += 2;
75 break;
76 }
77 }
78 }
79
80 Ok(Self {
81 repr: LabelRepr::Bytes {
82 context,
83 start,
84 end,
85 },
86 })
87 }
88}
89
90impl Label<'_> {
91 pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'_>> {
92 self.repr.iter()
93 }
94
95 pub fn names(&self) -> impl Iterator<Item = &'_ str> {
96 match self.repr {
97 LabelRepr::Str(view) => Either::A(view.split('.')),
98 LabelRepr::Bytes { context, start, .. } => Either::B(
99 LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()),
100 ),
101 }
102 }
103
104 pub fn is_empty(&self) -> bool {
105 self.repr.iter().next().is_none()
106 }
107}
108
/// Storage behind a [`Label`].
#[derive(Clone, Copy, PartialEq, Eq)]
enum LabelRepr<'a> {
    /// A wire-encoded name located inside a buffer.
    Bytes {
        /// The buffer holding the name; compression pointers are offsets into it.
        context: &'a [u8],
        /// Offset of the name's first octet within `context`.
        start: usize,
        /// Offset one past the name's own octets (its terminator or its first pointer).
        end: usize,
    },
    /// A dotted name such as `"_service._udp.local"`, encoded only when serialized.
    Str(&'a str),
}
118
/// One segment of a DNS name as it appears on the wire.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LabelSegment<'a> {
    /// The zero-length segment that terminates a name.
    Empty,

    /// A length-prefixed string segment, e.g. `example` in `example.com`.
    String(&'a str),

    /// A compression pointer: the offset (pointer tag bits cleared) where the rest of the
    /// name continues.
    Pointer(u16),
}
131
132impl<'a> LabelSegment<'a> {
133 fn parse(input: &mut &'a [u8]) -> ModalResult<Self> {
134 let b1 = be_u8(input)?;
135
136 match b1 {
137 0 => Ok(Self::Empty),
138 b1 if b1 & PTR_MASK == PTR_MASK => {
139 let b2 = be_u8(input)?;
140
141 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
142
143 Ok(Self::Pointer(ptr))
144 }
145 len => {
146 if len > MAX_STR_LEN {
147 return Err(ErrMode::Cut(ContextError::from_external_error(
148 input,
149 DnsError::LabelTooLong,
150 )));
151 }
152
153 let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?;
154
155 Ok(Self::String(segment))
156 }
157 }
158 }
159
160 /// ## Safety
161 /// The caller upholds that this function is not called when parsing from newly received data. Data that
162 /// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`],
163 /// and that the entire data/context has been validated, not just a portion of it.
164 #[inline]
165 unsafe fn parse_unchecked(input: &'a [u8]) -> Option<Self> {
166 input.split_first().map(|(b1, input)| match *b1 {
167 0 => Self::Empty,
168 b1 if b1 & PTR_MASK == PTR_MASK => {
169 // SAFETY: The caller has already validated that a second byte is available for a
170 // Pointer segment.
171 let b2 = unsafe { *input.get_unchecked(0) };
172
173 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
174
175 Self::Pointer(ptr)
176 }
177 len => {
178 // SAFETY: The caller has validated that this length value is correct and will only
179 // access within the bounds of the provided slice.
180 let segment = unsafe { input.get_unchecked(0..(len as usize)) };
181 // SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before.
182 let segment = unsafe { core::str::from_utf8_unchecked(segment) };
183
184 Self::String(segment)
185 }
186 })
187 }
188
189 fn as_str(&self) -> Option<&'a str> {
190 match self {
191 Self::String(label) => Some(*label),
192 _ => None,
193 }
194 }
195}
196
/// Iterator over the segments of a wire-encoded name stored inside a buffer.
///
/// Compression pointers are followed instead of being yielded, and iteration stops after
/// the [`LabelSegment::Empty`] terminator.
pub struct LabelSegmentBytesIter<'a> {
    /// The buffer holding the name; compression pointers index into it.
    context: &'a [u8],
    /// Offset in `context` of the next segment to decode.
    start: usize,
}

impl<'a> LabelSegmentBytesIter<'a> {
    /// Creates an iterator over the name that begins at offset `start` of `context`.
    ///
    /// Segments are decoded without validation. The name at `start`, and every segment
    /// reachable from it through compression pointers, must already have been validated with
    /// [`LabelSegment::parse`].
    pub(crate) fn new(context: &'a [u8], start: usize) -> Self {
        LabelSegmentBytesIter { context, start }
    }
}
207
208impl<'a> Iterator for LabelSegmentBytesIter<'a> {
209 type Item = LabelSegment<'a>;
210
211 fn next(&mut self) -> Option<Self::Item> {
212 loop {
213 let view = &self.context[self.start..];
214
215 // SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes
216 let segment = unsafe { LabelSegment::parse_unchecked(view)? };
217
218 match segment {
219 LabelSegment::String(label) => {
220 self.start = self.start.saturating_add(label.len() + 1);
221 return Some(segment);
222 }
223 LabelSegment::Pointer(ptr) => {
224 self.start = ptr as usize;
225 }
226 LabelSegment::Empty => {
227 // Set the index offset to be len() so that the view is empty and terminates the loop
228 self.start = self.context.len();
229 return Some(LabelSegment::Empty);
230 }
231 }
232 }
233 }
234}
235
236impl<'a> LabelRepr<'a> {
237 fn iter(&self) -> impl Iterator<Item = LabelSegment<'a>> {
238 match *self {
239 LabelRepr::Bytes { context, start, .. } => {
240 Either::A(LabelSegmentBytesIter::new(context, start))
241 }
242 LabelRepr::Str(view) => Either::B(
243 view.split('.')
244 .map(LabelSegment::String)
245 .chain(Some(LabelSegment::Empty)),
246 ),
247 }
248 }
249}
250
251impl fmt::Debug for Label<'_> {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 struct LabelFmt<'a>(&'a Label<'a>);
254
255 impl fmt::Debug for LabelFmt<'_> {
256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257 fmt::Display::fmt(self.0, f)
258 }
259 }
260
261 f.debug_tuple("Label").field(&LabelFmt(self)).finish()
262 }
263}
264
265impl fmt::Display for Label<'_> {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 let mut names = self.names();
268
269 if let Some(name) = names.next() {
270 f.write_str(name)?;
271
272 names.try_for_each(|name| {
273 f.write_str(".")?;
274 f.write_str(name)
275 })
276 } else {
277 Ok(())
278 }
279 }
280}
281
282impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> {
283 fn eq(&self, other: &Label<'a>) -> bool {
284 self.segments().eq(other.segments())
285 }
286}
287
288impl Eq for Label<'_> {}
289
290impl PartialEq<&str> for Label<'_> {
291 fn eq(&self, other: &&str) -> bool {
292 let mut self_iter = self.names();
293 let mut other_iter = other.split('.');
294
295 loop {
296 match (self_iter.next(), other_iter.next()) {
297 (Some(self_part), Some(other_part)) => {
298 if self_part != other_part {
299 return false;
300 }
301 }
302 (None, None) => return true,
303 _ => return false,
304 }
305 }
306 }
307}
308
#[cfg(feature = "defmt")]
impl defmt::Format for Label<'_> {
    /// Formats as `Label(<dotted name>)`.
    fn format(&self, fmt: defmt::Formatter) {
        defmt::write!(fmt, "Label(");
        for (index, part) in self.names().enumerate() {
            if index == 0 {
                defmt::write!(fmt, "{}", part);
            } else {
                defmt::write!(fmt, ".{}", part);
            }
        }
        defmt::write!(fmt, ")");
    }
}
322
/// Either one iterator or another.
///
/// This lets a single `impl Iterator` return site produce two different concrete types.
enum Either<A, B> {
    /// Yields the items of the first iterator.
    A(A),
    /// Yields the items of the second iterator.
    B(B),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(left) => left.size_hint(),
            Self::B(right) => right.size_hint(),
        }
    }

    // Forwarding `fold` lets internal-iteration consumers (e.g. `for_each`) use the inner
    // iterator's own implementation instead of repeated `next` calls.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Self::A(left) => left.fold(init, f),
            Self::B(right) => right.fold(init, f),
        }
    }
}
357
#[cfg(test)]
mod test {
    use super::*;

    /// `example.com`, followed by a name made only of a compression pointer to offset 0.
    static EXAMPLE_THEN_POINTER: &[u8] = b"\x07example\x03com\x00\xC0\x00";

    /// Offset of the pointer-only name inside [`EXAMPLE_THEN_POINTER`].
    const POINTER_NAME_OFFSET: usize = 13;

    /// Parses the pointer-only name from [`EXAMPLE_THEN_POINTER`].
    ///
    /// The buffer as a whole is valid, even though parsing starts partway through it.
    fn parse_pointer_name() -> Label<'static> {
        let context = EXAMPLE_THEN_POINTER;
        let mut view = &context[POINTER_NAME_OFFSET..];
        Label::parse(&mut view, context).unwrap()
    }

    /// Asserts that `label` yields exactly `expected` segments and then stops.
    fn assert_segments(label: &Label<'_>, expected: &[LabelSegment<'_>]) {
        let mut segments = label.segments();
        for want in expected {
            assert_eq!(segments.next().as_ref(), Some(want));
        }
        assert_eq!(segments.next(), None);
    }

    /// Asserts that `label` yields exactly `expected` names and then stops.
    fn assert_names(label: &Label<'_>, expected: &[&str]) {
        let mut names = label.names();
        for want in expected {
            assert_eq!(names.next(), Some(*want));
        }
        assert_eq!(names.next(), None);
    }

    #[test]
    fn segments_iter_test() {
        let label: Label<'static> = Label::from("_service._udp.local");
        assert_segments(
            &label,
            &[
                LabelSegment::String("_service"),
                LabelSegment::String("_udp"),
                LabelSegment::String("local"),
                LabelSegment::Empty,
            ],
        );

        // example.com, reached only through a pointer to the start of the buffer.
        let label = parse_pointer_name();
        assert_segments(
            &label,
            &[
                LabelSegment::String("example"),
                LabelSegment::String("com"),
                LabelSegment::Empty,
            ],
        );
    }

    #[test]
    fn names_iter_test() {
        let label: Label<'static> = Label::from("_service._udp.local");
        assert_names(&label, &["_service", "_udp", "local"]);

        let label = parse_pointer_name();
        assert_names(&label, &["example", "com"]);
    }

    #[test]
    fn serialize_str_label() {
        let label: Label<'static> = Label::from("_service._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);
        label.serialize(&mut encoder).unwrap();
        assert_eq!(encoder.finish(), b"\x08_service\x04_udp\x05local\x00");
    }

    #[test]
    fn serialize_compressed_str_label() {
        let first: Label<'static> = Label::from("_service._udp.local");
        let second: Label<'static> = Label::from("other._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);
        first.serialize(&mut encoder).unwrap();
        second.serialize(&mut encoder).unwrap();
        assert_eq!(
            encoder.finish(),
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );
    }

    #[test]
    fn round_trip_compressed_str_label() {
        let first: Label<'static> = Label::from("_service._udp.local");
        let second: Label<'static> = Label::from("other._udp.local");
        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);
        first.serialize(&mut encoder).unwrap();
        second.serialize(&mut encoder).unwrap();
        let context = encoder.finish();
        assert_eq!(
            context,
            b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
        );

        let view = &mut &context[..];
        let parsed_first = Label::parse(view, context).unwrap();
        let parsed_second = Label::parse(view, context).unwrap();

        // Both names decode to the same number of segments.
        assert_eq!(
            parsed_first.segments().count(),
            parsed_second.segments().count()
        );

        assert_eq!(parsed_first, first);
        assert_eq!(parsed_second, second);
    }

    #[test]
    fn label_byte_repr_serialization_quick_path() {
        let label = parse_pointer_name();

        let mut storage = [0u8; 256];
        let mut encoder = Encoder::new(&mut storage);
        label.serialize(&mut encoder).unwrap();
        // A parsed name that is just a pointer is re-emitted as that same pointer, which is
        // only valid if the output also holds the referenced data.
        assert_eq!(encoder.finish(), b"\xC0\x00");
    }

    #[test]
    fn parse_and_eq_created_label() {
        let parsed = parse_pointer_name();
        let created = Label::from("example.com");

        assert_eq!(parsed, created);
    }

    #[test]
    fn parse_and_eq_label_with_str() {
        let data = b"\x07example\x03com\x00";
        let context = &data[..];
        let mut view = &data[..];

        let parsed = Label::parse(&mut view, context).unwrap();

        assert_eq!(parsed, "example.com");
    }

    #[test]
    fn parse_ptr_label_and_eq_with_str() {
        let parsed = parse_pointer_name();

        assert_eq!(parsed, "example.com");
    }

    #[test]
    fn label_new_without_dot_is_not_empty() {
        let label: Label = Label::from("example");
        assert!(!label.is_empty());
    }
}