Repository of `no_std` crates for my personal embedded projects.
1use core::{
2 convert::Infallible,
3 net::{Ipv4Addr, Ipv6Addr},
4 str,
5};
6
7use alloc::vec::Vec;
8use winnow::token::take;
9use winnow::{ModalResult, Parser};
10use winnow::{binary::be_u8, error::ContextError};
11use winnow::{binary::be_u16, error::FromExternalError};
12
13use super::label::Label;
14use crate::{
15 dns::traits::{DnsParse, DnsParseKind, DnsSerialize},
16 encoder::{DnsError, Encoder},
17};
18
/// Resource-record TYPE / question QTYPE codes understood by this module.
///
/// The explicit discriminants equal the wire codes used by `from_u16` /
/// `to_u16`; any other code is carried verbatim in [`QType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
// Variant names deliberately mirror the DNS mnemonics (A, AAAA, PTR, ...).
#[allow(clippy::upper_case_acronyms)]
pub enum QType {
    /// IPv4 host address (code 1).
    A = 1,
    /// Domain-name pointer (code 12).
    PTR = 12,
    /// Text strings (code 16).
    TXT = 16,
    /// IPv6 host address (code 28).
    AAAA = 28,
    /// Service locator (code 33).
    SRV = 33,
    /// Query for any record type (code 255).
    Any = 255,
    /// Any other code, kept exactly as read.
    Unknown(u16),
}
31
32impl<'a> DnsParse<'a> for QType {
33 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
34 be_u16.map(QType::from_u16).parse_next(input)
35 }
36}
37
38impl<'a> DnsSerialize<'a> for QType {
39 type Error = Infallible;
40
41 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
42 encoder.write(&self.to_u16().to_be_bytes());
43 Ok(())
44 }
45
46 fn size(&self) -> usize {
47 core::mem::size_of::<QType>()
48 }
49}
50
51impl<'a> DnsParseKind<'a> for QType {
52 type Output = Record<'a>;
53
54 fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output> {
55 match self {
56 QType::A => {
57 let record = A::parse(input, context)?;
58 Ok(Record::A(record))
59 }
60 QType::AAAA => {
61 let record = AAAA::parse(input, context)?;
62 Ok(Record::AAAA(record))
63 }
64 QType::PTR => {
65 let record = PTR::parse(input, context)?;
66 Ok(Record::PTR(record))
67 }
68 QType::TXT => {
69 let record = TXT::parse(input, context)?;
70 Ok(Record::TXT(record))
71 }
72 QType::SRV => {
73 let record = SRV::parse(input, context)?;
74 Ok(Record::SRV(record))
75 }
76 QType::Any => Err(winnow::error::ErrMode::Backtrack(
77 ContextError::from_external_error(input, DnsError::Unsupported),
78 )),
79 QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack(
80 ContextError::from_external_error(input, DnsError::Unsupported),
81 )),
82 }
83 }
84}
85
86impl QType {
87 fn from_u16(value: u16) -> Self {
88 match value {
89 1 => QType::A,
90 28 => QType::AAAA,
91 12 => QType::PTR,
92 16 => QType::TXT,
93 33 => QType::SRV,
94 255 => QType::Any,
95 _ => QType::Unknown(value),
96 }
97 }
98
99 fn to_u16(self) -> u16 {
100 match self {
101 QType::A => 1,
102 QType::AAAA => 28,
103 QType::PTR => 12,
104 QType::TXT => 16,
105 QType::SRV => 33,
106 QType::Any => 255,
107 QType::Unknown(value) => value,
108 }
109 }
110}
111
112#[derive(Debug, PartialEq, Eq)]
113#[allow(clippy::upper_case_acronyms)]
114// Enum for DNS-SD records
115pub enum Record<'a> {
116 A(A),
117 AAAA(AAAA),
118 PTR(PTR<'a>),
119 TXT(TXT<'a>),
120 SRV(SRV<'a>),
121}
122
123impl<'a> DnsSerialize<'a> for Record<'a> {
124 type Error = DnsError;
125
126 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
127 match self {
128 Record::A(record) => {
129 record.serialize(encoder).ok();
130 }
131 Record::AAAA(record) => {
132 record.serialize(encoder).ok();
133 }
134 Record::PTR(record) => {
135 record.serialize(encoder)?;
136 }
137 Record::TXT(record) => {
138 record.serialize(encoder).ok();
139 }
140 Record::SRV(record) => {
141 record.serialize(encoder)?;
142 }
143 };
144
145 Ok(())
146 }
147
148 fn size(&self) -> usize {
149 match self {
150 Self::A(a) => a.size(),
151 Self::AAAA(aaaa) => aaaa.size(),
152 Self::PTR(ptr) => ptr.size(),
153 Self::TXT(txt) => txt.size(),
154 Self::SRV(srv) => srv.size(),
155 }
156 }
157}
158
/// RDATA of an A record: a single IPv4 address.
#[derive(Debug, PartialEq, Eq)]
pub struct A {
    /// The IPv4 address carried by the record.
    pub address: Ipv4Addr,
}
164
165impl<'a> DnsParse<'a> for A {
166 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
167 let len = be_u16.parse_next(input)?;
168 let address = take(len)
169 .try_map(<[u8; 4]>::try_from)
170 .map(Ipv4Addr::from)
171 .parse_next(input)?;
172
173 Ok(A { address })
174 }
175}
176
177impl<'a> DnsSerialize<'a> for A {
178 type Error = Infallible;
179
180 fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
181 let len = 4u16.to_be_bytes();
182 writer.write(&len);
183 writer.write(&self.address.octets());
184 Ok(())
185 }
186
187 fn size(&self) -> usize {
188 core::mem::size_of::<Ipv4Addr>() + core::mem::size_of::<u16>()
189 }
190}
191
/// RDATA of an AAAA record: a single IPv6 address.
#[derive(Debug, PartialEq, Eq)]
// Type name deliberately mirrors the DNS mnemonic.
#[allow(clippy::upper_case_acronyms)]
pub struct AAAA {
    /// The IPv6 address carried by the record.
    pub address: Ipv6Addr,
}
198
199impl<'a> DnsParse<'a> for AAAA {
200 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
201 let len = be_u16.parse_next(input)?;
202 let address = take(len)
203 .try_map(<[u8; 16]>::try_from)
204 .map(Ipv6Addr::from)
205 .parse_next(input)?;
206
207 Ok(AAAA { address })
208 }
209}
210
211impl<'a> DnsSerialize<'a> for AAAA {
212 type Error = Infallible;
213
214 fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
215 let len = 16u16.to_be_bytes();
216 writer.write(&len);
217 writer.write(&self.address.octets());
218 Ok(())
219 }
220
221 fn size(&self) -> usize {
222 core::mem::size_of::<Ipv6Addr>() + core::mem::size_of::<u16>()
223 }
224}
225
226// Struct for PTR record
227#[derive(Debug, PartialEq, Eq)]
228#[allow(clippy::upper_case_acronyms)]
229pub struct PTR<'a> {
230 pub name: Label<'a>,
231}
232
233impl<'a> DnsParse<'a> for PTR<'a> {
234 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
235 let _ = be_u16.parse_next(input)?;
236 let name = Label::parse(input, context)?;
237 Ok(PTR { name })
238 }
239}
240
241impl<'a> DnsSerialize<'a> for PTR<'a> {
242 type Error = DnsError;
243
244 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
245 encoder.with_record_length(|enc| self.name.serialize(enc))
246 }
247
248 fn size(&self) -> usize {
249 self.name.size() + core::mem::size_of::<u16>()
250 }
251}
252
/// RDATA of a TXT record: its character-strings in order.
///
/// Each entry borrows for `'a`; when parsed, that is the input message.
#[derive(Debug, PartialEq, Eq)]
// Type name deliberately mirrors the DNS mnemonic.
#[allow(clippy::upper_case_acronyms)]
pub struct TXT<'a> {
    /// The record's strings; each must fit a one-byte length when serialized.
    pub text: Vec<&'a str>,
}
259
260impl<'a> DnsParse<'a> for TXT<'a> {
261 fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
262 let text_len = be_u16.parse_next(input)?;
263
264 let mut total = 0u16;
265 let mut text = Vec::new();
266
267 while total < text_len {
268 let len = be_u8(input)?;
269
270 total += 1 + len as u16;
271
272 if len > 0 {
273 let part = take(len).try_map(core::str::from_utf8).parse_next(input)?;
274 text.push(part);
275 }
276 }
277
278 Ok(TXT { text })
279 }
280}
281
282impl<'a> DnsSerialize<'a> for TXT<'a> {
283 type Error = DnsError;
284
285 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
286 encoder.with_record_length(|enc| {
287 self.text.iter().try_for_each(|&part| {
288 let text_len = u8::try_from(part.len())
289 .map_err(|_| DnsError::InvalidTxt)
290 .map(u8::to_be_bytes)?;
291
292 enc.write(&text_len);
293 enc.write(part.as_bytes());
294
295 Ok(())
296 })
297 })
298 }
299
300 fn size(&self) -> usize {
301 let len_size = core::mem::size_of::<u16>();
302
303 let text_size = if self.text.is_empty() {
304 1
305 } else {
306 self.text.iter().map(|part| part.len() + 1).sum()
307 };
308
309 len_size + text_size
310 }
311}
312
313// Struct for SRV record
314#[derive(Debug, PartialEq, Eq)]
315#[allow(clippy::upper_case_acronyms)]
316pub struct SRV<'a> {
317 pub priority: u16,
318 pub weight: u16,
319 pub port: u16,
320 pub target: Label<'a>,
321}
322
323impl<'a> DnsParse<'a> for SRV<'a> {
324 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
325 let _ = be_u16.parse_next(input)?;
326 let priority = be_u16.parse_next(input)?;
327 let weight = be_u16.parse_next(input)?;
328 let port = be_u16.parse_next(input)?;
329 let target = Label::parse(input, context)?;
330
331 Ok(SRV {
332 priority,
333 weight,
334 port,
335 target,
336 })
337 }
338}
339
340impl<'a> DnsSerialize<'a> for SRV<'a> {
341 type Error = DnsError;
342
343 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
344 encoder.with_record_length(|enc| {
345 enc.write(&self.priority.to_be_bytes());
346 enc.write(&self.weight.to_be_bytes());
347 enc.write(&self.port.to_be_bytes());
348
349 self.target.serialize(enc)
350 })
351 }
352
353 fn size(&self) -> usize {
354 (core::mem::size_of::<u16>() * 4) + self.target.size()
355 }
356}
357
/// Compact `defmt` rendering of an A record as `A(<address>)`.
// NOTE(review): relies on the `defmt` crate providing `Format` for
// `core::net::Ipv4Addr`; confirm against the defmt version in use.
#[cfg(feature = "defmt")]
impl defmt::Format for A {
    fn format(&self, fmt: defmt::Formatter) {
        let address = self.address;
        defmt::write!(fmt, "A({})", address)
    }
}
365
/// Compact `defmt` rendering of an AAAA record as `AAAA(<address>)`.
// NOTE(review): relies on the `defmt` crate providing `Format` for
// `core::net::Ipv6Addr`; confirm against the defmt version in use.
#[cfg(feature = "defmt")]
impl defmt::Format for AAAA {
    fn format(&self, fmt: defmt::Formatter) {
        let address = self.address;
        defmt::write!(fmt, "AAAA({})", address)
    }
}
373
/// `defmt` rendering that wraps the inner record's output in its
/// `Record::<Variant>(...)` name.
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for Record<'a> {
    fn format(&self, fmt: defmt::Formatter) {
        match self {
            Self::A(inner) => defmt::write!(fmt, "Record::A({:?})", inner),
            Self::AAAA(inner) => defmt::write!(fmt, "Record::AAAA({:?})", inner),
            Self::PTR(inner) => defmt::write!(fmt, "Record::PTR({:?})", inner),
            Self::TXT(inner) => defmt::write!(fmt, "Record::TXT({:?})", inner),
            Self::SRV(inner) => defmt::write!(fmt, "Record::SRV({:?})", inner),
        }
    }
}
386
/// Struct-style `defmt` rendering of a PTR record.
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for PTR<'a> {
    fn format(&self, fmt: defmt::Formatter) {
        let Self { name } = self;
        defmt::write!(fmt, "PTR {{ name: {:?} }}", name);
    }
}
393
/// Struct-style `defmt` rendering of a TXT record's strings.
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for TXT<'a> {
    fn format(&self, fmt: defmt::Formatter) {
        let Self { text } = self;
        defmt::write!(fmt, "TXT {{ text: {:?} }}", text);
    }
}
400
/// Struct-style `defmt` rendering of every SRV field.
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for SRV<'a> {
    fn format(&self, fmt: defmt::Formatter) {
        let Self {
            priority,
            weight,
            port,
            target,
        } = self;

        defmt::write!(
            fmt,
            "SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}",
            priority,
            weight,
            port,
            target
        );
    }
}