a go dns packet parser

refractor code to be more standard

+30 -31
domain_name.go
···
import (
"encoding/binary"
"strings"
)
-
// decode_domain decodes a domain name from a buffer starting at offset.
// It returns the domain name along with the offset and error.
-
func decode_domain(buf []byte, offset int) (string, int, error) {
var builder strings.Builder
firstLabel := true
-
seen_offsets := make(map[int]struct{})
finalOffsetAfterJump := -1
currentOffset := offset
for {
-
if _, found := seen_offsets[currentOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
-
seen_offsets[currentOffset] = struct{}{}
length, nextOffsetAfterLen, err := getU8(buf, currentOffset)
if err != nil {
-
return "", len(buf), err
}
if length == 0 {
···
if (length & 0xC0) == 0xC0 {
sec, nextOffsetAfterPtr, err := getU8(buf, nextOffsetAfterLen)
if err != nil {
-
return "", len(buf), err
}
jumpTargetOffset := int(length&0x3F)<<8 | int(sec)
···
if jumpTargetOffset >= len(buf) {
return "", len(buf), &BufferOverflowError{Length: len(buf), Offset: jumpTargetOffset}
}
-
if _, found := seen_offsets[jumpTargetOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
···
labelBytes, nextOffsetAfterLabel, err := getSlice(buf, nextOffsetAfterLen, int(length))
if err != nil {
-
return "", len(buf), err
}
if !firstLabel {
···
return builder.String(), finalReadOffset, nil
}
-
// encode_domain returns the bytes of the input bytes appened with the encoded domain name.
-
func encode_domain(bytes []byte, domain_name string, offsets *map[string]uint16) []byte {
-
if domain_name == "." || domain_name == "" {
-
return append(bytes, 0)
}
-
clean_domain := strings.TrimSuffix(domain_name, ".")
-
if clean_domain == "" {
-
return append(bytes, 0)
}
start := 0
-
for start < len(clean_domain) {
-
suffix := clean_domain[start:]
if offset, found := (*offsets)[suffix]; found {
-
if offset > 0x3FFF {
-
end := strings.IndexByte(suffix, '.')
-
if end == -1 {
-
end = len(suffix)
-
}
-
} else {
-
pointer := 0xC000 | offset
-
return binary.BigEndian.AppendUint16(bytes, pointer)
-
}
}
currentPos := uint16(len(bytes))
···
end := strings.IndexByte(suffix, '.')
var label string
-
nextStart := len(clean_domain)
if end == -1 {
label = suffix
···
labelBytes := []byte(label)
if len(labelBytes) > 63 {
-
// XXX: maybe should return an error
-
labelBytes = labelBytes[:63]
}
bytes = append(bytes, byte(len(labelBytes)))
bytes = append(bytes, labelBytes...)
}
-
return append(bytes, 0)
}
···
import (
"encoding/binary"
+
"fmt"
"strings"
)
+
// decodeDomain decodes a domain name from a buffer starting at offset.
// It returns the domain name along with the offset and error.
+
func decodeDomain(buf []byte, offset int) (string, int, error) {
var builder strings.Builder
firstLabel := true
+
seenOffsets := make(map[int]struct{})
finalOffsetAfterJump := -1
currentOffset := offset
for {
+
if _, found := seenOffsets[currentOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
+
seenOffsets[currentOffset] = struct{}{}
length, nextOffsetAfterLen, err := getU8(buf, currentOffset)
if err != nil {
+
return "", len(buf), fmt.Errorf("failed to read domain label length: %w", err)
}
if length == 0 {
···
if (length & 0xC0) == 0xC0 {
sec, nextOffsetAfterPtr, err := getU8(buf, nextOffsetAfterLen)
if err != nil {
+
return "", len(buf), fmt.Errorf("failed to read domain compression pointer offset byte: %w", err)
}
jumpTargetOffset := int(length&0x3F)<<8 | int(sec)
···
if jumpTargetOffset >= len(buf) {
return "", len(buf), &BufferOverflowError{Length: len(buf), Offset: jumpTargetOffset}
}
+
if _, found := seenOffsets[jumpTargetOffset]; found {
return "", len(buf), &DomainCompressionError{}
}
···
labelBytes, nextOffsetAfterLabel, err := getSlice(buf, nextOffsetAfterLen, int(length))
if err != nil {
+
return "", len(buf), fmt.Errorf("failed to read domain label data: %w", err)
}
if !firstLabel {
···
return builder.String(), finalReadOffset, nil
}
+
// encodeDomain returns the bytes of the input bytes appened with the encoded domain name.
+
func encodeDomain(bytes []byte, domainName string, offsets *map[string]uint16) ([]byte, error) {
+
if domainName == "." || domainName == "" {
+
return append(bytes, 0), nil
}
+
cleanDomain := strings.TrimSuffix(domainName, ".")
+
if cleanDomain == "" {
+
return append(bytes, 0), nil
}
start := 0
+
for start < len(cleanDomain) {
+
suffix := cleanDomain[start:]
if offset, found := (*offsets)[suffix]; found {
+
pointer := 0xC000 | offset
+
bytes = binary.BigEndian.AppendUint16(bytes, pointer)
+
return bytes, nil
}
currentPos := uint16(len(bytes))
···
end := strings.IndexByte(suffix, '.')
var label string
+
nextStart := len(cleanDomain)
if end == -1 {
label = suffix
···
labelBytes := []byte(label)
if len(labelBytes) > 63 {
+
return nil, &InvalidLabelError{Length: int(len(labelBytes))}
+
}
+
+
if len(labelBytes) == 0 && start < len(cleanDomain) {
+
return nil, &InvalidLabelError{Length: 0}
}
bytes = append(bytes, byte(len(labelBytes)))
bytes = append(bytes, labelBytes...)
}
+
bytes = append(bytes, 0)
+
return bytes, nil
}
+143 -22
domain_test.go
···
package magna
import (
"testing"
"github.com/stretchr/testify/assert"
···
func BenchmarkDecodeDomainSimple(b *testing.B) {
input := []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}
for i := 0; i < b.N; i++ {
-
_, _, _ = decode_domain(input, 0)
}
}
···
offset := 13
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
_, _, _ = decode_domain(input, offset)
}
}
···
out := make([]byte, 0, 64)
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
_ = encode_domain(out[:0], domain, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
out := make([]byte, 0, 128)
b.ResetTimer()
for i := 0; i < b.N; i++ {
-
tempOut := encode_domain(out[:0], domain1, &offsets)
-
_ = encode_domain(tempOut, domain2, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
expectedDomain string
expectedOffset int
expectedError error
}{
{
name: "Simple domain",
···
expectedDomain: "",
expectedOffset: 2,
expectedError: &InvalidLabelError{Length: 64},
},
{
name: "Compression loop",
···
expectedDomain: "",
expectedOffset: 4,
expectedError: &DomainCompressionError{},
},
{
name: "Truncated input",
···
expectedDomain: "",
expectedOffset: 3,
expectedError: &BufferOverflowError{Length: 3, Offset: 4},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
domain, offset, err := decode_domain(tt.input, tt.offset)
-
t.Log(tt.name)
-
assert.Equal(t, tt.expectedError, err)
-
assert.Equal(t, tt.expectedDomain, domain)
-
assert.Equal(t, tt.expectedOffset, offset)
})
}
}
func TestEncodeDomain(t *testing.T) {
tests := []struct {
-
name string
-
input string
-
offsets map[string]uint16
-
expected []byte
-
newOffsets map[string]uint16
}{
{
name: "Simple domain",
input: "example.com",
offsets: make(map[string]uint16),
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
-
name: "Domain with existing offset",
input: "test.example.com",
offsets: map[string]uint16{"example.com": 10},
expected: []byte{4, 't', 'e', 's', 't', 0xC0, 0x0A},
newOffsets: map[string]uint16{"test.example.com": 0, "example.com": 10},
···
{
name: "Multiple subdomains",
input: "a.b.c.d",
offsets: make(map[string]uint16),
expected: []byte{1, 'a', 1, 'b', 1, 'c', 1, 'd', 0},
newOffsets: map[string]uint16{"a.b.c.d": 0, "b.c.d": 2, "c.d": 4, "d": 6},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
result := encode_domain([]byte{}, tt.input, &tt.offsets)
-
assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output")
-
assert.Equal(t, tt.newOffsets, tt.offsets, "Offsets map does not match expected state")
})
}
}
···
0x03, 0x63, 0x6f, 0x6d, 0x00,
},
{
-
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xC0, 0x00,
},
{
-
0x03, 0x63, 0x6f, 0x6d, 0xC0, 0x00,
},
}
for _, tc := range testcases {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, msg []byte) {
-
decode_domain(msg, 0)
})
}
···
package magna
import (
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
···
func BenchmarkDecodeDomainSimple(b *testing.B) {
input := []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}
for i := 0; i < b.N; i++ {
+
_, _, _ = decodeDomain(input, 0)
}
}
···
offset := 13
b.ResetTimer()
for i := 0; i < b.N; i++ {
+
_, _, _ = decodeDomain(input, offset)
}
}
···
out := make([]byte, 0, 64)
b.ResetTimer()
for i := 0; i < b.N; i++ {
+
_, _ = encodeDomain(out[:0], domain, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
out := make([]byte, 0, 128)
b.ResetTimer()
for i := 0; i < b.N; i++ {
+
tempOut, _ := encodeDomain(out[:0], domain1, &offsets)
+
_, _ = encodeDomain(tempOut, domain2, &offsets)
for k := range offsets {
delete(offsets, k)
}
···
expectedDomain string
expectedOffset int
expectedError error
+
errorCheck func(t *testing.T, err error)
}{
{
name: "Simple domain",
···
expectedDomain: "",
expectedOffset: 2,
expectedError: &InvalidLabelError{Length: 64},
+
errorCheck: func(t *testing.T, err error) {
+
var target *InvalidLabelError
+
assert.True(t, errors.As(err, &target))
+
assert.Equal(t, 64, target.Length)
+
},
},
{
name: "Compression loop",
···
expectedDomain: "",
expectedOffset: 4,
expectedError: &DomainCompressionError{},
+
errorCheck: func(t *testing.T, err error) {
+
assert.IsType(t, &DomainCompressionError{}, err)
+
},
},
{
name: "Truncated input",
···
expectedDomain: "",
expectedOffset: 3,
expectedError: &BufferOverflowError{Length: 3, Offset: 4},
+
errorCheck: func(t *testing.T, err error) {
+
var target *BufferOverflowError
+
assert.True(t, errors.As(err, &target), "Expected BufferOverflowError")
+
if target != nil {
+
assert.Equal(t, 3, target.Length)
+
assert.Equal(t, 1+3, target.Offset)
+
}
+
assert.Contains(t, err.Error(), "failed to read domain label data")
+
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+
domain, offset, err := decodeDomain(tt.input, tt.offset)
+
+
t.Logf("Test: %s, Input: %x, OffsetIn: %d => Domain: '%s', OffsetOut: %d, Err: %v", tt.name, tt.input, tt.offset, domain, offset, err)
+
+
if tt.expectedError != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
if tt.errorCheck != nil {
+
tt.errorCheck(t, err)
+
} else {
+
assert.IsType(t, tt.expectedError, err, "Error type mismatch")
+
}
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
}
+
assert.Equal(t, tt.expectedDomain, domain, "Domain mismatch")
+
if tt.expectedError == nil {
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch")
+
}
})
}
}
func TestEncodeDomain(t *testing.T) {
tests := []struct {
+
name string
+
input string
+
initialBuf []byte
+
offsets map[string]uint16
+
expected []byte
+
expectedErr error
+
newOffsets map[string]uint16
}{
{
name: "Simple domain",
input: "example.com",
+
initialBuf: []byte{},
offsets: make(map[string]uint16),
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
+
name: "Domain with existing offset for compression",
input: "test.example.com",
+
initialBuf: []byte{},
offsets: map[string]uint16{"example.com": 10},
expected: []byte{4, 't', 'e', 's', 't', 0xC0, 0x0A},
newOffsets: map[string]uint16{"test.example.com": 0, "example.com": 10},
···
{
name: "Multiple subdomains",
input: "a.b.c.d",
+
initialBuf: []byte{},
offsets: make(map[string]uint16),
expected: []byte{1, 'a', 1, 'b', 1, 'c', 1, 'd', 0},
newOffsets: map[string]uint16{"a.b.c.d": 0, "b.c.d": 2, "c.d": 4, "d": 6},
},
+
{
+
name: "Root domain",
+
input: ".",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: []byte{0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Empty domain",
+
input: "",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: []byte{0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Label too long",
+
input: "labeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolong.com",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{Length: 72},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Empty label inside domain",
+
input: "example..com",
+
initialBuf: []byte{},
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{Length: 0},
+
newOffsets: map[string]uint16{},
+
},
+
{
+
name: "Append to existing buffer",
+
input: "example.com",
+
initialBuf: []byte{0xAA, 0xBB},
+
offsets: make(map[string]uint16),
+
expected: []byte{0xAA, 0xBB, 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0},
+
newOffsets: map[string]uint16{"example.com": 2, "com": 10},
+
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+
currentOffsets := make(map[string]uint16)
+
for k, v := range tt.offsets {
+
currentOffsets[k] = v
+
}
+
+
result, err := encodeDomain(tt.initialBuf, tt.input, &currentOffsets)
+
+
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.IsType(t, tt.expectedErr, err, "Error type mismatch")
+
if expectedILE, ok := tt.expectedErr.(*InvalidLabelError); ok {
+
actualILE := &InvalidLabelError{}
+
if assert.True(t, errors.As(err, &actualILE)) {
+
assert.Equal(t, expectedILE.Length, actualILE.Length)
+
}
+
}
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output")
+
assert.Equal(t, tt.newOffsets, currentOffsets, "Offsets map does not match expected state")
+
}
})
}
}
···
0x03, 0x63, 0x6f, 0x6d, 0x00,
},
{
+
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xc0, 0x00,
+
},
+
{
+
0x03, 0x63, 0x6f, 0x6d, 0xc0, 0x00,
+
},
+
{
+
0xc0, 0x00,
+
},
+
{
+
0xc0, 0xff,
},
{
+
0x40,
+
},
+
{
+
0x03, 0x63, 0x6f,
+
},
+
{
+
0xc0,
},
}
for _, tc := range testcases {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, msg []byte) {
+
_, _, err := decodeDomain(msg, 0)
+
if err != nil {
+
var bufErr *BufferOverflowError
+
var labelErr *InvalidLabelError
+
var compErr *DomainCompressionError
+
+
if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr)) {
+
t.Errorf("Fuzzing decodeDomain: unexpected error type %T: %v for input %x", err, err, msg)
+
}
+
}
})
}
+24 -10
errors.go
···
}
func (e *BufferOverflowError) Error() string {
-
return fmt.Sprintf("magna: offset %d is past the buffer length %d", e.Offset, e.Length)
}
// InvalidLabelError represents an error when an invalid label length is encountered.
···
}
func (e *InvalidLabelError) Error() string {
-
return fmt.Sprintf("magna: received invalid label length %d", e.Length)
}
// DomainCompressionError represents an error related to domain compression.
type DomainCompressionError struct{}
func (e *DomainCompressionError) Error() string {
-
return "magna: loop detected in domain compression"
}
-
// MagnaError represents a generic error with a custom message.
-
type MagnaError struct {
-
Message string
-
}
-
-
func (e *MagnaError) Error() string {
-
return fmt.Sprintf("magna: %s", e.Message)
}
···
}
func (e *BufferOverflowError) Error() string {
+
return fmt.Sprintf("buffer overflow: attempted to read past buffer length %d at offset %d", e.Length, e.Offset)
+
}
+
+
func (e *BufferOverflowError) Is(target error) bool {
+
_, ok := target.(*BufferOverflowError)
+
return ok
}
// InvalidLabelError represents an error when an invalid label length is encountered.
···
}
func (e *InvalidLabelError) Error() string {
+
if e.Length > 63 {
+
return fmt.Sprintf("invalid domain label: length %d exceeds maximum 63", e.Length)
+
}
+
if e.Length == 0 {
+
return "invalid domain label: zero length label encountered"
+
}
+
+
// XXX: this should be unreachable
+
return fmt.Sprintf("invalid domain label: unexpected length %d", e.Length)
+
}
+
+
func (e *InvalidLabelError) Is(target error) bool {
+
_, ok := target.(*InvalidLabelError)
+
return ok
}
// DomainCompressionError represents an error related to domain compression.
type DomainCompressionError struct{}
func (e *DomainCompressionError) Error() string {
+
return "invalid domain compression: pointer loop detected"
}
+
func (e *DomainCompressionError) Is(target error) bool {
+
_, ok := target.(*DomainCompressionError)
+
return ok
}
+8 -28
errors_test.go
···
package magna
import (
-
"fmt"
"testing"
"github.com/stretchr/testify/assert"
···
offset int
expected string
}{
-
{"PositiveOffset", 10, 15, "magna: offset 15 is past the buffer length 10"},
-
{"ZeroLengthBuffer", 0, 5, "magna: offset 5 is past the buffer length 0"},
-
{"NegativeOffset", 10, -1, "magna: offset -1 is past the buffer length 10"},
-
{"EqualOffset", 10, 10, "magna: offset 10 is past the buffer length 10"},
}
for _, tt := range tests {
···
length int
expected string
}{
-
{"LengthTooLarge", 64, "magna: received invalid label length 64"},
-
{"LengthZero", 0, "magna: received invalid label length 0"},
-
{"NegativeLength", -1, "magna: received invalid label length -1"},
}
for _, tt := range tests {
···
func TestDomainCompressionError(t *testing.T) {
t.Run("Standard", func(t *testing.T) {
err := &DomainCompressionError{}
-
expected := "magna: loop detected in domain compression"
assert.Equal(t, expected, err.Error(), "Error() output mismatch")
})
}
-
-
func TestMagnaError(t *testing.T) {
-
tests := []struct {
-
name string
-
message string
-
expected string
-
}{
-
{"EmptyMessage", "", "magna: "},
-
{"SimpleMessage", "test error", "magna: test error"},
-
{"MessageWithPunctuation", "error: invalid input!", "magna: error: invalid input!"},
-
}
-
-
for _, tt := range tests {
-
t.Run(tt.name, func(t *testing.T) {
-
err := &MagnaError{Message: tt.message}
-
assert.Equal(t, tt.expected, err.Error())
-
})
-
}
-
}
···
package magna
import (
"testing"
"github.com/stretchr/testify/assert"
···
offset int
expected string
}{
+
{"PositiveOffset", 10, 15, "buffer overflow: attempted to read past buffer length 10 at offset 15"},
+
{"ZeroLengthBuffer", 0, 5, "buffer overflow: attempted to read past buffer length 0 at offset 5"},
+
{"NegativeOffset", 10, -1, "buffer overflow: attempted to read past buffer length 10 at offset -1"},
+
{"EqualOffset", 10, 10, "buffer overflow: attempted to read past buffer length 10 at offset 10"},
}
for _, tt := range tests {
···
length int
expected string
}{
+
{"LengthTooLarge", 64, "invalid domain label: length 64 exceeds maximum 63"},
+
{"LengthZero", 0, "invalid domain label: zero length label encountered"},
+
{"ValidLength", 30, "invalid domain label: unexpected length 30"},
}
for _, tt := range tests {
···
func TestDomainCompressionError(t *testing.T) {
t.Run("Standard", func(t *testing.T) {
err := &DomainCompressionError{}
+
expected := "invalid domain compression: pointer loop detected"
assert.Equal(t, expected, err.Error(), "Error() output mismatch")
})
}
+10 -7
header.go
···
package magna
-
import "encoding/binary"
// Decode decodes the header from the bytes.
func (h *Header) Decode(buf []byte, offset int) (int, error) {
var err error
h.ID, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
flags, offset, err := getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
h.QDCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
h.ANCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
h.NSCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
h.ARCount, offset, err = getU16(buf, offset)
if err != nil {
-
return len(buf), err
}
h.QR = ((flags >> 15) & 0x01) == 1
···
package magna
+
import (
+
"encoding/binary"
+
"fmt"
+
)
// Decode decodes the header from the bytes.
func (h *Header) Decode(buf []byte, offset int) (int, error) {
var err error
h.ID, offset, err = getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read ID: %w", err)
}
flags, offset, err := getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read flags: %w", err)
}
h.QDCount, offset, err = getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read QDCount: %w", err)
}
h.ANCount, offset, err = getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read ANCount: %w", err)
}
h.NSCount, offset, err = getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read NSCount: %w", err)
}
h.ARCount, offset, err = getU16(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("header decode: failed to read ARCount: %w", err)
}
h.QR = ((flags >> 15) & 0x01) == 1
+81 -60
header_test.go
···
package magna
import (
"encoding/binary"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHeaderDecode(t *testing.T) {
···
expectedHeader Header
expectedOffset int
expectedErr error
}{
{
name: "Valid header",
···
expectedErr: nil,
},
{
-
name: "Insufficient buffer length",
input: []byte{0x12, 0x34, 0x81},
-
expectedHeader: Header{},
expectedOffset: 3,
-
expectedErr: &BufferOverflowError{Length: 3, Offset: 3},
},
{
-
name: "Invalid ID",
input: []byte{0x12},
expectedHeader: Header{},
expectedOffset: 1,
-
expectedErr: &BufferOverflowError{Length: 1, Offset: 1},
},
{
name: "Missing QDCount",
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00},
-
expectedHeader: Header{},
expectedOffset: 5,
expectedErr: &BufferOverflowError{},
},
{
name: "Missing ANCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01},
-
expectedHeader: Header{},
-
expectedOffset: 6,
expectedErr: &BufferOverflowError{},
},
{
name: "Missing NSCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02},
-
expectedHeader: Header{},
-
expectedOffset: 8,
expectedErr: &BufferOverflowError{},
},
{
name: "Missing ARCount",
-
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03},
-
expectedHeader: Header{},
-
expectedOffset: 10,
expectedErr: &BufferOverflowError{},
},
}
···
offset, err := h.Decode(tt.input, 0)
if tt.expectedErr != nil {
-
assert.Error(t, err)
-
assert.IsType(t, tt.expectedErr, err)
} else {
-
assert.NoError(t, err)
-
assert.Equal(t, tt.expectedHeader, *h)
}
-
-
assert.Equal(t, tt.expectedOffset, offset)
})
}
}
···
}
h := &Header{}
-
_, err := h.Decode(input, 0)
assert.NoError(t, err)
-
assert.Equal(t, tt.expected.QR, h.QR)
-
assert.Equal(t, tt.expected.OPCode, h.OPCode)
-
assert.Equal(t, tt.expected.AA, h.AA)
-
assert.Equal(t, tt.expected.TC, h.TC)
-
assert.Equal(t, tt.expected.RD, h.RD)
-
assert.Equal(t, tt.expected.RA, h.RA)
-
assert.Equal(t, tt.expected.Z, h.Z)
-
assert.Equal(t, tt.expected.RCode, h.RCode)
})
}
}
···
},
},
{
-
name: "No flags set",
header: Header{
ID: 0x5678,
QR: false,
···
},
},
{
-
name: "Mixed flags",
header: Header{
ID: 0x9abc,
QR: true,
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encoded := tt.header.Encode()
-
assert.Equal(t, tt.expected, encoded)
})
}
}
···
}
encoded := originalHeader.Encode()
decodedHeader := &Header{}
offset, err := decodedHeader.Decode(encoded, 0)
-
assert.NoError(t, err)
-
assert.Equal(t, len(encoded), offset)
-
assert.Equal(t, originalHeader, *decodedHeader)
}
func TestHeaderEncodeFlagCombinations(t *testing.T) {
···
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := tc.header.Encode()
flags := binary.BigEndian.Uint16(encoded[2:4])
-
assert.Equal(t, tc.expected, flags)
})
}
}
···
{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}
for _, tc := range testcases {
···
}
f.Fuzz(func(t *testing.T, data []byte) {
-
// limit to only 12 bytes
-
if len(data) > 12 {
-
data = data[0:12]
-
}
-
h := &Header{}
offset, err := h.Decode(data, 0)
if err != nil {
-
switch err.(type) {
-
case *BufferOverflowError, *InvalidLabelError:
-
// these are expected error types
-
default:
-
t.Errorf("unexpected error type: %T", err)
}
return
}
-
if offset != len(data) {
-
t.Errorf("offset (%d) does not match data length (%d)", offset, len(data))
}
if h.OPCode > 15 {
-
t.Errorf("invalid OPCode: %d", h.OPCode)
}
-
if h.Z > 7 {
-
t.Errorf("invalid Z value: %d", h.Z)
}
-
if h.RCode > 15 {
-
t.Errorf("invalid RCode: %d", h.RCode)
}
-
encoded := h.Encode()
-
if len(encoded) != len(data) {
-
t.Errorf("encoded length (%d) does not match input length (%d)", len(encoded), len(data))
-
-
for i := 0; i < len(data); i++ {
-
t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i])
}
}
})
···
package magna
import (
+
"bytes"
"encoding/binary"
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
)
func TestHeaderDecode(t *testing.T) {
···
expectedHeader Header
expectedOffset int
expectedErr error
+
wantErrMsg string
}{
{
name: "Valid header",
···
expectedErr: nil,
},
{
+
name: "Insufficient buffer length for Flags",
input: []byte{0x12, 0x34, 0x81},
+
expectedHeader: Header{ID: 0x1234},
expectedOffset: 3,
+
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read flags",
},
{
+
name: "Insufficient buffer length for ID",
input: []byte{0x12},
expectedHeader: Header{},
expectedOffset: 1,
+
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ID",
},
{
name: "Missing QDCount",
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00},
+
expectedHeader: Header{ID: 0x1234},
expectedOffset: 5,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read QDCount",
},
{
name: "Missing ANCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1},
+
expectedOffset: 7,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ANCount",
},
{
name: "Missing NSCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2},
+
expectedOffset: 9,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read NSCount",
},
{
name: "Missing ARCount",
+
input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00},
+
expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2, NSCount: 3},
+
expectedOffset: 11,
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "header decode: failed to read ARCount",
},
}
···
offset, err := h.Decode(tt.input, 0)
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
+
}
+
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
+
assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch")
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
}
})
}
}
···
}
h := &Header{}
+
offset, err := h.Decode(input, 0)
assert.NoError(t, err)
+
assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header")
+
+
assert.Equal(t, tt.expected.QR, h.QR, "QR flag mismatch")
+
assert.Equal(t, tt.expected.OPCode, h.OPCode, "OPCode mismatch")
+
assert.Equal(t, tt.expected.AA, h.AA, "AA flag mismatch")
+
assert.Equal(t, tt.expected.TC, h.TC, "TC flag mismatch")
+
assert.Equal(t, tt.expected.RD, h.RD, "RD flag mismatch")
+
assert.Equal(t, tt.expected.RA, h.RA, "RA flag mismatch")
+
assert.Equal(t, tt.expected.Z, h.Z, "Z value mismatch")
+
assert.Equal(t, tt.expected.RCode, h.RCode, "RCode mismatch")
})
}
}
···
},
},
{
+
name: "No flags set, different counts",
header: Header{
ID: 0x5678,
QR: false,
···
},
},
{
+
name: "Mixed flags and counts",
header: Header{
ID: 0x9abc,
QR: true,
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encoded := tt.header.Encode()
+
assert.Equal(t, tt.expected, encoded, "Encoded header mismatch")
})
}
}
···
}
encoded := originalHeader.Encode()
+
assert.Len(t, encoded, 12, "Encoded header should be 12 bytes")
decodedHeader := &Header{}
offset, err := decodedHeader.Decode(encoded, 0)
+
assert.NoError(t, err, "Decoding failed unexpectedly")
+
assert.Equal(t, 12, offset, "Offset after decoding should be 12")
+
assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original")
}
func TestHeaderEncodeFlagCombinations(t *testing.T) {
···
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded := tc.header.Encode()
+
require.Len(t, encoded, 12, "Encoded header length invariant")
+
flags := binary.BigEndian.Uint16(encoded[2:4])
+
assert.Equal(t, tc.expected, flags, "Flags value mismatch")
})
}
}
···
{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+
{0x12, 0x34},
+
{0x12, 0x34, 0x81, 0x80, 0x00},
+
{},
}
for _, tc := range testcases {
···
}
f.Fuzz(func(t *testing.T, data []byte) {
h := &Header{}
offset, err := h.Decode(data, 0)
if err != nil {
+
var bofErr *BufferOverflowError
+
if !errors.As(err, &bofErr) {
+
t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err)
+
}
+
if offset > len(data) {
+
t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data))
}
return
}
+
if len(data) < 12 {
+
t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data))
+
return
+
}
+
if offset != 12 {
+
t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset)
}
if h.OPCode > 15 {
+
t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode)
}
if h.Z > 7 {
+
t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z)
}
if h.RCode > 15 {
+
t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode)
}
+
if len(data) >= 12 {
+
encoded := h.Encode()
+
if !bytes.Equal(encoded, data[:12]) {
+
t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded)
}
}
})
+41 -20
message.go
···
package magna
import (
"math/rand"
)
···
func (m *Message) Decode(buf []byte) (err error) {
offset, err := m.Header.Decode(buf, 0)
if err != nil {
-
return err
}
-
for x := 0; x < int(m.Header.QDCount); x++ {
var question Question
offset, err = question.Decode(buf, offset)
if err != nil {
-
return err
}
m.Question = append(m.Question, question)
}
-
for x := 0; x < int(m.Header.ANCount); x++ {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
-
return err
}
m.Answer = append(m.Answer, rr)
}
-
for x := 0; x < int(m.Header.NSCount); x++ {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
-
return err
}
m.Authority = append(m.Authority, rr)
}
-
for x := 0; x < int(m.Header.ARCount); x++ {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
-
return err
}
m.Additional = append(m.Additional, rr)
···
// Encode encodes a message to a DNS packet.
// TODO: set truncation bit if over 512 and udp is protocol
-
func (m *Message) Encode() []byte {
m.offsets = make(map[string]uint16)
bytes := make([]byte, 0, 512)
-
bytes = append(bytes, m.Header.Encode()...)
-
for _, question := range m.Question {
-
bytes = question.Encode(bytes, &m.offsets)
}
-
for _, answer := range m.Answer {
-
bytes = answer.Encode(bytes, &m.offsets)
}
-
for _, authority := range m.Authority {
-
bytes = authority.Encode(bytes, &m.offsets)
}
-
for _, additional := range m.Additional {
-
bytes = additional.Encode(bytes, &m.offsets)
}
-
return bytes
}
func CreateRequest(op OPCode, rd bool) *Message {
···
package magna
import (
+
"fmt"
"math/rand"
)
···
func (m *Message) Decode(buf []byte) (err error) {
offset, err := m.Header.Decode(buf, 0)
if err != nil {
+
return fmt.Errorf("failed to decode message header: %w", err)
}
+
m.Question = make([]Question, 0, m.Header.QDCount)
+
for i := range m.Header.QDCount {
var question Question
offset, err = question.Decode(buf, offset)
if err != nil {
+
return fmt.Errorf("failed to decode question #%d: %w", i+1, err)
}
m.Question = append(m.Question, question)
}
+
m.Answer = make([]ResourceRecord, 0, m.Header.ANCount)
+
for i := range m.Header.ANCount {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
+
return fmt.Errorf("failed to decode answer record #%d: %w", i+1, err)
}
m.Answer = append(m.Answer, rr)
}
+
m.Authority = make([]ResourceRecord, 0, m.Header.NSCount)
+
for i := range m.Header.NSCount {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
+
return fmt.Errorf("failed to decode authority record #%d: %w", i+1, err)
}
m.Authority = append(m.Authority, rr)
}
+
m.Additional = make([]ResourceRecord, 0, m.Header.ARCount)
+
for i := range m.Header.ARCount {
var rr ResourceRecord
offset, err = rr.Decode(buf, offset)
if err != nil {
+
return fmt.Errorf("failed to decode additional record #%d: %w", i+1, err)
}
m.Additional = append(m.Additional, rr)
···
// Encode encodes a message to a DNS packet.
// TODO: set truncation bit if over 512 and udp is protocol
+
func (m *Message) Encode() ([]byte, error) {
m.offsets = make(map[string]uint16)
bytes := make([]byte, 0, 512)
+
+
headerBytes := m.Header.Encode()
+
bytes = append(bytes, headerBytes...)
+
+
var err error
+
for i, question := range m.Question {
+
bytes, err = question.Encode(bytes, &m.offsets)
+
if err != nil {
+
return nil, fmt.Errorf("failed to encode question #%d (%s): %w", i+1, question.QName, err)
+
}
}
+
for i, answer := range m.Answer {
+
bytes, err = answer.Encode(bytes, &m.offsets)
+
if err != nil {
+
return nil, fmt.Errorf("failed to encode answer record #%d (%s): %w", i+1, answer.Name, err)
+
}
}
+
for i, authority := range m.Authority {
+
bytes, err = authority.Encode(bytes, &m.offsets)
+
if err != nil {
+
return nil, fmt.Errorf("failed to encode authority record #%d (%s): %w", i+1, authority.Name, err)
+
}
}
+
for i, additional := range m.Additional {
+
bytes, err = additional.Encode(bytes, &m.offsets)
+
if err != nil {
+
return nil, fmt.Errorf("failed to encode additional record #%d (%s): %w", i+1, additional.Name, err)
+
}
}
+
return bytes, nil
}
func CreateRequest(op OPCode, rd bool) *Message {
+232 -52
message_test.go
···
import (
"bytes"
"encoding/binary"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMessageDecode(t *testing.T) {
tests := []struct {
-
name string
-
input []byte
-
expected Message
-
wantErr bool
}{
{
-
name: "Valid DNS message with one question",
-
input: func() []byte {
-
buf := new(bytes.Buffer)
-
binary.Write(buf, binary.BigEndian, uint16(1234))
-
binary.Write(buf, binary.BigEndian, uint16(0x0100))
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0})
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
return buf.Bytes()
-
}(),
expected: Message{
Header: Header{
ID: 1234,
QR: false,
RD: true,
-
OPCode: 0,
QDCount: 1,
},
Question: []Question{
{
QName: "www.example.com",
-
QType: 1,
-
QClass: 1,
},
},
},
wantErr: false,
},
{
-
name: "Valid DNS message with one answer",
-
input: func() []byte {
-
buf := new(bytes.Buffer)
-
binary.Write(buf, binary.BigEndian, uint16(5678))
-
binary.Write(buf, binary.BigEndian, uint16(0x8180))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
binary.Write(buf, binary.BigEndian, uint16(0))
-
buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0})
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
binary.Write(buf, binary.BigEndian, uint16(1))
-
binary.Write(buf, binary.BigEndian, uint32(3600))
-
binary.Write(buf, binary.BigEndian, uint16(4))
-
binary.Write(buf, binary.BigEndian, uint32(0x0A000001))
-
return buf.Bytes()
-
}(),
expected: Message{
Header: Header{
ID: 5678,
···
RCode: 0,
ANCount: 1,
},
Answer: []ResourceRecord{
{
Name: "www.example.com",
-
RType: 1,
-
RClass: 1,
TTL: 3600,
RDLength: 4,
-
RData: &A{net.IP([]byte{10, 0, 0, 1})},
},
},
},
wantErr: false,
},
{
-
name: "Invalid input - empty buffer",
-
input: []byte{},
-
wantErr: true,
},
}
···
err := m.Decode(tt.input)
if tt.wantErr {
-
assert.Error(t, err)
} else {
-
assert.NoError(t, err)
-
assert.Equal(t, tt.expected.Header, m.Header)
-
assert.Equal(t, tt.expected.Question, m.Question)
-
assert.Equal(t, tt.expected.Answer, m.Answer)
-
assert.Equal(t, tt.expected.Authority, m.Authority)
-
assert.Equal(t, tt.expected.Additional, m.Additional)
}
})
}
}
func FuzzDecodeMessage(f *testing.F) {
testcases := [][]byte{
{
···
}
f.Fuzz(func(t *testing.T, msg []byte) {
var m Message
-
m.Decode(msg)
})
}
···
import (
"bytes"
"encoding/binary"
+
"errors"
"net"
+
"strings"
"testing"
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
)
func TestMessageDecode(t *testing.T) {
+
buildQuery := func(id uint16, qname string, qtype DNSType, qclass DNSClass) []byte {
+
buf := new(bytes.Buffer)
+
binary.Write(buf, binary.BigEndian, id)
+
binary.Write(buf, binary.BigEndian, uint16(0x0100))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
offsets := make(map[string]uint16)
+
qBytes, err := encodeDomain([]byte{}, qname, &offsets)
+
require.NoError(t, err)
+
buf.Write(qBytes)
+
binary.Write(buf, binary.BigEndian, uint16(qtype))
+
binary.Write(buf, binary.BigEndian, uint16(qclass))
+
return buf.Bytes()
+
}
+
+
buildAnswer := func(id uint16, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdata ResourceRecordData) []byte {
+
buf := new(bytes.Buffer)
+
+
binary.Write(buf, binary.BigEndian, id)
+
binary.Write(buf, binary.BigEndian, uint16(0x8180))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
rr := ResourceRecord{
+
Name: name,
+
RType: rtype,
+
RClass: rclass,
+
TTL: ttl,
+
RData: rdata,
+
}
+
offsets := make(map[string]uint16)
+
rrBytes, err := rr.Encode([]byte{}, &offsets)
+
require.NoError(t, err)
+
buf.Write(rrBytes)
+
return buf.Bytes()
+
}
+
tests := []struct {
+
name string
+
input []byte
+
expected Message
+
wantErr bool
+
wantErrType error
+
wantErrMsg string
}{
{
+
name: "Valid DNS query message with one question",
+
input: buildQuery(1234, "www.example.com", AType, IN),
expected: Message{
Header: Header{
ID: 1234,
QR: false,
RD: true,
+
OPCode: OPCode(0),
QDCount: 1,
+
RCode: NOERROR,
},
Question: []Question{
{
QName: "www.example.com",
+
QType: AType,
+
QClass: IN,
},
},
+
Answer: []ResourceRecord{},
+
Additional: []ResourceRecord{},
+
Authority: []ResourceRecord{},
},
wantErr: false,
},
{
+
name: "Valid DNS answer message with one A record",
+
input: buildAnswer(5678, "www.example.com", AType, IN, 3600,
+
&A{Address: net.ParseIP("10.0.0.1").To4()},
+
),
expected: Message{
Header: Header{
ID: 5678,
···
RCode: 0,
ANCount: 1,
},
+
Question: []Question{},
Answer: []ResourceRecord{
{
Name: "www.example.com",
+
RType: AType,
+
RClass: IN,
TTL: 3600,
RDLength: 4,
+
RData: &A{Address: net.IP([]byte{10, 0, 0, 1})},
},
},
+
Additional: []ResourceRecord{},
+
Authority: []ResourceRecord{},
},
wantErr: false,
},
{
+
name: "Invalid input - empty buffer",
+
input: []byte{},
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "failed to decode message header: header decode: failed to read ID",
+
},
+
{
+
name: "Invalid input - truncated header",
+
input: []byte{0x12, 0x34},
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "failed to decode message header: header decode: failed to read flags",
+
},
+
{
+
name: "Invalid input - truncated question name",
+
input: func() []byte {
+
buf := new(bytes.Buffer)
+
binary.Write(buf, binary.BigEndian, uint16(1235))
+
binary.Write(buf, binary.BigEndian, uint16(0x0100))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
buf.Write([]byte{7, 'e', 'x', 'a'})
+
return buf.Bytes()
+
}(),
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "failed to decode question #1:",
+
},
+
{
+
name: "Invalid input - truncated answer record data",
+
input: func() []byte {
+
buf := new(bytes.Buffer)
+
+
binary.Write(buf, binary.BigEndian, uint16(5679))
+
binary.Write(buf, binary.BigEndian, uint16(0x8180))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(1))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
binary.Write(buf, binary.BigEndian, uint16(0))
+
+
offsets := make(map[string]uint16)
+
nameBytes, _ := encodeDomain([]byte{}, "example.com", &offsets)
+
buf.Write(nameBytes)
+
binary.Write(buf, binary.BigEndian, uint16(AType))
+
binary.Write(buf, binary.BigEndian, uint16(IN))
+
binary.Write(buf, binary.BigEndian, uint32(300))
+
binary.Write(buf, binary.BigEndian, uint16(4))
+
+
buf.Write([]byte{192, 168})
+
return buf.Bytes()
+
}(),
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "failed to decode answer record #1:",
},
}
···
err := m.Decode(tt.input)
if tt.wantErr {
+
assert.Error(t, err, "Expected an error but got nil")
+
if tt.wantErrType != nil {
+
assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T, expected %T", err, tt.wantErrType)
+
}
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Error message mismatch")
+
}
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
+
assert.Equal(t, tt.expected.Header.ID, m.Header.ID, "Header ID mismatch")
+
assert.Equal(t, tt.expected.Header.QR, m.Header.QR, "Header QR mismatch")
+
assert.Equal(t, tt.expected.Header.OPCode, m.Header.OPCode, "Header OPCode mismatch")
+
assert.Equal(t, tt.expected.Header.RCode, m.Header.RCode, "Header RCode mismatch")
+
assert.Equal(t, tt.expected.Header.QDCount, m.Header.QDCount, "Header QDCount mismatch")
+
assert.Equal(t, tt.expected.Header.ANCount, m.Header.ANCount, "Header ANCount mismatch")
+
+
assert.Equal(t, tt.expected.Question, m.Question, "Question section mismatch")
+
assert.Equal(t, tt.expected.Answer, m.Answer, "Answer section mismatch")
+
assert.Equal(t, tt.expected.Authority, m.Authority, "Authority section mismatch")
+
assert.Equal(t, tt.expected.Additional, m.Additional, "Additional section mismatch")
}
})
}
}
+
func TestMessageEncodeDecodeRoundTrip(t *testing.T) {
+
tests := []struct {
+
name string
+
message *Message
+
}{
+
{
+
name: "Query with one question",
+
message: CreateRequest(QUERY, true).AddQuestion(Question{
+
QName: "google.com",
+
QType: AType,
+
QClass: IN,
+
}),
+
},
+
{
+
name: "Response with one A answer",
+
message: &Message{
+
Header: Header{
+
ID: 12345, QR: true, OPCode: QUERY, RD: true, RA: true, RCode: NOERROR, ANCount: 1,
+
},
+
Question: []Question{},
+
Answer: []ResourceRecord{
+
{Name: "test.local", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.ParseIP("192.0.2.1").To4()}},
+
},
+
Additional: []ResourceRecord{},
+
Authority: []ResourceRecord{},
+
},
+
},
+
{
+
name: "Response with multiple answers and compression",
+
message: &Message{
+
Header: Header{ID: 54321, QR: true, RCode: NOERROR, ANCount: 2},
+
Question: []Question{},
+
Answer: []ResourceRecord{
+
{Name: "www.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.2").To4()}},
+
{Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.3").To4()}},
+
},
+
Additional: []ResourceRecord{},
+
Authority: []ResourceRecord{},
+
},
+
},
+
{
+
name: "Message with various record types",
+
message: &Message{
+
Header: Header{ID: 1111, QR: true, RCode: NOERROR, ANCount: 3},
+
Question: []Question{},
+
Answer: []ResourceRecord{
+
{Name: "example.com", RType: MXType, RClass: IN, TTL: 3600, RDLength: 9, RData: &MX{Preference: 10, Exchange: "mail.example.com"}},
+
{Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.4").To4()}},
+
{Name: "example.com", RType: TXTType, RClass: IN, TTL: 600, RDLength: 36, RData: &TXT{TxtData: []string{"v=spf1 include:_spf.google.com ~all"}}},
+
},
+
Additional: []ResourceRecord{},
+
Authority: []ResourceRecord{},
+
},
+
},
+
}
+
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
encodedBytes, err := tt.message.Encode()
+
require.NoError(t, err, "Encoding failed unexpectedly")
+
require.NotEmpty(t, encodedBytes, "Encoded bytes should not be empty")
+
+
decodedMsg := &Message{}
+
err = decodedMsg.Decode(encodedBytes)
+
require.NoError(t, err, "Decoding failed unexpectedly")
+
+
assert.Equal(t, tt.message.Header.ID, decodedMsg.Header.ID, "Header ID mismatch")
+
assert.Equal(t, tt.message.Header.QR, decodedMsg.Header.QR, "Header QR mismatch")
+
assert.Equal(t, tt.message.Header.OPCode, decodedMsg.Header.OPCode, "Header OPCode mismatch")
+
assert.Equal(t, tt.message.Header.RCode, decodedMsg.Header.RCode, "Header RCode mismatch")
+
+
assert.Equal(t, tt.message.Question, decodedMsg.Question, "Question section mismatch")
+
assert.Equal(t, tt.message.Answer, decodedMsg.Answer, "Answer section mismatch")
+
assert.Equal(t, tt.message.Authority, decodedMsg.Authority, "Authority section mismatch")
+
assert.Equal(t, tt.message.Additional, decodedMsg.Additional, "Additional section mismatch")
+
})
+
}
+
}
+
func FuzzDecodeMessage(f *testing.F) {
testcases := [][]byte{
{
···
}
f.Fuzz(func(t *testing.T, msg []byte) {
var m Message
+
err := m.Decode(msg)
+
if err != nil {
+
var bufErr *BufferOverflowError
+
var labelErr *InvalidLabelError
+
var compErr *DomainCompressionError
+
if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr) || strings.Contains(err.Error(), "record:")) {
+
t.Errorf("FuzzDecodeMessage: unexpected error type %T: %v for input %x", err, err, msg)
+
}
+
}
})
}
+16 -8
question.go
···
package magna
-
import "encoding/binary"
// Decode decodes a question from buf at the offset
func (q *Question) Decode(buf []byte, offset int) (int, error) {
var err error
-
q.QName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
qtype, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
}
qclass, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
}
q.QType = DNSType(qtype)
···
}
// Encode serializes a Question into bytes, using a map to handle domain name compression offsets.
-
func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
bytes = encode_domain(bytes, q.QName, offsets)
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QClass))
-
return bytes
}
···
package magna
+
import (
+
"encoding/binary"
+
"fmt"
+
)
// Decode decodes a question from buf at the offset
func (q *Question) Decode(buf []byte, offset int) (int, error) {
var err error
+
q.QName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("question decode: failed to decode QName: %w", err)
}
qtype, offset, err := getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("question decode: failed to decode QType for %s: %w", q.QName, err)
}
qclass, offset, err := getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("question decode: failed to decode QClass for %s: %w", q.QName, err)
}
q.QType = DNSType(qtype)
···
}
// Encode serializes a Question into bytes, using a map to handle domain name compression offsets.
+
func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, q.QName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("question encode: failed to encode QName %s: %w", q.QName, err)
+
}
+
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QClass))
+
return bytes, nil
}
+103 -82
question_test.go
···
package magna
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestQuestionDecode(t *testing.T) {
···
expectedOffset int
expected Question
expectedErr error
}{
{
name: "Valid question - example.com A IN",
···
expectedErr: nil,
},
{
-
name: "Invalid domain name",
-
input: []byte{255, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1},
expectedOffset: 13,
expected: Question{},
-
expectedErr: &BufferOverflowError{},
},
{
name: "Insufficient buffer for QType",
···
expectedOffset: 14,
expected: Question{QName: "example.com"},
expectedErr: &BufferOverflowError{},
},
{
name: "Insufficient buffer for QClass",
···
expectedOffset: 16,
expected: Question{QName: "example.com", QType: DNSType(1)},
expectedErr: &BufferOverflowError{},
},
}
···
q := &Question{}
offset, err := q.Decode(tt.input, 0)
-
assert.Equal(t, tt.expectedOffset, offset)
-
if tt.expectedErr != nil {
-
assert.Error(t, err)
-
assert.IsType(t, tt.expectedErr, err)
} else {
-
assert.NoError(t, err)
-
assert.Equal(t, tt.expected, *q)
}
})
}
···
func TestQuestionEncode(t *testing.T) {
tests := []struct {
-
name string
-
question Question
-
offsets map[string]uint16
-
expected []byte
}{
{
name: "Simple domain - example.com A IN",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
},
{
name: "Subdomain - subdomain.example.com AAAA IN",
···
QType: DNSType(28),
QClass: DNSClass(1),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1},
},
{
name: "Different class - example.com MX CH",
···
QType: DNSType(15),
QClass: DNSClass(3),
},
-
offsets: make(map[string]uint16),
-
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
},
{
name: "Domain compression - example.com after subdomain.example.com",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
offsets: map[string]uint16{
-
"com": 22,
-
"example.com": 19,
},
-
expected: []byte{0xC0, 0x13, 0x00, 0x01, 0x00, 0x01},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-
result := tt.question.Encode(nil, &tt.offsets)
-
assert.Equal(t, tt.expected, result)
-
if len(tt.offsets) == 0 {
-
expectedOffsets := map[string]uint16{
-
tt.question.QName: 0,
-
}
-
for i := 0; i < len(tt.question.QName); i++ {
-
if tt.question.QName[i] == '.' {
-
expectedOffsets[tt.question.QName[i+1:]] = uint16(i + 1)
-
}
}
-
assert.Equal(t, expectedOffsets, tt.offsets)
}
})
}
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
offsets := make(map[string]uint16)
-
encoded := tt.question.Encode(nil, &offsets)
decodedQuestion := &Question{}
-
_, err := decodedQuestion.Decode(encoded, 0)
-
assert.NoError(t, err)
-
assert.Equal(t, tt.question, *decodedQuestion)
})
}
}
-
-
func TestQuestionEncodeWithExistingBuffer(t *testing.T) {
-
question := Question{
-
QName: "example.com",
-
QType: DNSType(1),
-
QClass: DNSClass(1),
-
}
-
-
existingBuffer := []byte{0xFF, 0xFF, 0xFF, 0xFF}
-
offsets := make(map[string]uint16)
-
-
result := question.Encode(existingBuffer, &offsets)
-
-
expected := append(
-
existingBuffer,
-
[]byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}...,
-
)
-
-
assert.Equal(t, expected, result)
-
}
-
-
func TestQuestionEncodeLongDomainName(t *testing.T) {
-
longLabel := make([]byte, 63)
-
for i := range longLabel {
-
longLabel[i] = 'a'
-
}
-
longDomainName := string(longLabel) + "." + string(longLabel) + "." + string(longLabel) + "." + string(longLabel[:61])
-
-
question := Question{
-
QName: longDomainName,
-
QType: DNSType(1),
-
QClass: DNSClass(1),
-
}
-
-
offsets := make(map[string]uint16)
-
encoded := question.Encode(nil, &offsets)
-
-
assert.Equal(t, 259, len(encoded))
-
-
decodedQuestion := &Question{}
-
_, err := decodedQuestion.Decode(encoded, 0)
-
-
assert.NoError(t, err)
-
assert.Equal(t, question, *decodedQuestion)
-
}
···
package magna
import (
+
"errors"
"testing"
"github.com/stretchr/testify/assert"
+
"github.com/stretchr/testify/require"
)
func TestQuestionDecode(t *testing.T) {
···
expectedOffset int
expected Question
expectedErr error
+
wantErrMsg string
}{
{
name: "Valid question - example.com A IN",
···
expectedErr: nil,
},
{
+
name: "Invalid domain name - label too long",
+
input: []byte{64, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1},
expectedOffset: 13,
expected: Question{},
+
expectedErr: &InvalidLabelError{},
+
wantErrMsg: "question decode: failed to decode QName: invalid domain label: length 64 exceeds maximum 63",
+
},
+
{
+
name: "Invalid domain name - compression loop",
+
input: []byte{0xC0, 0x00, 0, 1, 0, 1},
+
expectedOffset: 6,
+
expected: Question{},
+
expectedErr: &DomainCompressionError{},
+
wantErrMsg: "question decode: failed to decode QName: invalid domain compression: pointer loop detected",
},
{
name: "Insufficient buffer for QType",
···
expectedOffset: 14,
expected: Question{QName: "example.com"},
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "question decode: failed to decode QType for example.com: buffer overflow",
},
{
name: "Insufficient buffer for QClass",
···
expectedOffset: 16,
expected: Question{QName: "example.com", QType: DNSType(1)},
expectedErr: &BufferOverflowError{},
+
wantErrMsg: "question decode: failed to decode QClass for example.com: buffer overflow",
},
}
···
q := &Question{}
offset, err := q.Decode(tt.input, 0)
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
+
}
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, *q, "Decoded question mismatch")
+
assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
}
})
}
···
func TestQuestionEncode(t *testing.T) {
tests := []struct {
+
name string
+
question Question
+
initialBuf []byte
+
offsets map[string]uint16
+
expected []byte
+
expectedErr error
+
wantErrMsg string
+
newOffsets map[string]uint16
}{
{
name: "Simple domain - example.com A IN",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1},
+
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
name: "Subdomain - subdomain.example.com AAAA IN",
···
QType: DNSType(28),
QClass: DNSClass(1),
},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1},
+
newOffsets: map[string]uint16{"subdomain.example.com": 0, "example.com": 10, "com": 18},
},
{
name: "Different class - example.com MX CH",
···
QType: DNSType(15),
QClass: DNSClass(3),
},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3},
+
newOffsets: map[string]uint16{"example.com": 0, "com": 8},
},
{
name: "Domain compression - example.com after subdomain.example.com",
···
QType: DNSType(1),
QClass: DNSClass(1),
},
+
initialBuf: nil,
offsets: map[string]uint16{
+
"subdomain.example.com": 0,
+
"example.com": 10,
+
"com": 18,
+
},
+
expected: []byte{0xC0, 0x0a, 0x00, 0x01, 0x00, 0x01},
+
newOffsets: map[string]uint16{
+
"subdomain.example.com": 0,
+
"example.com": 10,
+
"com": 18,
+
},
+
},
+
{
+
name: "Encode with initial buffer",
+
question: Question{
+
QName: "test.org",
+
QType: AType,
+
QClass: IN,
+
},
+
initialBuf: []byte{0xAA, 0xBB},
+
offsets: make(map[string]uint16),
+
expected: []byte{0xAA, 0xBB, 4, 't', 'e', 's', 't', 3, 'o', 'r', 'g', 0, 0, 1, 0, 1},
+
newOffsets: map[string]uint16{"test.org": 2, "org": 7},
+
},
+
{
+
name: "Encode invalid domain - label too long",
+
question: Question{
+
QName: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com",
+
QType: AType,
+
QClass: IN,
},
+
initialBuf: nil,
+
offsets: make(map[string]uint16),
+
expected: nil,
+
expectedErr: &InvalidLabelError{},
+
wantErrMsg: "question encode: failed to encode QName",
+
newOffsets: map[string]uint16{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+
currentOffsets := make(map[string]uint16)
+
for k, v := range tt.offsets {
+
currentOffsets[k] = v
+
}
+
+
result, err := tt.question.Encode(tt.initialBuf, &currentOffsets)
+
if tt.expectedErr != nil {
+
assert.Error(t, err, "Expected an error but got nil")
+
assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
}
+
} else {
+
assert.NoError(t, err, "Expected no error but got one")
+
assert.Equal(t, tt.expected, result, "Encoded question mismatch")
+
assert.Equal(t, tt.newOffsets, currentOffsets, "Final offsets mismatch")
}
})
}
···
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
offsets := make(map[string]uint16)
+
encoded, err := tt.question.Encode(nil, &offsets)
+
require.NoError(t, err, "Encoding failed")
decodedQuestion := &Question{}
+
offset, err := decodedQuestion.Decode(encoded, 0)
+
assert.NoError(t, err, "Decoding failed")
+
assert.Equal(t, len(encoded), offset, "Offset after decoding should match encoded length")
+
assert.Equal(t, tt.question, *decodedQuestion, "Decoded question does not match original")
})
}
}
+208 -126
resource_record.go
···
func (a *A) Decode(buf []byte, offset int, rdlength int) (int, error) {
bytes, offset, err := getSlice(buf, offset, rdlength)
if err != nil {
-
return offset, err
}
a.Address = net.IP(bytes)
-
return offset, err
}
-
func (a *A) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return append(bytes, a.Address.To4()...)
}
func (a A) String() string {
···
func (ns *NS) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
ns.NSDName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (ns *NS) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, ns.NSDName, offsets)
}
func (ns NS) String() string {
···
func (md *MD) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
md.MADName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (md *MD) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, md.MADName, offsets)
}
func (md MD) String() string {
···
func (mf *MF) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
mf.MADName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (mf *MF) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, mf.MADName, offsets)
}
func (mf MF) String() string {
···
func (c *CNAME) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
c.CName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, c.CName, offsets)
}
func (c CNAME) String() string {
···
func (soa *SOA) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
soa.MName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
soa.RName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
soa.Serial, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
soa.Refresh, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
soa.Retry, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
soa.Expire, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
soa.Minimum, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
bytes = encode_domain(bytes, soa.MName, offsets)
-
bytes = encode_domain(bytes, soa.RName, offsets)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Serial)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Refresh)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Retry)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Expire)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Minimum)
-
return bytes
}
func (soa SOA) String() string {
···
}
func (mb *MB) Decode(buf []byte, offset int, rdlength int) (int, error) {
-
madname, offset, err := decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
mb.MADName = string(madname)
-
return offset, err
}
-
func (mb *MB) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, mb.MADName, offsets)
}
func (mb MB) String() string {
···
func (mg *MG) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
mg.MGMName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (mg *MG) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, mg.MGMName, offsets)
}
func (mg MG) String() string {
···
func (mr *MR) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
mr.NEWName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (mr *MR) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, mr.NEWName, offsets)
}
func (mr MR) String() string {
···
var err error
null.Anything, offset, err = getSlice(buf, offset, int(rdlength))
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (null *NULL) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return append(bytes, null.Anything...)
}
func (null NULL) String() string {
···
func (wks *WKS) Decode(buf []byte, offset int, rdlength int) (int, error) {
if rdlength < 5 {
-
return len(buf), &MagnaError{Message: fmt.Sprintf("magna: WKS RDLENGTH too short: %d", rdlength)}
}
addressBytes, nextOffset, err := getSlice(buf, offset, 4)
if err != nil {
-
return len(buf), fmt.Errorf("magna: WKS error reading address: %w", err)
}
offset = nextOffset
wks.Address = net.IP(addressBytes)
protocol, nextOffset, err := getU8(buf, offset)
if err != nil {
-
return len(buf), fmt.Errorf("magna: WKS error reading protocol: %w", err)
}
offset = nextOffset
wks.Protocol = protocol
···
bitmapLength := rdlength - 5
wks.BitMap, nextOffset, err = getSlice(buf, offset, bitmapLength)
if err != nil {
-
return len(buf), fmt.Errorf("magna: WKS error reading bitmap: %w", err)
}
offset = nextOffset
return offset, nil
}
-
func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = append(bytes, wks.Address.To4()...)
bytes = append(bytes, wks.Protocol)
bytes = append(bytes, wks.BitMap...)
-
return bytes
}
func (wks WKS) String() string {
-
return fmt.Sprintf("%s %d %s", wks.Address.String(), wks.Protocol, wks.BitMap)
}
func (ptr *PTR) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
ptr.PTRDName, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return encode_domain(bytes, ptr.PTRDName, offsets)
}
func (ptr PTR) String() string {
···
}
func (hinfo *HINFO) Decode(buf []byte, offset int, rdlength int) (int, error) {
endOffset := offset + rdlength
if endOffset > len(buf) {
return len(buf), &BufferOverflowError{Length: len(buf), Offset: endOffset}
···
cpuLen, nextOffset, err := getU8(buf, currentOffset)
if err != nil {
-
return len(buf), fmt.Errorf("magna: HINFO error reading CPU length: %w", err)
}
currentOffset = nextOffset
if currentOffset+int(cpuLen) > endOffset {
···
}
cpuBytes, nextOffset, err := getSlice(buf, currentOffset, int(cpuLen))
if err != nil {
-
return len(buf), fmt.Errorf("magna: HINFO error reading CPU data: %w", err)
}
currentOffset = nextOffset
hinfo.CPU = string(cpuBytes)
···
osLen, nextOffset, err := getU8(buf, currentOffset)
if err != nil {
if currentOffset == endOffset {
-
return len(buf), &MagnaError{Message: "magna: HINFO missing OS string"}
}
-
return len(buf), fmt.Errorf("magna: HINFO error reading OS length: %w", err)
}
currentOffset = nextOffset
if currentOffset+int(osLen) > endOffset {
···
}
osBytes, nextOffset, err := getSlice(buf, currentOffset, int(osLen))
if err != nil {
-
return len(buf), fmt.Errorf("magna: HINFO error reading OS data: %w", err)
}
currentOffset = nextOffset
hinfo.OS = string(osBytes)
if currentOffset != endOffset {
-
return len(buf), &MagnaError{Message: fmt.Sprintf("magna: HINFO RDATA length mismatch, expected end at %d, ended at %d", endOffset, currentOffset)}
}
return currentOffset, nil
}
-
func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
// XXX: should probally return an error
if len(hinfo.CPU) > 255 {
-
hinfo.CPU = hinfo.CPU[:255]
}
if len(hinfo.OS) > 255 {
-
hinfo.OS = hinfo.OS[:255]
}
bytes = append(bytes, byte(len(hinfo.CPU)))
bytes = append(bytes, []byte(hinfo.CPU)...)
bytes = append(bytes, byte(len(hinfo.OS)))
bytes = append(bytes, []byte(hinfo.OS)...)
-
return bytes
}
func (hinfo HINFO) String() string {
···
func (minfo *MINFO) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
-
minfo.RMailBx, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
minfo.EMailBx, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
bytes = encode_domain(bytes, minfo.RMailBx, offsets)
-
bytes = encode_domain(bytes, minfo.EMailBx, offsets)
-
return bytes
}
func (minfo MINFO) String() string {
···
var err error
mx.Preference, offset, err = getU16(buf, offset)
if err != nil {
-
return offset, err
}
-
mx.Exchange, offset, err = decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
return offset, err
}
-
func (mx *MX) Encode(bytes []byte, offsets *map[string]uint16) []byte {
bytes = binary.BigEndian.AppendUint16(bytes, mx.Preference)
-
bytes = encode_domain(bytes, mx.Exchange, offsets)
-
return bytes
}
func (mx MX) String() string {
···
for currentOffset < endOffset {
strLen, nextOffsetAfterLen, err := getU8(buf, currentOffset)
if err != nil {
-
return len(buf), fmt.Errorf("magna: error reading TXT string length byte: %w", err)
}
nextOffsetAfterData := nextOffsetAfterLen + int(strLen)
if nextOffsetAfterData > endOffset {
-
return len(buf), &MagnaError{
-
Message: fmt.Sprintf("magna: TXT string segment length %d at offset %d exceeds RDLENGTH boundary %d", strLen, nextOffsetAfterLen, endOffset),
-
}
}
strBytes, actualNextOffsetAfterData, err := getSlice(buf, nextOffsetAfterLen, int(strLen))
if err != nil {
-
return len(buf), fmt.Errorf("magna: error reading TXT string data: %w", err)
}
txt.TxtData = append(txt.TxtData, string(strBytes))
···
}
if currentOffset != endOffset {
-
return len(buf), &MagnaError{
-
Message: fmt.Sprintf("magna: TXT RDATA parsing finished at offset %d, but expected end at %d based on RDLENGTH", currentOffset, endOffset),
-
}
}
return currentOffset, nil
}
-
func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint16) []byte {
for _, s := range txt.TxtData {
if len(s) > 255 {
-
// XXX: should return probably an error
-
s = s[:255]
}
bytes = append(bytes, byte(len(s)))
bytes = append(bytes, []byte(s)...)
}
-
return bytes
}
func (txt TXT) String() string {
···
var err error
r.Bytes, offset, err = getSlice(buf, offset, int(rdlength))
if err != nil {
-
return offset, err
}
return offset, err
}
-
func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
return append(bytes, r.Bytes...)
}
func (r Reserved) String() string {
-
return string(r.Bytes)
}
// Decode decodes a resource record from buf at the offset.
func (r *ResourceRecord) Decode(buf []byte, offset int) (int, error) {
-
name, offset, err := decode_domain(buf, offset)
if err != nil {
-
return offset, err
}
-
r.Name = name
-
rtype, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
}
r.RType = DNSType(rtype)
-
rclass, offset, err := getU16(buf, offset)
if err != nil {
-
return offset, err
}
r.RClass = DNSClass(rclass)
r.TTL, offset, err = getU32(buf, offset)
if err != nil {
-
return offset, err
}
r.RDLength, offset, err = getU16(buf, offset)
if err != nil {
-
return offset, err
}
switch r.RType {
···
if r.RData != nil {
offset, err = r.RData.Decode(buf, offset, int(r.RDLength))
if err != nil {
-
return offset, err
}
}
···
}
// Encode encdoes a resource record and returns the input bytes appened.
-
func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint16) []byte {
-
bytes = encode_domain(bytes, r.Name, offsets)
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RClass))
bytes = binary.BigEndian.AppendUint32(bytes, r.TTL)
rdata_start := len(bytes)
bytes = binary.BigEndian.AppendUint16(bytes, 0)
-
bytes = r.RData.Encode(bytes, offsets)
rdata_length := uint16(len(bytes) - rdata_start - 2)
binary.BigEndian.PutUint16(bytes[rdata_start:rdata_start+2], rdata_length)
-
return bytes
}
···
func (a *A) Decode(buf []byte, offset int, rdlength int) (int, error) {
bytes, offset, err := getSlice(buf, offset, rdlength)
if err != nil {
+
return offset, fmt.Errorf("A record: failed to read address data: %w", err)
}
a.Address = net.IP(bytes)
+
if a.Address.To4() == nil {
+
return offset, fmt.Errorf("A record: decoded data is not a valid IPv4 address: %v", bytes)
+
}
+
return offset, nil
}
+
func (a *A) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
ipv4 := a.Address.To4()
+
if ipv4 == nil {
+
return nil, fmt.Errorf("A record: cannot encode non-IPv4 address %s", a.Address.String())
+
}
+
+
return append(bytes, a.Address.To4()...), nil
}
func (a A) String() string {
···
func (ns *NS) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
ns.NSDName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("NS record: failed to decode NSDName: %w", err)
}
+
return offset, nil
}
+
func (ns *NS) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, ns.NSDName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("NS record: failed to encode NSDName %s: %w", ns.NSDName, err)
+
}
+
+
return bytes, nil
}
func (ns NS) String() string {
···
func (md *MD) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
md.MADName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MD record: failed to decode MADName %s: %w", md.MADName, err)
}
+
return offset, nil
}
+
func (md *MD) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, md.MADName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MD record: failed to encode MADName %s: %w", md.MADName, err)
+
}
+
+
return bytes, nil
}
func (md MD) String() string {
···
func (mf *MF) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
mf.MADName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MF record: failed to decode MADName: %w", err)
}
+
return offset, nil
}
+
func (mf *MF) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, mf.MADName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MF record: failed to encode MADName %s: %w", mf.MADName, err)
+
}
+
+
return bytes, nil
}
func (mf MF) String() string {
···
func (c *CNAME) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
c.CName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("CNAME record: failed to decode CNAME: %w", err)
}
+
return offset, nil
}
+
func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, c.CName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("CNAME record: failed to encode CNAME %s: %w", c.CName, err)
+
}
+
+
return bytes, nil
}
func (c CNAME) String() string {
···
func (soa *SOA) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
soa.MName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode MName: %w", err)
}
+
soa.RName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode RName: %w", err)
}
soa.Serial, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode Serial: %w", err)
}
soa.Refresh, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode Refresh: %w", err)
}
soa.Retry, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode Retry: %w", err)
}
soa.Expire, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode Expire: %w", err)
}
soa.Minimum, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("SOA record: failed to decode Minimum: %w", err)
}
+
return offset, nil
}
+
func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, soa.MName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("SOA record: failed to encode MName %s: %w", soa.MName, err)
+
}
+
+
bytes, err = encodeDomain(bytes, soa.RName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("SOA record: failed to encode RName %s: %w", soa.RName, err)
+
}
+
bytes = binary.BigEndian.AppendUint32(bytes, soa.Serial)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Refresh)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Retry)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Expire)
bytes = binary.BigEndian.AppendUint32(bytes, soa.Minimum)
+
return bytes, nil
}
func (soa SOA) String() string {
···
}
func (mb *MB) Decode(buf []byte, offset int, rdlength int) (int, error) {
+
madname, offset, err := decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MB record: failed to decode MADName: %w", err)
}
mb.MADName = string(madname)
+
return offset, nil
}
+
func (mb *MB) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, mb.MADName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MB record: failed to encode MADName %s: %w", mb.MADName, err)
+
}
+
+
return bytes, nil
}
func (mb MB) String() string {
···
func (mg *MG) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
mg.MGMName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MG record: failed to decode MGMName: %w", err)
}
+
return offset, nil
}
+
func (mg *MG) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, mg.MGMName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MG record: failed to encode MGMName %s: %w", mg.MGMName, err)
+
}
+
return bytes, nil
}
func (mg MG) String() string {
···
func (mr *MR) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
mr.NEWName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MR record: failed to decode NEWName: %w", err)
}
+
return offset, nil
}
+
func (mr *MR) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, mr.NEWName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MR record: failed to encode NEWName: %w", err)
+
}
+
+
return bytes, nil
}
func (mr MR) String() string {
···
var err error
null.Anything, offset, err = getSlice(buf, offset, int(rdlength))
if err != nil {
+
return offset, fmt.Errorf("NULL record: failed to read data: %w", err)
}
+
return offset, nil
}
+
func (null *NULL) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
return append(bytes, null.Anything...), nil
}
func (null NULL) String() string {
···
func (wks *WKS) Decode(buf []byte, offset int, rdlength int) (int, error) {
if rdlength < 5 {
+
return len(buf), fmt.Errorf("WKS record: RDLENGTH %d is too short, minimum 5 required", rdlength)
}
addressBytes, nextOffset, err := getSlice(buf, offset, 4)
if err != nil {
+
return len(buf), fmt.Errorf("WKS record: failed to read address: %w", err)
}
offset = nextOffset
wks.Address = net.IP(addressBytes)
protocol, nextOffset, err := getU8(buf, offset)
if err != nil {
+
return len(buf), fmt.Errorf("WKS record: failed to read protocol: %w", err)
}
offset = nextOffset
wks.Protocol = protocol
···
bitmapLength := rdlength - 5
wks.BitMap, nextOffset, err = getSlice(buf, offset, bitmapLength)
if err != nil {
+
return len(buf), fmt.Errorf("WKS record: failed to read bitmap: %w", err)
}
offset = nextOffset
return offset, nil
}
+
func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
bytes = append(bytes, wks.Address.To4()...)
bytes = append(bytes, wks.Protocol)
bytes = append(bytes, wks.BitMap...)
+
return bytes, nil
}
func (wks WKS) String() string {
+
return fmt.Sprintf("%s %d %x", wks.Address.String(), wks.Protocol, wks.BitMap)
}
func (ptr *PTR) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
ptr.PTRDName, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("PTR record: failed to decode PTRDName: %w", err)
}
+
return offset, nil
}
+
func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, ptr.PTRDName, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("PTR record: failed to encode PTRD %s: %w", ptr.PTRDName, err)
+
}
+
+
return bytes, nil
}
func (ptr PTR) String() string {
···
}
func (hinfo *HINFO) Decode(buf []byte, offset int, rdlength int) (int, error) {
+
startOffset := offset
endOffset := offset + rdlength
if endOffset > len(buf) {
return len(buf), &BufferOverflowError{Length: len(buf), Offset: endOffset}
···
cpuLen, nextOffset, err := getU8(buf, currentOffset)
if err != nil {
+
return len(buf), fmt.Errorf("HINFO record: failed to read CPU length: %w", err)
}
currentOffset = nextOffset
if currentOffset+int(cpuLen) > endOffset {
···
}
cpuBytes, nextOffset, err := getSlice(buf, currentOffset, int(cpuLen))
if err != nil {
+
return len(buf), fmt.Errorf("HINFO record: failed to read CPU data: %w", err)
}
currentOffset = nextOffset
hinfo.CPU = string(cpuBytes)
···
osLen, nextOffset, err := getU8(buf, currentOffset)
if err != nil {
if currentOffset == endOffset {
+
return len(buf), fmt.Errorf("HINFO record: missing OS length byte at offset %d (expected end: %d)", currentOffset, endOffset)
}
+
return len(buf), fmt.Errorf("HINFO record: failed to read OS length: %w", err)
}
currentOffset = nextOffset
if currentOffset+int(osLen) > endOffset {
···
}
osBytes, nextOffset, err := getSlice(buf, currentOffset, int(osLen))
if err != nil {
+
return len(buf), fmt.Errorf("HINFO record: failed to read OS data: %w", err)
}
currentOffset = nextOffset
hinfo.OS = string(osBytes)
if currentOffset != endOffset {
+
return len(buf), fmt.Errorf("HINFO record: RDATA length mismatch, consumed %d bytes, expected %d", currentOffset-startOffset, rdlength)
}
return currentOffset, nil
}
+
func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
if len(hinfo.CPU) > 255 {
+
return nil, fmt.Errorf("HINFO record: CPU string length %d exceeds maximum 255", len(hinfo.CPU))
}
if len(hinfo.OS) > 255 {
+
return nil, fmt.Errorf("HINFO record: OS string length %d exceeds maximum 255", len(hinfo.OS))
}
bytes = append(bytes, byte(len(hinfo.CPU)))
bytes = append(bytes, []byte(hinfo.CPU)...)
bytes = append(bytes, byte(len(hinfo.OS)))
bytes = append(bytes, []byte(hinfo.OS)...)
+
return bytes, nil
}
func (hinfo HINFO) String() string {
···
func (minfo *MINFO) Decode(buf []byte, offset int, rdlength int) (int, error) {
var err error
+
minfo.RMailBx, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MINFO record: failed to decode RMailBx: %w", err)
}
+
minfo.EMailBx, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MINFO record: failed to decode EMailBx: %w", err)
}
+
return offset, nil
}
+
func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, minfo.RMailBx, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MINFO record: failed to encode RMailBx %s: %w", minfo.RMailBx, err)
+
}
+
+
bytes, err = encodeDomain(bytes, minfo.EMailBx, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MINFO record: failed to encode EMailBx %s: %w", minfo.EMailBx, err)
+
}
+
return bytes, nil
}
func (minfo MINFO) String() string {
···
var err error
mx.Preference, offset, err = getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MX record: failed to decode Preference: %w", err)
}
+
mx.Exchange, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("MX record: failed to decode Exchange: %w", err)
}
+
return offset, nil
}
+
func (mx *MX) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
bytes = binary.BigEndian.AppendUint16(bytes, mx.Preference)
+
bytes, err = encodeDomain(bytes, mx.Exchange, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("MX record: failed to encode Exchange %s: %w", mx.Exchange, err)
+
}
+
return bytes, nil
}
func (mx MX) String() string {
···
for currentOffset < endOffset {
strLen, nextOffsetAfterLen, err := getU8(buf, currentOffset)
if err != nil {
+
return len(buf), fmt.Errorf("TXT record: failed to read string length byte: %w", err)
}
nextOffsetAfterData := nextOffsetAfterLen + int(strLen)
if nextOffsetAfterData > endOffset {
+
return len(buf), fmt.Errorf("TXT record: string segment length %d exceeds RDLENGTH boundary %d", strLen, endOffset)
}
strBytes, actualNextOffsetAfterData, err := getSlice(buf, nextOffsetAfterLen, int(strLen))
if err != nil {
+
return len(buf), fmt.Errorf("TXT record: failed to read string data (length %d): %w", strLen, err)
}
txt.TxtData = append(txt.TxtData, string(strBytes))
···
}
if currentOffset != endOffset {
+
return len(buf), fmt.Errorf("TXT record: RDATA parsing finished at offset %d, but expected end at %d", currentOffset, endOffset)
}
return currentOffset, nil
}
+
func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
for _, s := range txt.TxtData {
if len(s) > 255 {
+
return nil, fmt.Errorf("TXT record: string segment length %d exceeds maximum 255", len(s))
}
bytes = append(bytes, byte(len(s)))
bytes = append(bytes, []byte(s)...)
}
+
return bytes, nil
}
func (txt TXT) String() string {
···
var err error
r.Bytes, offset, err = getSlice(buf, offset, int(rdlength))
if err != nil {
+
return offset, fmt.Errorf("reserved record: failed to read data: %w", err)
}
return offset, err
}
+
func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
return append(bytes, r.Bytes...), nil
}
func (r Reserved) String() string {
+
return fmt.Sprintf("[Reserved Data: %x]", r.Bytes)
}
// Decode decodes a resource record from buf at the offset.
func (r *ResourceRecord) Decode(buf []byte, offset int) (int, error) {
+
var err error
+
r.Name, offset, err = decodeDomain(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode record name: %w", err)
}
+
var rtype uint16
+
rtype, offset, err = getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode RType for %s: %w", r.Name, err)
}
r.RType = DNSType(rtype)
+
var rclass uint16
+
rclass, offset, err = getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode RClass for %s: %w", r.Name, err)
}
r.RClass = DNSClass(rclass)
r.TTL, offset, err = getU32(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode TTL for %s: %w", r.Name, err)
}
r.RDLength, offset, err = getU16(buf, offset)
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode RDLength for %s: %w", r.Name, err)
}
switch r.RType {
···
if r.RData != nil {
offset, err = r.RData.Decode(buf, offset, int(r.RDLength))
if err != nil {
+
return offset, fmt.Errorf("rr decode: failed to decode RData for %s (%s): %w", r.Name, r.RType.String(), err)
}
}
···
}
// Encode encdoes a resource record and returns the input bytes appened.
+
func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) {
+
var err error
+
bytes, err = encodeDomain(bytes, r.Name, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("rr encode: failed to encode record name %s: %w", r.Name, err)
+
}
+
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RType))
bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RClass))
bytes = binary.BigEndian.AppendUint32(bytes, r.TTL)
rdata_start := len(bytes)
bytes = binary.BigEndian.AppendUint16(bytes, 0)
+
bytes, err = r.RData.Encode(bytes, offsets)
+
if err != nil {
+
return nil, fmt.Errorf("rr encode: failed to encode RData for %s (%s): %w", r.Name, r.RType.String(), err)
+
}
+
rdata_length := uint16(len(bytes) - rdata_start - 2)
binary.BigEndian.PutUint16(bytes[rdata_start:rdata_start+2], rdata_length)
+
return bytes, nil
}
+472 -83
resource_record_test.go
···
import (
"encoding/binary"
"net"
"testing"
···
"github.com/stretchr/testify/require"
)
-
func TestTXTRecord(t *testing.T) {
-
rdataBytes := []byte{0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'}
-
rdlength := uint16(len(rdataBytes))
-
buf := []byte{0x00}
-
buf = binary.BigEndian.AppendUint16(buf, uint16(TXTType))
-
buf = binary.BigEndian.AppendUint16(buf, uint16(IN))
-
buf = binary.BigEndian.AppendUint32(buf, 3600)
-
buf = binary.BigEndian.AppendUint16(buf, rdlength)
buf = append(buf, rdataBytes...)
-
rr := &ResourceRecord{}
-
offset, err := rr.Decode(buf, 0)
-
require.NoError(t, err)
-
assert.Equal(t, len(buf), offset)
-
require.IsType(t, &TXT{}, rr.RData)
-
txtData := rr.RData.(*TXT)
-
expectedDecodedData := []string{"abc", "def"}
-
assert.Equal(t, expectedDecodedData, txtData.TxtData, "Decoded TXT data does not match expected concatenation")
-
txtToEncode := &TXT{TxtData: []string{"test"}}
-
expectedEncodedRdata := []byte{0x04, 't', 'e', 's', 't'}
-
encodeBuf := []byte{}
-
encodedRdata := txtToEncode.Encode(encodeBuf, nil)
-
assert.Equal(t, expectedEncodedRdata, encodedRdata, "Encoded TXT RDATA is incorrect")
}
-
func TestHINFORecordRFCCompliance(t *testing.T) {
-
rdataBytes := []byte{0x03, 'C', 'P', 'U', 0x02, 'O', 'S'}
-
rdlength := uint16(len(rdataBytes))
-
buf := []byte{0x00}
-
buf = binary.BigEndian.AppendUint16(buf, uint16(HINFOType))
-
buf = binary.BigEndian.AppendUint16(buf, uint16(IN))
-
buf = binary.BigEndian.AppendUint32(buf, 3600)
-
buf = binary.BigEndian.AppendUint16(buf, rdlength)
-
buf = append(buf, rdataBytes...)
-
rr := &ResourceRecord{}
-
offset, err := rr.Decode(buf, 0)
-
require.NoError(t, err)
-
assert.Equal(t, len(buf), offset)
-
require.IsType(t, &HINFO{}, rr.RData)
-
hinfoData := rr.RData.(*HINFO)
-
assert.Equal(t, "CPU", hinfoData.CPU, "Decoded HINFO CPU does not match")
-
assert.Equal(t, "OS", hinfoData.OS, "Decoded HINFO OS does not match")
-
hinfoToEncode := &HINFO{CPU: "Intel", OS: "Linux"}
-
expectedEncodedRdata := []byte{0x05, 'I', 'n', 't', 'e', 'l', 0x05, 'L', 'i', 'n', 'u', 'x'}
-
encodeBuf := []byte{}
-
encodedRdata := hinfoToEncode.Encode(encodeBuf, nil)
-
assert.Equal(t, expectedEncodedRdata, encodedRdata, "Encoded HINFO RDATA is incorrect")
}
-
func TestWKSRecordDecoding(t *testing.T) {
addr := net.ParseIP("192.168.1.1").To4()
proto := byte(6)
bitmap := []byte{0x01, 0x80}
-
rdataBytes := append(addr, proto)
-
rdataBytes = append(rdataBytes, bitmap...)
-
rdlength := uint16(len(rdataBytes))
-
buf := []byte{0x00}
-
buf = binary.BigEndian.AppendUint16(buf, uint16(WKSType))
-
buf = binary.BigEndian.AppendUint16(buf, uint16(IN))
-
buf = binary.BigEndian.AppendUint32(buf, 3600)
-
buf = binary.BigEndian.AppendUint16(buf, rdlength)
-
buf = append(buf, rdataBytes...)
-
rr := &ResourceRecord{}
-
offset, err := rr.Decode(buf, 0)
-
require.NoError(t, err)
-
assert.Equal(t, len(buf), offset)
-
require.IsType(t, &WKS{}, rr.RData)
-
wksData := rr.RData.(*WKS)
-
assert.Equal(t, addr, wksData.Address.To4())
-
assert.Equal(t, proto, wksData.Protocol)
-
assert.Equal(t, bitmap, wksData.BitMap)
}
-
func TestSOADecodeWithCompression(t *testing.T) {
-
input := []byte{0x69, 0x7b, 0x81, 0x83, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xf, 0x6e, 0x6f, 0x77, 0x61, 0x79, 0x74, 0x68, 0x69, 0x73, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x3, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x1, 0x0, 0x1, 0xc0, 0x1c, 0x0, 0x6, 0x0, 0x1, 0x0, 0x0, 0x3, 0x84, 0x0, 0x3d, 0x1, 0x61, 0xc, 0x67, 0x74, 0x6c, 0x64, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x3, 0x6e, 0x65, 0x74, 0x0, 0x5, 0x6e, 0x73, 0x74, 0x6c, 0x64, 0xc, 0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2d, 0x67, 0x72, 0x73, 0xc0, 0x1c, 0x67, 0xaa, 0xc5, 0x6b, 0x0, 0x0, 0x7, 0x8, 0x0, 0x0, 0x3, 0x84, 0x0, 0x9, 0x3a, 0x80, 0x0, 0x0, 0x3, 0x84}
-
msg := &Message{}
-
err := msg.Decode(input)
assert.NoError(t, err)
-
assert.Equal(t, 1, len(msg.Authority))
-
rr := msg.Authority[0]
-
assert.Equal(t, DNSType(6), rr.RType)
-
assert.Equal(t, DNSClass(1), rr.RClass)
-
assert.Equal(t, uint32(900), rr.TTL)
-
assert.Equal(t, uint16(61), rr.RDLength)
-
soa, ok := msg.Authority[0].RData.(*SOA)
-
assert.True(t, ok)
-
assert.Equal(t, "a.gtld-servers.net", soa.MName)
-
assert.Equal(t, "nstld.verisign-grs.com", soa.RName)
-
assert.Equal(t, uint32(1739244907), soa.Serial)
-
assert.Equal(t, uint32(1800), soa.Refresh)
-
assert.Equal(t, uint32(900), soa.Retry)
-
assert.Equal(t, uint32(604800), soa.Expire)
-
assert.Equal(t, uint32(900), soa.Minimum)
-
encoded := msg.Encode()
-
assert.Equal(t, input, encoded)
}
···
import (
"encoding/binary"
+
"errors"
"net"
"testing"
···
"github.com/stretchr/testify/require"
)
+
func buildRRBytes(t *testing.T, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdataBytes []byte) []byte {
+
t.Helper()
+
buf := []byte{}
+
offsets := make(map[string]uint16)
+
encodedName, err := encodeDomain(buf, name, &offsets)
+
require.NoError(t, err, "Failed to encode name in test helper")
+
buf = encodedName
+
+
buf = binary.BigEndian.AppendUint16(buf, uint16(rtype))
+
buf = binary.BigEndian.AppendUint16(buf, uint16(rclass))
+
buf = binary.BigEndian.AppendUint32(buf, ttl)
+
+
buf = binary.BigEndian.AppendUint16(buf, uint16(len(rdataBytes)))
buf = append(buf, rdataBytes...)
+
return buf
+
}
+
func encodeRData(t *testing.T, rdata ResourceRecordData) []byte {
+
t.Helper()
+
buf := []byte{}
+
offsets := make(map[string]uint16)
+
encodedRData, err := rdata.Encode(buf, &offsets)
+
require.NoError(t, err, "Failed to encode RDATA in test helper")
+
return encodedRData
+
}
+
func TestARecord(t *testing.T) {
+
addr := net.ParseIP("192.168.1.1").To4()
+
rdataBytes := []byte(addr)
+
a := &A{}
+
+
offset, err := a.Decode([]byte{}, 0, 4)
+
assert.Error(t, err, "Decode should fail with empty buffer")
+
assert.True(t, errors.Is(err, &BufferOverflowError{}))
+
+
offset, err = a.Decode(rdataBytes, 0, 4)
+
assert.NoError(t, err)
+
assert.Equal(t, 4, offset)
+
assert.Equal(t, addr, a.Address)
+
+
_, err = a.Decode([]byte{1, 2, 3}, 0, 3)
+
assert.Error(t, err)
+
assert.Contains(t, err.Error(), "A record:")
+
+
_, err = a.Decode([]byte{1, 2, 3, 4, 5}, 0, 5)
+
assert.Error(t, err)
+
assert.Contains(t, err.Error(), "A record:")
+
+
addr = net.ParseIP("192.168.1.1").To4()
+
err = nil
+
aEncode := &A{Address: addr}
+
encoded := encodeRData(t, aEncode)
+
assert.NoError(t, err)
+
assert.Equal(t, rdataBytes, encoded)
+
}
+
+
func TestNSRecord(t *testing.T) {
+
nsName := "ns1.example.com"
+
offsets := make(map[string]uint16)
+
rdataBytes, _ := encodeDomain([]byte{}, nsName, &offsets)
+
ns := &NS{}
+
offset, err := ns.Decode(rdataBytes, 0, len(rdataBytes))
+
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, nsName, ns.NSDName)
+
_, err = ns.Decode(rdataBytes[:len(rdataBytes)-2], 0, len(rdataBytes)-2)
+
assert.Error(t, err)
+
assert.True(t, errors.Is(err, &BufferOverflowError{}))
+
assert.ErrorContains(t, err, "NS record: failed to decode NSDName")
+
nsEncode := &NS{NSDName: nsName}
+
encoded := encodeRData(t, nsEncode)
+
assert.Equal(t, rdataBytes, encoded)
}
+
func TestCNAMERecord(t *testing.T) {
+
cname := "target.example.com"
+
offsets := make(map[string]uint16)
+
rdataBytes, _ := encodeDomain([]byte{}, cname, &offsets)
+
c := &CNAME{}
+
+
offset, err := c.Decode(rdataBytes, 0, len(rdataBytes))
+
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, cname, c.CName)
+
_, err = c.Decode(rdataBytes[:5], 0, 5)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "CNAME record")
+
cEncode := &CNAME{CName: cname}
+
encoded := encodeRData(t, cEncode)
+
assert.Equal(t, rdataBytes, encoded)
+
}
+
func TestSOARecord(t *testing.T) {
+
mname := "ns.example.com"
+
rname := "admin.example.com"
+
serial := uint32(2023010101)
+
refresh := uint32(7200)
+
retry := uint32(3600)
+
expire := uint32(1209600)
+
minimum := uint32(3600)
+
soaEncode := &SOA{MName: mname, RName: rname, Serial: serial, Refresh: refresh, Retry: retry, Expire: expire, Minimum: minimum}
+
rdataBytes := encodeRData(t, soaEncode)
+
soa := &SOA{}
+
offset, err := soa.Decode(rdataBytes, 0, len(rdataBytes))
+
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, *soaEncode, *soa)
+
+
_, err = soa.Decode(rdataBytes[:len(rdataBytes)-5], 0, len(rdataBytes)-5)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "SOA record:")
+
+
nameOffset := make(map[string]uint16)
+
mnameBytes, _ := encodeDomain([]byte{}, mname, &nameOffset)
+
rnameBytes, _ := encodeDomain([]byte{}, rname, &nameOffset)
+
shortRdataBytes := append(mnameBytes, rnameBytes...)
+
_, err = soa.Decode(shortRdataBytes, 0, len(shortRdataBytes))
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "SOA record")
+
}
+
+
func TestPTRRecord(t *testing.T) {
+
ptrName := "host.example.com"
+
offsets := make(map[string]uint16)
+
rdataBytes, _ := encodeDomain([]byte{}, ptrName, &offsets)
+
ptr := &PTR{}
+
+
offset, err := ptr.Decode(rdataBytes, 0, len(rdataBytes))
+
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, ptrName, ptr.PTRDName)
+
+
_, err = ptr.Decode(rdataBytes[:3], 0, 3)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "PTR record: failed to decode PTRDName")
+
+
ptrEncode := &PTR{PTRDName: ptrName}
+
encoded := encodeRData(t, ptrEncode)
+
assert.Equal(t, rdataBytes, encoded)
+
}
+
+
func TestMXRecord(t *testing.T) {
+
preference := uint16(10)
+
exchange := "mail.example.com"
+
+
mxEncode := &MX{Preference: preference, Exchange: exchange}
+
rdataBytes := encodeRData(t, mxEncode)
+
mx := &MX{}
+
+
offset, err := mx.Decode(rdataBytes, 0, len(rdataBytes))
+
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, *mxEncode, *mx)
+
+
_, err = mx.Decode([]byte{0}, 0, 1)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "MX record")
+
+
buf := make([]byte, 2)
+
binary.BigEndian.PutUint16(buf, preference)
+
buf = append(buf, []byte{4, 'm', 'a'}...)
+
_, err = mx.Decode(buf, 0, len(buf))
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "MX record: failed to decode Exchange")
+
}
+
+
func TestTXTRecord(t *testing.T) {
+
txtData := []string{"abc", "def"}
+
txtEncode := &TXT{TxtData: txtData}
+
rdataBytes := encodeRData(t, txtEncode)
+
txt := &TXT{}
+
+
offset, err := txt.Decode(rdataBytes, 0, len(rdataBytes))
+
require.NoError(t, err, "TXT Decode failed")
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, txtData, txt.TxtData, "Decoded TXT data mismatch")
+
+
txtDataEmpty := []string{""}
+
txtEncodeEmpty := &TXT{TxtData: txtDataEmpty}
+
rdataBytesEmpty := encodeRData(t, txtEncodeEmpty)
+
offset, err = txt.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty))
+
require.NoError(t, err, "TXT Decode with empty string failed")
+
assert.Equal(t, len(rdataBytesEmpty), offset)
+
assert.Equal(t, txtDataEmpty, txt.TxtData)
+
+
txtDataMulti := []string{"v=spf1", "include:_spf.google.com", "~all"}
+
txtEncodeMulti := &TXT{TxtData: txtDataMulti}
+
rdataBytesMulti := encodeRData(t, txtEncodeMulti)
+
offset, err = txt.Decode(rdataBytesMulti, 0, len(rdataBytesMulti))
+
require.NoError(t, err, "TXT Decode with multiple strings failed")
+
assert.Equal(t, len(rdataBytesMulti), offset)
+
assert.Equal(t, txtDataMulti, txt.TxtData)
+
+
_, err = txt.Decode([]byte{}, 0, 0)
+
assert.NoError(t, err)
+
+
_, err = txt.Decode([]byte{5, 'd', 'a', 't'}, 0, 4)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "TXT record: string segment length 5 exceeds RDLENGTH boundary 4")
+
+
encoded := encodeRData(t, txtEncode)
+
assert.Equal(t, rdataBytes, encoded)
+
}
+
+
func TestHINFORecord(t *testing.T) {
+
cpu := "Intel"
+
os := "Linux"
+
hinfoEncode := &HINFO{CPU: cpu, OS: os}
+
rdataBytes := encodeRData(t, hinfoEncode)
+
hinfo := &HINFO{}
+
+
offset, err := hinfo.Decode(rdataBytes, 0, len(rdataBytes))
+
require.NoError(t, err, "HINFO Decode failed")
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, cpu, hinfo.CPU)
+
assert.Equal(t, os, hinfo.OS)
+
+
hinfoEncodeEmpty := &HINFO{CPU: "", OS: ""}
+
rdataBytesEmpty := encodeRData(t, hinfoEncodeEmpty)
+
offset, err = hinfo.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty))
+
require.NoError(t, err, "HINFO Decode with empty strings failed")
+
assert.Equal(t, len(rdataBytesEmpty), offset)
+
assert.Equal(t, "", hinfo.CPU)
+
assert.Equal(t, "", hinfo.OS)
+
+
_, err = hinfo.Decode([]byte{}, 0, 0)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "HINFO record:")
+
+
_, err = hinfo.Decode([]byte{5, 'I', 'n'}, 0, 3)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "buffer overflow:")
+
+
_, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l'}, 0, 6)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "HINFO record:")
+
+
_, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l', 5, 'L', 'i'}, 0, 9)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "buffer overflow:")
+
+
extraData := append(rdataBytes, 0xFF)
+
_, err = hinfo.Decode(extraData, 0, len(extraData))
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "HINFO record:")
+
_, err = hinfo.Decode([]byte{10, 'a', 'b', 'c'}, 0, 4)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "buffer overflow:")
}
+
func TestWKSRecord(t *testing.T) {
addr := net.ParseIP("192.168.1.1").To4()
proto := byte(6)
bitmap := []byte{0x01, 0x80}
+
wksEncode := &WKS{Address: addr, Protocol: proto, BitMap: bitmap}
+
rdataBytes := encodeRData(t, wksEncode)
+
wks := &WKS{}
+
offset, err := wks.Decode(rdataBytes, 0, len(rdataBytes))
+
require.NoError(t, err, "WKS Decode failed")
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, addr, wks.Address.To4())
+
assert.Equal(t, proto, wks.Protocol)
+
assert.Equal(t, bitmap, wks.BitMap)
+
wksEncodeNoBitmap := &WKS{Address: addr, Protocol: proto, BitMap: []byte{}}
+
rdataBytesNoBitmap := encodeRData(t, wksEncodeNoBitmap)
+
wks = &WKS{}
+
offset, err = wks.Decode(rdataBytesNoBitmap, 0, len(rdataBytesNoBitmap))
+
require.NoError(t, err, "WKS Decode without bitmap failed")
+
assert.Equal(t, len(rdataBytesNoBitmap), offset)
+
assert.Equal(t, addr, wks.Address.To4())
+
assert.Equal(t, proto, wks.Protocol)
+
assert.Empty(t, wks.BitMap)
+
_, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 4)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "WKS record: RDLENGTH 4 is too short")
+
+
_, err = wks.Decode([]byte{1, 2, 3}, 0, 5)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "WKS record: failed to read address")
+
+
_, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 5)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "WKS record: failed to read protocol")
+
+
_, err = wks.Decode([]byte{1, 2, 3, 4, 6, 0x01}, 0, 7)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "WKS record: failed to read bitmap")
}
+
func TestReservedRecord(t *testing.T) {
+
rdataBytes := []byte{0xDE, 0xAD, 0xBE, 0xEF}
+
r := &Reserved{}
+
offset, err := r.Decode(rdataBytes, 0, len(rdataBytes))
assert.NoError(t, err)
+
assert.Equal(t, len(rdataBytes), offset)
+
assert.Equal(t, rdataBytes, r.Bytes)
+
_, err = r.Decode(rdataBytes[:2], 0, 4)
+
assert.Error(t, err)
+
assert.ErrorContains(t, err, "reserved record: failed to read data")
+
+
rEncode := &Reserved{Bytes: rdataBytes}
+
encoded := encodeRData(t, rEncode)
+
assert.Equal(t, rdataBytes, encoded)
+
+
rEncodeNil := &Reserved{Bytes: nil}
+
encodedNil := encodeRData(t, rEncodeNil)
+
assert.Empty(t, encodedNil)
+
}
+
+
func TestResourceRecordDecode(t *testing.T) {
+
tests := []struct {
+
name string
+
input []byte
+
expectedRR *ResourceRecord
+
wantErr bool
+
wantErrType error
+
wantErrMsg string
+
}{
+
{
+
name: "Valid A record",
+
input: buildRRBytes(t, "a.com", AType, IN, 60, []byte{1, 1, 1, 1}),
+
expectedRR: &ResourceRecord{
+
Name: "a.com", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.IP{1, 1, 1, 1}},
+
},
+
},
+
{
+
name: "Valid TXT record",
+
input: buildRRBytes(t, "b.org", TXTType, IN, 300, encodeRData(t, &TXT{[]string{"hello", "world"}})),
+
expectedRR: &ResourceRecord{
+
Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RDLength: 12, RData: &TXT{[]string{"hello", "world"}},
+
},
+
},
+
{
+
name: "Unknown record type",
+
input: buildRRBytes(t, "c.net", DNSType(9999), IN, 10, []byte{0xca, 0xfe}),
+
expectedRR: &ResourceRecord{
+
Name: "c.net", RType: DNSType(9999), RClass: IN, TTL: 10, RDLength: 2, RData: &Reserved{[]byte{0xca, 0xfe}},
+
},
+
},
+
{
+
name: "Truncated name",
+
input: []byte{3, 'a', 'b'},
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "rr decode:",
+
},
+
{
+
name: "Truncated type",
+
input: buildRRBytes(t, "d.com", AType, IN, 60, []byte{1})[:5],
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "rr decode:",
+
},
+
{
+
name: "Truncated RDATA section",
+
input: buildRRBytes(t, "e.com", AType, IN, 60, []byte{1, 2, 3, 4})[:15],
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "rr decode:",
+
},
+
{
+
name: "RDLENGTH mismatch (claims longer than buffer)",
+
input: func() []byte {
+
buf := buildRRBytes(t, "f.com", AType, IN, 60, []byte{1, 2, 3, 4})
+
binary.BigEndian.PutUint16(buf[10:12], 10)
+
return buf[:14]
+
}(),
+
wantErr: true,
+
wantErrType: &BufferOverflowError{},
+
wantErrMsg: "rr decode:",
+
},
+
{
+
name: "RDLENGTH mismatch (RData decoder consumes less)",
+
input: func() []byte {
+
rdataBytes := encodeRData(t, &TXT{[]string{"short"}})
+
buf := buildRRBytes(t, "g.com", TXTType, IN, 60, rdataBytes)
+
nameLen := len(buf) - 10 - len(rdataBytes)
+
rdlenPos := nameLen + 8
+
binary.BigEndian.PutUint16(buf[rdlenPos:rdlenPos+2], uint16(len(rdataBytes)+5))
+
return buf
+
}(),
+
wantErr: true,
+
wantErrMsg: "rr decode:",
+
},
+
}
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
rr := &ResourceRecord{}
+
offset, err := rr.Decode(tt.input, 0)
+
if tt.wantErr {
+
assert.Error(t, err)
+
if tt.wantErrType != nil {
+
assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
+
}
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg)
+
}
+
} else {
+
assert.NoError(t, err)
+
assert.Equal(t, len(tt.input), offset, "Offset should match input length")
+
assert.Equal(t, tt.expectedRR.Name, rr.Name)
+
assert.Equal(t, tt.expectedRR.RType, rr.RType)
+
assert.Equal(t, tt.expectedRR.RClass, rr.RClass)
+
assert.Equal(t, tt.expectedRR.TTL, rr.TTL)
+
assert.Equal(t, tt.expectedRR.RDLength, rr.RDLength)
+
assert.Equal(t, tt.expectedRR.RData, rr.RData)
+
}
+
})
+
}
+
}
+
func TestResourceRecordEncode(t *testing.T) {
+
tests := []struct {
+
name string
+
rr *ResourceRecord
+
expectedLen int
+
wantErr bool
+
wantErrType error
+
wantErrMsg string
+
}{
+
{
+
name: "Valid A record",
+
rr: &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
+
},
+
{
+
name: "Valid TXT record",
+
rr: &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{"hello", "world"}}},
+
},
+
{
+
name: "Encode fail - Invalid Name",
+
rr: &ResourceRecord{Name: "a..b", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
+
wantErr: true,
+
wantErrType: &InvalidLabelError{},
+
wantErrMsg: "rr encode: failed to encode record name a..b",
+
},
+
{
+
name: "Encode fail - Invalid RData (A record)",
+
rr: &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.ParseIP("::1")}},
+
wantErr: true,
+
wantErrMsg: "rr encode: failed to encode RData for a.com (A): A record: cannot encode non-IPv4 address",
+
},
+
{
+
name: "Encode fail - Invalid RData (TXT record)",
+
rr: &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{string(make([]byte, 256))}}},
+
wantErr: true,
+
wantErrMsg: "rr encode: failed to encode RData for b.org (TXT): TXT record: string segment length 256 exceeds maximum 255",
+
},
+
}
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
offsets := make(map[string]uint16)
+
encodedBytes, err := tt.rr.Encode([]byte{}, &offsets)
+
+
if tt.wantErr {
+
assert.Error(t, err)
+
if tt.wantErrType != nil {
+
assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
+
}
+
if tt.wantErrMsg != "" {
+
assert.ErrorContains(t, err, tt.wantErrMsg)
+
}
+
} else {
+
assert.NoError(t, err)
+
assert.NotEmpty(t, encodedBytes)
+
+
decodedRR := &ResourceRecord{}
+
offset, decodeErr := decodedRR.Decode(encodedBytes, 0)
+
assert.NoError(t, decodeErr, "Failed to decode back encoded RR")
+
if decodeErr == nil {
+
assert.Equal(t, len(encodedBytes), offset, "Decoded offset mismatch")
+
assert.Equal(t, tt.rr.Name, decodedRR.Name)
+
assert.Equal(t, tt.rr.RType, decodedRR.RType)
+
assert.Equal(t, tt.rr.RClass, decodedRR.RClass)
+
assert.Equal(t, tt.rr.TTL, decodedRR.TTL)
+
if tt.rr.RData == nil {
+
assert.IsType(t, &Reserved{}, decodedRR.RData, "Nil RData should decode as Reserved")
+
assert.Empty(t, decodedRR.RData.(*Reserved).Bytes, "Nil RData should decode as empty Reserved")
+
assert.Equal(t, uint16(0), decodedRR.RDLength, "Nil RData should have RDLength 0")
+
} else {
+
assert.Equal(t, tt.rr.RData, decodedRR.RData, "RData mismatch after round trip")
+
assert.NotEqual(t, uint16(0), decodedRR.RDLength, "Non-nil RData should have non-zero RDLength")
+
}
+
}
+
}
+
})
+
}
}
+1 -1
types.go
···
// *map[string]uint8 - map containing labels and offsets for domain name compression
type ResourceRecordData interface {
Decode([]byte, int, int) (int, error)
-
Encode([]byte, *map[string]uint16) []byte
String() string
}
···
// *map[string]uint8 - map containing labels and offsets for domain name compression
type ResourceRecordData interface {
Decode([]byte, int, int) (int, error)
+
Encode([]byte, *map[string]uint16) ([]byte, error)
String() string
}
+16 -16
utils.go
···
// getU8 returns the first byte from a byte array at offset.
func getU8(buf []byte, offset int) (uint8, int, error) {
-
next_offset := offset + 1
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
}
-
return buf[offset], next_offset, nil
}
// getU16 returns the bigEndian uint16 from a byte array at offset.
func getU16(buf []byte, offset int) (uint16, int, error) {
-
next_offset := offset + 2
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
}
-
return binary.BigEndian.Uint16(buf[offset:]), next_offset, nil
}
// getU32 returns the bigEndian uint32 from a byte array at offset.
func getU32(buf []byte, offset int) (uint32, int, error) {
-
next_offset := offset + 4
-
if next_offset > len(buf) {
-
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
}
-
return binary.BigEndian.Uint32(buf[offset:]), next_offset, nil
}
// getSlice returns a slice of bytes from a byte array at an offset and of length.
func getSlice(buf []byte, offset int, length int) ([]byte, int, error) {
-
next_offset := offset + length
-
if next_offset > len(buf) {
-
return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset}
}
-
return buf[offset:next_offset], next_offset, nil
}
···
// getU8 returns the first byte from a byte array at offset.
func getU8(buf []byte, offset int) (uint8, int, error) {
+
nextOffset := offset + 1
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
+
return buf[offset], nextOffset, nil
}
// getU16 returns the bigEndian uint16 from a byte array at offset.
func getU16(buf []byte, offset int) (uint16, int, error) {
+
nextOffset := offset + 2
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
+
return binary.BigEndian.Uint16(buf[offset:]), nextOffset, nil
}
// getU32 returns the bigEndian uint32 from a byte array at offset.
func getU32(buf []byte, offset int) (uint32, int, error) {
+
nextOffset := offset + 4
+
if nextOffset > len(buf) {
+
return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
+
return binary.BigEndian.Uint32(buf[offset:]), nextOffset, nil
}
// getSlice returns a slice of bytes from a byte array at an offset and of length.
func getSlice(buf []byte, offset int, length int) ([]byte, int, error) {
+
nextOffset := offset + length
+
if nextOffset > len(buf) {
+
return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset}
}
+
return buf[offset:nextOffset], nextOffset, nil
}