use winnow::binary::{be_u16, be_u32}; use winnow::{ModalResult, Parser}; use super::label::Label; use super::records::Record; use crate::encoder::Encoder; use crate::{ dns::{ records::QType, traits::{DnsParse, DnsParseKind, DnsSerialize}, }, encoder::DnsError, }; #[derive(Debug, PartialEq, Eq)] pub struct Query<'a> { pub name: Label<'a>, pub qtype: QType, pub qclass: QClass, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u16)] pub enum QClass { IN = 1, Multicast = 32769, // (IN + Cache flush bit) Unknown(u16), } impl<'a> DnsParse<'a> for Query<'a> { fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { let name = Label::parse(input, context)?; let qtype = QType::parse(input, context)?; let qclass = be_u16.map(QClass::from_u16).parse_next(input)?; Ok(Query { name, qtype, qclass, }) } } impl<'a> DnsSerialize<'a> for Query<'a> { type Error = DnsError; fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { self.name.serialize(encoder)?; self.qtype.serialize(encoder).ok(); encoder.write(&self.qclass.to_u16().to_be_bytes()); Ok(()) } fn size(&self) -> usize { self.name.size() + self.qtype.size() + core::mem::size_of::() } } #[derive(Debug, PartialEq, Eq)] pub struct Answer<'a> { pub name: Label<'a>, pub atype: QType, pub aclass: QClass, pub ttl: u32, pub record: Record<'a>, } impl QClass { fn from_u16(value: u16) -> Self { match value { 1 => QClass::IN, 32769 => QClass::Multicast, _ => QClass::Unknown(value), } } fn to_u16(self) -> u16 { match self { QClass::IN => 1, QClass::Multicast => 32769, QClass::Unknown(value) => value, } } } impl<'a> DnsParse<'a> for Answer<'a> { fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { let name = Label::parse(input, context)?; let atype = QType::parse(input, context)?; let aclass = be_u16.map(QClass::from_u16).parse_next(input)?; let ttl = be_u32.parse_next(input)?; let record = atype.parse_kind(input, context)?; Ok(Answer { name, atype, aclass, ttl, record, }) } } impl<'a> DnsSerialize<'a> for Answer<'a> { type Error = DnsError; fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { self.name.serialize(encoder)?; self.atype.serialize(encoder).ok(); encoder.write(&self.aclass.to_u16().to_be_bytes()); encoder.write(&self.ttl.to_be_bytes()); self.record.serialize(encoder) } fn size(&self) -> usize { self.name.size() + self.atype.size() + core::mem::size_of::() + core::mem::size_of::() + self.record.size() } } #[cfg(feature = "defmt")] impl<'a> defmt::Format for Query<'a> { fn format(&self, fmt: defmt::Formatter) { defmt::write!( fmt, "Query {{ name: {:?}, qtype: {:?}, qclass: {:?} }}", self.name, self.qtype, self.qclass ); } } #[cfg(feature = "defmt")] impl defmt::Format for QType { fn format(&self, fmt: defmt::Formatter) { let qtype_str = match self { QType::A => "A", QType::AAAA => "AAAA", QType::PTR => "PTR", QType::TXT => "TXT", QType::SRV => "SRV", QType::Any => "Any", QType::Unknown(_) => "Unknown", }; defmt::write!(fmt, "QType({=str})", qtype_str); } } #[cfg(feature = "defmt")] impl defmt::Format for QClass { fn format(&self, fmt: defmt::Formatter) { let qclass_str = match self { QClass::IN => "IN", QClass::Multicast => "Multicast", QClass::Unknown(_) => "Unknown", }; defmt::write!(fmt, "QClass({=str})", qclass_str); } } #[cfg(feature = "defmt")] impl<'a> defmt::Format for Answer<'a> { fn format(&self, fmt: defmt::Formatter) { defmt::write!( fmt, "Answer {{ name: {:?}, atype: {:?}, aclass: {:?}, ttl: {}, record: {:?} }}", self.name, self.atype, self.aclass, self.ttl, self.record ); } } #[cfg(test)] mod tests { use super::*; use crate::dns::records::A; use core::net::Ipv4Addr; #[test] fn roundtrip_query() { let name = Label::from("example.local"); let query = Query { name, qtype: QType::A, qclass: QClass::IN, }; let mut buffer = [0u8; 256]; let mut buffer = Encoder::new(&mut buffer); query.serialize(&mut buffer).unwrap(); let buffer = buffer.finish(); let parsed_query = Query::parse(&mut &buffer[..], buffer).unwrap(); assert_eq!(query, parsed_query); } #[test] fn roundtrip_answer() { let name = Label::from("example.local"); let answer: Answer = Answer { name, atype: QType::A, aclass: QClass::IN, ttl: 120, record: Record::A(A { address: Ipv4Addr::new(192, 168, 1, 1), }), }; let mut buffer = [0u8; 256]; let mut buffer = Encoder::new(&mut buffer); answer.serialize(&mut buffer).unwrap(); let buffer = buffer.finish(); let parsed_answer = Answer::parse(&mut &buffer[..], buffer).unwrap(); assert_eq!(answer, parsed_answer); } }