a go dns packet parser

Merge pull request 'creator-funcs' (#6) from creator-funcs into main

Reviewed-on: https://code.kiri.systems/kiri/magna/pulls/6

blu 2c2a1c50 8d0a079d

+89 -23
domain_test.go
···
)
func TestDecodeDomain(t *testing.T) {
-
buf := []byte{
-
0x03, 0x63, 0x6f, 0x6d, 0x00,
+
tests := []struct {
+
name string
+
offset int
+
input []byte
+
expectedDomain string
+
expectedOffset int
+
expectedError error
+
}{
+
{
+
name: "Simple domain",
+
input: []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
+
expectedDomain: "www.example.com",
+
expectedOffset: 17,
+
expectedError: nil,
+
},
+
{
+
name: "Domain with compression",
+
offset: 17,
+
input: []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 3, 'f', 'o', 'o', 0xC0, 0},
+
expectedDomain: "foo.www.example.com",
+
expectedOffset: 23,
+
expectedError: nil,
+
},
+
{
+
name: "Invalid label length",
+
input: []byte{64, 'x'},
+
expectedDomain: "",
+
expectedOffset: 2,
+
expectedError: &InvalidLabelError{Length: 64},
+
},
+
{
+
name: "Compression loop",
+
input: []byte{0xC0, 0, 0xC0, 0},
+
expectedDomain: "",
+
expectedOffset: 4,
+
expectedError: &DomainCompressionError{},
+
},
+
{
+
name: "Truncated input",
+
input: []byte{3, 'w', 'w'},
+
expectedDomain: "",
+
expectedOffset: 3,
+
expectedError: &BufferOverflowError{Length: 3, Offset: 4},
+
},
}
-
domain, offset, err := decode_domain(buf, 0)
-
assert.Equal(t, "com", domain)
-
assert.Equal(t, 5, offset)
-
assert.NoError(t, err)
-
}
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
domain, offset, err := decode_domain(tt.input, tt.offset)
-
func TestDecodeDomainWithCompression(t *testing.T) {
-
buf := []byte{
-
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xC0, 0x00,
+
t.Log(tt.name)
+
assert.Equal(t, tt.expectedError, err)
+
assert.Equal(t, tt.expectedDomain, domain)
+
assert.Equal(t, tt.expectedOffset, offset)
+
})
}
-
-
domain, offset, err := decode_domain(buf, 5)
-
assert.Equal(t, "c.com", domain)
-
assert.Equal(t, 9, offset)
-
assert.NoError(t, err)
}
-
func TestDecodeDomainWithCompressionLoop(t *testing.T) {
-
buf := []byte{
-
0x03, 0x63, 0x6f, 0x6d, 0xC0, 0x00,
+
func TestEncodeDomain(t *testing.T) {
+
tests := []struct {
+
name string
+
input string
+
offsets map[string]uint16
+
expected []byte
+
newOffsets map[string]uint16
+
}{
+
{
+
name: "Simple domain",
+
input: "example.com",
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
+
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
+
},
+
{
+
name: "Domain with existing offset",
+
input: "test.example.com",
+
offsets: map[string]uint16{"example.com": 10},
+
expected: []byte{4, 't', 'e', 's', 't', 0xC0, 0x0A},
+
newOffsets: map[string]uint16{"test.example.com": 0, "example.com": 10},
+
},
+
{
+
name: "Multiple subdomains",
+
input: "a.b.c.d",
+
offsets: make(map[string]uint16),
+
expected: []byte{1, 'a', 1, 'b', 1, 'c', 1, 'd', 0},
+
newOffsets: map[string]uint16{"a.b.c.d": 0, "b.c.d": 2, "c.d": 4, "d": 6},
+
},
}
-
domain, offset, err := decode_domain(buf, 0)
-
assert.Equal(t, "", domain)
-
assert.Equal(t, 6, offset)
-
assert.Error(t, err)
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := encode_domain([]byte{}, tt.input, &tt.offsets)
+
assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output")
+
assert.Equal(t, tt.newOffsets, tt.offsets, "Offsets map does not match expected state")
+
})
+
}
}
func FuzzDecodeDomain(f *testing.F) {
···
},
}
for _, tc := range testcases {
-
f.Add(tc) // Use f.Add to provide a seed corpus
+
f.Add(tc)
}
f.Fuzz(func(t *testing.T, msg []byte) {
decode_domain(msg, 0)
+71
errors_test.go
···
+
package magna
+
+
import (
+
"testing"
+
+
"github.com/stretchr/testify/assert"
+
)
+
+
func TestBufferOverflowError(t *testing.T) {
+
tests := []struct {
+
name string
+
length int
+
offset int
+
expected string
+
}{
+
{"basic overflow", 10, 15, "magna: offset 15 is past the buffer length 10"},
+
{"zero length", 0, 5, "magna: offset 5 is past the buffer length 0"},
+
{"negative offset", 10, -1, "magna: offset -1 is past the buffer length 10"},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
err := &BufferOverflowError{Length: tt.length, Offset: tt.offset}
+
assert.Equal(t, tt.expected, err.Error())
+
})
+
}
+
}
+
+
func TestInvalidLabelError(t *testing.T) {
+
tests := []struct {
+
name string
+
length int
+
expected string
+
}{
+
{"zero length", 0, "magna: received invalid label length 0"},
+
{"negative length", -1, "magna: received invalid label length -1"},
+
{"large length", 1000, "magna: received invalid label length 1000"},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
err := &InvalidLabelError{Length: tt.length}
+
assert.Equal(t, tt.expected, err.Error())
+
})
+
}
+
}
+
+
func TestDomainCompressionError(t *testing.T) {
+
err := &DomainCompressionError{}
+
expected := "magna: loop detected in domain compression"
+
assert.Equal(t, expected, err.Error())
+
}
+
+
func TestMagnaError(t *testing.T) {
+
tests := []struct {
+
name string
+
message string
+
expected string
+
}{
+
{"empty message", "", "magna: "},
+
{"basic message", "test error", "magna: test error"},
+
{"message with punctuation", "error: invalid input!", "magna: error: invalid input!"},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
err := &MagnaError{Message: tt.message}
+
assert.Equal(t, tt.expected, err.Error())
+
})
+
}
+
}
+348 -49
header_test.go
···
package magna
import (
+
"encoding/binary"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHeaderDecode(t *testing.T) {
-
bytes := []byte{
-
0x01, 0x02, // ID
-
0xaa, 0xaa, // QR, Opcode, AA, TC, RD, RA, Z, RCODE
-
0x00, 0x01, // QDCOUNT
-
0x00, 0x02, // ANCOUNT
-
0x00, 0x03, // NSCOUNT
-
0x00, 0x04, // ARCOUNT
+
tests := []struct {
+
name string
+
input []byte
+
expectedHeader Header
+
expectedOffset int
+
expectedErr error
+
}{
+
{
+
name: "Valid header",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
+
expectedHeader: Header{
+
ID: 0x1234,
+
QR: true,
+
OPCode: OPCode(0),
+
AA: false,
+
TC: false,
+
RD: true,
+
RA: true,
+
Z: 0,
+
RCode: RCode(0),
+
QDCount: 1,
+
ANCount: 2,
+
NSCount: 3,
+
ARCount: 4,
+
},
+
expectedOffset: 12,
+
expectedErr: nil,
+
},
+
{
+
name: "Insufficient buffer length",
+
input: []byte{0x12, 0x34, 0x81},
+
expectedHeader: Header{},
+
expectedOffset: 3,
+
expectedErr: &BufferOverflowError{Length: 3, Offset: 3},
+
},
+
{
+
name: "Invalid ID",
+
input: []byte{0x12},
+
expectedHeader: Header{},
+
expectedOffset: 1,
+
expectedErr: &BufferOverflowError{Length: 1, Offset: 1},
+
},
+
{
+
name: "Missing QDCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00},
+
expectedHeader: Header{},
+
expectedOffset: 5,
+
expectedErr: &BufferOverflowError{},
+
},
+
{
+
name: "Missing ANCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01},
+
expectedHeader: Header{},
+
expectedOffset: 6,
+
expectedErr: &BufferOverflowError{},
+
},
+
{
+
name: "Missing NSCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02},
+
expectedHeader: Header{},
+
expectedOffset: 8,
+
expectedErr: &BufferOverflowError{},
+
},
+
{
+
name: "Missing ARCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03},
+
expectedHeader: Header{},
+
expectedOffset: 10,
+
expectedErr: &BufferOverflowError{},
+
},
}
-
var header Header
-
offset, err := header.Decode(bytes, 0)
-
if err != nil {
-
t.Errorf("error should be nil\n")
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
h := &Header{}
+
offset, err := h.Decode(tt.input, 0)
+
+
if tt.expectedErr != nil {
+
assert.Error(t, err)
+
assert.IsType(t, tt.expectedErr, err)
+
} else {
+
assert.NoError(t, err)
+
assert.Equal(t, tt.expectedHeader, *h)
+
}
+
+
assert.Equal(t, tt.expectedOffset, offset)
+
})
}
+
}
-
if offset != 12 {
-
t.Errorf("offset should be 12 not %v\n", offset)
+
func TestHeaderDecodeFlags(t *testing.T) {
+
tests := []struct {
+
name string
+
flags uint16
+
expected Header
+
}{
+
{
+
name: "All flags set",
+
flags: 0xFFFF,
+
expected: Header{
+
QR: true,
+
OPCode: OPCode(15),
+
AA: true,
+
TC: true,
+
RD: true,
+
RA: true,
+
Z: 7,
+
RCode: RCode(15),
+
},
+
},
+
{
+
name: "No flags set",
+
flags: 0x0000,
+
expected: Header{
+
QR: false,
+
OPCode: OPCode(0),
+
AA: false,
+
TC: false,
+
RD: false,
+
RA: false,
+
Z: 0,
+
RCode: RCode(0),
+
},
+
},
+
{
+
name: "Mixed flags",
+
flags: 0x8510,
+
expected: Header{
+
QR: true,
+
OPCode: OPCode(0),
+
AA: true,
+
TC: false,
+
RD: true,
+
RA: false,
+
Z: 1,
+
RCode: RCode(0),
+
},
+
},
}
-
assert.Equal(t, header.ID, uint16(258))
-
assert.Equal(t, header.QR, true)
-
assert.Equal(t, header.OPCode, OPCode(5))
-
assert.Equal(t, header.AA, false)
-
assert.Equal(t, header.TC, true)
-
assert.Equal(t, header.RD, false)
-
assert.Equal(t, header.RA, true)
-
assert.Equal(t, header.Z, uint8(0b010))
-
assert.Equal(t, header.RCode, RCode(0b1010))
-
assert.Equal(t, header.QDCount, uint16(1))
-
assert.Equal(t, header.ANCount, uint16(2))
-
assert.Equal(t, header.NSCount, uint16(3))
-
assert.Equal(t, header.ARCount, uint16(4))
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
input := []byte{
+
0x00, 0x00,
+
byte(tt.flags >> 8), byte(tt.flags),
+
0x00, 0x00,
+
0x00, 0x00,
+
0x00, 0x00,
+
0x00, 0x00,
+
}
+
+
h := &Header{}
+
_, err := h.Decode(input, 0)
+
+
assert.NoError(t, err)
+
assert.Equal(t, tt.expected.QR, h.QR)
+
assert.Equal(t, tt.expected.OPCode, h.OPCode)
+
assert.Equal(t, tt.expected.AA, h.AA)
+
assert.Equal(t, tt.expected.TC, h.TC)
+
assert.Equal(t, tt.expected.RD, h.RD)
+
assert.Equal(t, tt.expected.RA, h.RA)
+
assert.Equal(t, tt.expected.Z, h.Z)
+
assert.Equal(t, tt.expected.RCode, h.RCode)
+
})
+
}
}
func TestHeaderEncode(t *testing.T) {
-
bytes := []byte{
-
0x01, 0x02, // ID
-
0xaa, 0xaa, // QR, Opcode, AA, TC, RD, RA, Z, RCODE
-
0x00, 0x01, // QDCOUNT
-
0x00, 0x02, // ANCOUNT
-
0x00, 0x03, // NSCOUNT
-
0x00, 0x04, // ARCOUNT
+
tests := []struct {
+
name string
+
header Header
+
expected []byte
+
}{
+
{
+
name: "All fields set",
+
header: Header{
+
ID: 0x1234,
+
QR: true,
+
OPCode: OPCode(1),
+
AA: true,
+
TC: true,
+
RD: true,
+
RA: true,
+
Z: 5,
+
RCode: RCode(3),
+
QDCount: 1,
+
ANCount: 2,
+
NSCount: 3,
+
ARCount: 4,
+
},
+
expected: []byte{
+
0x12, 0x34,
+
0x8f, 0xd3,
+
0x00, 0x01,
+
0x00, 0x02,
+
0x00, 0x03,
+
0x00, 0x04,
+
},
+
},
+
{
+
name: "No flags set",
+
header: Header{
+
ID: 0x5678,
+
QR: false,
+
OPCode: OPCode(0),
+
AA: false,
+
TC: false,
+
RD: false,
+
RA: false,
+
Z: 0,
+
RCode: RCode(0),
+
QDCount: 5,
+
ANCount: 6,
+
NSCount: 7,
+
ARCount: 8,
+
},
+
expected: []byte{
+
0x56, 0x78,
+
0x00, 0x00,
+
0x00, 0x05,
+
0x00, 0x06,
+
0x00, 0x07,
+
0x00, 0x08,
+
},
+
},
+
{
+
name: "Mixed flags",
+
header: Header{
+
ID: 0x9abc,
+
QR: true,
+
OPCode: OPCode(2),
+
AA: false,
+
TC: true,
+
RD: false,
+
RA: true,
+
Z: 3,
+
RCode: RCode(4),
+
QDCount: 9,
+
ANCount: 10,
+
NSCount: 11,
+
ARCount: 12,
+
},
+
expected: []byte{
+
0x9a, 0xbc,
+
0x92, 0xb4,
+
0x00, 0x09,
+
0x00, 0x0a,
+
0x00, 0x0b,
+
0x00, 0x0c,
+
},
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
encoded := tt.header.Encode()
+
assert.Equal(t, tt.expected, encoded)
+
})
}
+
}
-
var header Header
-
_, err := header.Decode(bytes, 0)
+
func TestHeaderEncodeDecodeRoundTrip(t *testing.T) {
+
originalHeader := Header{
+
ID: 0xdcba,
+
QR: true,
+
OPCode: OPCode(3),
+
AA: true,
+
TC: false,
+
RD: true,
+
RA: false,
+
Z: 6,
+
RCode: RCode(2),
+
QDCount: 13,
+
ANCount: 14,
+
NSCount: 15,
+
ARCount: 16,
+
}
+
+
encoded := originalHeader.Encode()
+
+
decodedHeader := &Header{}
+
offset, err := decodedHeader.Decode(encoded, 0)
+
assert.NoError(t, err)
+
assert.Equal(t, len(encoded), offset)
+
assert.Equal(t, originalHeader, *decodedHeader)
+
}
-
actual := header.Encode()
-
assert.Equal(t, bytes, actual)
+
func TestHeaderEncodeFlagCombinations(t *testing.T) {
+
testCases := []struct {
+
name string
+
header Header
+
expected uint16
+
}{
+
{"QR flag", Header{QR: true}, 0x8000},
+
{"OPCode", Header{OPCode: OPCode(5)}, 0x2800},
+
{"AA flag", Header{AA: true}, 0x0400},
+
{"TC flag", Header{TC: true}, 0x0200},
+
{"RD flag", Header{RD: true}, 0x0100},
+
{"RA flag", Header{RA: true}, 0x0080},
+
{"Z value", Header{Z: 5}, 0x0050},
+
{"RCode", Header{RCode: RCode(7)}, 0x0007},
+
{"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
encoded := tc.header.Encode()
+
flags := binary.BigEndian.Uint16(encoded[2:4])
+
assert.Equal(t, tc.expected, flags)
+
})
+
}
}
func FuzzDecodeHeader(f *testing.F) {
testcases := [][]byte{
-
{
-
0x01, 0x02,
-
0xaa, 0xaa,
-
0x00, 0x01,
-
0x00, 0x02,
-
0x00, 0x03,
-
0x00, 0x04,
-
},
+
{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
+
{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}
+
for _, tc := range testcases {
-
f.Add(tc) // Use f.Add to provide a seed corpus
+
f.Add(tc)
}
-
var header Header
-
f.Fuzz(func(t *testing.T, msg []byte) {
-
header.Decode(msg, 0)
+
+
f.Fuzz(func(t *testing.T, data []byte) {
+
// limit to only 12 bytes
+
if len(data) > 12 {
+
data = data[0:12]
+
}
+
+
h := &Header{}
+
offset, err := h.Decode(data, 0)
+
if err != nil {
+
switch err.(type) {
+
case *BufferOverflowError, *InvalidLabelError:
+
// these are expected error types
+
default:
+
t.Errorf("unexpected error type: %T", err)
+
}
+
return
+
}
+
+
if offset != len(data) {
+
t.Errorf("offset (%d) does not match data length (%d)", offset, len(data))
+
}
+
+
if h.OPCode > 15 {
+
t.Errorf("invalid OPCode: %d", h.OPCode)
+
}
+
+
if h.Z > 7 {
+
t.Errorf("invalid Z value: %d", h.Z)
+
}
+
+
if h.RCode > 15 {
+
t.Errorf("invalid RCode: %d", h.RCode)
+
}
+
+
encoded := h.Encode()
+
if len(encoded) != len(data) {
+
t.Errorf("encoded length (%d) does not match input length (%d)", len(encoded), len(data))
+
+
for i := 0; i < len(data); i++ {
+
t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i])
+
}
+
}
})
}
+49
message.go
···
package magna
+
import (
+
"math/rand"
+
)
+
// Decode decodes a DNS packet.
func (m *Message) Decode(buf []byte) (err error) {
offset, err := m.Header.Decode(buf, 0)
···
return bytes
}
+
+
func CreateRequest(op OPCode, rd bool) *Message {
+
return &Message{
+
Header: Header{
+
ID: uint16(rand.Intn(65534) + 1),
+
QR: false,
+
OPCode: op,
+
AA: false,
+
TC: false,
+
RD: rd,
+
RA: false,
+
Z: 0,
+
RCode: NOERROR,
+
QDCount: 0,
+
ARCount: 0,
+
NSCount: 0,
+
ANCount: 0,
+
},
+
Question: make([]Question, 0),
+
Answer: make([]ResourceRecord, 0),
+
Additional: make([]ResourceRecord, 0),
+
Authority: make([]ResourceRecord, 0),
+
}
+
}
+
+
func (m *Message) CreateReply(req *Message) *Message {
+
m.Header.ID = req.Header.ID
+
m.Header.QR = true
+
m.Header.OPCode = req.Header.OPCode
+
+
return m
+
}
+
+
func (m *Message) AddQuestion(q Question) *Message {
+
m.Header.QDCount += 1
+
m.Question = append(m.Question, q)
+
+
return m
+
}
+
+
func (m *Message) SetRCode(rc RCode) *Message {
+
m.Header.RCode = rc
+
+
return m
+
}
+105 -176
message_test.go
···
package magna
import (
+
"bytes"
+
"encoding/binary"
+
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMessageDecode(t *testing.T) {
-
bytes := []byte{
-
0x8e, 0x19, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x6e, 0x65,
-
0x77, 0x73, 0x0b, 0x79, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x03,
-
0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00,
-
0x00, 0x00, 0x01, 0x00, 0x04, 0xd1, 0xd8, 0xe6, 0xcf,
+
tests := []struct {
+
name string
+
input []byte
+
expected Message
+
wantErr bool
+
}{
+
{
+
name: "Valid DNS message with one question",
+
input: func() []byte {
+
buf := new(bytes.Buffer)
+
binary.Write(buf, binary.BigEndian, uint16(1234))
+
binary.Write(buf, binary.BigEndian, uint16(0x0100))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0})
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
return buf.Bytes()
+
}(),
+
expected: Message{
+
Header: Header{
+
ID: 1234,
+
QR: false,
+
RD: true,
+
OPCode: 0,
+
QDCount: 1,
+
},
+
Question: []Question{
+
{
+
QName: "www.example.com",
+
QType: 1,
+
QClass: 1,
+
},
+
},
+
},
+
wantErr: false,
+
},
+
{
+
name: "Valid DNS message with one answer",
+
input: func() []byte {
+
buf := new(bytes.Buffer)
+
binary.Write(buf, binary.BigEndian, uint16(5678))
+
binary.Write(buf, binary.BigEndian, uint16(0x8180))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0})
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint32(3600))
+
binary.Write(buf, binary.BigEndian, uint16(4))
+
binary.Write(buf, binary.BigEndian, uint32(0x0A000001))
+
return buf.Bytes()
+
}(),
+
expected: Message{
+
Header: Header{
+
ID: 5678,
+
QR: true,
+
OPCode: 0,
+
AA: false,
+
RD: true,
+
RA: true,
+
RCode: 0,
+
ANCount: 1,
+
},
+
Answer: []ResourceRecord{
+
{
+
Name: "www.example.com",
+
RType: 1,
+
RClass: 1,
+
TTL: 3600,
+
RDLength: 4,
+
RData: &A{net.IP([]byte{10, 0, 0, 1})},
+
},
+
},
+
},
+
wantErr: false,
+
},
+
{
+
name: "Invalid input - empty buffer",
+
input: []byte{},
+
wantErr: true,
+
},
}
-
var msg Message
-
msg.Decode(bytes)
-
assert.Equal(t, uint16(0x8e19), msg.Header.ID)
-
assert.Equal(t, true, msg.Header.QR)
-
assert.Equal(t, OPCode(0), msg.Header.OPCode)
-
assert.Equal(t, false, msg.Header.AA)
-
assert.Equal(t, false, msg.Header.TC)
-
assert.Equal(t, true, msg.Header.RD)
-
assert.Equal(t, true, msg.Header.RA)
-
assert.Equal(t, uint8(0), msg.Header.Z)
-
assert.Equal(t, RCode(0), msg.Header.RCode)
-
assert.Equal(t, uint16(1), msg.Header.QDCount)
-
assert.Equal(t, uint16(1), msg.Header.ANCount)
-
assert.Equal(t, uint16(0), msg.Header.NSCount)
-
assert.Equal(t, uint16(0), msg.Header.ARCount)
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
m := &Message{}
+
err := m.Decode(tt.input)
-
assert.Equal(t, 1, len(msg.Question))
-
-
question := msg.Question[0]
-
assert.Equal(t, "news.ycombinator.com", question.QName)
-
assert.Equal(t, DNSType(1), question.QType)
-
assert.Equal(t, DNSClass(1), question.QClass)
-
-
assert.Equal(t, 1, len(msg.Answer))
-
answer := msg.Answer[0]
-
assert.Equal(t, answer.Name, "news.ycombinator.com")
-
assert.Equal(t, DNSType(1), answer.RType)
-
assert.Equal(t, DNSClass(1), answer.RClass)
-
assert.Equal(t, uint32(1), answer.TTL)
-
assert.Equal(t, uint16(4), answer.RDLength)
-
}
-
-
func TestMessageDecodeWithU14Offset(t *testing.T) {
-
bytes := []byte{
-
0x0e, 0xc3, 0x80, 0x00, 0x00, 0x01, 0x00, 0x00,
-
0x00, 0x0d, 0x00, 0x0e, 0x06, 0x6e, 0x73, 0x2d,
-
0x33, 0x37, 0x32, 0x09, 0x61, 0x77, 0x73, 0x64,
-
0x6e, 0x73, 0x2d, 0x34, 0x36, 0x03, 0x63, 0x6f,
-
0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x14, 0x01, 0x61, 0x0c, 0x67, 0x74, 0x6c,
-
0x64, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
-
0x73, 0x03, 0x6e, 0x65, 0x74, 0x00, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x62, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x63, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x64, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x65, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x66, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x67, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x68, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x69, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x6a, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x6b, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x6c, 0xc0, 0x34, 0xc0, 0x1d,
-
0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0x01, 0x6d, 0xc0, 0x34, 0xc0, 0x32,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x05, 0x06, 0x1e, 0xc0, 0x52,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x21, 0x0e, 0x1e, 0xc0, 0x62,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x1a, 0x5c, 0x1e, 0xc0, 0x72,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x1f, 0x50, 0x1e, 0xc0, 0x82,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x0c, 0x5e, 0x1e, 0xc0, 0x92,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x23, 0x33, 0x1e, 0xc0, 0xa2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x2a, 0x5d, 0x1e, 0xc0, 0xb2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x36, 0x70, 0x1e, 0xc0, 0xc2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x2b, 0xac, 0x1e, 0xc0, 0xd2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x30, 0x4f, 0x1e, 0xc0, 0xe2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x34, 0xb2, 0x1e, 0xc0, 0xf2,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x29, 0xa2, 0x1e, 0xc1, 0x02,
-
0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x04, 0xc0, 0x37, 0x53, 0x1e, 0xc0, 0x32,
-
0x00, 0x1c, 0x00, 0x01, 0x00, 0x02, 0xa3, 0x00,
-
0x00, 0x10, 0x20, 0x01, 0x05, 0x03, 0xa8, 0x3e,
-
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
-
0x00, 0x30,
+
if tt.wantErr {
+
assert.Error(t, err)
+
} else {
+
assert.NoError(t, err)
+
assert.Equal(t, tt.expected.Header, m.Header)
+
assert.Equal(t, tt.expected.Question, m.Question)
+
assert.Equal(t, tt.expected.Answer, m.Answer)
+
assert.Equal(t, tt.expected.Authority, m.Authority)
+
assert.Equal(t, tt.expected.Additional, m.Additional)
+
}
+
})
}
-
-
var msg Message
-
_ = msg.Decode(bytes)
-
// assert_no_error(t, err)
-
-
// Header Section
-
assert.Equal(t, uint16(0x0ec3), msg.Header.ID)
-
assert.Equal(t, true, msg.Header.QR)
-
assert.Equal(t, QUERY, msg.Header.OPCode)
-
assert.Equal(t, false, msg.Header.AA)
-
assert.Equal(t, false, msg.Header.TC)
-
assert.Equal(t, false, msg.Header.RD)
-
assert.Equal(t, false, msg.Header.RA)
-
assert.Equal(t, uint8(0), msg.Header.Z)
-
assert.Equal(t, NOERROR, msg.Header.RCode)
-
assert.Equal(t, uint16(1), msg.Header.QDCount)
-
assert.Equal(t, uint16(0), msg.Header.ANCount)
-
assert.Equal(t, uint16(13), msg.Header.NSCount)
-
assert.Equal(t, uint16(14), msg.Header.ARCount)
-
-
// Query Section
-
assert.Equal(t, 1, len(msg.Question))
-
question := msg.Question[0]
-
assert.Equal(t, "ns-372.awsdns-46.com", question.QName)
-
assert.Equal(t, AType, question.QType)
-
assert.Equal(t, IN, question.QClass)
-
-
assert.Equal(t, 13, len(msg.Authority))
-
assert.Equal(t, 14, len(msg.Additional))
-
}
-
-
func TestMessageEncode(t *testing.T) {
-
bytes := []byte{
-
0x8e, 0x19, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x6e, 0x65,
-
0x77, 0x73, 0x0b, 0x79, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x03,
-
0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00,
-
0x00, 0x00, 0x01, 0x00, 0x04, 0xd1, 0xd8, 0xe6, 0xcf,
-
}
-
-
var msg Message
-
err := msg.Decode(bytes)
-
assert.NoError(t, err)
-
-
actual := msg.Encode()
-
assert.Equal(t, bytes, actual)
-
}
-
-
func TestMessageEncode2(t *testing.T) {
-
bytes := []byte{
-
0xfc, 0xa9, 0x81, 0x80, 0x00, 0x01, 0x00, 0x05,
-
0x00, 0x00, 0x00, 0x00, 0x03, 0x6f, 0x6c, 0x64,
-
0x06, 0x72, 0x65, 0x64, 0x64, 0x69, 0x74, 0x03,
-
0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
-
0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00,
-
0x2a, 0x17, 0x00, 0x17, 0x06, 0x72, 0x65, 0x64,
-
0x64, 0x69, 0x74, 0x03, 0x6d, 0x61, 0x70, 0x06,
-
0x66, 0x61, 0x73, 0x74, 0x6c, 0x79, 0x03, 0x6e,
-
0x65, 0x74, 0x00, 0xc0, 0x2c, 0x00, 0x01, 0x00,
-
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
-
0x65, 0x01, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
-
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
-
0x65, 0xc1, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
-
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
-
0x65, 0x41, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
-
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
-
0x65, 0x81, 0x8c,
-
}
-
-
var msg Message
-
err := msg.Decode(bytes)
-
assert.NoError(t, err)
-
-
actual := msg.Encode()
-
assert.Equal(t, bytes, actual)
}
func FuzzDecodeMessage(f *testing.F) {
···
},
}
for _, tc := range testcases {
-
f.Add(tc) // Use f.Add to provide a seed corpus
+
f.Add(tc)
}
f.Fuzz(func(t *testing.T, msg []byte) {
var m Message
+240
question_test.go
···
+
package magna
+
+
import (
+
"testing"
+
+
"github.com/stretchr/testify/assert"
+
)
+
+
func TestQuestionDecode(t *testing.T) {
+
tests := []struct {
+
name string
+
input []byte
+
expectedOffset int
+
expected Question
+
expectedErr error
+
}{
+
{
+
name: "Valid question - example.com A IN",
+
input: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
+
expectedOffset: 17,
+
expected: Question{
+
QName: "example.com",
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
},
+
expectedErr: nil,
+
},
+
{
+
name: "Valid question - example.com MX CH",
+
input: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
+
expectedOffset: 17,
+
expected: Question{
+
QName: "example.com",
+
QType: DNSType(15),
+
QClass: DNSClass(3),
+
},
+
expectedErr: nil,
+
},
+
{
+
name: "Invalid domain name",
+
input: []byte{255, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1},
+
expectedOffset: 13,
+
expected: Question{},
+
expectedErr: &BufferOverflowError{},
+
},
+
{
+
name: "Insufficient buffer for QType",
+
input: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0},
+
expectedOffset: 14,
+
expected: Question{QName: "example.com"},
+
expectedErr: &BufferOverflowError{},
+
},
+
{
+
name: "Insufficient buffer for QClass",
+
input: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0},
+
expectedOffset: 16,
+
expected: Question{QName: "example.com", QType: DNSType(1)},
+
expectedErr: &BufferOverflowError{},
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
q := &Question{}
+
offset, err := q.Decode(tt.input, 0)
+
+
assert.Equal(t, tt.expectedOffset, offset)
+
+
if tt.expectedErr != nil {
+
assert.Error(t, err)
+
assert.IsType(t, tt.expectedErr, err)
+
} else {
+
assert.NoError(t, err)
+
assert.Equal(t, tt.expected, *q)
+
}
+
})
+
}
+
}
+
+
func TestQuestionEncode(t *testing.T) {
+
tests := []struct {
+
name string
+
question Question
+
offsets map[string]uint16
+
expected []byte
+
}{
+
{
+
name: "Simple domain - example.com A IN",
+
question: Question{
+
QName: "example.com",
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
},
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
+
},
+
{
+
name: "Subdomain - subdomain.example.com AAAA IN",
+
question: Question{
+
QName: "subdomain.example.com",
+
QType: DNSType(28),
+
QClass: DNSClass(1),
+
},
+
offsets: make(map[string]uint16),
+
expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1},
+
},
+
{
+
name: "Different class - example.com MX CH",
+
question: Question{
+
QName: "example.com",
+
QType: DNSType(15),
+
QClass: DNSClass(3),
+
},
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
+
},
+
{
+
name: "Domain compression - example.com after subdomain.example.com",
+
question: Question{
+
QName: "example.com",
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
},
+
offsets: map[string]uint16{
+
"com": 22,
+
"example.com": 19,
+
},
+
expected: []byte{0xC0, 0x13, 0x00, 0x01, 0x00, 0x01},
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
result := tt.question.Encode(nil, &tt.offsets)
+
assert.Equal(t, tt.expected, result)
+
+
if len(tt.offsets) == 0 {
+
expectedOffsets := map[string]uint16{
+
tt.question.QName: 0,
+
}
+
for i := 0; i < len(tt.question.QName); i++ {
+
if tt.question.QName[i] == '.' {
+
expectedOffsets[tt.question.QName[i+1:]] = uint16(i + 1)
+
}
+
}
+
assert.Equal(t, expectedOffsets, tt.offsets)
+
}
+
})
+
}
+
}
+
+
func TestQuestionEncodeDecodeRoundTrip(t *testing.T) {
+
tests := []struct {
+
name string
+
question Question
+
}{
+
{
+
name: "Simple domain - example.com A IN",
+
question: Question{
+
QName: "example.com",
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
},
+
},
+
{
+
name: "Subdomain - subdomain.example.com AAAA IN",
+
question: Question{
+
QName: "subdomain.example.com",
+
QType: DNSType(28),
+
QClass: DNSClass(1),
+
},
+
},
+
{
+
name: "Different class - example.com MX CH",
+
question: Question{
+
QName: "example.com",
+
QType: DNSType(15),
+
QClass: DNSClass(3),
+
},
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
offsets := make(map[string]uint16)
+
encoded := tt.question.Encode(nil, &offsets)
+
+
decodedQuestion := &Question{}
+
_, err := decodedQuestion.Decode(encoded, 0)
+
+
assert.NoError(t, err)
+
assert.Equal(t, tt.question, *decodedQuestion)
+
})
+
}
+
}
+
+
func TestQuestionEncodeWithExistingBuffer(t *testing.T) {
+
question := Question{
+
QName: "example.com",
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
}
+
+
existingBuffer := []byte{0xFF, 0xFF, 0xFF, 0xFF}
+
offsets := make(map[string]uint16)
+
+
result := question.Encode(existingBuffer, &offsets)
+
+
expected := append(
+
existingBuffer,
+
[]byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}...,
+
)
+
+
assert.Equal(t, expected, result)
+
}
+
+
func TestQuestionEncodeLongDomainName(t *testing.T) {
+
longLabel := make([]byte, 63)
+
for i := range longLabel {
+
longLabel[i] = 'a'
+
}
+
longDomainName := string(longLabel) + "." + string(longLabel) + "." + string(longLabel) + "." + string(longLabel[:61])
+
+
question := Question{
+
QName: longDomainName,
+
QType: DNSType(1),
+
QClass: DNSClass(1),
+
}
+
+
offsets := make(map[string]uint16)
+
encoded := question.Encode(nil, &offsets)
+
+
assert.Equal(t, 259, len(encoded))
+
+
decodedQuestion := &Question{}
+
_, err := decodedQuestion.Decode(encoded, 0)
+
+
assert.NoError(t, err)
+
assert.Equal(t, question, *decodedQuestion)
+
}