Repository of `no_std` crates for my personal embedded projects.
1use alloc::vec::Vec;
2use winnow::ModalResult;
3use winnow::binary::be_u16;
4
5use super::flags::Flags;
6use super::query::{Answer, Query};
7use crate::{
8 dns::traits::{DnsParse, DnsSerialize},
9 encoder::{DnsError, Encoder},
10};
11
/// Big-endian encoding of a `u16` zero, written for header section counts that are always empty.
const ZERO_U16: [u8; 2] = [0; 2];
13
14#[derive(Debug, PartialEq, Eq)]
15pub struct Request<'a> {
16 pub id: u16,
17 pub flags: Flags,
18 pub(crate) queries: Vec<Query<'a>>,
19}
20
21impl<'a> DnsParse<'a> for Request<'a> {
22 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
23 let id = be_u16(input)?;
24 let flags = Flags::parse(input, context)?;
25 let qdcount = be_u16(input)?;
26 let _ancount = be_u16(input)?;
27 let _nscount = be_u16(input)?;
28 let _arcount = be_u16(input)?;
29 let queries = (0..qdcount)
30 .map(|_| Query::parse(input, context))
31 .collect::<Result<Vec<_>, _>>()?;
32 Ok(Request { id, flags, queries })
33 }
34}
35
36impl<'a> DnsSerialize<'a> for Request<'a> {
37 type Error = DnsError;
38
39 fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
40 writer.write(&self.id.to_be_bytes());
41 self.flags.serialize(writer).ok();
42 writer.write(&(self.queries.len() as u16).to_be_bytes());
43 writer.write(&ZERO_U16);
44 writer.write(&ZERO_U16);
45 writer.write(&ZERO_U16);
46
47 self.queries
48 .iter()
49 .try_for_each(|query| query.serialize(writer))
50 }
51
52 fn size(&self) -> usize {
53 let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
54
55 core::mem::size_of::<u16>()
56 + self.flags.size()
57 + (core::mem::size_of::<u16>() * 4)
58 + total_query_size
59 }
60}
61
62#[derive(Debug, PartialEq, Eq)]
63pub struct Response<'a> {
64 pub id: u16,
65 pub flags: Flags,
66 pub queries: Vec<Query<'a>>,
67 pub answers: Vec<Answer<'a>>,
68}
69
70impl<'a> DnsParse<'a> for Response<'a> {
71 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
72 let id = be_u16(input)?;
73 let flags = Flags::parse(input, context)?;
74 let qdcount = be_u16(input)?;
75 let ancount = be_u16(input)?;
76 let _nscount = be_u16(input)?;
77 let _arcount = be_u16(input)?;
78
79 let queries = (0..qdcount)
80 .map(|_| Query::parse(input, context))
81 .collect::<Result<Vec<_>, _>>()?;
82
83 let answers = (0..ancount)
84 .map(|_| Answer::parse(input, context))
85 .collect::<Result<Vec<_>, _>>()?;
86
87 Ok(Response {
88 id,
89 flags,
90 queries,
91 answers,
92 })
93 }
94}
95
96impl<'a> DnsSerialize<'a> for Response<'a> {
97 type Error = DnsError;
98
99 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
100 encoder.write(&self.id.to_be_bytes());
101 self.flags.serialize(encoder).ok();
102 encoder.write(&(self.queries.len() as u16).to_be_bytes());
103 encoder.write(&(self.answers.len() as u16).to_be_bytes());
104 encoder.write(&ZERO_U16);
105 encoder.write(&ZERO_U16);
106
107 self.queries
108 .iter()
109 .try_for_each(|query| query.serialize(encoder))?;
110 self.answers
111 .iter()
112 .try_for_each(|answer| answer.serialize(encoder))
113 }
114
115 fn size(&self) -> usize {
116 let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
117 let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum();
118
119 core::mem::size_of::<u16>()
120 + self.flags.size()
121 + (core::mem::size_of::<u16>() * 4)
122 + total_query_size
123 + total_answer_size
124 }
125}
126
/// Compact `defmt` representation of a [`Request`] listing its ID, flags and questions.
#[cfg(feature = "defmt")]
impl defmt::Format for Request<'_> {
    fn format(&self, fmt: defmt::Formatter) {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self { id, flags, queries } = self;
        defmt::write!(
            fmt,
            "Request {{ id: {}, flags: {:?}, queries: {:?} }}",
            id,
            flags,
            queries
        );
    }
}
139
/// Compact `defmt` representation of a [`Response`] listing its ID, flags, questions and answers.
#[cfg(feature = "defmt")]
impl defmt::Format for Response<'_> {
    fn format(&self, fmt: defmt::Formatter) {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            id,
            flags,
            queries,
            answers,
        } = self;
        defmt::write!(
            fmt,
            "Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}",
            id,
            flags,
            queries,
            answers
        );
    }
}
153
154#[cfg(test)]
155mod tests {
156 use alloc::vec;
157
158 use super::*;
159 use crate::dns::{
160 label::Label,
161 query::QClass,
162 records::{A, PTR, QType, Record, SRV, TXT},
163 };
164 use core::net::Ipv4Addr;
165
166 #[test]
167 fn parse_query() {
168 let data = [
169 0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65,
170 // example . com in label format
171 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
172 //
173 0x00, 0x01, 0x00, 0x01,
174 ];
175
176 let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap();
177
178 assert_eq!(request.id, 0xAAAA);
179 assert_eq!(request.flags.0, 0x0100);
180 assert_eq!(request.queries.len(), 1);
181 assert_eq!(request.queries[0].name, "example.com");
182 assert_eq!(request.queries[0].qtype, QType::A);
183 assert_eq!(request.queries[0].qclass, QClass::IN);
184 }
185
186 #[test]
187 fn parse_response() {
188 let data = [
189 0xAA, 0xAA, // transaction ID
190 0x81, 0x80, // flags
191 0x00, 0x01, // 1 question
192 0x00, 0x01, // 1 A-answer
193 0x00, 0x00, // no authority
194 0x00, 0x00, // no additional answers
195 // example . com in label format
196 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
197 //
198 0x00, 0x01, 0x00, 0x01, //
199 //
200 0xC0, 0x0C, // ptr to question section
201 //
202 0x00, 0x01, 0x00, 0x01, // A and IN
203 //
204 0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds
205 //
206 0x00, 0x04, // length of address
207 // IP address:
208 192, 168, 1, 3,
209 ];
210
211 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
212
213 assert_eq!(response.id, 0xAAAA);
214 assert_eq!(response.flags.0, 0x8180);
215 assert_eq!(response.answers.len(), 1);
216 assert_eq!(response.answers[0].name, "example.com");
217 assert_eq!(response.answers[0].atype, QType::A);
218 assert_eq!(response.answers[0].aclass, QClass::IN);
219 assert_eq!(response.answers[0].ttl, 60);
220 if let Record::A(a) = &response.answers[0].record {
221 assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3));
222 } else {
223 panic!("Expected A record");
224 }
225 }
226
227 #[test]
228 fn parse_response_two_records() {
229 #[rustfmt::skip]
230 let data = [
231 0xAA, 0xAA, //
232 0x81, 0x80, //
233 0x00, 0x01, //
234 0x00, 0x02, //
235 0x00, 0x00, //
236 0x00, 0x00, //
237 // example . com in label format
238 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
239 //
240 0x00, 0x01, // query type
241 0x00, 0x01, // query class
242 //
243 0xC0, 0x0C, // pointer
244 0x00, 0x01, //
245 0x00, 0x01, //
246 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
247 0x00, 0x04, // length of A-record
248 0x5D, 0xB8, 0xD8, 0x22, // a-record
249 //
250 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
251 //
252 0x00, 0x10, // TXT
253 0x00, 0x01, // IN
254 //
255 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
256 //
257 0x00, 0x10, // length of txt record
258 // (len) "test txt record"
259 0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72,
260 0x64,
261 ];
262
263 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
264
265 assert_eq!(response.id, 0xAAAA);
266 assert_eq!(response.flags.0, 0x8180);
267 assert_eq!(response.answers.len(), 2);
268
269 // First answer
270 assert_eq!(response.answers[0].name, "example.com");
271 assert_eq!(response.answers[0].atype, QType::A);
272 assert_eq!(response.answers[0].aclass, QClass::IN);
273 assert_eq!(response.answers[0].ttl, 60);
274 if let Record::A(a) = &response.answers[0].record {
275 assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34));
276 } else {
277 panic!("Expected A record");
278 }
279
280 // Second answer
281 assert_eq!(response.answers[1].name, "example.com");
282 assert_eq!(response.answers[1].atype, QType::TXT);
283 assert_eq!(response.answers[1].aclass, QClass::IN);
284 assert_eq!(response.answers[1].ttl, 60);
285 if let Record::TXT(txt) = &response.answers[1].record
286 && let Some(&text) = txt.text.first()
287 {
288 assert_eq!(text, "test txt record");
289 } else {
290 panic!("Expected TXT record");
291 }
292 }
293
294 #[test]
295 fn parse_response_srv() {
296 let data = [
297 //
298 0xAA, 0xAA, // id
299 0x81, 0x80, // flags
300 0x00, 0x01, // one question
301 0x00, 0x01, // one answer
302 0x00, 0x00, // no authority
303 0x00, 0x00, // no extra
304 //
305 0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61,
306 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
307 //
308 0x00, 0x21, // type SRV
309 0x00, 0x01, // IN
310 //
311 0xc0, 0x0c, //
312 //
313 0x00, 0x21, // SRV
314 0x00, 0x01, // IN
315 0x00, 0x00, 0x00, 0x3C, // ttl 60
316 //
317 0x00, 0x19, // data len
318 0x00, 0x0A, // prio
319 0x00, 0x05, // weight
320 0x13, 0xC4, // PORT
321 //
322 0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61,
323 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
324 ];
325
326 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
327
328 assert_eq!(response.id, 0xAAAA);
329 assert_eq!(response.flags.0, 0x8180);
330 assert_eq!(response.answers.len(), 1);
331
332 // Answer
333 assert_eq!(response.answers[0].name, "_sip._tcp.example.com");
334 assert_eq!(response.answers[0].atype, QType::SRV);
335 assert_eq!(response.answers[0].aclass, QClass::IN);
336 assert_eq!(response.answers[0].ttl, 60);
337 let Record::SRV(srv) = &response.answers[0].record else {
338 panic!("Expected SRV record");
339 };
340
341 assert_eq!(srv.priority, 10);
342 assert_eq!(srv.weight, 5);
343 assert_eq!(srv.port, 5060);
344 assert_eq!(srv.target, "sipserver.example.com");
345 }
346
347 #[test]
348 fn parse_response_back_forth() {
349 #[rustfmt::skip]
350 let data = [
351 0, 0, // Transaction ID
352 132, 0, // Response, Authoritative Answer, No Recursion
353 0, 0, // 0 questions
354 0, 4, // 4 answers
355 0, 0, // 0 authority RRs
356 0, 0, // 0 additional RRs
357 // _midiriff
358 9, 95, 109, 105, 100, 105, 114, 105, 102, 102, //
359 // _udp
360 4, 95, 117, 100, 112, //
361 // local
362 5, 108, 111, 99, 97, 108, //
363 0, // <end>
364 //
365 0, 12, // PTR
366 0, 1, // Class IN
367 0, 0, 0, 120, // TTL 120 seconds
368 0, 10, // Data Length 10
369 // pi35291
370 7, 112, 105, 51, 53, 50, 57, 49, //
371 //
372 192, 12, // Pointer to _midirif._udp._local.
373 //
374 192, 44, // Pointer to instace name: pi35291._midirif._udp._local.
375 0, 33, // SRV
376 128, 1, // IN (Cache flush bit set)
377 0, 0, 0, 120, // TTL 120 seconds
378 0, 11, // Data Length 11
379 0, 0, // Priority 0
380 0, 0, // Weight 0
381 137, 219, // Port 35291
382 2, 112, 105, // _pi
383 192, 27, // Pointer to: .local.
384 // TXT (Empty)
385 192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0,
386 // A (10.1.1.9)
387 192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9,
388 ];
389
390 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
391
392 assert_eq!(response.answers[0].name, "_midiriff._udp.local");
393 assert_eq!(response.answers[0].ttl, 120);
394 let Record::PTR(ptr) = &response.answers[0].record else {
395 panic!()
396 };
397 assert_eq!(ptr.name, "pi35291._midiriff._udp.local");
398
399 let mut buffer = [0u8; 256];
400 let mut buffer = Encoder::new(&mut buffer);
401 response.serialize(&mut buffer).unwrap();
402
403 let buffer = buffer.finish();
404
405 let response2 = Response::parse(&mut &buffer[..], buffer).unwrap();
406
407 assert_eq!(response, response2);
408 }
409
410 #[test]
411 fn mdns_service_response() {
412 let mut response = Response {
413 id: 0x1234,
414 flags: Flags::standard_response(),
415 queries: Vec::new(),
416 answers: Vec::new(),
417 };
418
419 let query = Query {
420 name: Label::from("_test._udp.local"),
421 qtype: QType::PTR,
422 qclass: QClass::IN,
423 };
424 response.queries.push(query);
425
426 let ptr_answer = Answer {
427 name: Label::from("_test._udp.local"),
428 atype: QType::PTR,
429 aclass: QClass::IN,
430 ttl: 4500,
431 record: Record::PTR(PTR {
432 name: Label::from("test-service._test._udp.local"),
433 }),
434 };
435 response.answers.push(ptr_answer);
436
437 let srv_answer = Answer {
438 name: Label::from("test-service._test._udp.local"),
439 atype: QType::SRV,
440 aclass: QClass::IN,
441 ttl: 120,
442 record: Record::SRV(SRV {
443 priority: 0,
444 weight: 0,
445 port: 8080,
446 target: Label::from("host.local"),
447 }),
448 };
449 response.answers.push(srv_answer);
450
451 let txt_answer = Answer {
452 name: Label::from("test-service._test._udp.local"),
453 atype: QType::TXT,
454 aclass: QClass::IN,
455 ttl: 120,
456 record: Record::TXT(TXT {
457 text: vec!["path=/test"],
458 }),
459 };
460 response.answers.push(txt_answer);
461
462 let a_answer = Answer {
463 name: Label::from("host.local"),
464 atype: QType::A,
465 aclass: QClass::IN,
466 ttl: 120,
467 record: Record::A(A {
468 address: Ipv4Addr::new(192, 168, 1, 100),
469 }),
470 };
471 response.answers.push(a_answer);
472
473 let mut buffer = [0u8; 256];
474 let mut buffer = Encoder::new(&mut buffer);
475 response.serialize(&mut buffer).unwrap();
476
477 let buffer = buffer.finish();
478
479 let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap();
480
481 assert_eq!(response, parsed_response);
482 }
483}