Repository of `no_std` crates for my personal embedded projects.
1use core::ops::Range;
2
3use alloc::collections::BTreeMap;
4
5use crate::dns::traits::DnsSerialize;
6
/// The two high bits of a name length octet. When both are set, the octet is the
/// first byte of a two-byte compression pointer rather than a label length.
pub(crate) const PTR_MASK: u8 = 0xC0;
/// Longest label that can be length-prefixed without touching the `PTR_MASK`
/// bits: the remaining six bits, i.e. 63.
pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK;
9
/// Errors reported while encoding DNS data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum DnsError {
    /// A dot-separated segment of a domain name is longer than `MAX_STR_LEN` bytes.
    LabelTooLong,
    /// A TXT segment could not be encoded. Not raised in this file; presumably
    /// produced by the TXT record serializer (confirm).
    InvalidTxt,
    /// The record type has no encoder. Not raised in this file.
    Unsupported,
}

impl core::fmt::Display for DnsError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // One static description per variant; no formatting arguments are needed.
        let description = match *self {
            DnsError::LabelTooLong => "Encoding Error: Segment too long",
            DnsError::InvalidTxt => "Encoding Error: TXT segment is invalid",
            DnsError::Unsupported => "Encoding Error: Unsupported Record Type",
        };
        f.write_str(description)
    }
}

/// Marker impl so `DnsError` can be used wherever a `core::error::Error` is expected.
impl core::error::Error for DnsError {}
29
30#[derive(Debug)]
31pub struct Encoder<'a, 'b> {
32 output: &'b mut [u8],
33 position: usize,
34 lookup: BTreeMap<&'a str, u16>,
35 reservation: Option<usize>,
36}
37
38impl<'a, 'b> Encoder<'a, 'b> {
39 pub const fn new(buffer: &'b mut [u8]) -> Self {
40 Self {
41 output: buffer,
42 position: 0,
43 lookup: BTreeMap::new(),
44 reservation: None,
45 }
46 }
47
48 /// Takes a payload and encodes it, consuming the encoder and yielding the resulting
49 /// slice.
50 pub fn encode<T, E>(mut self, payload: T) -> Result<&'b [u8], E>
51 where
52 E: core::error::Error,
53 T: DnsSerialize<'a, Error = E>,
54 {
55 payload.serialize(&mut self)?;
56 Ok(self.finish())
57 }
58
59 pub(crate) fn finish(self) -> &'b [u8] {
60 &self.output[..self.position]
61 }
62
63 fn increment(&mut self, amount: usize) {
64 self.position += amount;
65 }
66
67 pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> {
68 loop {
69 if let Some(pos) = self.get_label_position(label) {
70 let [b1, b2] = u16::to_be_bytes(pos);
71 self.write(&[b1 | PTR_MASK, b2]);
72 return Ok(());
73 }
74
75 let dot = label.find('.');
76
77 let end = dot.unwrap_or(label.len());
78 let segment = &label[..end];
79 let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?;
80
81 if len > MAX_STR_LEN {
82 return Err(DnsError::LabelTooLong);
83 }
84
85 self.store_label_position(label);
86 self.write(&len.to_be_bytes());
87 self.write(segment.as_bytes());
88
89 match dot {
90 Some(end) => {
91 label = &label[end + 1..];
92 }
93 None => {
94 self.write(&[0]);
95 return Ok(());
96 }
97 }
98 }
99 }
100
101 pub(crate) fn write(&mut self, bytes: &[u8]) {
102 let len = bytes.len();
103 let end = self.position + len;
104 self.output[self.position..end].copy_from_slice(bytes);
105 self.increment(len);
106 }
107
108 fn get_label_position(&mut self, label: &str) -> Option<u16> {
109 self.lookup.get(label).copied()
110 }
111
112 fn store_label_position(&mut self, label: &'a str) {
113 self.lookup.insert(label, self.position as u16);
114 }
115
116 fn reserve_record_length(&mut self) {
117 if self.reservation.is_none() {
118 self.reservation = Some(self.position);
119 self.increment(2);
120 }
121 }
122
123 fn distance_from_reservation(&mut self) -> Option<(Range<usize>, u16)> {
124 self.reservation
125 .take()
126 .map(|start| (start..(start + 2), (self.position - start - 2) as u16))
127 }
128
129 fn write_record_length(&mut self) {
130 if let Some((reservation, len)) = self.distance_from_reservation() {
131 self.output[reservation].copy_from_slice(&len.to_be_bytes());
132 }
133 }
134
135 pub(crate) fn with_record_length<E, F>(&mut self, encoding_scope: F) -> Result<(), E>
136 where
137 E: core::error::Error,
138 F: FnOnce(&mut Self) -> Result<(), E>,
139 {
140 self.reserve_record_length();
141 encoding_scope(self)?;
142 self.write_record_length();
143 Ok(())
144 }
145}