Repo of no-std crates for my personal embedded projects

feat: WIP MDNS crate

Adds a mdns resolver/state-machine crate for providing MDNS-SD functionality for an embedded device.

+21
.tangled/workflows/miri.yml
···
+
when:
+
- event: ["push", "pull_request"]
+
branch: main
+
+
engine: nixery
+
+
dependencies:
+
nixpkgs:
+
- clang
+
- rustup
+
+
steps:
+
- name: Install Nightly
+
command: |
+
rustup toolchain install nightly --component miri
+
rustup override set nightly
+
cargo miri setup
+
- name: Miri Test
+
command: cargo miri test --locked -p sachy-mdns
+
environment:
+
RUSTFLAGS: -Zrandomize-layout
+19 -3
Cargo.lock
···
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
-
"windows-sys 0.52.0",
+
"windows-sys 0.61.2",
]
[[package]]
···
"errno",
"libc",
"linux-raw-sys 0.11.0",
-
"windows-sys 0.52.0",
+
"windows-sys 0.61.2",
[[package]]
···
version = "0.1.0"
[[package]]
+
name = "sachy-mdns"
+
version = "0.1.0"
+
dependencies = [
+
"defmt 1.0.1",
+
"embassy-time",
+
"sachy-fmt",
+
"winnow",
+
]
+
+
[[package]]
name = "sachy-shtc3"
version = "0.1.0"
dependencies = [
···
"getrandom",
"once_cell",
"rustix 1.1.2",
-
"windows-sys 0.52.0",
+
"windows-sys 0.61.2",
[[package]]
···
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+
+
[[package]]
+
name = "winnow"
+
version = "0.7.14"
+
source = "registry+https://github.com/rust-lang/crates.io-index"
+
checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829"
[[package]]
name = "wit-bindgen"
+1
Cargo.toml
···
"sachy-esphome",
"sachy-fmt",
"sachy-fnv",
+
"sachy-mdns",
"sachy-shtc3",
"sachy-sntp",
]
+22
sachy-mdns/Cargo.toml
···
+
[package]
+
name = "sachy-mdns"
+
authors.workspace = true
+
edition.workspace = true
+
repository.workspace = true
+
license.workspace = true
+
version.workspace = true
+
rust-version.workspace = true
+
+
[dependencies]
+
defmt = { workspace = true, optional = true, features = ["alloc"] }
+
embassy-time = { workspace = true }
+
sachy-fmt = { path = "../sachy-fmt" }
+
winnow = { version = "0.7.12", default-features = false }
+
+
[features]
+
default = []
+
std = []
+
defmt = ["dep:defmt"]
+
+
[dev-dependencies]
+
winnow = { version = "0.7.12", default-features = false, features = ["alloc"] }
+6
sachy-mdns/src/dns.rs
···
+
pub(crate) mod flags;
+
pub(crate) mod label;
+
pub(crate) mod query;
+
pub(crate) mod records;
+
pub(crate) mod reqres;
+
pub mod traits;
+231
sachy-mdns/src/dns/flags.rs
···
+
#![allow(dead_code)]
+
+
use core::{convert::Infallible, fmt};
+
use winnow::{ModalResult, Parser, binary::be_u16};
+
+
use crate::{
+
dns::traits::{DnsParse, DnsSerialize},
+
encoder::Encoder,
+
};
+
+
#[derive(Default, Clone, Copy, PartialEq, Eq)]
+
pub struct Flags(pub u16);
+
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
#[repr(u8)]
+
pub enum Opcode {
+
Query = 0,
+
IQuery = 1,
+
Status = 2,
+
Reserved = 3,
+
Notify = 4,
+
Update = 5,
+
// Other values are reserved
+
}
+
+
impl Opcode {
+
const fn cast(value: u8) -> Self {
+
match value {
+
0 => Opcode::Query,
+
1 => Opcode::IQuery,
+
2 => Opcode::Status,
+
4 => Opcode::Notify,
+
5 => Opcode::Update,
+
_ => Opcode::Reserved,
+
}
+
}
+
}
+
+
impl From<u8> for Opcode {
+
fn from(value: u8) -> Self {
+
Self::cast(value)
+
}
+
}
+
+
impl From<Opcode> for u8 {
+
fn from(opcode: Opcode) -> Self {
+
opcode as u8
+
}
+
}
+
+
impl Flags {
+
const fn new() -> Self {
+
Flags(0)
+
}
+
+
pub const fn standard_request() -> Self {
+
let mut flags = Flags::new();
+
flags.set_query(true);
+
flags.set_opcode(Opcode::Query);
+
flags.set_recursion_desired(true);
+
flags
+
}
+
+
pub const fn standard_response() -> Self {
+
let mut flags = Flags::new();
+
flags.set_query(false);
+
flags.set_opcode(Opcode::Query);
+
flags.set_authoritative(true);
+
flags.set_recursion_available(false);
+
flags
+
}
+
+
// QR: Query/Response Flag
+
pub const fn is_query(&self) -> bool {
+
(self.0 & 0x8000) == 0
+
}
+
+
pub const fn set_query(&mut self, is_query: bool) {
+
if is_query {
+
self.0 &= !0x8000;
+
} else {
+
self.0 |= 0x8000;
+
}
+
}
+
+
// Opcode (bits 1-4)
+
pub const fn get_opcode(&self) -> Opcode {
+
Opcode::cast(((self.0 >> 11) & 0x0F) as u8)
+
}
+
+
pub const fn set_opcode(&mut self, opcode: Opcode) {
+
self.0 = (self.0 & !0x7800) | (((opcode as u8) as u16 & 0x0F) << 11);
+
}
+
+
// AA: Authoritative Answer
+
pub const fn is_authoritative(&self) -> bool {
+
(self.0 & 0x0400) != 0
+
}
+
+
pub const fn set_authoritative(&mut self, authoritative: bool) {
+
if authoritative {
+
self.0 |= 0x0400;
+
} else {
+
self.0 &= !0x0400;
+
}
+
}
+
+
// TC: Truncated
+
pub const fn is_truncated(&self) -> bool {
+
(self.0 & 0x0200) != 0
+
}
+
+
pub const fn set_truncated(&mut self, truncated: bool) {
+
if truncated {
+
self.0 |= 0x0200;
+
} else {
+
self.0 &= !0x0200;
+
}
+
}
+
+
// RD: Recursion Desired
+
pub const fn is_recursion_desired(&self) -> bool {
+
(self.0 & 0x0100) != 0
+
}
+
+
pub const fn set_recursion_desired(&mut self, recursion_desired: bool) {
+
if recursion_desired {
+
self.0 |= 0x0100;
+
} else {
+
self.0 &= !0x0100;
+
}
+
}
+
+
// RA: Recursion Available
+
pub const fn is_recursion_available(&self) -> bool {
+
(self.0 & 0x0080) != 0
+
}
+
+
pub const fn set_recursion_available(&mut self, recursion_available: bool) {
+
if recursion_available {
+
self.0 |= 0x0080;
+
} else {
+
self.0 &= !0x0080;
+
}
+
}
+
+
// Z: Reserved for future use (bits 9-11)
+
pub const fn get_reserved(&self) -> u8 {
+
((self.0 >> 4) & 0x07) as u8
+
}
+
+
pub const fn set_reserved(&mut self, reserved: u8) {
+
self.0 = (self.0 & !0x0070) | ((reserved as u16 & 0x07) << 4);
+
}
+
+
// RCODE: Response Code (bits 12-15)
+
pub const fn get_rcode(&self) -> u8 {
+
(self.0 & 0x000F) as u8
+
}
+
+
pub const fn set_rcode(&mut self, rcode: u8) {
+
self.0 = (self.0 & !0x000F) | (rcode as u16 & 0x0F);
+
}
+
}
+
+
impl<'a> DnsParse<'a> for Flags {
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
+
be_u16.map(Flags).parse_next(input)
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Flags {
+
type Error = Infallible;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.write(&self.0.to_be_bytes());
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
core::mem::size_of::<u16>()
+
}
+
}
+
+
impl fmt::Debug for Flags {
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+
f.debug_struct("Flags")
+
.field("query", &self.is_query())
+
.field("opcode", &self.get_opcode())
+
.field("authoritative", &self.is_authoritative())
+
.field("truncated", &self.is_truncated())
+
.field("recursion_desired", &self.is_recursion_desired())
+
.field("recursion_available", &self.is_recursion_available())
+
.field("reserved", &self.get_reserved())
+
.field("rcode", &self.get_rcode())
+
.finish()
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for Flags {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"Flags {{ query: {}, opcode: {:?}, authoritative: {}, truncated: {}, recursion_desired: {}, recursion_available: {}, reserved: {}, rcode: {} }}",
+
self.is_query(),
+
self.get_opcode(),
+
self.is_authoritative(),
+
self.is_truncated(),
+
self.is_recursion_desired(),
+
self.is_recursion_available(),
+
self.get_reserved(),
+
self.get_rcode()
+
);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for Opcode {
+
fn format(&self, fmt: defmt::Formatter) {
+
let opcode_str = match self {
+
Opcode::Query => "Query",
+
Opcode::IQuery => "IQuery",
+
Opcode::Status => "Status",
+
Opcode::Reserved => "Reserved",
+
Opcode::Notify => "Notify",
+
Opcode::Update => "Update",
+
};
+
defmt::write!(fmt, "Opcode({=str})", opcode_str);
+
}
+
}
+512
sachy-mdns/src/dns/label.rs
···
+
use core::{fmt, str};
+
+
use winnow::{
+
ModalResult, Parser,
+
binary::be_u8,
+
error::{ContextError, ErrMode, FromExternalError},
+
stream::Offset,
+
token::take,
+
};
+
+
use crate::{
+
dns::traits::{DnsParse, DnsSerialize},
+
encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK},
+
};
+
+
#[derive(Clone, Copy)]
+
pub struct Label<'a> {
+
repr: LabelRepr<'a>,
+
}
+
+
impl<'a> From<&'a str> for Label<'a> {
+
fn from(value: &'a str) -> Self {
+
Self {
+
repr: LabelRepr::Str(value),
+
}
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Label<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
match self.repr {
+
LabelRepr::Bytes {
+
context,
+
start,
+
end,
+
} => {
+
encoder.write(&context[start..end]);
+
+
Ok(())
+
}
+
LabelRepr::Str(label) => encoder.write_label(label),
+
}
+
}
+
+
fn size(&self) -> usize {
+
match self.repr {
+
LabelRepr::Bytes {
+
context,
+
start,
+
end,
+
} => core::mem::size_of_val(&context[start..end]),
+
LabelRepr::Str(label) => core::mem::size_of_val(label) + 1,
+
}
+
}
+
}
+
+
impl<'a> DnsParse<'a> for Label<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let start = input.offset_from(&context);
+
let mut end = start;
+
+
loop {
+
match LabelSegment::parse(input)? {
+
LabelSegment::Empty => {
+
end += 1;
+
break;
+
}
+
LabelSegment::String(label) => {
+
end += 1 + label.len();
+
}
+
LabelSegment::Pointer(_) => {
+
end += 2;
+
break;
+
}
+
}
+
}
+
+
Ok(Self {
+
repr: LabelRepr::Bytes {
+
context,
+
start,
+
end,
+
},
+
})
+
}
+
}
+
+
impl Label<'_> {
+
pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'_>> {
+
self.repr.iter()
+
}
+
+
pub fn names(&self) -> impl Iterator<Item = &'_ str> {
+
match self.repr {
+
LabelRepr::Str(view) => Either::A(view.split('.')),
+
LabelRepr::Bytes { context, start, .. } => Either::B(
+
LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()),
+
),
+
}
+
}
+
+
pub fn is_empty(&self) -> bool {
+
self.repr.iter().next().is_none()
+
}
+
}
+
+
#[derive(Clone, Copy, PartialEq, Eq)]
+
enum LabelRepr<'a> {
+
Bytes {
+
context: &'a [u8],
+
start: usize,
+
end: usize,
+
},
+
Str(&'a str),
+
}
+
+
/// A DNS-compatible label segment.
+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+
pub enum LabelSegment<'a> {
+
/// The empty terminator.
+
Empty,
+
+
/// A string label.
+
String(&'a str),
+
+
/// A pointer to a previous name.
+
Pointer(u16),
+
}
+
+
impl<'a> LabelSegment<'a> {
+
fn parse(input: &mut &'a [u8]) -> ModalResult<Self> {
+
let b1 = be_u8(input)?;
+
+
match b1 {
+
0 => Ok(Self::Empty),
+
b1 if b1 & PTR_MASK == PTR_MASK => {
+
let b2 = be_u8(input)?;
+
+
let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
+
+
Ok(Self::Pointer(ptr))
+
}
+
len => {
+
if len > MAX_STR_LEN {
+
return Err(ErrMode::Cut(ContextError::from_external_error(
+
input,
+
DnsError::LabelTooLong,
+
)));
+
}
+
+
let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?;
+
+
Ok(Self::String(segment))
+
}
+
}
+
}
+
+
/// ## Safety
+
/// The caller upholds that this function is not called when parsing from newly received data. Data that
+
/// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`],
+
/// and that the entire data/context has been validated, not just a portion of it.
+
#[inline]
+
unsafe fn parse_unchecked(input: &'a [u8]) -> Option<Self> {
+
input.split_first().map(|(b1, input)| match *b1 {
+
0 => Self::Empty,
+
b1 if b1 & PTR_MASK == PTR_MASK => {
+
// SAFETY: The caller has already validated that a second byte is available for a
+
// Pointer segment.
+
let b2 = unsafe { *input.get_unchecked(0) };
+
+
let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
+
+
Self::Pointer(ptr)
+
}
+
len => {
+
// SAFETY: The caller has validated that this length value is correct and will only
+
// access within the bounds of the provided slice.
+
let segment = unsafe { input.get_unchecked(0..(len as usize)) };
+
// SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before.
+
let segment = unsafe { core::str::from_utf8_unchecked(segment) };
+
+
Self::String(segment)
+
}
+
})
+
}
+
+
fn as_str(&self) -> Option<&'a str> {
+
match self {
+
Self::String(label) => Some(*label),
+
_ => None,
+
}
+
}
+
}
+
+
pub struct LabelSegmentBytesIter<'a> {
+
context: &'a [u8],
+
start: usize,
+
}
+
+
impl<'a> LabelSegmentBytesIter<'a> {
+
pub(crate) fn new(context: &'a [u8], start: usize) -> Self {
+
Self { context, start }
+
}
+
}
+
+
impl<'a> Iterator for LabelSegmentBytesIter<'a> {
+
type Item = LabelSegment<'a>;
+
+
fn next(&mut self) -> Option<Self::Item> {
+
loop {
+
let view = &self.context[self.start..];
+
+
// SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes
+
let segment = unsafe { LabelSegment::parse_unchecked(view)? };
+
+
match segment {
+
LabelSegment::String(label) => {
+
self.start = self.start.saturating_add(label.len() + 1);
+
return Some(segment);
+
}
+
LabelSegment::Pointer(ptr) => {
+
self.start = ptr as usize;
+
}
+
LabelSegment::Empty => {
+
// Set the index offset to be len() so that the view is empty and terminates the loop
+
self.start = self.context.len();
+
return Some(LabelSegment::Empty);
+
}
+
}
+
}
+
}
+
}
+
+
impl<'a> LabelRepr<'a> {
+
fn iter(&self) -> impl Iterator<Item = LabelSegment<'a>> {
+
match *self {
+
LabelRepr::Bytes { context, start, .. } => {
+
Either::A(LabelSegmentBytesIter::new(context, start))
+
}
+
LabelRepr::Str(view) => Either::B(
+
view.split('.')
+
.map(LabelSegment::String)
+
.chain(Some(LabelSegment::Empty)),
+
),
+
}
+
}
+
}
+
+
impl fmt::Debug for Label<'_> {
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+
struct LabelFmt<'a>(&'a Label<'a>);
+
+
impl fmt::Debug for LabelFmt<'_> {
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+
fmt::Display::fmt(self.0, f)
+
}
+
}
+
+
f.debug_tuple("Label").field(&LabelFmt(self)).finish()
+
}
+
}
+
+
impl fmt::Display for Label<'_> {
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+
let mut names = self.names();
+
+
if let Some(name) = names.next() {
+
f.write_str(name)?;
+
+
names.try_for_each(|name| {
+
f.write_str(".")?;
+
f.write_str(name)
+
})
+
} else {
+
Ok(())
+
}
+
}
+
}
+
+
impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> {
+
fn eq(&self, other: &Label<'a>) -> bool {
+
self.segments().eq(other.segments())
+
}
+
}
+
+
impl Eq for Label<'_> {}
+
+
impl PartialEq<&str> for Label<'_> {
+
fn eq(&self, other: &&str) -> bool {
+
let mut self_iter = self.names();
+
let mut other_iter = other.split('.');
+
+
loop {
+
match (self_iter.next(), other_iter.next()) {
+
(Some(self_part), Some(other_part)) => {
+
if self_part != other_part {
+
return false;
+
}
+
}
+
(None, None) => return true,
+
_ => return false,
+
}
+
}
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for Label<'_> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(fmt, "Label(");
+
let mut iter = self.names();
+
if let Some(first) = iter.next() {
+
defmt::write!(fmt, "{}", first);
+
+
iter.for_each(|part| defmt::write!(fmt, ".{}", part));
+
}
+
defmt::write!(fmt, ")");
+
}
+
}
+
+
/// One iterator or another.
+
enum Either<A, B> {
+
A(A),
+
B(B),
+
}
+
+
impl<A: Iterator, Other: Iterator<Item = A::Item>> Iterator for Either<A, Other> {
+
type Item = A::Item;
+
+
fn next(&mut self) -> Option<Self::Item> {
+
match self {
+
Either::A(a) => a.next(),
+
Either::B(b) => b.next(),
+
}
+
}
+
+
fn size_hint(&self) -> (usize, Option<usize>) {
+
match self {
+
Either::A(a) => a.size_hint(),
+
Either::B(b) => b.size_hint(),
+
}
+
}
+
+
fn fold<B, F>(self, init: B, f: F) -> B
+
where
+
Self: Sized,
+
F: FnMut(B, Self::Item) -> B,
+
{
+
match self {
+
Either::A(a) => a.fold(init, f),
+
Either::B(b) => b.fold(init, f),
+
}
+
}
+
}
+
+
#[cfg(test)]
+
mod test {
+
use super::*;
+
+
#[test]
+
fn segments_iter_test() {
+
let label: Label<'static> = Label::from("_service._udp.local");
+
let mut segments = label.segments();
+
+
assert_eq!(segments.next(), Some(LabelSegment::String("_service")));
+
assert_eq!(segments.next(), Some(LabelSegment::String("_udp")));
+
assert_eq!(segments.next(), Some(LabelSegment::String("local")));
+
assert_eq!(segments.next(), Some(LabelSegment::Empty));
+
assert_eq!(segments.next(), None);
+
+
// example.com with a pointer to the start
+
let data = b"\x07example\x03com\x00\xC0\x00";
+
let context = &data[..];
+
// The data here is entirely valid, even though we parse only a portion of it.
+
let label = Label::parse(&mut &data[13..], context).unwrap();
+
+
let mut segments = label.segments();
+
assert_eq!(segments.next(), Some(LabelSegment::String("example")));
+
assert_eq!(segments.next(), Some(LabelSegment::String("com")));
+
assert_eq!(segments.next(), Some(LabelSegment::Empty));
+
assert_eq!(segments.next(), None);
+
}
+
+
#[test]
+
fn names_iter_test() {
+
let label: Label<'static> = Label::from("_service._udp.local");
+
let mut names = label.names();
+
+
assert_eq!(names.next(), Some("_service"));
+
assert_eq!(names.next(), Some("_udp"));
+
assert_eq!(names.next(), Some("local"));
+
assert_eq!(names.next(), None);
+
+
let data = b"\x07example\x03com\x00\xC0\x00";
+
let context = &data[..];
+
// The data here is entirely valid, even though we parse only a portion of it.
+
let label = Label::parse(&mut &data[13..], context).unwrap();
+
+
let mut names = label.names();
+
assert_eq!(names.next(), Some("example"));
+
assert_eq!(names.next(), Some("com"));
+
assert_eq!(names.next(), None);
+
}
+
+
#[test]
+
fn serialize_str_label() {
+
let label: Label<'static> = Label::from("_service._udp.local");
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
label.serialize(&mut buffer).unwrap();
+
assert_eq!(buffer.finish(), b"\x08_service\x04_udp\x05local\x00");
+
}
+
+
#[test]
+
fn serialize_compressed_str_label() {
+
let label: Label<'static> = Label::from("_service._udp.local");
+
let label2: Label<'static> = Label::from("other._udp.local");
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
label.serialize(&mut buffer).unwrap();
+
label2.serialize(&mut buffer).unwrap();
+
assert_eq!(
+
buffer.finish(),
+
b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
+
);
+
}
+
+
#[test]
+
fn round_trip_compressed_str_label() {
+
let label: Label<'static> = Label::from("_service._udp.local");
+
let label2: Label<'static> = Label::from("other._udp.local");
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
label.serialize(&mut buffer).unwrap();
+
label2.serialize(&mut buffer).unwrap();
+
let context = buffer.finish();
+
assert_eq!(
+
context,
+
b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
+
);
+
let view = &mut &context[..];
+
+
let parsed_label = Label::parse(view, context).unwrap();
+
let parsed_label2 = Label::parse(view, context).unwrap();
+
+
let parsed_label_count = parsed_label.segments().count();
+
let parsed_label2_count = parsed_label2.segments().count();
+
+
// Both have same amount of segments.
+
assert_eq!(parsed_label_count, parsed_label2_count);
+
+
assert_eq!(parsed_label, label);
+
assert_eq!(parsed_label2, label2);
+
}
+
+
#[test]
+
fn label_byte_repr_serialization_quick_path() {
+
let data = b"\x07example\x03com\x00\xC0\x00";
+
let context = &data[..];
+
// The data here is entirely valid, even though we parse only a portion of it.
+
let label = Label::parse(&mut &data[13..], context).unwrap();
+
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
label.serialize(&mut buffer).unwrap();
+
// If the original Label is just a pointer, the new output will be a pointer, assuming
+
// the original data is also present in the output
+
assert_eq!(buffer.finish(), b"\xC0\x00");
+
}
+
+
#[test]
+
fn parse_and_eq_created_label() {
+
let data = b"\x07example\x03com\x00\xC0\x00";
+
let context = &data[..];
+
+
// The data here is entirely valid, even though we parse only a portion of it.
+
let parsed_label = Label::parse(&mut &data[13..], context).unwrap();
+
+
let created_label = Label::from("example.com");
+
+
assert_eq!(parsed_label, created_label);
+
}
+
+
#[test]
+
fn parse_and_eq_label_with_str() {
+
let data = b"\x07example\x03com\x00";
+
let context = &data[..];
+
+
let parsed_label = Label::parse(&mut &data[..], context).unwrap();
+
+
assert_eq!(parsed_label, "example.com");
+
}
+
+
#[test]
+
fn parse_ptr_label_and_eq_with_str() {
+
let data = b"\x07example\x03com\x00\xC0\x00";
+
let context = &data[..];
+
+
// The data here is entirely valid, even though we parse only a portion of it.
+
let parsed_label = Label::parse(&mut &data[13..], context).unwrap();
+
+
assert_eq!(parsed_label, "example.com");
+
}
+
+
#[test]
+
fn label_new_without_dot_is_not_empty() {
+
let label: Label = Label::from("example");
+
assert!(!label.is_empty());
+
}
+
}
+228
sachy-mdns/src/dns/query.rs
···
+
use winnow::binary::{be_u16, be_u32};
+
use winnow::{ModalResult, Parser};
+
+
use super::label::Label;
+
use super::records::Record;
+
use crate::encoder::Encoder;
+
use crate::{
+
dns::{
+
records::QType,
+
traits::{DnsParse, DnsParseKind, DnsSerialize},
+
},
+
encoder::DnsError,
+
};
+
+
#[derive(Debug, PartialEq, Eq)]
+
pub struct Query<'a> {
+
pub name: Label<'a>,
+
pub qtype: QType,
+
pub qclass: QClass,
+
}
+
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
#[repr(u16)]
+
pub enum QClass {
+
IN = 1,
+
Multicast = 32769, // (IN + Cache flush bit)
+
Unknown(u16),
+
}
+
+
impl<'a> DnsParse<'a> for Query<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let name = Label::parse(input, context)?;
+
let qtype = QType::parse(input, context)?;
+
let qclass = be_u16.map(QClass::from_u16).parse_next(input)?;
+
+
Ok(Query {
+
name,
+
qtype,
+
qclass,
+
})
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Query<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
self.name.serialize(encoder)?;
+
self.qtype.serialize(encoder).ok();
+
encoder.write(&self.qclass.to_u16().to_be_bytes());
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
self.name.size() + self.qtype.size() + core::mem::size_of::<QClass>()
+
}
+
}
+
+
#[derive(Debug, PartialEq, Eq)]
+
pub struct Answer<'a> {
+
pub name: Label<'a>,
+
pub atype: QType,
+
pub aclass: QClass,
+
pub ttl: u32,
+
pub record: Record<'a>,
+
}
+
+
impl QClass {
+
fn from_u16(value: u16) -> Self {
+
match value {
+
1 => QClass::IN,
+
32769 => QClass::Multicast,
+
_ => QClass::Unknown(value),
+
}
+
}
+
+
fn to_u16(self) -> u16 {
+
match self {
+
QClass::IN => 1,
+
QClass::Multicast => 32769,
+
QClass::Unknown(value) => value,
+
}
+
}
+
}
+
+
impl<'a> DnsParse<'a> for Answer<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let name = Label::parse(input, context)?;
+
let atype = QType::parse(input, context)?;
+
let aclass = be_u16.map(QClass::from_u16).parse_next(input)?;
+
+
let ttl = be_u32.parse_next(input)?;
+
let record = atype.parse_kind(input, context)?;
+
+
Ok(Answer {
+
name,
+
atype,
+
aclass,
+
ttl,
+
record,
+
})
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Answer<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
self.name.serialize(encoder)?;
+
self.atype.serialize(encoder).ok();
+
encoder.write(&self.aclass.to_u16().to_be_bytes());
+
encoder.write(&self.ttl.to_be_bytes());
+
self.record.serialize(encoder)
+
}
+
+
fn size(&self) -> usize {
+
self.name.size()
+
+ self.atype.size()
+
+ core::mem::size_of::<QClass>()
+
+ core::mem::size_of::<u32>()
+
+ self.record.size()
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for Query<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"Query {{ name: {:?}, qtype: {:?}, qclass: {:?} }}",
+
self.name,
+
self.qtype,
+
self.qclass
+
);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for QType {
+
fn format(&self, fmt: defmt::Formatter) {
+
let qtype_str = match self {
+
QType::A => "A",
+
QType::AAAA => "AAAA",
+
QType::PTR => "PTR",
+
QType::TXT => "TXT",
+
QType::SRV => "SRV",
+
QType::Any => "Any",
+
QType::Unknown(_) => "Unknown",
+
};
+
defmt::write!(fmt, "QType({=str})", qtype_str);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for QClass {
+
fn format(&self, fmt: defmt::Formatter) {
+
let qclass_str = match self {
+
QClass::IN => "IN",
+
QClass::Multicast => "Multicast",
+
QClass::Unknown(_) => "Unknown",
+
};
+
defmt::write!(fmt, "QClass({=str})", qclass_str);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for Answer<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"Answer {{ name: {:?}, atype: {:?}, aclass: {:?}, ttl: {}, record: {:?} }}",
+
self.name,
+
self.atype,
+
self.aclass,
+
self.ttl,
+
self.record
+
);
+
}
+
}
+
+
#[cfg(test)]
+
mod tests {
+
use super::*;
+
use crate::dns::records::A;
+
use core::net::Ipv4Addr;
+
+
#[test]
+
fn roundtrip_query() {
+
let name = Label::from("example.local");
+
+
let query = Query {
+
name,
+
qtype: QType::A,
+
qclass: QClass::IN,
+
};
+
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
query.serialize(&mut buffer).unwrap();
+
let buffer = buffer.finish();
+
let parsed_query = Query::parse(&mut &buffer[..], buffer).unwrap();
+
+
assert_eq!(query, parsed_query);
+
}
+
+
#[test]
+
fn roundtrip_answer() {
+
let name = Label::from("example.local");
+
+
let answer: Answer = Answer {
+
name,
+
atype: QType::A,
+
aclass: QClass::IN,
+
ttl: 120,
+
record: Record::A(A {
+
address: Ipv4Addr::new(192, 168, 1, 1),
+
}),
+
};
+
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
answer.serialize(&mut buffer).unwrap();
+
let buffer = buffer.finish();
+
let parsed_answer = Answer::parse(&mut &buffer[..], buffer).unwrap();
+
+
assert_eq!(answer, parsed_answer);
+
}
+
}
+413
sachy-mdns/src/dns/records.rs
···
+
use core::{
+
convert::Infallible,
+
net::{Ipv4Addr, Ipv6Addr},
+
str,
+
};
+
+
use alloc::vec::Vec;
+
use winnow::token::take;
+
use winnow::{ModalResult, Parser};
+
use winnow::{binary::be_u8, error::ContextError};
+
use winnow::{binary::be_u16, error::FromExternalError};
+
+
use super::label::Label;
+
use crate::{
+
dns::traits::{DnsParse, DnsParseKind, DnsSerialize},
+
encoder::{DnsError, Encoder},
+
};
+
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
#[repr(u16)]
+
#[allow(clippy::upper_case_acronyms)]
+
pub enum QType {
+
A = 1,
+
AAAA = 28,
+
PTR = 12,
+
TXT = 16,
+
SRV = 33,
+
Any = 255,
+
Unknown(u16),
+
}
+
+
impl<'a> DnsParse<'a> for QType {
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
+
be_u16.map(QType::from_u16).parse_next(input)
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for QType {
+
type Error = Infallible;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.write(&self.to_u16().to_be_bytes());
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
core::mem::size_of::<QType>()
+
}
+
}
+
+
impl<'a> DnsParseKind<'a> for QType {
+
type Output = Record<'a>;
+
+
fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output> {
+
match self {
+
QType::A => {
+
let record = A::parse(input, context)?;
+
Ok(Record::A(record))
+
}
+
QType::AAAA => {
+
let record = AAAA::parse(input, context)?;
+
Ok(Record::AAAA(record))
+
}
+
QType::PTR => {
+
let record = PTR::parse(input, context)?;
+
Ok(Record::PTR(record))
+
}
+
QType::TXT => {
+
let record = TXT::parse(input, context)?;
+
Ok(Record::TXT(record))
+
}
+
QType::SRV => {
+
let record = SRV::parse(input, context)?;
+
Ok(Record::SRV(record))
+
}
+
QType::Any => Err(winnow::error::ErrMode::Backtrack(
+
ContextError::from_external_error(input, DnsError::Unsupported),
+
)),
+
QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack(
+
ContextError::from_external_error(input, DnsError::Unsupported),
+
)),
+
}
+
}
+
}
+
+
impl QType {
+
fn from_u16(value: u16) -> Self {
+
match value {
+
1 => QType::A,
+
28 => QType::AAAA,
+
12 => QType::PTR,
+
16 => QType::TXT,
+
33 => QType::SRV,
+
255 => QType::Any,
+
_ => QType::Unknown(value),
+
}
+
}
+
+
fn to_u16(self) -> u16 {
+
match self {
+
QType::A => 1,
+
QType::AAAA => 28,
+
QType::PTR => 12,
+
QType::TXT => 16,
+
QType::SRV => 33,
+
QType::Any => 255,
+
QType::Unknown(value) => value,
+
}
+
}
+
}
+
+
#[derive(Debug, PartialEq, Eq)]
+
#[allow(clippy::upper_case_acronyms)]
+
// Enum for DNS-SD records
+
pub enum Record<'a> {
+
A(A),
+
AAAA(AAAA),
+
PTR(PTR<'a>),
+
TXT(TXT<'a>),
+
SRV(SRV<'a>),
+
}
+
+
impl<'a> DnsSerialize<'a> for Record<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
match self {
+
Record::A(record) => {
+
record.serialize(encoder).ok();
+
}
+
Record::AAAA(record) => {
+
record.serialize(encoder).ok();
+
}
+
Record::PTR(record) => {
+
record.serialize(encoder)?;
+
}
+
Record::TXT(record) => {
+
record.serialize(encoder).ok();
+
}
+
Record::SRV(record) => {
+
record.serialize(encoder)?;
+
}
+
};
+
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
match self {
+
Self::A(a) => a.size(),
+
Self::AAAA(aaaa) => aaaa.size(),
+
Self::PTR(ptr) => ptr.size(),
+
Self::TXT(txt) => txt.size(),
+
Self::SRV(srv) => srv.size(),
+
}
+
}
+
}
+
+
// Struct for A record
+
#[derive(Debug, PartialEq, Eq)]
+
pub struct A {
+
pub address: Ipv4Addr,
+
}
+
+
impl<'a> DnsParse<'a> for A {
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
+
let len = be_u16.parse_next(input)?;
+
let address = take(len)
+
.try_map(<[u8; 4]>::try_from)
+
.map(Ipv4Addr::from)
+
.parse_next(input)?;
+
+
Ok(A { address })
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for A {
+
type Error = Infallible;
+
+
fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
+
let len = 4u16.to_be_bytes();
+
writer.write(&len);
+
writer.write(&self.address.octets());
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
core::mem::size_of::<Ipv4Addr>() + core::mem::size_of::<u16>()
+
}
+
}
+
+
// Struct for AAAA record
+
#[derive(Debug, PartialEq, Eq)]
+
#[allow(clippy::upper_case_acronyms)]
+
pub struct AAAA {
+
pub address: Ipv6Addr,
+
}
+
+
impl<'a> DnsParse<'a> for AAAA {
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
+
let len = be_u16.parse_next(input)?;
+
let address = take(len)
+
.try_map(<[u8; 16]>::try_from)
+
.map(Ipv6Addr::from)
+
.parse_next(input)?;
+
+
Ok(AAAA { address })
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for AAAA {
+
type Error = Infallible;
+
+
fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
+
let len = 16u16.to_be_bytes();
+
writer.write(&len);
+
writer.write(&self.address.octets());
+
Ok(())
+
}
+
+
fn size(&self) -> usize {
+
core::mem::size_of::<Ipv6Addr>() + core::mem::size_of::<u16>()
+
}
+
}
+
+
// Struct for PTR record
+
#[derive(Debug, PartialEq, Eq)]
+
#[allow(clippy::upper_case_acronyms)]
+
pub struct PTR<'a> {
+
pub name: Label<'a>,
+
}
+
+
impl<'a> DnsParse<'a> for PTR<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let _ = be_u16.parse_next(input)?;
+
let name = Label::parse(input, context)?;
+
Ok(PTR { name })
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for PTR<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.with_record_length(|enc| self.name.serialize(enc))
+
}
+
+
fn size(&self) -> usize {
+
self.name.size() + core::mem::size_of::<u16>()
+
}
+
}
+
+
// Struct for TXT record
+
#[derive(Debug, PartialEq, Eq)]
+
#[allow(clippy::upper_case_acronyms)]
+
pub struct TXT<'a> {
+
pub text: Vec<&'a str>,
+
}
+
+
impl<'a> DnsParse<'a> for TXT<'a> {
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
+
let text_len = be_u16.parse_next(input)?;
+
+
let mut total = 0u16;
+
let mut text = Vec::new();
+
+
while total < text_len {
+
let len = be_u8(input)?;
+
+
total += 1 + len as u16;
+
+
if len > 0 {
+
let part = take(len).try_map(core::str::from_utf8).parse_next(input)?;
+
text.push(part);
+
}
+
}
+
+
Ok(TXT { text })
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for TXT<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.with_record_length(|enc| {
+
self.text.iter().try_for_each(|&part| {
+
let text_len = u8::try_from(part.len())
+
.map_err(|_| DnsError::InvalidTxt)
+
.map(u8::to_be_bytes)?;
+
+
enc.write(&text_len);
+
enc.write(part.as_bytes());
+
+
Ok(())
+
})
+
})
+
}
+
+
fn size(&self) -> usize {
+
let len_size = core::mem::size_of::<u16>();
+
+
let text_size = if self.text.is_empty() {
+
1
+
} else {
+
self.text.iter().map(|part| part.len() + 1).sum()
+
};
+
+
len_size + text_size
+
}
+
}
+
+
// Struct for SRV record
+
#[derive(Debug, PartialEq, Eq)]
+
#[allow(clippy::upper_case_acronyms)]
+
pub struct SRV<'a> {
+
pub priority: u16,
+
pub weight: u16,
+
pub port: u16,
+
pub target: Label<'a>,
+
}
+
+
impl<'a> DnsParse<'a> for SRV<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let _ = be_u16.parse_next(input)?;
+
let priority = be_u16.parse_next(input)?;
+
let weight = be_u16.parse_next(input)?;
+
let port = be_u16.parse_next(input)?;
+
let target = Label::parse(input, context)?;
+
+
Ok(SRV {
+
priority,
+
weight,
+
port,
+
target,
+
})
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for SRV<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.with_record_length(|enc| {
+
enc.write(&self.priority.to_be_bytes());
+
enc.write(&self.weight.to_be_bytes());
+
enc.write(&self.port.to_be_bytes());
+
+
self.target.serialize(enc)
+
})
+
}
+
+
fn size(&self) -> usize {
+
(core::mem::size_of::<u16>() * 4) + self.target.size()
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for A {
+
fn format(&self, fmt: defmt::Formatter) {
+
// use crate::format::FormatIpv4Addr;
+
defmt::write!(fmt, "A({})", self.address)
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl defmt::Format for AAAA {
+
fn format(&self, fmt: defmt::Formatter) {
+
// use crate::format::FormatIpv6Addr;
+
defmt::write!(fmt, "AAAA({})", self.address)
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for Record<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
match self {
+
Record::A(record) => defmt::write!(fmt, "Record::A({:?})", record),
+
Record::AAAA(record) => defmt::write!(fmt, "Record::AAAA({:?})", record),
+
Record::PTR(record) => defmt::write!(fmt, "Record::PTR({:?})", record),
+
Record::TXT(record) => defmt::write!(fmt, "Record::TXT({:?})", record),
+
Record::SRV(record) => defmt::write!(fmt, "Record::SRV({:?})", record),
+
}
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for PTR<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(fmt, "PTR {{ name: {:?} }}", self.name);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for TXT<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(fmt, "TXT {{ text: {:?} }}", self.text);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for SRV<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}",
+
self.priority,
+
self.weight,
+
self.port,
+
self.target
+
);
+
}
+
}
+483
sachy-mdns/src/dns/reqres.rs
···
+
use alloc::vec::Vec;
+
use winnow::ModalResult;
+
use winnow::binary::be_u16;
+
+
use super::flags::Flags;
+
use super::query::{Answer, Query};
+
use crate::{
+
dns::traits::{DnsParse, DnsSerialize},
+
encoder::{DnsError, Encoder},
+
};
+
+
const ZERO_U16: [u8; 2] = 0u16.to_be_bytes();
+
+
#[derive(Debug, PartialEq, Eq)]
+
pub struct Request<'a> {
+
pub id: u16,
+
pub flags: Flags,
+
pub(crate) queries: Vec<Query<'a>>,
+
}
+
+
impl<'a> DnsParse<'a> for Request<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let id = be_u16(input)?;
+
let flags = Flags::parse(input, context)?;
+
let qdcount = be_u16(input)?;
+
let _ancount = be_u16(input)?;
+
let _nscount = be_u16(input)?;
+
let _arcount = be_u16(input)?;
+
let queries = (0..qdcount)
+
.map(|_| Query::parse(input, context))
+
.collect::<Result<Vec<_>, _>>()?;
+
Ok(Request { id, flags, queries })
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Request<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
writer.write(&self.id.to_be_bytes());
+
self.flags.serialize(writer).ok();
+
writer.write(&(self.queries.len() as u16).to_be_bytes());
+
writer.write(&ZERO_U16);
+
writer.write(&ZERO_U16);
+
writer.write(&ZERO_U16);
+
+
self.queries
+
.iter()
+
.try_for_each(|query| query.serialize(writer))
+
}
+
+
fn size(&self) -> usize {
+
let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
+
+
core::mem::size_of::<u16>()
+
+ self.flags.size()
+
+ (core::mem::size_of::<u16>() * 4)
+
+ total_query_size
+
}
+
}
+
+
#[derive(Debug, PartialEq, Eq)]
+
pub struct Response<'a> {
+
pub id: u16,
+
pub flags: Flags,
+
pub queries: Vec<Query<'a>>,
+
pub answers: Vec<Answer<'a>>,
+
}
+
+
impl<'a> DnsParse<'a> for Response<'a> {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
+
let id = be_u16(input)?;
+
let flags = Flags::parse(input, context)?;
+
let qdcount = be_u16(input)?;
+
let ancount = be_u16(input)?;
+
let _nscount = be_u16(input)?;
+
let _arcount = be_u16(input)?;
+
+
let queries = (0..qdcount)
+
.map(|_| Query::parse(input, context))
+
.collect::<Result<Vec<_>, _>>()?;
+
+
let answers = (0..ancount)
+
.map(|_| Answer::parse(input, context))
+
.collect::<Result<Vec<_>, _>>()?;
+
+
Ok(Response {
+
id,
+
flags,
+
queries,
+
answers,
+
})
+
}
+
}
+
+
impl<'a> DnsSerialize<'a> for Response<'a> {
+
type Error = DnsError;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
+
encoder.write(&self.id.to_be_bytes());
+
self.flags.serialize(encoder).ok();
+
encoder.write(&(self.queries.len() as u16).to_be_bytes());
+
encoder.write(&(self.answers.len() as u16).to_be_bytes());
+
encoder.write(&ZERO_U16);
+
encoder.write(&ZERO_U16);
+
+
self.queries
+
.iter()
+
.try_for_each(|query| query.serialize(encoder))?;
+
self.answers
+
.iter()
+
.try_for_each(|answer| answer.serialize(encoder))
+
}
+
+
fn size(&self) -> usize {
+
let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
+
let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum();
+
+
core::mem::size_of::<u16>()
+
+ self.flags.size()
+
+ (core::mem::size_of::<u16>() * 4)
+
+ total_query_size
+
+ total_answer_size
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for Request<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"Request {{ id: {}, flags: {:?}, queries: {:?} }}",
+
self.id,
+
self.flags,
+
self.queries
+
);
+
}
+
}
+
+
#[cfg(feature = "defmt")]
+
impl<'a> defmt::Format for Response<'a> {
+
fn format(&self, fmt: defmt::Formatter) {
+
defmt::write!(
+
fmt,
+
"Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}",
+
self.id,
+
self.flags,
+
self.queries,
+
self.answers
+
);
+
}
+
}
+
+
#[cfg(test)]
+
mod tests {
+
use alloc::vec;
+
+
use super::*;
+
use crate::dns::{
+
label::Label,
+
query::QClass,
+
records::{A, PTR, QType, Record, SRV, TXT},
+
};
+
use core::net::Ipv4Addr;
+
+
#[test]
+
fn parse_query() {
+
let data = [
+
0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65,
+
// example . com in label format
+
0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
+
//
+
0x00, 0x01, 0x00, 0x01,
+
];
+
+
let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap();
+
+
assert_eq!(request.id, 0xAAAA);
+
assert_eq!(request.flags.0, 0x0100);
+
assert_eq!(request.queries.len(), 1);
+
assert_eq!(request.queries[0].name, "example.com");
+
assert_eq!(request.queries[0].qtype, QType::A);
+
assert_eq!(request.queries[0].qclass, QClass::IN);
+
}
+
+
#[test]
+
fn parse_response() {
+
let data = [
+
0xAA, 0xAA, // transaction ID
+
0x81, 0x80, // flags
+
0x00, 0x01, // 1 question
+
0x00, 0x01, // 1 A-answer
+
0x00, 0x00, // no authority
+
0x00, 0x00, // no additional answers
+
// example . com in label format
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
+
//
+
0x00, 0x01, 0x00, 0x01, //
+
//
+
0xC0, 0x0C, // ptr to question section
+
//
+
0x00, 0x01, 0x00, 0x01, // A and IN
+
//
+
0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds
+
//
+
0x00, 0x04, // length of address
+
// IP address:
+
192, 168, 1, 3,
+
];
+
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
+
+
assert_eq!(response.id, 0xAAAA);
+
assert_eq!(response.flags.0, 0x8180);
+
assert_eq!(response.answers.len(), 1);
+
assert_eq!(response.answers[0].name, "example.com");
+
assert_eq!(response.answers[0].atype, QType::A);
+
assert_eq!(response.answers[0].aclass, QClass::IN);
+
assert_eq!(response.answers[0].ttl, 60);
+
if let Record::A(a) = &response.answers[0].record {
+
assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3));
+
} else {
+
panic!("Expected A record");
+
}
+
}
+
+
#[test]
+
fn parse_response_two_records() {
+
#[rustfmt::skip]
+
let data = [
+
0xAA, 0xAA, //
+
0x81, 0x80, //
+
0x00, 0x01, //
+
0x00, 0x02, //
+
0x00, 0x00, //
+
0x00, 0x00, //
+
// example . com in label format
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
+
//
+
0x00, 0x01, // query type
+
0x00, 0x01, // query class
+
//
+
0xC0, 0x0C, // pointer
+
0x00, 0x01, //
+
0x00, 0x01, //
+
0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
+
0x00, 0x04, // length of A-record
+
0x5D, 0xB8, 0xD8, 0x22, // a-record
+
//
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
+
//
+
0x00, 0x10, // TXT
+
0x00, 0x01, // IN
+
//
+
0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
+
//
+
0x00, 0x10, // length of txt record
+
// (len) "test txt record"
+
0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72,
+
0x64,
+
];
+
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
+
+
assert_eq!(response.id, 0xAAAA);
+
assert_eq!(response.flags.0, 0x8180);
+
assert_eq!(response.answers.len(), 2);
+
+
// First answer
+
assert_eq!(response.answers[0].name, "example.com");
+
assert_eq!(response.answers[0].atype, QType::A);
+
assert_eq!(response.answers[0].aclass, QClass::IN);
+
assert_eq!(response.answers[0].ttl, 60);
+
if let Record::A(a) = &response.answers[0].record {
+
assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34));
+
} else {
+
panic!("Expected A record");
+
}
+
+
// Second answer
+
assert_eq!(response.answers[1].name, "example.com");
+
assert_eq!(response.answers[1].atype, QType::TXT);
+
assert_eq!(response.answers[1].aclass, QClass::IN);
+
assert_eq!(response.answers[1].ttl, 60);
+
if let Record::TXT(txt) = &response.answers[1].record
+
&& let Some(&text) = txt.text.first()
+
{
+
assert_eq!(text, "test txt record");
+
} else {
+
panic!("Expected TXT record");
+
}
+
}
+
+
#[test]
+
fn parse_response_srv() {
+
let data = [
+
//
+
0xAA, 0xAA, // id
+
0x81, 0x80, // flags
+
0x00, 0x01, // one question
+
0x00, 0x01, // one answer
+
0x00, 0x00, // no authority
+
0x00, 0x00, // no extra
+
//
+
0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61,
+
0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
+
//
+
0x00, 0x21, // type SRV
+
0x00, 0x01, // IN
+
//
+
0xc0, 0x0c, //
+
//
+
0x00, 0x21, // SRV
+
0x00, 0x01, // IN
+
0x00, 0x00, 0x00, 0x3C, // ttl 60
+
//
+
0x00, 0x19, // data len
+
0x00, 0x0A, // prio
+
0x00, 0x05, // weight
+
0x13, 0xC4, // PORT
+
//
+
0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61,
+
0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
+
];
+
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
+
+
assert_eq!(response.id, 0xAAAA);
+
assert_eq!(response.flags.0, 0x8180);
+
assert_eq!(response.answers.len(), 1);
+
+
// Answer
+
assert_eq!(response.answers[0].name, "_sip._tcp.example.com");
+
assert_eq!(response.answers[0].atype, QType::SRV);
+
assert_eq!(response.answers[0].aclass, QClass::IN);
+
assert_eq!(response.answers[0].ttl, 60);
+
let Record::SRV(srv) = &response.answers[0].record else {
+
panic!("Expected SRV record");
+
};
+
+
assert_eq!(srv.priority, 10);
+
assert_eq!(srv.weight, 5);
+
assert_eq!(srv.port, 5060);
+
assert_eq!(srv.target, "sipserver.example.com");
+
}
+
+
#[test]
+
fn parse_response_back_forth() {
+
#[rustfmt::skip]
+
let data = [
+
0, 0, // Transaction ID
+
132, 0, // Response, Authoritative Answer, No Recursion
+
0, 0, // 0 questions
+
0, 4, // 4 answers
+
0, 0, // 0 authority RRs
+
0, 0, // 0 additional RRs
+
// _midiriff
+
9, 95, 109, 105, 100, 105, 114, 105, 102, 102, //
+
// _udp
+
4, 95, 117, 100, 112, //
+
// local
+
5, 108, 111, 99, 97, 108, //
+
0, // <end>
+
//
+
0, 12, // PTR
+
0, 1, // Class IN
+
0, 0, 0, 120, // TTL 120 seconds
+
0, 10, // Data Length 10
+
// pi35291
+
7, 112, 105, 51, 53, 50, 57, 49, //
+
//
+
192, 12, // Pointer to _midirif._udp._local.
+
//
+
192, 44, // Pointer to instace name: pi35291._midirif._udp._local.
+
0, 33, // SRV
+
128, 1, // IN (Cache flush bit set)
+
0, 0, 0, 120, // TTL 120 seconds
+
0, 11, // Data Length 11
+
0, 0, // Priority 0
+
0, 0, // Weight 0
+
137, 219, // Port 35291
+
2, 112, 105, // _pi
+
192, 27, // Pointer to: .local.
+
// TXT (Empty)
+
192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0,
+
// A (10.1.1.9)
+
192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9,
+
];
+
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
+
+
assert_eq!(response.answers[0].name, "_midiriff._udp.local");
+
assert_eq!(response.answers[0].ttl, 120);
+
let Record::PTR(ptr) = &response.answers[0].record else {
+
panic!()
+
};
+
assert_eq!(ptr.name, "pi35291._midiriff._udp.local");
+
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
response.serialize(&mut buffer).unwrap();
+
+
let buffer = buffer.finish();
+
+
let response2 = Response::parse(&mut &buffer[..], buffer).unwrap();
+
+
assert_eq!(response, response2);
+
}
+
+
#[test]
+
fn mdns_service_response() {
+
let mut response = Response {
+
id: 0x1234,
+
flags: Flags::standard_response(),
+
queries: Vec::new(),
+
answers: Vec::new(),
+
};
+
+
let query = Query {
+
name: Label::from("_test._udp.local"),
+
qtype: QType::PTR,
+
qclass: QClass::IN,
+
};
+
response.queries.push(query);
+
+
let ptr_answer = Answer {
+
name: Label::from("_test._udp.local"),
+
atype: QType::PTR,
+
aclass: QClass::IN,
+
ttl: 4500,
+
record: Record::PTR(PTR {
+
name: Label::from("test-service._test._udp.local"),
+
}),
+
};
+
response.answers.push(ptr_answer);
+
+
let srv_answer = Answer {
+
name: Label::from("test-service._test._udp.local"),
+
atype: QType::SRV,
+
aclass: QClass::IN,
+
ttl: 120,
+
record: Record::SRV(SRV {
+
priority: 0,
+
weight: 0,
+
port: 8080,
+
target: Label::from("host.local"),
+
}),
+
};
+
response.answers.push(srv_answer);
+
+
let txt_answer = Answer {
+
name: Label::from("test-service._test._udp.local"),
+
atype: QType::TXT,
+
aclass: QClass::IN,
+
ttl: 120,
+
record: Record::TXT(TXT {
+
text: vec!["path=/test"],
+
}),
+
};
+
response.answers.push(txt_answer);
+
+
let a_answer = Answer {
+
name: Label::from("host.local"),
+
atype: QType::A,
+
aclass: QClass::IN,
+
ttl: 120,
+
record: Record::A(A {
+
address: Ipv4Addr::new(192, 168, 1, 100),
+
}),
+
};
+
response.answers.push(a_answer);
+
+
let mut buffer = [0u8; 256];
+
let mut buffer = Encoder::new(&mut buffer);
+
response.serialize(&mut buffer).unwrap();
+
+
let buffer = buffer.finish();
+
+
let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap();
+
+
assert_eq!(response, parsed_response);
+
}
+
}
+21
sachy-mdns/src/dns/traits.rs
···
+
use winnow::ModalResult;
+
+
use crate::encoder::Encoder;
+
+
pub trait DnsParse<'a>: Sized {
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self>;
+
}
+
+
pub trait DnsParseKind<'a> {
+
type Output;
+
+
fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output>;
+
}
+
+
pub trait DnsSerialize<'a> {
+
type Error;
+
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error>;
+
#[allow(dead_code)]
+
fn size(&self) -> usize;
+
}
+145
sachy-mdns/src/encoder.rs
···
+
use core::ops::Range;
+
+
use alloc::collections::BTreeMap;
+
+
use crate::dns::traits::DnsSerialize;
+
+
pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK;
+
pub(crate) const PTR_MASK: u8 = 0b1100_0000;
+
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub enum DnsError {
+
LabelTooLong,
+
InvalidTxt,
+
Unsupported,
+
}
+
+
impl core::fmt::Display for DnsError {
+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+
match self {
+
Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"),
+
Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"),
+
Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"),
+
}
+
}
+
}
+
+
impl core::error::Error for DnsError {}
+
+
#[derive(Debug)]
+
pub struct Encoder<'a, 'b> {
+
output: &'b mut [u8],
+
position: usize,
+
lookup: BTreeMap<&'a str, u16>,
+
reservation: Option<usize>,
+
}
+
+
impl<'a, 'b> Encoder<'a, 'b> {
+
pub const fn new(buffer: &'b mut [u8]) -> Self {
+
Self {
+
output: buffer,
+
position: 0,
+
lookup: BTreeMap::new(),
+
reservation: None,
+
}
+
}
+
+
/// Takes a payload and encodes it, consuming the encoder and yielding the resulting
+
/// slice.
+
pub fn encode<T, E>(mut self, payload: T) -> Result<&'b [u8], E>
+
where
+
E: core::error::Error,
+
T: DnsSerialize<'a, Error = E>,
+
{
+
payload.serialize(&mut self)?;
+
Ok(self.finish())
+
}
+
+
pub(crate) fn finish(self) -> &'b [u8] {
+
&self.output[..self.position]
+
}
+
+
fn increment(&mut self, amount: usize) {
+
self.position += amount;
+
}
+
+
pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> {
+
loop {
+
if let Some(pos) = self.get_label_position(label) {
+
let [b1, b2] = u16::to_be_bytes(pos);
+
self.write(&[b1 | PTR_MASK, b2]);
+
return Ok(());
+
}
+
+
let dot = label.find('.');
+
+
let end = dot.unwrap_or(label.len());
+
let segment = &label[..end];
+
let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?;
+
+
if len > MAX_STR_LEN {
+
return Err(DnsError::LabelTooLong);
+
}
+
+
self.store_label_position(label);
+
self.write(&len.to_be_bytes());
+
self.write(segment.as_bytes());
+
+
match dot {
+
Some(end) => {
+
label = &label[end + 1..];
+
}
+
None => {
+
self.write(&[0]);
+
return Ok(());
+
}
+
}
+
}
+
}
+
+
pub(crate) fn write(&mut self, bytes: &[u8]) {
+
let len = bytes.len();
+
let end = self.position + len;
+
self.output[self.position..end].copy_from_slice(bytes);
+
self.increment(len);
+
}
+
+
fn get_label_position(&mut self, label: &str) -> Option<u16> {
+
self.lookup.get(label).copied()
+
}
+
+
fn store_label_position(&mut self, label: &'a str) {
+
self.lookup.insert(label, self.position as u16);
+
}
+
+
fn reserve_record_length(&mut self) {
+
if self.reservation.is_none() {
+
self.reservation = Some(self.position);
+
self.increment(2);
+
}
+
}
+
+
fn distance_from_reservation(&mut self) -> Option<(Range<usize>, u16)> {
+
self.reservation
+
.take()
+
.map(|start| (start..(start + 2), (self.position - start - 2) as u16))
+
}
+
+
fn write_record_length(&mut self) {
+
if let Some((reservation, len)) = self.distance_from_reservation() {
+
self.output[reservation].copy_from_slice(&len.to_be_bytes());
+
}
+
}
+
+
pub(crate) fn with_record_length<E, F>(&mut self, encoding_scope: F) -> Result<(), E>
+
where
+
E: core::error::Error,
+
F: FnOnce(&mut Self) -> Result<(), E>,
+
{
+
self.reserve_record_length();
+
encoding_scope(self)?;
+
self.write_record_length();
+
Ok(())
+
}
+
}
+60
sachy-mdns/src/lib.rs
···
+
#![no_std]
+
+
mod dns;
+
pub(crate) mod encoder;
+
pub mod server;
+
mod service;
+
mod state;
+
+
extern crate alloc;
+
+
use core::net::{Ipv4Addr, SocketAddrV4};
+
+
pub use crate::service::Service;
+
pub use crate::state::MdnsAction;
+
use crate::{dns::flags::Flags, server::Server, state::MdnsStateMachine};
+
+
/// Standard port for mDNS (5353).
+
pub const MDNS_PORT: u16 = 5353;
+
+
/// Standard IPv4 multicast address for mDNS (224.0.0.251).
+
pub const GROUP_ADDR_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
+
pub const GROUP_SOCK_V4: SocketAddrV4 = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT);
+
+
#[derive(Debug)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub struct MdnsService {
+
server: Server,
+
state: MdnsStateMachine,
+
}
+
+
impl MdnsService {
+
pub fn new(service: Service) -> Self {
+
Self {
+
server: Server::new(service),
+
state: Default::default(),
+
}
+
}
+
+
pub fn next_action(&mut self) -> MdnsAction {
+
self.state.drive_next_action()
+
}
+
+
pub fn send_announcement<'buffer>(&self, outgoing: &'buffer mut [u8]) -> Option<&'buffer [u8]> {
+
self.server.broadcast(
+
server::ResponseKind::Announcement,
+
Flags::standard_response(),
+
1,
+
alloc::vec::Vec::new(),
+
outgoing,
+
)
+
}
+
+
pub fn listen_for_queries<'buffer>(
+
&mut self,
+
incoming: &[u8],
+
outgoing: &'buffer mut [u8],
+
) -> Option<&'buffer [u8]> {
+
self.server.respond(incoming, outgoing)
+
}
+
}
+103
sachy-mdns/src/server.rs
···
+
use alloc::vec::Vec;
+
use sachy_fmt::{error, info};
+
+
use crate::{
+
dns::{
+
flags::Flags,
+
query::{QClass, Query},
+
records::QType,
+
reqres::{Request, Response},
+
traits::DnsParse,
+
},
+
encoder::Encoder,
+
service::Service,
+
};
+
+
pub(crate) enum ResponseKind {
+
Announcement,
+
QueryResponse(Vec<(QType, QClass)>),
+
}
+
+
#[derive(Debug)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub(crate) struct Server {
+
service: Service,
+
}
+
+
impl Server {
+
pub(crate) fn new(service: Service) -> Self {
+
Self { service }
+
}
+
+
pub(crate) fn broadcast<'a, 'b>(
+
&self,
+
response_kind: ResponseKind,
+
flags: Flags,
+
id: u16,
+
queries: Vec<Query<'b>>,
+
outgoing: &'a mut [u8],
+
) -> Option<&'a [u8]> {
+
let answers: Vec<_> = match response_kind {
+
ResponseKind::Announcement => self.service.as_answers(QClass::Multicast).collect(),
+
ResponseKind::QueryResponse(valid) => valid
+
.iter()
+
.flat_map(|&(qtype, qclass)| match qtype {
+
QType::A | QType::AAAA => self.service.ip_answer(qclass),
+
QType::PTR => self.service.ptr_answer(qclass),
+
QType::TXT => self.service.txt_answer(qclass),
+
QType::SRV => self.service.srv_answer(qclass),
+
QType::Any | QType::Unknown(_) => None,
+
})
+
.collect(),
+
};
+
+
if !answers.is_empty() {
+
let res = Response {
+
flags,
+
id,
+
queries,
+
answers,
+
};
+
+
info!("MDNS RESPONSE: {}", res);
+
+
return Encoder::new(outgoing)
+
.encode(res)
+
.inspect_err(|err| error!("Encoder errored: {}", err))
+
.ok();
+
}
+
+
None
+
}
+
+
pub(crate) fn respond<'a>(&self, incoming: &[u8], outgoing: &'a mut [u8]) -> Option<&'a [u8]> {
+
Request::parse(&mut &incoming[..], incoming)
+
.ok()
+
.and_then(|req| {
+
let valid_queries =
+
req.queries
+
.iter()
+
.filter_map(|q| match q.qtype {
+
QType::A | QType::AAAA | QType::TXT | QType::SRV => {
+
(q.name == self.service.hostname()).then_some((q.qtype, q.qclass))
+
}
+
QType::PTR => (q.name == self.service.service_type())
+
.then_some((q.qtype, q.qclass)),
+
QType::Any | QType::Unknown(_) => None,
+
})
+
.collect::<Vec<_>>();
+
+
if !valid_queries.is_empty() {
+
self.broadcast(
+
ResponseKind::QueryResponse(valid_queries),
+
req.flags,
+
req.id,
+
req.queries,
+
outgoing,
+
)
+
} else {
+
None
+
}
+
})
+
}
+
}
+193
sachy-mdns/src/service.rs
···
+
use core::net::IpAddr;
+
+
use alloc::{
+
string::{String, ToString},
+
vec::Vec,
+
};
+
+
use crate::dns::{
+
label::Label,
+
query::{Answer, QClass},
+
records::{A, AAAA, PTR, QType, Record, SRV, TXT},
+
};
+
+
#[derive(Debug, Default)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub struct Service {
+
service_type: String,
+
instance: String,
+
hostname: String,
+
ip: Option<IpAddr>,
+
port: u16,
+
}
+
+
impl Service {
+
pub fn new(
+
service_type: impl Into<String>,
+
instance: impl Into<String>,
+
hostname: impl Into<String>,
+
ip: Option<IpAddr>,
+
port: u16,
+
) -> Self {
+
let service_type = service_type.into();
+
let mut instance = instance.into();
+
let mut hostname = hostname.into();
+
+
instance.push('.');
+
instance.push_str(&service_type);
+
hostname.push_str(".local");
+
+
Self {
+
service_type,
+
instance,
+
hostname,
+
ip,
+
port,
+
}
+
}
+
+
pub fn service_type(&self) -> Label<'_> {
+
Label::from(self.service_type.as_ref())
+
}
+
+
pub fn instance(&self) -> Label<'_> {
+
Label::from(self.instance.as_ref())
+
}
+
+
pub fn hostname(&self) -> Label<'_> {
+
Label::from(self.hostname.as_ref())
+
}
+
+
pub fn ip(&self) -> Option<IpAddr> {
+
self.ip
+
}
+
+
pub fn port(&self) -> u16 {
+
self.port
+
}
+
+
pub(crate) fn ptr_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
+
Some(Answer {
+
name: self.service_type(),
+
atype: QType::PTR,
+
aclass,
+
ttl: 4500,
+
record: Record::PTR(PTR {
+
name: self.instance(),
+
}),
+
})
+
}
+
+
pub(crate) fn srv_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
+
Some(Answer {
+
name: self.instance(),
+
atype: QType::SRV,
+
aclass,
+
ttl: 120,
+
record: Record::SRV(SRV {
+
priority: 0,
+
weight: 0,
+
port: self.port,
+
target: self.hostname(),
+
}),
+
})
+
}
+
+
pub(crate) fn txt_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
+
Some(Answer {
+
name: self.instance(),
+
atype: QType::TXT,
+
aclass,
+
ttl: 120,
+
record: Record::TXT(TXT { text: Vec::new() }),
+
})
+
}
+
+
pub(crate) fn ip_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
+
self.ip().map(|address| match address {
+
IpAddr::V4(address) => Answer {
+
name: self.hostname(),
+
atype: QType::A,
+
aclass,
+
ttl: 120,
+
record: Record::A(A { address }),
+
},
+
IpAddr::V6(address) => Answer {
+
name: self.hostname(),
+
atype: QType::AAAA,
+
aclass,
+
ttl: 120,
+
record: Record::AAAA(AAAA { address }),
+
},
+
})
+
}
+
+
pub(crate) fn as_answers(&self, aclass: QClass) -> impl Iterator<Item = Answer<'_>> {
+
self.ptr_answer(aclass)
+
.into_iter()
+
.chain(self.srv_answer(aclass))
+
.chain(self.txt_answer(aclass))
+
.chain(self.ip_answer(aclass))
+
}
+
+
#[allow(dead_code)]
+
pub(crate) fn from_answers(answers: &[Answer<'_>]) -> Vec<Self> {
+
let mut output = Vec::new();
+
+
// Step 1: Process PTR records
+
for answer in answers {
+
if let Record::PTR(ptr) = &answer.record {
+
let instance = ptr.name.to_string();
+
let service_type = answer.name.to_string();
+
output.push(Self {
+
service_type,
+
instance,
+
..Default::default()
+
});
+
}
+
}
+
+
// Step 2: Process SRV records, A and AAAA records and merge data
+
for answer in answers {
+
match &answer.record {
+
Record::SRV(srv) => {
+
if let Some(stub) = output
+
.iter_mut()
+
.find(|stub| answer.name == stub.instance.as_ref())
+
{
+
stub.hostname = srv.target.to_string();
+
stub.port = srv.port;
+
}
+
}
+
Record::A(a) => {
+
if let Some(stub) = output
+
.iter_mut()
+
.find(|stub| answer.name == stub.hostname.as_ref())
+
{
+
stub.ip = Some(IpAddr::V4(a.address));
+
}
+
}
+
Record::AAAA(aaaa) => {
+
if let Some(stub) = output
+
.iter_mut()
+
.find(|stub| answer.name == stub.hostname.as_ref())
+
{
+
stub.ip = Some(IpAddr::V6(aaaa.address));
+
}
+
}
+
_ => {}
+
}
+
}
+
+
// Final step: Retain only complete services
+
output.retain(|stub| {
+
!stub.service_type.is_empty()
+
&& !stub.instance.is_empty()
+
&& !stub.hostname.is_empty()
+
&& stub.ip.is_some()
+
&& stub.port != 0
+
});
+
+
output
+
}
+
}
+100
sachy-mdns/src/state.rs
···
+
use embassy_time::{Duration, Instant};
+
use sachy_fmt::{debug, unwrap};
+
+
#[derive(Debug, Default)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub(crate) enum MdnsStateMachine {
+
#[default]
+
Start,
+
Announce {
+
last_sent: Instant,
+
},
+
ListenFor {
+
last_sent: Instant,
+
timeout: Duration,
+
},
+
WaitFor {
+
last_sent: Instant,
+
duration: Duration,
+
},
+
}
+
+
impl MdnsStateMachine {
+
/// Set the state to announced, if we have timed out the listening period and need
+
/// to announce, or if we received a query while listening and have sent a response.
+
pub(crate) fn announced(&mut self) {
+
*self = Self::Announce {
+
last_sent: Instant::now(),
+
};
+
}
+
+
fn next_state(&mut self) {
+
match self {
+
Self::Start => self.announced(),
+
&mut Self::Announce { last_sent } => {
+
let duration_since = last_sent.elapsed();
+
let duration = Duration::from_secs(1) - duration_since;
+
+
*self = Self::WaitFor {
+
last_sent,
+
duration,
+
};
+
}
+
&mut Self::ListenFor { last_sent, .. } => {
+
let duration_since = last_sent.elapsed();
+
let time_limit = Duration::from_secs(120);
+
+
if duration_since >= time_limit {
+
self.announced();
+
} else {
+
let timeout = time_limit - duration_since;
+
*self = Self::ListenFor { last_sent, timeout };
+
}
+
}
+
&mut Self::WaitFor { last_sent, .. } => {
+
let duration_since = last_sent.elapsed();
+
let time_limit = Duration::from_secs(120);
+
let timeout = time_limit - duration_since;
+
+
*self = Self::ListenFor { last_sent, timeout };
+
}
+
}
+
}
+
+
pub(crate) fn drive_next_action(&mut self) -> MdnsAction {
+
self.next_state();
+
unwrap!(MdnsAction::try_from(self))
+
}
+
}
+
+
#[derive(Debug)]
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+
pub enum MdnsAction {
+
Announce,
+
ListenFor { timeout: Duration },
+
WaitFor { duration: Duration },
+
}
+
+
impl TryFrom<&mut MdnsStateMachine> for MdnsAction {
+
type Error = MdnsStateMachine;
+
+
fn try_from(value: &mut MdnsStateMachine) -> Result<Self, Self::Error> {
+
match value {
+
// We should start in this state, but never remain nor return to it when
+
// executing our state machine event loop.
+
MdnsStateMachine::Start => Err(MdnsStateMachine::Start),
+
MdnsStateMachine::Announce { .. } => {
+
debug!("ANNOUNCE");
+
Ok(Self::Announce)
+
}
+
&mut MdnsStateMachine::ListenFor { timeout, .. } => {
+
debug!("LISTEN FOR {}ms", timeout.as_millis());
+
Ok(Self::ListenFor { timeout })
+
}
+
&mut MdnsStateMachine::WaitFor { duration, .. } => {
+
debug!("WAIT FOR {}ms", duration.as_millis());
+
Ok(Self::WaitFor { duration })
+
}
+
}
+
}
+
}