Repository of `no_std` crates for my personal embedded projects.
at main 15 kB view raw
1use alloc::vec::Vec; 2use winnow::ModalResult; 3use winnow::binary::be_u16; 4 5use super::flags::Flags; 6use super::query::{Answer, Query}; 7use crate::{ 8 dns::traits::{DnsParse, DnsSerialize}, 9 encoder::{DnsError, Encoder}, 10}; 11 12const ZERO_U16: [u8; 2] = 0u16.to_be_bytes(); 13 14#[derive(Debug, PartialEq, Eq)] 15pub struct Request<'a> { 16 pub id: u16, 17 pub flags: Flags, 18 pub(crate) queries: Vec<Query<'a>>, 19} 20 21impl<'a> DnsParse<'a> for Request<'a> { 22 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 23 let id = be_u16(input)?; 24 let flags = Flags::parse(input, context)?; 25 let qdcount = be_u16(input)?; 26 let _ancount = be_u16(input)?; 27 let _nscount = be_u16(input)?; 28 let _arcount = be_u16(input)?; 29 let queries = (0..qdcount) 30 .map(|_| Query::parse(input, context)) 31 .collect::<Result<Vec<_>, _>>()?; 32 Ok(Request { id, flags, queries }) 33 } 34} 35 36impl<'a> DnsSerialize<'a> for Request<'a> { 37 type Error = DnsError; 38 39 fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 40 writer.write(&self.id.to_be_bytes()); 41 self.flags.serialize(writer).ok(); 42 writer.write(&(self.queries.len() as u16).to_be_bytes()); 43 writer.write(&ZERO_U16); 44 writer.write(&ZERO_U16); 45 writer.write(&ZERO_U16); 46 47 self.queries 48 .iter() 49 .try_for_each(|query| query.serialize(writer)) 50 } 51 52 fn size(&self) -> usize { 53 let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); 54 55 core::mem::size_of::<u16>() 56 + self.flags.size() 57 + (core::mem::size_of::<u16>() * 4) 58 + total_query_size 59 } 60} 61 62#[derive(Debug, PartialEq, Eq)] 63pub struct Response<'a> { 64 pub id: u16, 65 pub flags: Flags, 66 pub queries: Vec<Query<'a>>, 67 pub answers: Vec<Answer<'a>>, 68} 69 70impl<'a> DnsParse<'a> for Response<'a> { 71 fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 72 let id = be_u16(input)?; 73 let flags = Flags::parse(input, context)?; 74 
let qdcount = be_u16(input)?; 75 let ancount = be_u16(input)?; 76 let _nscount = be_u16(input)?; 77 let _arcount = be_u16(input)?; 78 79 let queries = (0..qdcount) 80 .map(|_| Query::parse(input, context)) 81 .collect::<Result<Vec<_>, _>>()?; 82 83 let answers = (0..ancount) 84 .map(|_| Answer::parse(input, context)) 85 .collect::<Result<Vec<_>, _>>()?; 86 87 Ok(Response { 88 id, 89 flags, 90 queries, 91 answers, 92 }) 93 } 94} 95 96impl<'a> DnsSerialize<'a> for Response<'a> { 97 type Error = DnsError; 98 99 fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 100 encoder.write(&self.id.to_be_bytes()); 101 self.flags.serialize(encoder).ok(); 102 encoder.write(&(self.queries.len() as u16).to_be_bytes()); 103 encoder.write(&(self.answers.len() as u16).to_be_bytes()); 104 encoder.write(&ZERO_U16); 105 encoder.write(&ZERO_U16); 106 107 self.queries 108 .iter() 109 .try_for_each(|query| query.serialize(encoder))?; 110 self.answers 111 .iter() 112 .try_for_each(|answer| answer.serialize(encoder)) 113 } 114 115 fn size(&self) -> usize { 116 let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); 117 let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum(); 118 119 core::mem::size_of::<u16>() 120 + self.flags.size() 121 + (core::mem::size_of::<u16>() * 4) 122 + total_query_size 123 + total_answer_size 124 } 125} 126 127#[cfg(feature = "defmt")] 128impl<'a> defmt::Format for Request<'a> { 129 fn format(&self, fmt: defmt::Formatter) { 130 defmt::write!( 131 fmt, 132 "Request {{ id: {}, flags: {:?}, queries: {:?} }}", 133 self.id, 134 self.flags, 135 self.queries 136 ); 137 } 138} 139 140#[cfg(feature = "defmt")] 141impl<'a> defmt::Format for Response<'a> { 142 fn format(&self, fmt: defmt::Formatter) { 143 defmt::write!( 144 fmt, 145 "Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}", 146 self.id, 147 self.flags, 148 self.queries, 149 self.answers 150 ); 151 } 152} 153 
154#[cfg(test)] 155mod tests { 156 use alloc::vec; 157 158 use super::*; 159 use crate::dns::{ 160 label::Label, 161 query::QClass, 162 records::{A, PTR, QType, Record, SRV, TXT}, 163 }; 164 use core::net::Ipv4Addr; 165 166 #[test] 167 fn parse_query() { 168 let data = [ 169 0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, 170 // example . com in label format 171 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 172 // 173 0x00, 0x01, 0x00, 0x01, 174 ]; 175 176 let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 177 178 assert_eq!(request.id, 0xAAAA); 179 assert_eq!(request.flags.0, 0x0100); 180 assert_eq!(request.queries.len(), 1); 181 assert_eq!(request.queries[0].name, "example.com"); 182 assert_eq!(request.queries[0].qtype, QType::A); 183 assert_eq!(request.queries[0].qclass, QClass::IN); 184 } 185 186 #[test] 187 fn parse_response() { 188 let data = [ 189 0xAA, 0xAA, // transaction ID 190 0x81, 0x80, // flags 191 0x00, 0x01, // 1 question 192 0x00, 0x01, // 1 A-answer 193 0x00, 0x00, // no authority 194 0x00, 0x00, // no additional answers 195 // example . 
com in label format 196 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 197 // 198 0x00, 0x01, 0x00, 0x01, // 199 // 200 0xC0, 0x0C, // ptr to question section 201 // 202 0x00, 0x01, 0x00, 0x01, // A and IN 203 // 204 0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds 205 // 206 0x00, 0x04, // length of address 207 // IP address: 208 192, 168, 1, 3, 209 ]; 210 211 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 212 213 assert_eq!(response.id, 0xAAAA); 214 assert_eq!(response.flags.0, 0x8180); 215 assert_eq!(response.answers.len(), 1); 216 assert_eq!(response.answers[0].name, "example.com"); 217 assert_eq!(response.answers[0].atype, QType::A); 218 assert_eq!(response.answers[0].aclass, QClass::IN); 219 assert_eq!(response.answers[0].ttl, 60); 220 if let Record::A(a) = &response.answers[0].record { 221 assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3)); 222 } else { 223 panic!("Expected A record"); 224 } 225 } 226 227 #[test] 228 fn parse_response_two_records() { 229 #[rustfmt::skip] 230 let data = [ 231 0xAA, 0xAA, // 232 0x81, 0x80, // 233 0x00, 0x01, // 234 0x00, 0x02, // 235 0x00, 0x00, // 236 0x00, 0x00, // 237 // example . 
com in label format 238 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 239 // 240 0x00, 0x01, // query type 241 0x00, 0x01, // query class 242 // 243 0xC0, 0x0C, // pointer 244 0x00, 0x01, // 245 0x00, 0x01, // 246 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds 247 0x00, 0x04, // length of A-record 248 0x5D, 0xB8, 0xD8, 0x22, // a-record 249 // 250 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 251 // 252 0x00, 0x10, // TXT 253 0x00, 0x01, // IN 254 // 255 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds 256 // 257 0x00, 0x10, // length of txt record 258 // (len) "test txt record" 259 0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72, 260 0x64, 261 ]; 262 263 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 264 265 assert_eq!(response.id, 0xAAAA); 266 assert_eq!(response.flags.0, 0x8180); 267 assert_eq!(response.answers.len(), 2); 268 269 // First answer 270 assert_eq!(response.answers[0].name, "example.com"); 271 assert_eq!(response.answers[0].atype, QType::A); 272 assert_eq!(response.answers[0].aclass, QClass::IN); 273 assert_eq!(response.answers[0].ttl, 60); 274 if let Record::A(a) = &response.answers[0].record { 275 assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34)); 276 } else { 277 panic!("Expected A record"); 278 } 279 280 // Second answer 281 assert_eq!(response.answers[1].name, "example.com"); 282 assert_eq!(response.answers[1].atype, QType::TXT); 283 assert_eq!(response.answers[1].aclass, QClass::IN); 284 assert_eq!(response.answers[1].ttl, 60); 285 if let Record::TXT(txt) = &response.answers[1].record 286 && let Some(&text) = txt.text.first() 287 { 288 assert_eq!(text, "test txt record"); 289 } else { 290 panic!("Expected TXT record"); 291 } 292 } 293 294 #[test] 295 fn parse_response_srv() { 296 let data = [ 297 // 298 0xAA, 0xAA, // id 299 0x81, 0x80, // flags 300 0x00, 0x01, // one question 301 0x00, 0x01, // one answer 302 
0x00, 0x00, // no authority 303 0x00, 0x00, // no extra 304 // 305 0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61, 306 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // 307 // 308 0x00, 0x21, // type SRV 309 0x00, 0x01, // IN 310 // 311 0xc0, 0x0c, // 312 // 313 0x00, 0x21, // SRV 314 0x00, 0x01, // IN 315 0x00, 0x00, 0x00, 0x3C, // ttl 60 316 // 317 0x00, 0x19, // data len 318 0x00, 0x0A, // prio 319 0x00, 0x05, // weight 320 0x13, 0xC4, // PORT 321 // 322 0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61, 323 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // 324 ]; 325 326 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 327 328 assert_eq!(response.id, 0xAAAA); 329 assert_eq!(response.flags.0, 0x8180); 330 assert_eq!(response.answers.len(), 1); 331 332 // Answer 333 assert_eq!(response.answers[0].name, "_sip._tcp.example.com"); 334 assert_eq!(response.answers[0].atype, QType::SRV); 335 assert_eq!(response.answers[0].aclass, QClass::IN); 336 assert_eq!(response.answers[0].ttl, 60); 337 let Record::SRV(srv) = &response.answers[0].record else { 338 panic!("Expected SRV record"); 339 }; 340 341 assert_eq!(srv.priority, 10); 342 assert_eq!(srv.weight, 5); 343 assert_eq!(srv.port, 5060); 344 assert_eq!(srv.target, "sipserver.example.com"); 345 } 346 347 #[test] 348 fn parse_response_back_forth() { 349 #[rustfmt::skip] 350 let data = [ 351 0, 0, // Transaction ID 352 132, 0, // Response, Authoritative Answer, No Recursion 353 0, 0, // 0 questions 354 0, 4, // 4 answers 355 0, 0, // 0 authority RRs 356 0, 0, // 0 additional RRs 357 // _midiriff 358 9, 95, 109, 105, 100, 105, 114, 105, 102, 102, // 359 // _udp 360 4, 95, 117, 100, 112, // 361 // local 362 5, 108, 111, 99, 97, 108, // 363 0, // <end> 364 // 365 0, 12, // PTR 366 0, 1, // Class IN 367 0, 0, 0, 120, // TTL 120 seconds 368 0, 10, // Data Length 10 369 // pi35291 370 7, 112, 105, 51, 53, 50, 57, 
49, // 371 // 372 192, 12, // Pointer to _midirif._udp._local. 373 // 374 192, 44, // Pointer to instace name: pi35291._midirif._udp._local. 375 0, 33, // SRV 376 128, 1, // IN (Cache flush bit set) 377 0, 0, 0, 120, // TTL 120 seconds 378 0, 11, // Data Length 11 379 0, 0, // Priority 0 380 0, 0, // Weight 0 381 137, 219, // Port 35291 382 2, 112, 105, // _pi 383 192, 27, // Pointer to: .local. 384 // TXT (Empty) 385 192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0, 386 // A (10.1.1.9) 387 192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9, 388 ]; 389 390 let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 391 392 assert_eq!(response.answers[0].name, "_midiriff._udp.local"); 393 assert_eq!(response.answers[0].ttl, 120); 394 let Record::PTR(ptr) = &response.answers[0].record else { 395 panic!() 396 }; 397 assert_eq!(ptr.name, "pi35291._midiriff._udp.local"); 398 399 let mut buffer = [0u8; 256]; 400 let mut buffer = Encoder::new(&mut buffer); 401 response.serialize(&mut buffer).unwrap(); 402 403 let buffer = buffer.finish(); 404 405 let response2 = Response::parse(&mut &buffer[..], buffer).unwrap(); 406 407 assert_eq!(response, response2); 408 } 409 410 #[test] 411 fn mdns_service_response() { 412 let mut response = Response { 413 id: 0x1234, 414 flags: Flags::standard_response(), 415 queries: Vec::new(), 416 answers: Vec::new(), 417 }; 418 419 let query = Query { 420 name: Label::from("_test._udp.local"), 421 qtype: QType::PTR, 422 qclass: QClass::IN, 423 }; 424 response.queries.push(query); 425 426 let ptr_answer = Answer { 427 name: Label::from("_test._udp.local"), 428 atype: QType::PTR, 429 aclass: QClass::IN, 430 ttl: 4500, 431 record: Record::PTR(PTR { 432 name: Label::from("test-service._test._udp.local"), 433 }), 434 }; 435 response.answers.push(ptr_answer); 436 437 let srv_answer = Answer { 438 name: Label::from("test-service._test._udp.local"), 439 atype: QType::SRV, 440 aclass: QClass::IN, 441 ttl: 120, 442 record: 
Record::SRV(SRV { 443 priority: 0, 444 weight: 0, 445 port: 8080, 446 target: Label::from("host.local"), 447 }), 448 }; 449 response.answers.push(srv_answer); 450 451 let txt_answer = Answer { 452 name: Label::from("test-service._test._udp.local"), 453 atype: QType::TXT, 454 aclass: QClass::IN, 455 ttl: 120, 456 record: Record::TXT(TXT { 457 text: vec!["path=/test"], 458 }), 459 }; 460 response.answers.push(txt_answer); 461 462 let a_answer = Answer { 463 name: Label::from("host.local"), 464 atype: QType::A, 465 aclass: QClass::IN, 466 ttl: 120, 467 record: Record::A(A { 468 address: Ipv4Addr::new(192, 168, 1, 100), 469 }), 470 }; 471 response.answers.push(a_answer); 472 473 let mut buffer = [0u8; 256]; 474 let mut buffer = Encoder::new(&mut buffer); 475 response.serialize(&mut buffer).unwrap(); 476 477 let buffer = buffer.finish(); 478 479 let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap(); 480 481 assert_eq!(response, parsed_response); 482 } 483}