Repo of no-std crates for my personal embedded projects
at main 4.1 kB view raw
1use core::ops::Range; 2 3use alloc::collections::BTreeMap; 4 5use crate::dns::traits::DnsSerialize; 6 7pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK; 8pub(crate) const PTR_MASK: u8 = 0b1100_0000; 9 10#[derive(Debug, Clone, Copy, PartialEq, Eq)] 11#[cfg_attr(feature = "defmt", derive(defmt::Format))] 12pub enum DnsError { 13 LabelTooLong, 14 InvalidTxt, 15 Unsupported, 16} 17 18impl core::fmt::Display for DnsError { 19 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 20 match self { 21 Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"), 22 Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"), 23 Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"), 24 } 25 } 26} 27 28impl core::error::Error for DnsError {} 29 30#[derive(Debug)] 31pub struct Encoder<'a, 'b> { 32 output: &'b mut [u8], 33 position: usize, 34 lookup: BTreeMap<&'a str, u16>, 35 reservation: Option<usize>, 36} 37 38impl<'a, 'b> Encoder<'a, 'b> { 39 pub const fn new(buffer: &'b mut [u8]) -> Self { 40 Self { 41 output: buffer, 42 position: 0, 43 lookup: BTreeMap::new(), 44 reservation: None, 45 } 46 } 47 48 /// Takes a payload and encodes it, consuming the encoder and yielding the resulting 49 /// slice. 50 pub fn encode<T, E>(mut self, payload: T) -> Result<&'b [u8], E> 51 where 52 E: core::error::Error, 53 T: DnsSerialize<'a, Error = E>, 54 { 55 payload.serialize(&mut self)?; 56 Ok(self.finish()) 57 } 58 59 pub(crate) fn finish(self) -> &'b [u8] { 60 &self.output[..self.position] 61 } 62 63 fn increment(&mut self, amount: usize) { 64 self.position += amount; 65 } 66 67 pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> { 68 loop { 69 if let Some(pos) = self.get_label_position(label) { 70 let [b1, b2] = u16::to_be_bytes(pos); 71 self.write(&[b1 | PTR_MASK, b2]); 72 return Ok(()); 73 } 74 75 let dot = label.find('.'); 76 77 let end = dot.unwrap_or(label.len()); 78 let segment = &label[..end]; 79 let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?; 80 81 if len > MAX_STR_LEN { 82 return Err(DnsError::LabelTooLong); 83 } 84 85 self.store_label_position(label); 86 self.write(&len.to_be_bytes()); 87 self.write(segment.as_bytes()); 88 89 match dot { 90 Some(end) => { 91 label = &label[end + 1..]; 92 } 93 None => { 94 self.write(&[0]); 95 return Ok(()); 96 } 97 } 98 } 99 } 100 101 pub(crate) fn write(&mut self, bytes: &[u8]) { 102 let len = bytes.len(); 103 let end = self.position + len; 104 self.output[self.position..end].copy_from_slice(bytes); 105 self.increment(len); 106 } 107 108 fn get_label_position(&mut self, label: &str) -> Option<u16> { 109 self.lookup.get(label).copied() 110 } 111 112 fn store_label_position(&mut self, label: &'a str) { 113 self.lookup.insert(label, self.position as u16); 114 } 115 116 fn reserve_record_length(&mut self) { 117 if self.reservation.is_none() { 118 self.reservation = Some(self.position); 119 self.increment(2); 120 } 121 } 122 123 fn distance_from_reservation(&mut self) -> Option<(Range<usize>, u16)> { 124 self.reservation 125 .take() 126 .map(|start| (start..(start + 2), (self.position - start - 2) as u16)) 127 } 128 129 fn write_record_length(&mut self) { 130 if let Some((reservation, len)) = self.distance_from_reservation() { 131 self.output[reservation].copy_from_slice(&len.to_be_bytes()); 132 } 133 } 134 135 pub(crate) fn with_record_length<E, F>(&mut self, encoding_scope: F) -> Result<(), E> 136 where 137 E: core::error::Error, 138 F: FnOnce(&mut Self) -> Result<(), E>, 139 { 140 self.reserve_record_length(); 141 encoding_scope(self)?; 142 self.write_record_length(); 143 Ok(()) 144 } 145}