a go dns packet parser

Compare changes

Choose any two refs to compare.

+30 -31
domain_name.go
···
import (
"encoding/binary"
+
"fmt"
"strings"
)
-
// decode_domain decodes a domain name from a buffer starting at offset.
+
// decodeDomain decodes a domain name from a buffer starting at offset.
// It returns the domain name along with the offset and error.
-
func decode_domain(buf []byte, offset int) (string, int, error) {
+
func decodeDomain(buf []byte, offset int) (string, int, error) {
var builder strings.Builder
firstLabel := true
-
seen_offsets := make(map[int]struct{})
+
seenOffsets := make(map[int]struct{})
finalOffsetAfterJump := -1
currentOffset := offset
for {
-
if _, found := seen_offsets[currentOffset]; found {
+
if _, found := seenOffsets[currentOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
-
seen_offsets[currentOffset] = struct{}{}
+
seenOffsets[currentOffset] = struct{}{}
length, nextOffsetAfterLen, err := getU8(buf, currentOffset)
if err != nil {
-
return "", len(buf), err
+
return "", len(buf), fmt.Errorf("failed to read domain label length: %w", err)
}
if length == 0 {
···
if (length & 0xC0) == 0xC0 {
sec, nextOffsetAfterPtr, err := getU8(buf, nextOffsetAfterLen)
if err != nil {
-
return "", len(buf), err
+
return "", len(buf), fmt.Errorf("failed to read domain compression pointer offset byte: %w", err)
}
jumpTargetOffset := int(length&0x3F)<<8 | int(sec)
···
if jumpTargetOffset >= len(buf) {
return "", len(buf), &BufferOverflowError{Length: len(buf), Offset: jumpTargetOffset}
}
-
if _, found := seen_offsets[jumpTargetOffset]; found {
+
if _, found := seenOffsets[jumpTargetOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
···
labelBytes, nextOffsetAfterLabel, err := getSlice(buf, nextOffsetAfterLen, int(length))
if err != nil {
-
return "", len(buf), err
+
return "", len(buf), fmt.Errorf("failed to read domain label data: %w", err)
}
if !firstLabel {
···
return builder.String(), finalReadOffset, nil
}
-
// encode_domain returns the bytes of the input bytes appened with the encoded domain name.
-
func encode_domain(bytes []byte, domain_name string, offsets *map[string]uint16) []byte {
-
if domain_name == "." || domain_name == "" {
-
return append(bytes, 0)
+
// encodeDomain returns the bytes of the input bytes appened with the encoded domain name.
+
func encodeDomain(bytes []byte, domainName string, offsets *map[string]uint16) ([]byte, error) {
+
if domainName == "." || domainName == "" {
+
return append(bytes, 0), nil
}
-
clean_domain := strings.TrimSuffix(domain_name, ".")
-
if clean_domain == "" {
-
return append(bytes, 0)
+
cleanDomain := strings.TrimSuffix(domainName, ".")
+
if cleanDomain == "" {
+
return append(bytes, 0), nil
}
start := 0
-
for start < len(clean_domain) {
-
suffix := clean_domain[start:]
+
for start < len(cleanDomain) {
+
suffix := cleanDomain[start:]
if offset, found := (*offsets)[suffix]; found {
-
if offset > 0x3FFF {
-
end := strings.IndexByte(suffix, '.')
-
if end == -1 {
-
end = len(suffix)
-
}
-
} else {
-
pointer := 0xC000 | offset
-
return binary.BigEndian.AppendUint16(bytes, pointer)
-
}
+
pointer := 0xC000 | offset
+
bytes = binary.BigEndian.AppendUint16(bytes, pointer)
+
return bytes, nil
}
currentPos := uint16(len(bytes))
···
end := strings.IndexByte(suffix, '.')
var label string
-
nextStart := len(clean_domain)
+
nextStart := len(cleanDomain)
if end == -1 {
label = suffix
···
labelBytes := []byte(label)
if len(labelBytes) > 63 {
-
// XXX: maybe should return an error
-
labelBytes = labelBytes[:63]
+
return nil, &InvalidLabelError{Length: int(len(labelBytes))}
+
}
+
+
if len(labelBytes) == 0 && start < len(cleanDomain) {
+
return nil, &InvalidLabelError{Length: 0}
}
bytes = append(bytes, byte(len(labelBytes)))
bytes = append(bytes, labelBytes...)
}
-
return append(bytes, 0)
+
bytes = append(bytes, 0)
+
return bytes, nil
}
+143 -22
domain_test.go
···
package magna
import (
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
···
func BenchmarkDecodeDomainSimple(b *testing.B) {
input := []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}
for i := 0; i < b.N; i++ {
-
_, _, _ = decode_domain(input, 0)
+
_, _, _ = decodeDomain(input, 0)
}
}
···
offset := 13
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
_, _, _ = decode_domain(input, offset)
+
_, _, _ = decodeDomain(input, offset)
}
}
···
out := make([]byte, 0, 64)
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
_ = encode_domain(out[:0], domain, &offsets)
+
_, _ = encodeDomain(out[:0], domain, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
out := make([]byte, 0, 128)
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
tempOut := encode_domain(out[:0], domain1, &offsets)
-
_ = encode_domain(tempOut, domain2, &offsets)
+
tempOut, _ := encodeDomain(out[:0], domain1, &offsets)
+
_, _ = encodeDomain(tempOut, domain2, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
expectedDomain string
expectedOffset int
expectedError error
+
errorCheck func(t *testing.T, err error)
}{
{
name: "Simple domain",
···
expectedDomain: "",
expectedOffset: 2,
expectedError: &InvalidLabelError{Length: 64},
+
errorCheck: func(t *testing.T, err error) {
+
var target *InvalidLabelError
+
assert.True(t, errors.As(err, &target))
+
assert.Equal(t, 64, target.Length)
+
},
},
{
name: "Compression loop",
···
expectedDomain: "",
expectedOffset: 4,
expectedError: &DomainCompressionError{},
+
errorCheck: func(t *testing.T, err error) {
+
assert.IsType(t, &DomainCompressionError{}, err)
+
},
},
{
name: "Truncated input",
···
expectedDomain: "",
expectedOffset: 3,
expectedError: &BufferOverflowError{Length: 3, Offset: 4},
+
errorCheck: func(t *testing.T, err error) {
+
var target *BufferOverflowError
+
assert.True(t, errors.As(err, &target), "Expected BufferOverflowError")
+
if target != nil {
+
assert.Equal(t, 3, target.Length)
+
assert.Equal(t, 1+3, target.Offset)
+
}
+
assert.Contains(t, err.Error(), "failed to read domain label data")
+
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
domain, offset, err := decode_domain(tt.input, tt.offset)
+
domain, offset, err := decodeDomain(tt.input, tt.offset)
-
t.Log(tt.name)
-
assert.Equal(t, tt.expectedError, err)
-
assert.Equal(t, tt.expectedDomain, domain)
-
assert.Equal(t, tt.expectedOffset, offset)
+
t.Logf("Test: %s, Input: %x, OffsetIn: %d => Domain: '%s', OffsetOut: %d, Err: %v", tt.name, tt.input, tt.offset, domain, offset, err)
+
+
if tt.expectedError != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
if tt.errorCheck != nil {
+
tt.errorCheck(t, err)
+
} else {
+
assert.IsType(t, tt.expectedError, err, "Error type mismatch")
+
}
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
}
+
+
assert.Equal(t, tt.expectedDomain, domain, "Domain mismatch")
+
if tt.expectedError == nil {
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch")
+
}
})
}
}
func TestEncodeDomain(t *testing.T) {
tests := []struct {
-
name string
-
input string
-
offsets map[string]uint16
-
expected []byte
-
newOffsets map[string]uint16
+
name string
+
input string
+
initialBuf []byte
+
offsets map[string]uint16
+
expected []byte
+
expectedErr error
+
newOffsets map[string]uint16
}{
{
name: "Simple domain",
input: "example.com",
+
initialBuf: []byte{},
offsets: make(map[string]uint16),
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
-
name: "Domain with existing offset",
+
name: "Domain with existing offset for compression",
input: "test.example.com",
+
initialBuf: []byte{},
offsets: map[string]uint16{"example.com": 10},
expected: []byte{4, 't', 'e', 's', 't', 0xC0, 0x0A},
newOffsets: map[string]uint16{"test.example.com": 0, "example.com": 10},
···
{
name: "Multiple subdomains",
input: "a.b.c.d",
+
initialBuf: []byte{},
offsets: make(map[string]uint16),
expected: []byte{1, 'a', 1, 'b', 1, 'c', 1, 'd', 0},
newOffsets: map[string]uint16{"a.b.c.d": 0, "b.c.d": 2, "c.d": 4, "d": 6},
},
+
{
+
name: "Root domain",
+
input: ".",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: []byte{0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Empty domain",
+
input: "",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: []byte{0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Label too long",
+
input: "labeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolong.com",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{Length: 72},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Empty label inside domain",
+
input: "example..com",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{Length: 0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Append to existing buffer",
+
input: "example.com",
+
initialBuf: []byte{0xAA, 0xBB},
+
offsets: make(map[string]uint16),
+
expected: []byte{0xAA, 0xBB, 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
+
newOffsets: map[string]uint16{"example.com": 2, "com": 10},
+
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
result := encode_domain([]byte{}, tt.input, &tt.offsets)
-
assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output")
-
assert.Equal(t, tt.newOffsets, tt.offsets, "Offsets map does not match expected state")
+
currentOffsets := make(map[string]uint16)
+
for k, v := range tt.offsets {
+
currentOffsets[k] = v
+
}
+
+
result, err := encodeDomain(tt.initialBuf, tt.input, &currentOffsets)
+
+
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.IsType(t, tt.expectedErr, err, "Error type mismatch")
+
if expectedILE, ok := tt.expectedErr.(*InvalidLabelError); ok {
+
actualILE := &InvalidLabelError{}
+
if assert.True(t, errors.As(err, &actualILE)) {
+
assert.Equal(t, expectedILE.Length, actualILE.Length)
+
}
+
}
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output")
+
assert.Equal(t, tt.newOffsets, currentOffsets, "Offsets map does not match expected state")
+
}
})
}
}
···
0x03, 0x63, 0x6f, 0x6d, 0x00,
},
{
-
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xC0, 0x00,
+
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xc0, 0x00,
+
},
+
{
+
0x03, 0x63, 0x6f, 0x6d, 0xc0, 0x00,
+
},
+
{
+
0xc0, 0x00,
},
{
-
0x03, 0x63, 0x6f, 0x6d, 0xC0, 0x00,
+
0xc0, 0xff,
+
},
+
{
+
0x40,
+
},
+
{
+
0x03, 0x63, 0x6f,
+
},
+
{
+
0xc0,
},
}
for _, tc := range testcases {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, msg []byte) {
-
decode_domain(msg, 0)
+
_, _, err := decodeDomain(msg, 0)
+
if err != nil {
+
var bufErr *BufferOverflowError
+
var labelErr *InvalidLabelError
+
var compErr *DomainCompressionError
+
+
if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr)) {
+
t.Errorf("Fuzzing decodeDomain: unexpected error type %T: %v for input %x", err, err, msg)
+
}
+
}
})
}
+24 -10
errors.go
···
}
func (e *BufferOverflowError) Error() string {
-
return fmt.Sprintf("magna: offset %d is past the buffer length %d", e.Offset, e.Length)
+
return fmt.Sprintf("buffer overflow: attempted to read past buffer length %d at offset %d", e.Length, e.Offset)
+
}
+
+
func (e *BufferOverflowError) Is(target error) bool {
+
_, ok := target.(*BufferOverflowError)
+
return ok
}
// InvalidLabelError represents an error when an invalid label length is encountered.
···
}
func (e *InvalidLabelError) Error() string {
-
return fmt.Sprintf("magna: received invalid label length %d", e.Length)
+
if e.Length > 63 {
+
return fmt.Sprintf("invalid domain label: length %d exceeds maximum 63", e.Length)
+
}
+
if e.Length == 0 {
+
return "invalid domain label: zero length label encountered"
+
}
+
+
// XXX: this should be unreachable
+
return fmt.Sprintf("invalid domain label: unexpected length %d", e.Length)
+
}
+
+
func (e *InvalidLabelError) Is(target error) bool {
+
_, ok := target.(*InvalidLabelError)
+
return ok
}
// DomainCompressionError represents an error related to domain compression.
type DomainCompressionError struct{}
func (e *DomainCompressionError) Error() string {
-
return "magna: loop detected in domain compression"
-
}
-
-
// MagnaError represents a generic error with a custom message.
-
type MagnaError struct {
-
Message string
+
return "invalid domain compression: pointer loop detected"
}
-
func (e *MagnaError) Error() string {
-
return fmt.Sprintf("magna: %s", e.Message)
+
func (e *DomainCompressionError) Is(target error) bool {
+
_, ok := target.(*DomainCompressionError)
+
return ok
}
+8 -28
errors_test.go
···
package magna
import (
-
"fmt"
"testing"
"github.com/stretchr/testify/assert"
···
offset int
expected string
}{
-
{"PositiveOffset", 10, 15, "magna: offset 15 is past the buffer length 10"},
-
{"ZeroLengthBuffer", 0, 5, "magna: offset 5 is past the buffer length 0"},
-
{"NegativeOffset", 10, -1, "magna: offset -1 is past the buffer length 10"},
-
{"EqualOffset", 10, 10, "magna: offset 10 is past the buffer length 10"},
+
{"PositiveOffset", 10, 15, "buffer overflow: attempted to read past buffer length 10 at offset 15"},
+
{"ZeroLengthBuffer", 0, 5, "buffer overflow: attempted to read past buffer length 0 at offset 5"},
+
{"NegativeOffset", 10, -1, "buffer overflow: attempted to read past buffer length 10 at offset -1"},
+
{"EqualOffset", 10, 10, "buffer overflow: attempted to read past buffer length 10 at offset 10"},
}
for _, tt := range tests {
···
length int
expected string
}{
-
{"LengthTooLarge", 64, "magna: received invalid label length 64"},
-
{"LengthZero", 0, "magna: received invalid label length 0"},
-
{"NegativeLength", -1, "magna: received invalid label length -1"},
+
{"LengthTooLarge", 64, "invalid domain label: length 64 exceeds maximum 63"},
+
{"LengthZero", 0, "invalid domain label: zero length label encountered"},
+
{"ValidLength", 30, "invalid domain label: unexpected length 30"},
}
for _, tt := range tests {
···
func TestDomainCompressionError(t *testing.T) {
t.Run("Standard", func(t *testing.T) {
err := &DomainCompressionError{}
-
expected := "magna: loop detected in domain compression"
+
expected := "invalid domain compression: pointer loop detected"
assert.Equal(t, expected, err.Error(), "Error() output mismatch")
})
}
-
-
func TestMagnaError(t *testing.T) {
-
tests := []struct {
-
name string
-
message string
-
expected string
-
}{
-
{"EmptyMessage", "", "magna: "},
-
{"SimpleMessage", "test error", "magna: test error"},
-
{"MessageWithPunctuation", "error: invalid input!", "magna: error: invalid input!"},
-
}
-
-
for _, tt := range tests {
-
t.Run(tt.name, func(t *testing.T) {
-
err := &MagnaError{Message: tt.message}
-
assert.Equal(t, tt.expected, err.Error())
-
})
-
}
-
}
+10 -7
header.go
···
package magna
-
import "encoding/binary"
+
import (
+
"encoding/binary"
+
"fmt"
+
)
// Decode decodes the header from the bytes.
func (h *Header) Decode(buf []byte, offset int) (int, error) {
var err error
h.ID, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read ID: %w", err)
}
flags, offset, err := getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read flags: %w", err)
}
h.QDCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read QDCount: %w", err)
}
h.ANCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read ANCount: %w", err)
}
h.NSCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read NSCount: %w", err)
}
h.ARCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
+
return len(buf), fmt.Errorf("header decode: failed to read ARCount: %w", err)
}
h.QR = ((flags >> 15) & 0x01) == 1
+81 -60
header_test.go
···
package magna
import (
+
"bytes"
"encoding/binary"
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
)
func TestHeaderDecode(t *testing.T) {
···
expectedHeader Header
expectedOffset int
expectedErr error
+
wantErrMsg string
}{
{
name: "Valid header",
···
expectedErr: nil,
},
{
-
name: "Insufficient buffer length",
+
name: "Insufficient buffer length for Flags",
input: []byte{0x12, 0x34, 0x81},
-
expectedHeader: Header{},
+
expectedHeader: Header{ID: 0x1234},
expectedOffset: 3,
-
expectedErr: &BufferOverflowError{Length: 3, Offset: 3},
+
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read flags",
},
{
-
name: "Invalid ID",
+
name: "Insufficient buffer length for ID",
input: []byte{0x12},
expectedHeader: Header{},
expectedOffset: 1,
-
expectedErr: &BufferOverflowError{Length: 1, Offset: 1},
+
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ID",
},
{
name: "Missing QDCount",
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00},
-
expectedHeader: Header{},
+
expectedHeader: Header{ID: 0x1234},
expectedOffset: 5,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read QDCount",
},
{
name: "Missing ANCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01},
-
expectedHeader: Header{},
-
expectedOffset: 6,
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1},
+
expectedOffset: 7,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ANCount",
},
{
name: "Missing NSCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02},
-
expectedHeader: Header{},
-
expectedOffset: 8,
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2},
+
expectedOffset: 9,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read NSCount",
},
{
name: "Missing ARCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03},
-
expectedHeader: Header{},
-
expectedOffset: 10,
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2, NSCount: 3},
+
expectedOffset: 11,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ARCount",
},
}
···
offset, err := h.Decode(tt.input, 0)
if tt.expectedErr != nil {
-
assert.Error(t, err)
-
assert.IsType(t, tt.expectedErr, err)
+
assert.Error(t, err, "Expected an error but got nil")
+
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
+
}
+
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
} else {
-
assert.NoError(t, err)
-
assert.Equal(t, tt.expectedHeader, *h)
-
}
+
assert.NoError(t, err, "Expected no error but got one")
-
assert.Equal(t, tt.expectedOffset, offset)
+
assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch")
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
+
}
})
}
}
···
}
h := &Header{}
-
_, err := h.Decode(input, 0)
+
offset, err := h.Decode(input, 0)
assert.NoError(t, err)
-
assert.Equal(t, tt.expected.QR, h.QR)
-
assert.Equal(t, tt.expected.OPCode, h.OPCode)
-
assert.Equal(t, tt.expected.AA, h.AA)
-
assert.Equal(t, tt.expected.TC, h.TC)
-
assert.Equal(t, tt.expected.RD, h.RD)
-
assert.Equal(t, tt.expected.RA, h.RA)
-
assert.Equal(t, tt.expected.Z, h.Z)
-
assert.Equal(t, tt.expected.RCode, h.RCode)
+
assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header")
+
+
assert.Equal(t, tt.expected.QR, h.QR, "QR flag mismatch")
+
assert.Equal(t, tt.expected.OPCode, h.OPCode, "OPCode mismatch")
+
assert.Equal(t, tt.expected.AA, h.AA, "AA flag mismatch")
+
assert.Equal(t, tt.expected.TC, h.TC, "TC flag mismatch")
+
assert.Equal(t, tt.expected.RD, h.RD, "RD flag mismatch")
+
assert.Equal(t, tt.expected.RA, h.RA, "RA flag mismatch")
+
assert.Equal(t, tt.expected.Z, h.Z, "Z value mismatch")
+
assert.Equal(t, tt.expected.RCode, h.RCode, "RCode mismatch")
})
}
}
···
},
},
{
-
name: "No flags set",
+
name: "No flags set, different counts",
header: Header{
ID: 0x5678,
QR: false,
···
},
},
{
-
name: "Mixed flags",
+
name: "Mixed flags and counts",
header: Header{
ID: 0x9abc,
QR: true,
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encoded := tt.header.Encode()
-
assert.Equal(t, tt.expected, encoded)
+
assert.Equal(t, tt.expected, encoded, "Encoded header mismatch")
})
}
}
···
}
encoded := originalHeader.Encode()
+
assert.Len(t, encoded, 12, "Encoded header should be 12 bytes")
decodedHeader := &Header{}
offset, err := decodedHeader.Decode(encoded, 0)
-
assert.NoError(t, err)
-
assert.Equal(t, len(encoded), offset)
-
assert.Equal(t, originalHeader, *decodedHeader)
+
assert.NoError(t, err, "Decoding failed unexpectedly")
+
assert.Equal(t, 12, offset, "Offset after decoding should be 12")
+
assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original")
}
func TestHeaderEncodeFlagCombinations(t *testing.T) {
···
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := tc.header.Encode()
+
require.Len(t, encoded, 12, "Encoded header length invariant")
+
flags := binary.BigEndian.Uint16(encoded[2:4])
-
assert.Equal(t, tc.expected, flags)
+
assert.Equal(t, tc.expected, flags, "Flags value mismatch")
})
}
}
···
{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+
{0x12, 0x34},
+
{0x12, 0x34, 0x81, 0x80, 0x00},
+
{},
}
for _, tc := range testcases {
···
}
f.Fuzz(func(t *testing.T, data []byte) {
-
// limit to only 12 bytes
-
if len(data) > 12 {
-
data = data[0:12]
-
}
-
h := &Header{}
offset, err := h.Decode(data, 0)
if err != nil {
-
switch err.(type) {
-
case *BufferOverflowError, *InvalidLabelError:
-
// these are expected error types
-
default:
-
t.Errorf("unexpected error type: %T", err)
+
var bofErr *BufferOverflowError
+
if !errors.As(err, &bofErr) {
+
t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err)
+
}
+
if offset > len(data) {
+
t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data))
}
return
}
-
if offset != len(data) {
-
t.Errorf("offset (%d) does not match data length (%d)", offset, len(data))
+
if len(data) < 12 {
+
t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data))
+
return
+
}
+
if offset != 12 {
+
t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset)
}
if h.OPCode > 15 {
-
t.Errorf("invalid OPCode: %d", h.OPCode)
+
t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode)
}
-
if h.Z > 7 {
-
t.Errorf("invalid Z value: %d", h.Z)
+
t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z)
}
-
if h.RCode > 15 {
-
t.Errorf("invalid RCode: %d", h.RCode)
+
t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode)
}
-
encoded := h.Encode()
-
if len(encoded) != len(data) {
-
t.Errorf("encoded length (%d) does not match input length (%d)", len(encoded), len(data))
-
-
for i := 0; i < len(data); i++ {
-
t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i])
+
if len(data) >= 12 {
+
encoded := h.Encode()
+
if !bytes.Equal(encoded, data[:12]) {
+
t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded)
}
}
})
+16 -8
question.go
···
package magna
-
import "encoding/binary"
+
import (
+
"encoding/binary"
+
"fmt"
+
)
// Decode decodes a question from buf at the offset
func (q *Question) Decode(buf []byte, offset int) (int, error) {
var err error
-
q.QName, offset, err = decode_domain(buf, offset)
+
q.QName, offset, err = decodeDomain(buf, offset)
if err != nil {
-
return offset, err
+
return offset, fmt.Errorf("question decode: failed to decode QName: %w", err)
}
qtype, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
+
return offset, fmt.Errorf("question decode: failed to decode QType for %s: %w", q.QName, err)
}
qclass, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
+
return offset, fmt.Errorf("question decode: failed to decode QClass for %s: %w", q.QName, err)
}
q.QType = DNSType(qtype)
···
}
// Encode serializes a Question into bytes, using a map to handle domain name compression offsets.
-
func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
bytes = encode_domain(bytes, q.QName, offsets)
+
func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, q.QName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("question encode: failed to encode QName %s: %w", q.QName, err)
+
}
+
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QClass))
-
return bytes
+
return bytes, nil
}
+103 -82
question_test.go
···
package magna
import (
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
)
func TestQuestionDecode(t *testing.T) {
···
expectedOffset int
expected Question
expectedErr error
+
wantErrMsg string
}{
{
name: "Valid question - example.com A IN",
···
expectedErr: nil,
},
{
-
name: "Invalid domain name",
-
input: []byte{255, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1},
+
name: "Invalid domain name - label too long",
+
input: []byte{64, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1},
expectedOffset: 13,
expected: Question{},
-
expectedErr: &BufferOverflowError{},
+
expectedErr: &InvalidLabelError{},
+
wantErrMsg: "question decode: failed to decode QName: invalid domain label: length 64 exceeds maximum 63",
+
},
+
{
+
name: "Invalid domain name - compression loop",
+
input: []byte{0xC0, 0x00, 0, 1, 0, 1},
+
expectedOffset: 6,
+
expected: Question{},
+
expectedErr: &DomainCompressionError{},
+
wantErrMsg: "question decode: failed to decode QName: invalid domain compression: pointer loop detected",
},
{
name: "Insufficient buffer for QType",
···
expectedOffset: 14,
expected: Question{QName: "example.com"},
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "question decode: failed to decode QType for example.com: buffer overflow",
},
{
name: "Insufficient buffer for QClass",
···
expectedOffset: 16,
expected: Question{QName: "example.com", QType: DNSType(1)},
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "question decode: failed to decode QClass for example.com: buffer overflow",
},
}
···
q := &Question{}
offset, err := q.Decode(tt.input, 0)
-
assert.Equal(t, tt.expectedOffset, offset)
-
if tt.expectedErr != nil {
-
assert.Error(t, err)
-
assert.IsType(t, tt.expectedErr, err)
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
+
}
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
} else {
-
assert.NoError(t, err)
-
assert.Equal(t, tt.expected, *q)
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, *q, "Decoded question mismatch")
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
}
})
}
···
func TestQuestionEncode(t *testing.T) {
tests := []struct {
-
name string
-
question Question
-
offsets map[string]uint16
-
expected []byte
+
name string
+
question Question
+
initialBuf []byte
+
offsets map[string]uint16
+
expected []byte
+
expectedErr error
+
wantErrMsg string
+
newOffsets map[string]uint16
}{
{
name: "Simple domain - example.com A IN",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
+
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
name: "Subdomain - subdomain.example.com AAAA IN",
···
QType: DNSType(28),
QClass: DNSClass(1),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1},
+
newOffsets: map[string]uint16{"subdomain.example.com": 0, "example.com": 10, "com": 18},
},
{
name: "Different class - example.com MX CH",
···
QType: DNSType(15),
QClass: DNSClass(3),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
+
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
name: "Domain compression - example.com after subdomain.example.com",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
+
initialBuf: nil,
offsets: map[string]uint16{
-
"com": 22,
-
"example.com": 19,
+
"subdomain.example.com": 0,
+
"example.com": 10,
+
"com": 18,
+
},
+
expected: []byte{0xC0, 0x0a, 0x00, 0x01, 0x00, 0x01},
+
newOffsets: map[string]uint16{
+
"subdomain.example.com": 0,
+
"example.com": 10,
+
"com": 18,
},
-
expected: []byte{0xC0, 0x13, 0x00, 0x01, 0x00, 0x01},
+
},
+
{
+
name: "Encode with initial buffer",
+
question: Question{
+
QName: "test.org",
+
QType: AType,
+
QClass: IN,
+
},
+
initialBuf: []byte{0xAA, 0xBB},
+
offsets: make(map[string]uint16),
+
expected: []byte{0xAA, 0xBB, 4, 't', 'e', 's', 't', 3, 'o', 'r', 'g', 0, 0, 1, 0, 1},
+
newOffsets: map[string]uint16{"test.org": 2, "org": 7},
+
},
+
{
+
name: "Encode invalid domain - label too long",
+
question: Question{
+
QName: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com",
+
QType: AType,
+
QClass: IN,
+
},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{},
+
wantErrMsg: "question encode: failed to encode QName",
+
newOffsets: map[string]uint16{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
result := tt.question.Encode(nil, &tt.offsets)
-
assert.Equal(t, tt.expected, result)
+
currentOffsets := make(map[string]uint16)
+
for k, v := range tt.offsets {
+
currentOffsets[k] = v
+
}
-
if len(tt.offsets) == 0 {
-
expectedOffsets := map[string]uint16{
-
tt.question.QName: 0,
-
}
-
for i := 0; i < len(tt.question.QName); i++ {
-
if tt.question.QName[i] == '.' {
-
expectedOffsets[tt.question.QName[i+1:]] = uint16(i + 1)
-
}
+
result, err := tt.question.Encode(tt.initialBuf, &currentOffsets)
+
+
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
}
-
assert.Equal(t, expectedOffsets, tt.offsets)
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, result, "Encoded question mismatch")
+
assert.Equal(t, tt.newOffsets, currentOffsets, "Final offsets mismatch")
}
})
}
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
offsets := make(map[string]uint16)
-
encoded := tt.question.Encode(nil, &offsets)
+
encoded, err := tt.question.Encode(nil, &offsets)
+
require.NoError(t, err, "Encoding failed")
decodedQuestion := &Question{}
-
_, err := decodedQuestion.Decode(encoded, 0)
+
offset, err := decodedQuestion.Decode(encoded, 0)
-
assert.NoError(t, err)
-
assert.Equal(t, tt.question, *decodedQuestion)
+
assert.NoError(t, err, "Decoding failed")
+
assert.Equal(t, len(encoded), offset, "Offset after decoding should match encoded length")
+
assert.Equal(t, tt.question, *decodedQuestion, "Decoded question does not match original")
})
}
}
-
-
func TestQuestionEncodeWithExistingBuffer(t *testing.T) {
-
question := Question{
-
QName: "example.com",
-
QType: DNSType(1),
-
QClass: DNSClass(1),
-
}
-
-
existingBuffer := []byte{0xFF, 0xFF, 0xFF, 0xFF}
-
offsets := make(map[string]uint16)
-
-
result := question.Encode(existingBuffer, &offsets)
-
-
expected := append(
-
existingBuffer,
-
[]byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}...,
-
)
-
-
assert.Equal(t, expected, result)
-
}
-
-
func TestQuestionEncodeLongDomainName(t *testing.T) {
-
longLabel := make([]byte, 63)
-
for i := range longLabel {
-
longLabel[i] = 'a'
-
}
-
longDomainName := string(longLabel) + "." + string(longLabel) + "." + string(longLabel) + "." + string(longLabel[:61])
-
-
question := Question{
-
QName: longDomainName,
-
QType: DNSType(1),
-
QClass: DNSClass(1),
-
}
-
-
offsets := make(map[string]uint16)
-
encoded := question.Encode(nil, &offsets)
-
-
assert.Equal(t, 259, len(encoded))
-
-
decodedQuestion := &Question{}
-
_, err := decodedQuestion.Decode(encoded, 0)
-
-
assert.NoError(t, err)
-
assert.Equal(t, question, *decodedQuestion)
-
}
+16 -16
utils.go
···
// getU8 returns the first byte from a byte array at offset.
func getU8(buf []byte, offset int) (uint8, int, error) {
-
next_offset := offset + 1
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
+
nextOffset := offset + 1
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
-
return buf[offset], next_offset, nil
+
return buf[offset], nextOffset, nil
}
// getU16 returns the bigEndian uint16 from a byte array at offset.
func getU16(buf []byte, offset int) (uint16, int, error) {
-
next_offset := offset + 2
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
+
nextOffset := offset + 2
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
-
return binary.BigEndian.Uint16(buf[offset:]), next_offset, nil
+
return binary.BigEndian.Uint16(buf[offset:]), nextOffset, nil
}
// getU32 returns the bigEndian uint32 from a byte array at offset.
func getU32(buf []byte, offset int) (uint32, int, error) {
-
next_offset := offset + 4
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
+
nextOffset := offset + 4
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
-
return binary.BigEndian.Uint32(buf[offset:]), next_offset, nil
+
return binary.BigEndian.Uint32(buf[offset:]), nextOffset, nil
}
// getSlice returns a slice of bytes from a byte array at an offset and of length.
func getSlice(buf []byte, offset int, length int) ([]byte, int, error) {
-
next_offset := offset + length
-
if next_offset > len(buf) {
-
return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
+
nextOffset := offset + length
+
if nextOffset > len(buf) {
+
return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
-
return buf[offset:next_offset], next_offset, nil
+
return buf[offset:nextOffset], nextOffset, nil
}
+11
.golangci.toml
···
+
version = "2"
+
[linters.settings.govet]
+
enable_all = true
+
[linters.settings.staticcheck]
+
checks = ["all"]
+
[linters.settings.exhaustive]
+
check = ["switch"]
+
[linters.settings.exhaustruct]
+
[linters.settings.goconst]
+
[linters.settings.testifylint]
+
enable-all = true
+7
Justfile
···
+
format:
+
go fmt ./...
+
gofumpt -l -w $(fd .go)
+
+
verify:
+
go vet ./...
+
golangci-lint run
+6
README.md
···
# Magna
+
## Contributing
+
Dependencies:
+
- go 1.24
+
- just
+
- golint-ci
+
this is a go package for packing/unpacking dns packets.
> which we expect to be so popular that it would be a waste of wire space
+7
types.go
···
X25Type = 19
ISDNType = 20
RTType = 21
+
AAAAType = 28
OPTType = 41
···
return "ISDN"
case RTType:
return "RT"
+
case AAAAType:
+
return "AAAA"
case OPTType:
return "OPT"
case AXFRType:
···
IntermediateHost string
}
+
type AAAA struct {
+
Address net.IP
+
}
+
type EDNSOption struct {
Code uint16
Data []byte
+7 -1
.tangled/workflows/lint.yml
···
when:
-
- event: ["push", "pull_request"]
+
- event: ["push", "pull_request", "manual"]
branch: ["main"]
+
engine: "nixery"
+
+
clone:
+
depth: 1
+
submodules: false
+
dependencies:
nixpkgs:
- go
+7 -2
.tangled/workflows/test.yml
···
when:
-
- event: ["push", "pull_request"]
+
- event: ["push", "pull_request", "manual"]
branch: ["main"]
-
- event: ["manual"]
+
+
engine: "nixery"
+
+
clone:
+
depth: 1
+
submodules: false
dependencies:
nixpkgs:
+18
.tangled/workflows/staticcheck.yml
···
+
when:
+
- event: ["push", "pull_request", "manual"]
+
branch: ["main"]
+
+
engine: "nixery"
+
+
clone:
+
depth: 1
+
submodules: false
+
+
dependencies:
+
nixpkgs:
+
- go
+
+
steps:
+
- name: "staticcheck"
+
command: |
+
go run honnef.co/go/tools/cmd/staticcheck@latest