a go dns packet parser

use assert to test and fix offset issue in resource record encode

+7 -11
domain_name.go
···
}
// encode_domain returns the bytes of the input bytes appened with the encoded domain name.
-
func encode_domain(bytes []byte, domain_name string, offsets *map[string]uint8) []byte {
-
pos := uint8(len(bytes))
+
func encode_domain(bytes []byte, domain_name string, offsets *map[string]uint16) []byte {
+
pos := uint16(len(bytes))
labels := strings.Split(domain_name, ".")
for i, label := range labels {
remaining_labels := strings.Join(labels[i:], ".")
if offset, found := (*offsets)[remaining_labels]; found {
-
pointer := 0xC000 | uint16(offset)
-
bytes = binary.BigEndian.AppendUint16(bytes, pointer)
-
-
return bytes
+
pointer := 0xC000 | offset
+
return binary.BigEndian.AppendUint16(bytes, pointer)
}
+
(*offsets)[remaining_labels] = pos
bytes = append(bytes, uint8(len(label)))
bytes = append(bytes, []byte(label)...)
-
-
(*offsets)[remaining_labels] = pos
-
pos += 1 + uint8(len(label))
+
pos += 1 + uint16(len(label))
}
-
bytes = append(bytes, 0)
-
return bytes
+
return append(bytes, 0)
}
+11 -9
domain_test.go
···
import (
"testing"
+
+
"github.com/stretchr/testify/assert"
)
func TestDecodeDomain(t *testing.T) {
···
}
domain, offset, err := decode_domain(buf, 0)
-
assert_eq(t, "com", domain)
-
assert_eq(t, 5, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, "com", domain)
+
assert.Equal(t, 5, offset)
+
assert.NoError(t, err)
}
func TestDecodeDomainWithCompression(t *testing.T) {
···
}
domain, offset, err := decode_domain(buf, 5)
-
assert_eq(t, "c.com", domain)
-
assert_eq(t, 9, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, "c.com", domain)
+
assert.Equal(t, 9, offset)
+
assert.NoError(t, err)
}
func TestDecodeDomainWithCompressionLoop(t *testing.T) {
···
}
domain, offset, err := decode_domain(buf, 0)
-
assert_eq(t, "", domain)
-
assert_eq(t, 6, offset)
-
assert_error(t, err)
+
assert.Equal(t, "", domain)
+
assert.Equal(t, 6, offset)
+
assert.Error(t, err)
}
func FuzzDecodeDomain(f *testing.F) {
+8
go.mod
···
module code.kiri.systems/kiri/magna
go 1.22.3
+
+
require github.com/stretchr/testify v1.9.0
+
+
require (
+
github.com/davecgh/go-spew v1.1.1 // indirect
+
github.com/pmezard/go-difflib v1.0.0 // indirect
+
gopkg.in/yaml.v3 v3.0.1 // indirect
+
)
+17 -42
header_test.go
···
package magna
import (
-
"reflect"
"testing"
-
)
-
func assert_eq(t *testing.T, expected any, actual any) {
-
if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
-
t.Errorf("expected type: %T\t actual type: %T\n", expected, actual)
-
}
-
-
if expected != actual {
-
t.Fatalf("expected: %#v\t actual: %#v\n", expected, actual)
-
}
-
}
-
-
func assert_no_error(t *testing.T, err error) {
-
if err != nil {
-
t.Fatalf("err is not nil: %v\n", err)
-
}
-
}
-
-
func assert_error(t *testing.T, err error) {
-
if err == nil {
-
t.Fatalf("err is not nil: %v\n", err)
-
}
-
}
+
"github.com/stretchr/testify/assert"
+
)
func TestHeaderDecode(t *testing.T) {
bytes := []byte{
···
t.Errorf("offset should be 12 not %v\n", offset)
}
-
assert_eq(t, header.ID, uint16(258))
-
assert_eq(t, header.QR, true)
-
assert_eq(t, header.OPCode, OPCode(5))
-
assert_eq(t, header.AA, false)
-
assert_eq(t, header.TC, true)
-
assert_eq(t, header.RD, false)
-
assert_eq(t, header.RA, true)
-
assert_eq(t, header.Z, uint8(0b010))
-
assert_eq(t, header.RCode, RCode(0b1010))
-
assert_eq(t, header.QDCount, uint16(1))
-
assert_eq(t, header.ANCount, uint16(2))
-
assert_eq(t, header.NSCount, uint16(3))
-
assert_eq(t, header.ARCount, uint16(4))
+
assert.Equal(t, header.ID, uint16(258))
+
assert.Equal(t, header.QR, true)
+
assert.Equal(t, header.OPCode, OPCode(5))
+
assert.Equal(t, header.AA, false)
+
assert.Equal(t, header.TC, true)
+
assert.Equal(t, header.RD, false)
+
assert.Equal(t, header.RA, true)
+
assert.Equal(t, header.Z, uint8(0b010))
+
assert.Equal(t, header.RCode, RCode(0b1010))
+
assert.Equal(t, header.QDCount, uint16(1))
+
assert.Equal(t, header.ANCount, uint16(2))
+
assert.Equal(t, header.NSCount, uint16(3))
+
assert.Equal(t, header.ARCount, uint16(4))
}
func TestHeaderEncode(t *testing.T) {
···
var header Header
_, err := header.Decode(bytes, 0)
-
assert_no_error(t, err)
+
assert.NoError(t, err)
actual := header.Encode()
-
for i := 0; i < len(bytes); i++ {
-
if bytes[i] != actual[i] {
-
t.Fatal(bytes, actual)
-
}
-
}
+
assert.Equal(t, bytes, actual)
}
func FuzzDecodeHeader(f *testing.F) {
+1 -1
message.go
···
// Encode encodes a message to a DNS packet.
// TODO: set truncation bit if over 512 and udp is protocol
func (m *Message) Encode() []byte {
-
m.offsets = make(map[string]uint8)
+
m.offsets = make(map[string]uint16)
bytes := m.Header.Encode()
for _, question := range m.Question {
+74 -47
message_test.go
···
import (
"testing"
+
+
"github.com/stretchr/testify/assert"
)
func TestMessageDecode(t *testing.T) {
···
var msg Message
msg.Decode(bytes)
-
assert_eq(t, uint16(0x8e19), msg.Header.ID)
-
assert_eq(t, true, msg.Header.QR)
-
assert_eq(t, OPCode(0), msg.Header.OPCode)
-
assert_eq(t, false, msg.Header.AA)
-
assert_eq(t, false, msg.Header.TC)
-
assert_eq(t, true, msg.Header.RD)
-
assert_eq(t, true, msg.Header.RA)
-
assert_eq(t, uint8(0), msg.Header.Z)
-
assert_eq(t, RCode(0), msg.Header.RCode)
-
assert_eq(t, uint16(1), msg.Header.QDCount)
-
assert_eq(t, uint16(1), msg.Header.ANCount)
-
assert_eq(t, uint16(0), msg.Header.NSCount)
-
assert_eq(t, uint16(0), msg.Header.ARCount)
+
assert.Equal(t, uint16(0x8e19), msg.Header.ID)
+
assert.Equal(t, true, msg.Header.QR)
+
assert.Equal(t, OPCode(0), msg.Header.OPCode)
+
assert.Equal(t, false, msg.Header.AA)
+
assert.Equal(t, false, msg.Header.TC)
+
assert.Equal(t, true, msg.Header.RD)
+
assert.Equal(t, true, msg.Header.RA)
+
assert.Equal(t, uint8(0), msg.Header.Z)
+
assert.Equal(t, RCode(0), msg.Header.RCode)
+
assert.Equal(t, uint16(1), msg.Header.QDCount)
+
assert.Equal(t, uint16(1), msg.Header.ANCount)
+
assert.Equal(t, uint16(0), msg.Header.NSCount)
+
assert.Equal(t, uint16(0), msg.Header.ARCount)
-
assert_eq(t, 1, len(msg.Question))
+
assert.Equal(t, 1, len(msg.Question))
question := msg.Question[0]
-
assert_eq(t, "news.ycombinator.com", question.QName)
-
assert_eq(t, DNSType(1), question.QType)
-
assert_eq(t, DNSClass(1), question.QClass)
+
assert.Equal(t, "news.ycombinator.com", question.QName)
+
assert.Equal(t, DNSType(1), question.QType)
+
assert.Equal(t, DNSClass(1), question.QClass)
-
assert_eq(t, 1, len(msg.Answer))
+
assert.Equal(t, 1, len(msg.Answer))
answer := msg.Answer[0]
-
assert_eq(t, answer.Name, "news.ycombinator.com")
-
assert_eq(t, DNSType(1), answer.RType)
-
assert_eq(t, DNSClass(1), answer.RClass)
-
assert_eq(t, uint32(1), answer.TTL)
-
assert_eq(t, uint16(4), answer.RDLength)
+
assert.Equal(t, answer.Name, "news.ycombinator.com")
+
assert.Equal(t, DNSType(1), answer.RType)
+
assert.Equal(t, DNSClass(1), answer.RClass)
+
assert.Equal(t, uint32(1), answer.TTL)
+
assert.Equal(t, uint16(4), answer.RDLength)
}
func TestMessageDecodeWithU14Offset(t *testing.T) {
···
// assert_no_error(t, err)
// Header Section
-
assert_eq(t, uint16(0x0ec3), msg.Header.ID)
-
assert_eq(t, true, msg.Header.QR)
-
assert_eq(t, QUERY, msg.Header.OPCode)
-
assert_eq(t, false, msg.Header.AA)
-
assert_eq(t, false, msg.Header.TC)
-
assert_eq(t, false, msg.Header.RD)
-
assert_eq(t, false, msg.Header.RA)
-
assert_eq(t, uint8(0), msg.Header.Z)
-
assert_eq(t, NOERROR, msg.Header.RCode)
-
assert_eq(t, uint16(1), msg.Header.QDCount)
-
assert_eq(t, uint16(0), msg.Header.ANCount)
-
assert_eq(t, uint16(13), msg.Header.NSCount)
-
assert_eq(t, uint16(14), msg.Header.ARCount)
+
assert.Equal(t, uint16(0x0ec3), msg.Header.ID)
+
assert.Equal(t, true, msg.Header.QR)
+
assert.Equal(t, QUERY, msg.Header.OPCode)
+
assert.Equal(t, false, msg.Header.AA)
+
assert.Equal(t, false, msg.Header.TC)
+
assert.Equal(t, false, msg.Header.RD)
+
assert.Equal(t, false, msg.Header.RA)
+
assert.Equal(t, uint8(0), msg.Header.Z)
+
assert.Equal(t, NOERROR, msg.Header.RCode)
+
assert.Equal(t, uint16(1), msg.Header.QDCount)
+
assert.Equal(t, uint16(0), msg.Header.ANCount)
+
assert.Equal(t, uint16(13), msg.Header.NSCount)
+
assert.Equal(t, uint16(14), msg.Header.ARCount)
// Query Section
-
assert_eq(t, 1, len(msg.Question))
+
assert.Equal(t, 1, len(msg.Question))
question := msg.Question[0]
-
assert_eq(t, "ns-372.awsdns-46.com", question.QName)
-
assert_eq(t, AType, question.QType)
-
assert_eq(t, IN, question.QClass)
+
assert.Equal(t, "ns-372.awsdns-46.com", question.QName)
+
assert.Equal(t, AType, question.QType)
+
assert.Equal(t, IN, question.QClass)
-
assert_eq(t, 13, len(msg.Authority))
-
assert_eq(t, 14, len(msg.Additional))
+
assert.Equal(t, 13, len(msg.Authority))
+
assert.Equal(t, 14, len(msg.Additional))
}
func TestMessageEncode(t *testing.T) {
···
var msg Message
err := msg.Decode(bytes)
-
assert_no_error(t, err)
+
assert.NoError(t, err)
actual := msg.Encode()
-
for i := 0; i < len(bytes); i++ {
-
if bytes[i] != actual[i] {
-
t.Fatal(bytes, actual)
-
}
+
assert.Equal(t, bytes, actual)
+
}
+
+
func TestMessageEncode2(t *testing.T) {
+
bytes := []byte{
+
0xfc, 0xa9, 0x81, 0x80, 0x00, 0x01, 0x00, 0x05,
+
0x00, 0x00, 0x00, 0x00, 0x03, 0x6f, 0x6c, 0x64,
+
0x06, 0x72, 0x65, 0x64, 0x64, 0x69, 0x74, 0x03,
+
0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
+
0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00,
+
0x2a, 0x17, 0x00, 0x17, 0x06, 0x72, 0x65, 0x64,
+
0x64, 0x69, 0x74, 0x03, 0x6d, 0x61, 0x70, 0x06,
+
0x66, 0x61, 0x73, 0x74, 0x6c, 0x79, 0x03, 0x6e,
+
0x65, 0x74, 0x00, 0xc0, 0x2c, 0x00, 0x01, 0x00,
+
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
+
0x65, 0x01, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
+
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
+
0x65, 0xc1, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
+
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
+
0x65, 0x41, 0x8c, 0xc0, 0x2c, 0x00, 0x01, 0x00,
+
0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x04, 0x97,
+
0x65, 0x81, 0x8c,
}
+
+
var msg Message
+
err := msg.Decode(bytes)
+
assert.NoError(t, err)
+
+
actual := msg.Encode()
+
assert.Equal(t, bytes, actual)
}
func FuzzDecodeMessage(f *testing.F) {
+1 -1
question.go
···
}
// Encode serializes a Question into bytes, using a map to handle domain name compression offsets.
-
func (q *Question) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = encode_domain(bytes, q.QName, offsets)
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QClass))
+23 -27
resource_record.go
···
return offset, err
}
-
func (a *A) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (a *A) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return append(bytes, a.Address.To4()...)
}
···
return offset, err
}
-
func (ns *NS) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (ns *NS) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, ns.NSDName, offsets)
}
···
return offset, err
}
-
func (md *MD) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (md *MD) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, md.MADName, offsets)
}
···
return offset, err
}
-
func (mf *MF) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (mf *MF) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, mf.MADName, offsets)
}
···
return offset, err
}
-
func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, c.CName, offsets)
}
···
return offset, err
}
-
func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = append(bytes, encode_domain(bytes, soa.MName, offsets)...)
bytes = append(bytes, encode_domain(bytes, soa.RName, offsets)...)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Serial)
···
return offset, err
}
-
func (mb *MB) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (mb *MB) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, mb.MADName, offsets)
}
···
return offset, err
}
-
func (mg *MG) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (mg *MG) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, mg.MGMName, offsets)
}
···
return offset, err
}
-
func (mr *MR) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (mr *MR) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, mr.NEWName, offsets)
}
···
return offset, err
}
-
func (null *NULL) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (null *NULL) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return append(bytes, null.Anything...)
}
···
return offset, err
}
-
func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = append(bytes, wks.Address.To4()...)
bytes = append(bytes, wks.Protocol)
bytes = append(bytes, wks.BitMap...)
···
return offset, err
}
-
func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return encode_domain(bytes, ptr.PTRDName, offsets)
}
···
return offset, err
}
-
func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = append(bytes, []byte(hinfo.CPU)...)
bytes = append(bytes, ' ')
bytes = append(bytes, []byte(hinfo.OS)...)
···
return offset, err
}
-
func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = encode_domain(bytes, minfo.RMailBx, offsets)
bytes = encode_domain(bytes, minfo.EMailBx, offsets)
···
return offset, err
}
-
func (mx *MX) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (mx *MX) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = binary.BigEndian.AppendUint16(bytes, mx.Preference)
bytes = encode_domain(bytes, mx.Exchange, offsets)
···
return offset, err
}
-
func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return append(bytes, []byte(txt.TxtData)...)
}
···
return offset, err
}
-
func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint16) []byte {
return append(bytes, r.Bytes...)
}
···
}
// Encode encdoes a resource record and returns the input bytes appened.
-
func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint8) []byte {
+
func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = encode_domain(bytes, r.Name, offsets)
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RClass))
bytes = binary.BigEndian.AppendUint32(bytes, r.TTL)
-
second_dinner := make([]byte, len(bytes))
-
copy(second_dinner, bytes)
-
start := len(second_dinner)
-
second_dinner = r.RData.Encode(second_dinner, offsets)
-
end := len(second_dinner)
-
data := second_dinner[start:end]
-
-
bytes = binary.BigEndian.AppendUint16(bytes, uint16(end-start))
-
bytes = append(bytes, data...)
+
rdata_start := len(bytes)
+
bytes = binary.BigEndian.AppendUint16(bytes, 0)
+
bytes = r.RData.Encode(bytes, offsets)
+
rdata_length := uint16(len(bytes) - rdata_start - 2)
+
binary.BigEndian.PutUint16(bytes[rdata_start:rdata_start+2], rdata_length)
return bytes
}
+2 -2
types.go
···
// offsets is a map of domains pointing to the seen offset used
// in domain compression.
-
offsets map[string]uint8
+
offsets map[string]uint16
}
// A Header represents the metadata information of a DNS packet.
···
// *map[string]uint8 - map containing labels and offsets for domain name compression
type ResourceRecordData interface {
Decode([]byte, int, int) (int, error)
-
Encode([]byte, *map[string]uint8) []byte
+
Encode([]byte, *map[string]uint16) []byte
String() string
}
+26 -24
utils_test.go
···
import (
"testing"
+
+
"github.com/stretchr/testify/assert"
)
func TestU8(t *testing.T) {
···
}
actual, offset, err := getU8(buf, 0)
-
assert_eq(t, uint8(1), actual)
-
assert_eq(t, 1, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, uint8(1), actual)
+
assert.Equal(t, 1, offset)
+
assert.NoError(t, err)
actual, offset, err = getU8(buf, 1)
-
assert_eq(t, uint8(0), actual)
-
assert_eq(t, 1, offset)
-
assert_error(t, err)
+
assert.Equal(t, uint8(0), actual)
+
assert.Equal(t, 1, offset)
+
assert.Error(t, err)
}
func TestU16(t *testing.T) {
···
}
actual, offset, err := getU16(buf, 0)
-
assert_eq(t, uint16(1), actual)
-
assert_eq(t, 2, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, uint16(1), actual)
+
assert.Equal(t, 2, offset)
+
assert.NoError(t, err)
actual, offset, err = getU16(buf, 1)
-
assert_eq(t, uint16(0), actual)
-
assert_eq(t, 2, offset)
-
assert_error(t, err)
+
assert.Equal(t, uint16(0), actual)
+
assert.Equal(t, 2, offset)
+
assert.Error(t, err)
}
func TestU32(t *testing.T) {
···
}
actual, offset, err := getU32(buf, 0)
-
assert_eq(t, uint32(1), actual)
-
assert_eq(t, 4, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, uint32(1), actual)
+
assert.Equal(t, 4, offset)
+
assert.NoError(t, err)
actual, offset, err = getU32(buf, 1)
-
assert_eq(t, uint32(0), actual)
-
assert_eq(t, 4, offset)
-
assert_error(t, err)
+
assert.Equal(t, uint32(0), actual)
+
assert.Equal(t, 4, offset)
+
assert.Error(t, err)
}
func TestSlice(t *testing.T) {
···
}
actual, offset, err := getSlice(buf, 0, 3)
-
assert_eq(t, "blu", string(actual))
-
assert_eq(t, 3, offset)
-
assert_no_error(t, err)
+
assert.Equal(t, "blu", string(actual))
+
assert.Equal(t, 3, offset)
+
assert.NoError(t, err)
actual, offset, err = getSlice(buf, 0, 4)
-
assert_eq(t, "", string(actual))
-
assert_eq(t, 3, offset)
-
assert_error(t, err)
+
assert.Equal(t, "", string(actual))
+
assert.Equal(t, 3, offset)
+
assert.Error(t, err)
}