use alloc::vec::Vec; use winnow::ModalResult; use winnow::binary::be_u16; use super::flags::Flags; use super::query::{Answer, Query}; use crate::{ dns::traits::{DnsParse, DnsSerialize}, encoder::{DnsError, Encoder}, }; const ZERO_U16: [u8; 2] = 0u16.to_be_bytes(); #[derive(Debug, PartialEq, Eq)] pub struct Request<'a> { pub id: u16, pub flags: Flags, pub(crate) queries: Vec>, } impl<'a> DnsParse<'a> for Request<'a> { fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { let id = be_u16(input)?; let flags = Flags::parse(input, context)?; let qdcount = be_u16(input)?; let _ancount = be_u16(input)?; let _nscount = be_u16(input)?; let _arcount = be_u16(input)?; let queries = (0..qdcount) .map(|_| Query::parse(input, context)) .collect::, _>>()?; Ok(Request { id, flags, queries }) } } impl<'a> DnsSerialize<'a> for Request<'a> { type Error = DnsError; fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { writer.write(&self.id.to_be_bytes()); self.flags.serialize(writer).ok(); writer.write(&(self.queries.len() as u16).to_be_bytes()); writer.write(&ZERO_U16); writer.write(&ZERO_U16); writer.write(&ZERO_U16); self.queries .iter() .try_for_each(|query| query.serialize(writer)) } fn size(&self) -> usize { let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); core::mem::size_of::() + self.flags.size() + (core::mem::size_of::() * 4) + total_query_size } } #[derive(Debug, PartialEq, Eq)] pub struct Response<'a> { pub id: u16, pub flags: Flags, pub queries: Vec>, pub answers: Vec>, } impl<'a> DnsParse<'a> for Response<'a> { fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { let id = be_u16(input)?; let flags = Flags::parse(input, context)?; let qdcount = be_u16(input)?; let ancount = be_u16(input)?; let _nscount = be_u16(input)?; let _arcount = be_u16(input)?; let queries = (0..qdcount) .map(|_| Query::parse(input, context)) .collect::, _>>()?; let answers = (0..ancount) .map(|_| Answer::parse(input, context)) .collect::, _>>()?; Ok(Response { id, flags, queries, answers, }) } } impl<'a> DnsSerialize<'a> for Response<'a> { type Error = DnsError; fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { encoder.write(&self.id.to_be_bytes()); self.flags.serialize(encoder).ok(); encoder.write(&(self.queries.len() as u16).to_be_bytes()); encoder.write(&(self.answers.len() as u16).to_be_bytes()); encoder.write(&ZERO_U16); encoder.write(&ZERO_U16); self.queries .iter() .try_for_each(|query| query.serialize(encoder))?; self.answers .iter() .try_for_each(|answer| answer.serialize(encoder)) } fn size(&self) -> usize { let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum(); core::mem::size_of::() + self.flags.size() + (core::mem::size_of::() * 4) + total_query_size + total_answer_size } } #[cfg(feature = "defmt")] impl<'a> defmt::Format for Request<'a> { fn format(&self, fmt: defmt::Formatter) { defmt::write!( fmt, "Request {{ id: {}, flags: {:?}, queries: {:?} }}", self.id, self.flags, self.queries ); } } #[cfg(feature = "defmt")] impl<'a> defmt::Format for Response<'a> { fn format(&self, fmt: defmt::Formatter) { defmt::write!( fmt, "Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}", self.id, self.flags, self.queries, self.answers ); } } #[cfg(test)] mod tests { use alloc::vec; use super::*; use crate::dns::{ label::Label, query::QClass, records::{A, PTR, QType, Record, SRV, TXT}, }; use core::net::Ipv4Addr; #[test] fn parse_query() { let data = [ 0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, // example . com in label format 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // // 0x00, 0x01, 0x00, 0x01, ]; let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap(); assert_eq!(request.id, 0xAAAA); assert_eq!(request.flags.0, 0x0100); assert_eq!(request.queries.len(), 1); assert_eq!(request.queries[0].name, "example.com"); assert_eq!(request.queries[0].qtype, QType::A); assert_eq!(request.queries[0].qclass, QClass::IN); } #[test] fn parse_response() { let data = [ 0xAA, 0xAA, // transaction ID 0x81, 0x80, // flags 0x00, 0x01, // 1 question 0x00, 0x01, // 1 A-answer 0x00, 0x00, // no authority 0x00, 0x00, // no additional answers // example . com in label format 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // // 0x00, 0x01, 0x00, 0x01, // // 0xC0, 0x0C, // ptr to question section // 0x00, 0x01, 0x00, 0x01, // A and IN // 0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds // 0x00, 0x04, // length of address // IP address: 192, 168, 1, 3, ]; let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); assert_eq!(response.id, 0xAAAA); assert_eq!(response.flags.0, 0x8180); assert_eq!(response.answers.len(), 1); assert_eq!(response.answers[0].name, "example.com"); assert_eq!(response.answers[0].atype, QType::A); assert_eq!(response.answers[0].aclass, QClass::IN); assert_eq!(response.answers[0].ttl, 60); if let Record::A(a) = &response.answers[0].record { assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3)); } else { panic!("Expected A record"); } } #[test] fn parse_response_two_records() { #[rustfmt::skip] let data = [ 0xAA, 0xAA, // 0x81, 0x80, // 0x00, 0x01, // 0x00, 0x02, // 0x00, 0x00, // 0x00, 0x00, // // example . com in label format 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // // 0x00, 0x01, // query type 0x00, 0x01, // query class // 0xC0, 0x0C, // pointer 0x00, 0x01, // 0x00, 0x01, // 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds 0x00, 0x04, // length of A-record 0x5D, 0xB8, 0xD8, 0x22, // a-record // 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // // 0x00, 0x10, // TXT 0x00, 0x01, // IN // 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds // 0x00, 0x10, // length of txt record // (len) "test txt record" 0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72, 0x64, ]; let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); assert_eq!(response.id, 0xAAAA); assert_eq!(response.flags.0, 0x8180); assert_eq!(response.answers.len(), 2); // First answer assert_eq!(response.answers[0].name, "example.com"); assert_eq!(response.answers[0].atype, QType::A); assert_eq!(response.answers[0].aclass, QClass::IN); assert_eq!(response.answers[0].ttl, 60); if let Record::A(a) = &response.answers[0].record { assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34)); } else { panic!("Expected A record"); } // Second answer assert_eq!(response.answers[1].name, "example.com"); assert_eq!(response.answers[1].atype, QType::TXT); assert_eq!(response.answers[1].aclass, QClass::IN); assert_eq!(response.answers[1].ttl, 60); if let Record::TXT(txt) = &response.answers[1].record && let Some(&text) = txt.text.first() { assert_eq!(text, "test txt record"); } else { panic!("Expected TXT record"); } } #[test] fn parse_response_srv() { let data = [ // 0xAA, 0xAA, // id 0x81, 0x80, // flags 0x00, 0x01, // one question 0x00, 0x01, // one answer 0x00, 0x00, // no authority 0x00, 0x00, // no extra // 0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // // 0x00, 0x21, // type SRV 0x00, 0x01, // IN // 0xc0, 0x0c, // // 0x00, 0x21, // SRV 0x00, 0x01, // IN 0x00, 0x00, 0x00, 0x3C, // ttl 60 // 0x00, 0x19, // data len 0x00, 0x0A, // prio 0x00, 0x05, // weight 0x13, 0xC4, // PORT // 0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // ]; let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); assert_eq!(response.id, 0xAAAA); assert_eq!(response.flags.0, 0x8180); assert_eq!(response.answers.len(), 1); // Answer assert_eq!(response.answers[0].name, "_sip._tcp.example.com"); assert_eq!(response.answers[0].atype, QType::SRV); assert_eq!(response.answers[0].aclass, QClass::IN); assert_eq!(response.answers[0].ttl, 60); let Record::SRV(srv) = &response.answers[0].record else { panic!("Expected SRV record"); }; assert_eq!(srv.priority, 10); assert_eq!(srv.weight, 5); assert_eq!(srv.port, 5060); assert_eq!(srv.target, "sipserver.example.com"); } #[test] fn parse_response_back_forth() { #[rustfmt::skip] let data = [ 0, 0, // Transaction ID 132, 0, // Response, Authoritative Answer, No Recursion 0, 0, // 0 questions 0, 4, // 4 answers 0, 0, // 0 authority RRs 0, 0, // 0 additional RRs // _midiriff 9, 95, 109, 105, 100, 105, 114, 105, 102, 102, // // _udp 4, 95, 117, 100, 112, // // local 5, 108, 111, 99, 97, 108, // 0, // // 0, 12, // PTR 0, 1, // Class IN 0, 0, 0, 120, // TTL 120 seconds 0, 10, // Data Length 10 // pi35291 7, 112, 105, 51, 53, 50, 57, 49, // // 192, 12, // Pointer to _midirif._udp._local. // 192, 44, // Pointer to instace name: pi35291._midirif._udp._local. 0, 33, // SRV 128, 1, // IN (Cache flush bit set) 0, 0, 0, 120, // TTL 120 seconds 0, 11, // Data Length 11 0, 0, // Priority 0 0, 0, // Weight 0 137, 219, // Port 35291 2, 112, 105, // _pi 192, 27, // Pointer to: .local. // TXT (Empty) 192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0, // A (10.1.1.9) 192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9, ]; let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); assert_eq!(response.answers[0].name, "_midiriff._udp.local"); assert_eq!(response.answers[0].ttl, 120); let Record::PTR(ptr) = &response.answers[0].record else { panic!() }; assert_eq!(ptr.name, "pi35291._midiriff._udp.local"); let mut buffer = [0u8; 256]; let mut buffer = Encoder::new(&mut buffer); response.serialize(&mut buffer).unwrap(); let buffer = buffer.finish(); let response2 = Response::parse(&mut &buffer[..], buffer).unwrap(); assert_eq!(response, response2); } #[test] fn mdns_service_response() { let mut response = Response { id: 0x1234, flags: Flags::standard_response(), queries: Vec::new(), answers: Vec::new(), }; let query = Query { name: Label::from("_test._udp.local"), qtype: QType::PTR, qclass: QClass::IN, }; response.queries.push(query); let ptr_answer = Answer { name: Label::from("_test._udp.local"), atype: QType::PTR, aclass: QClass::IN, ttl: 4500, record: Record::PTR(PTR { name: Label::from("test-service._test._udp.local"), }), }; response.answers.push(ptr_answer); let srv_answer = Answer { name: Label::from("test-service._test._udp.local"), atype: QType::SRV, aclass: QClass::IN, ttl: 120, record: Record::SRV(SRV { priority: 0, weight: 0, port: 8080, target: Label::from("host.local"), }), }; response.answers.push(srv_answer); let txt_answer = Answer { name: Label::from("test-service._test._udp.local"), atype: QType::TXT, aclass: QClass::IN, ttl: 120, record: Record::TXT(TXT { text: vec!["path=/test"], }), }; response.answers.push(txt_answer); let a_answer = Answer { name: Label::from("host.local"), atype: QType::A, aclass: QClass::IN, ttl: 120, record: Record::A(A { address: Ipv4Addr::new(192, 168, 1, 100), }), }; response.answers.push(a_answer); let mut buffer = [0u8; 256]; let mut buffer = Encoder::new(&mut buffer); response.serialize(&mut buffer).unwrap(); let buffer = buffer.finish(); let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap(); assert_eq!(response, parsed_response); } }