package magna

import (
	"encoding/binary"
	"errors"
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// buildRRBytes returns the wire encoding of a single resource record:
// - the encoded owner name;
// - TYPE, CLASS, TTL and RDLENGTH, where RDLENGTH is len(rdataBytes);
// - the raw RDATA bytes.
func buildRRBytes(t *testing.T, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdataBytes []byte) []byte {
	t.Helper()

	offsets := make(map[string]uint16)
	buf, err := encodeDomain([]byte{}, name, &offsets)
	require.NoError(t, err, "Failed to encode name in test helper")

	buf = binary.BigEndian.AppendUint16(buf, uint16(rtype))
	buf = binary.BigEndian.AppendUint16(buf, uint16(rclass))
	buf = binary.BigEndian.AppendUint32(buf, ttl)
	buf = binary.BigEndian.AppendUint16(buf, uint16(len(rdataBytes)))
	buf = append(buf, rdataBytes...)
	return buf
}

// encodeRData encodes rdata into a fresh buffer with an empty compression
// table. It fails the test immediately if encoding returns an error.
func encodeRData(t *testing.T, rdata ResourceRecordData) []byte {
	t.Helper()

	offsets := make(map[string]uint16)
	encoded, err := rdata.Encode([]byte{}, &offsets)
	require.NoError(t, err, "Failed to encode RDATA in test helper")
	return encoded
}

// TestARecord covers A RDATA decoding for valid, empty, short and long input.
// It also checks that encoding reproduces the raw 4-byte address.
func TestARecord(t *testing.T) {
	addr := net.ParseIP("192.168.1.1").To4()
	rdataBytes := []byte(addr)
	a := &A{}

	_, err := a.Decode([]byte{}, 0, 4)
	assert.Error(t, err, "Decode should fail with empty buffer")
	assert.True(t, errors.Is(err, &BufferOverflowError{}))

	offset, err := a.Decode(rdataBytes, 0, 4)
	assert.NoError(t, err)
	assert.Equal(t, 4, offset)
	assert.Equal(t, addr, a.Address)

	// Anything other than exactly four RDATA bytes is rejected.
	_, err = a.Decode([]byte{1, 2, 3}, 0, 3)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "A record:")

	_, err = a.Decode([]byte{1, 2, 3, 4, 5}, 0, 5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "A record:")

	encoded := encodeRData(t, &A{Address: addr})
	assert.Equal(t, rdataBytes, encoded)
}

// TestNSRecord round-trips an NS name through Decode/Encode.
// It also checks that a truncated name is reported as a buffer overflow.
func TestNSRecord(t *testing.T) {
	nsName := "ns1.example.com"
	offsets := make(map[string]uint16)
	rdataBytes, err := encodeDomain([]byte{}, nsName, &offsets)
	require.NoError(t, err)

	ns := &NS{}
	offset, err := ns.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, nsName, ns.NSDName)

	// Dropping the last two bytes leaves the encoded name incomplete.
	_, err = ns.Decode(rdataBytes[:len(rdataBytes)-2], 0, len(rdataBytes)-2)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, &BufferOverflowError{}))
	assert.ErrorContains(t, err, "NS record: failed to decode NSDName")

	encoded := encodeRData(t, &NS{NSDName: nsName})
	assert.Equal(t, rdataBytes, encoded)
}

// TestCNAMERecord round-trips a CNAME target through Decode/Encode.
// It also checks that a truncated name is rejected.
func TestCNAMERecord(t *testing.T) {
	cname := "target.example.com"
	offsets := make(map[string]uint16)
	rdataBytes, err := encodeDomain([]byte{}, cname, &offsets)
	require.NoError(t, err)

	c := &CNAME{}
	offset, err := c.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, cname, c.CName)

	_, err = c.Decode(rdataBytes[:5], 0, 5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "CNAME record")

	encoded := encodeRData(t, &CNAME{CName: cname})
	assert.Equal(t, rdataBytes, encoded)
}

// TestSOARecord round-trips a full SOA record.
// It also checks decoding of a truncated record and of one that carries
// only the two names.
func TestSOARecord(t *testing.T) {
	mname := "ns.example.com"
	rname := "admin.example.com"
	soaEncode := &SOA{
		MName:   mname,
		RName:   rname,
		Serial:  uint32(2023010101),
		Refresh: uint32(7200),
		Retry:   uint32(3600),
		Expire:  uint32(1209600),
		Minimum: uint32(3600),
	}
	rdataBytes := encodeRData(t, soaEncode)

	soa := &SOA{}
	offset, err := soa.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, *soaEncode, *soa)

	// Removing the last five bytes truncates the record.
	_, err = soa.Decode(rdataBytes[:len(rdataBytes)-5], 0, len(rdataBytes)-5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "SOA record:")

	// Both names are present, but none of the five uint32 fields follow.
	// nameOffset is shared, so rname is presumably allowed to compress
	// against mname. Any such pointer stays valid because mnameBytes is
	// placed at offset 0 of the concatenation.
	nameOffset := make(map[string]uint16)
	mnameBytes, err := encodeDomain([]byte{}, mname, &nameOffset)
	require.NoError(t, err)
	rnameBytes, err := encodeDomain([]byte{}, rname, &nameOffset)
	require.NoError(t, err)
	shortRdataBytes := append(mnameBytes, rnameBytes...)

	_, err = soa.Decode(shortRdataBytes, 0, len(shortRdataBytes))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "SOA record")
}

// TestPTRRecord round-trips a PTR name through Decode/Encode.
// It also checks that a truncated name is rejected.
func TestPTRRecord(t *testing.T) {
	ptrName := "host.example.com"
	offsets := make(map[string]uint16)
	rdataBytes, err := encodeDomain([]byte{}, ptrName, &offsets)
	require.NoError(t, err)

	ptr := &PTR{}
	offset, err := ptr.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, ptrName, ptr.PTRDName)

	_, err = ptr.Decode(rdataBytes[:3], 0, 3)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "PTR record: failed to decode PTRDName")

	encoded := encodeRData(t, &PTR{PTRDName: ptrName})
	assert.Equal(t, rdataBytes, encoded)
}

// TestMXRecord round-trips an MX record.
// It also checks the failures when the preference or the exchange name is
// truncated.
func TestMXRecord(t *testing.T) {
	preference := uint16(10)
	mxEncode := &MX{Preference: preference, Exchange: "mail.example.com"}
	rdataBytes := encodeRData(t, mxEncode)

	mx := &MX{}
	offset, err := mx.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, *mxEncode, *mx)

	// One byte cannot hold the 16-bit preference.
	_, err = mx.Decode([]byte{0}, 0, 1)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "MX record")

	// The preference is intact, but the first label claims 4 bytes and only 2 follow.
	buf := make([]byte, 2)
	binary.BigEndian.PutUint16(buf, preference)
	buf = append(buf, 4, 'm', 'a')

	_, err = mx.Decode(buf, 0, len(buf))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "MX record: failed to decode Exchange")
}

// TestTXTRecord checks the TXT wire format and round-trips single, empty and
// multi-segment data. It also checks that a segment overrunning RDLENGTH is
// rejected.
func TestTXTRecord(t *testing.T) {
	txtData := []string{"abc", "def"}
	rdataBytes := encodeRData(t, &TXT{TxtData: txtData})
	// Each segment is a length byte followed by that many bytes.
	assert.Equal(t, []byte{3, 'a', 'b', 'c', 3, 'd', 'e', 'f'}, rdataBytes)

	txt := &TXT{}
	offset, err := txt.Decode(rdataBytes, 0, len(rdataBytes))
	require.NoError(t, err, "TXT Decode failed")
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, txtData, txt.TxtData, "Decoded TXT data mismatch")

	txtDataEmpty := []string{""}
	rdataBytesEmpty := encodeRData(t, &TXT{TxtData: txtDataEmpty})
	offset, err = txt.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty))
	require.NoError(t, err, "TXT Decode with empty string failed")
	assert.Equal(t, len(rdataBytesEmpty), offset)
	assert.Equal(t, txtDataEmpty, txt.TxtData)

	txtDataMulti := []string{"v=spf1", "include:_spf.google.com", "~all"}
	rdataBytesMulti := encodeRData(t, &TXT{TxtData: txtDataMulti})
	offset, err = txt.Decode(rdataBytesMulti, 0, len(rdataBytesMulti))
	require.NoError(t, err, "TXT Decode with multiple strings failed")
	assert.Equal(t, len(rdataBytesMulti), offset)
	assert.Equal(t, txtDataMulti, txt.TxtData)

	// Zero-length RDATA is accepted.
	_, err = txt.Decode([]byte{}, 0, 0)
	assert.NoError(t, err)

	_, err = txt.Decode([]byte{5, 'd', 'a', 't'}, 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "TXT record: string segment length 5 exceeds RDLENGTH boundary 4")
}

// TestHINFORecord round-trips CPU/OS strings, including empty ones.
// It also checks the errors for missing, truncated and trailing data.
func TestHINFORecord(t *testing.T) {
	cpu := "Intel"
	osName := "Linux"
	rdataBytes := encodeRData(t, &HINFO{CPU: cpu, OS: osName})

	hinfo := &HINFO{}
	offset, err := hinfo.Decode(rdataBytes, 0, len(rdataBytes))
	require.NoError(t, err, "HINFO Decode failed")
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, cpu, hinfo.CPU)
	assert.Equal(t, osName, hinfo.OS)

	rdataBytesEmpty := encodeRData(t, &HINFO{CPU: "", OS: ""})
	offset, err = hinfo.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty))
	require.NoError(t, err, "HINFO Decode with empty strings failed")
	assert.Equal(t, len(rdataBytesEmpty), offset)
	assert.Equal(t, "", hinfo.CPU)
	assert.Equal(t, "", hinfo.OS)

	_, err = hinfo.Decode([]byte{}, 0, 0)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "HINFO record:")

	// The CPU string is cut short.
	_, err = hinfo.Decode([]byte{5, 'I', 'n'}, 0, 3)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "buffer overflow:")

	// The CPU string is complete, but there is no OS string.
	_, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l'}, 0, 6)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "HINFO record:")

	// The OS string is cut short.
	_, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l', 5, 'L', 'i'}, 0, 9)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "buffer overflow:")

	// A trailing byte follows a valid CPU/OS pair.
	extraData := append(rdataBytes, 0xFF)
	_, err = hinfo.Decode(extraData, 0, len(extraData))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "HINFO record:")

	_, err = hinfo.Decode([]byte{10, 'a', 'b', 'c'}, 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "buffer overflow:")
}

// TestWKSRecord round-trips WKS records with and without a bitmap.
// It also checks each truncation error the decoder reports.
func TestWKSRecord(t *testing.T) {
	addr := net.ParseIP("192.168.1.1").To4()
	proto := byte(6)
	bitmap := []byte{0x01, 0x80}

	rdataBytes := encodeRData(t, &WKS{Address: addr, Protocol: proto, BitMap: bitmap})
	wks := &WKS{}
	offset, err := wks.Decode(rdataBytes, 0, len(rdataBytes))
	require.NoError(t, err, "WKS Decode failed")
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, addr, wks.Address.To4())
	assert.Equal(t, proto, wks.Protocol)
	assert.Equal(t, bitmap, wks.BitMap)

	rdataBytesNoBitmap := encodeRData(t, &WKS{Address: addr, Protocol: proto, BitMap: []byte{}})
	wks = &WKS{}
	offset, err = wks.Decode(rdataBytesNoBitmap, 0, len(rdataBytesNoBitmap))
	require.NoError(t, err, "WKS Decode without bitmap failed")
	assert.Equal(t, len(rdataBytesNoBitmap), offset)
	assert.Equal(t, addr, wks.Address.To4())
	assert.Equal(t, proto, wks.Protocol)
	assert.Empty(t, wks.BitMap)

	_, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "WKS record: RDLENGTH 4 is too short")

	_, err = wks.Decode([]byte{1, 2, 3}, 0, 5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "WKS record: failed to read address")

	_, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "WKS record: failed to read protocol")

	_, err = wks.Decode([]byte{1, 2, 3, 4, 6, 0x01}, 0, 7)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "WKS record: failed to read bitmap")
}

// TestReservedRecord checks the opaque-bytes fallback RDATA type:
// - decoding returns the bytes verbatim;
// - short input is rejected;
// - encoding round-trips, and nil Bytes encodes to nothing.
func TestReservedRecord(t *testing.T) {
	rdataBytes := []byte{0xDE, 0xAD, 0xBE, 0xEF}

	r := &Reserved{}
	offset, err := r.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, rdataBytes, r.Bytes)

	_, err = r.Decode(rdataBytes[:2], 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "reserved record: failed to read data")

	encoded := encodeRData(t, &Reserved{Bytes: rdataBytes})
	assert.Equal(t, rdataBytes, encoded)

	encodedNil := encodeRData(t, &Reserved{Bytes: nil})
	assert.Empty(t, encodedNil)
}

// TestResourceRecordDecode decodes whole resource records, valid and malformed.
// Truncation points are computed from the encoded owner-name length, so each
// case cuts the field its name says it does.
func TestResourceRecordDecode(t *testing.T) {
	// buildRRBytes writes 10 bytes of fixed fields after the owner name:
	// TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2).
	const fixedFieldsLen = 10
	nameLen := func(rr, rdata []byte) int { return len(rr) - fixedFieldsLen - len(rdata) }

	tests := []struct {
		name        string
		input       []byte
		expectedRR  *ResourceRecord
		wantErr     bool
		wantErrType error
		wantErrMsg  string
	}{
		{
			name:  "Valid A record",
			input: buildRRBytes(t, "a.com", AType, IN, 60, []byte{1, 1, 1, 1}),
			expectedRR: &ResourceRecord{
				Name: "a.com", RType: AType, RClass: IN, TTL: 60, RDLength: 4,
				RData: &A{net.IP{1, 1, 1, 1}},
			},
		},
		{
			name:  "Valid TXT record",
			input: buildRRBytes(t, "b.org", TXTType, IN, 300, encodeRData(t, &TXT{[]string{"hello", "world"}})),
			expectedRR: &ResourceRecord{
				Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RDLength: 12,
				RData: &TXT{[]string{"hello", "world"}},
			},
		},
		{
			name:  "Unknown record type",
			input: buildRRBytes(t, "c.net", DNSType(9999), IN, 10, []byte{0xca, 0xfe}),
			expectedRR: &ResourceRecord{
				Name: "c.net", RType: DNSType(9999), RClass: IN, TTL: 10, RDLength: 2,
				RData: &Reserved{[]byte{0xca, 0xfe}},
			},
		},
		{
			name:        "Truncated name",
			input:       []byte{3, 'a', 'b'},
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "rr decode:",
		},
		{
			name: "Truncated type",
			input: func() []byte {
				rdata := []byte{1}
				buf := buildRRBytes(t, "d.com", AType, IN, 60, rdata)
				// Keep the whole owner name plus only the first byte of TYPE.
				return buf[:nameLen(buf, rdata)+1]
			}(),
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "rr decode:",
		},
		{
			name: "Truncated RDATA section",
			input: func() []byte {
				rdata := []byte{1, 2, 3, 4}
				buf := buildRRBytes(t, "e.com", AType, IN, 60, rdata)
				// The full header is present (RDLENGTH = 4), but only two RDATA bytes follow.
				return buf[:len(buf)-2]
			}(),
			// The failure may come from an RR-level length check or from the
			// A decoder. Only the wrapping prefix is asserted.
			wantErr:    true,
			wantErrMsg: "rr decode:",
		},
		{
			name: "RDLENGTH mismatch (claims longer than buffer)",
			input: func() []byte {
				rdata := []byte{1, 2, 3, 4}
				buf := buildRRBytes(t, "f.com", AType, IN, 60, rdata)
				// RDLENGTH sits after TYPE (2) + CLASS (2) + TTL (4).
				rdlenPos := nameLen(buf, rdata) + 8
				binary.BigEndian.PutUint16(buf[rdlenPos:rdlenPos+2], 10)
				return buf
			}(),
			wantErr:    true,
			wantErrMsg: "rr decode:",
		},
		{
			name: "RDLENGTH mismatch (longer than RData decoder accepts)",
			// Six RDATA bytes are present and RDLENGTH says 6.
			// TestARecord shows A RDATA must be exactly 4 bytes.
			input:      buildRRBytes(t, "g.com", AType, IN, 60, []byte{1, 2, 3, 4, 0xAA, 0xBB}),
			wantErr:    true,
			wantErrMsg: "rr decode:",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rr := &ResourceRecord{}
			offset, err := rr.Decode(tt.input, 0)

			if tt.wantErr {
				assert.Error(t, err)
				if tt.wantErrType != nil {
					assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
				}
				if tt.wantErrMsg != "" {
					assert.ErrorContains(t, err, tt.wantErrMsg)
				}
				return
			}

			require.NoError(t, err)
			assert.Equal(t, len(tt.input), offset, "Offset should match input length")
			assert.Equal(t, tt.expectedRR.Name, rr.Name)
			assert.Equal(t, tt.expectedRR.RType, rr.RType)
			assert.Equal(t, tt.expectedRR.RClass, rr.RClass)
			assert.Equal(t, tt.expectedRR.TTL, rr.TTL)
			assert.Equal(t, tt.expectedRR.RDLength, rr.RDLength)
			assert.Equal(t, tt.expectedRR.RData, rr.RData)
		})
	}
}

// TestResourceRecordEncode encodes whole resource records.
// Successes are checked by decoding the bytes back and comparing fields.
// Failures are checked against the expected error type and message.
func TestResourceRecordEncode(t *testing.T) {
	tests := []struct {
		name        string
		rr          *ResourceRecord
		wantErr     bool
		wantErrType error
		wantErrMsg  string
	}{
		{
			name: "Valid A record",
			rr:   &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
		},
		{
			name: "Valid TXT record",
			rr:   &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{"hello", "world"}}},
		},
		{
			name:        "Encode fail - Invalid Name",
			rr:          &ResourceRecord{Name: "a..b", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
			wantErr:     true,
			wantErrType: &InvalidLabelError{},
			wantErrMsg:  "rr encode: failed to encode record name a..b",
		},
		{
			name:       "Encode fail - Invalid RData (A record)",
			rr:         &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.ParseIP("::1")}},
			wantErr:    true,
			wantErrMsg: "rr encode: failed to encode RData for a.com (A): A record: cannot encode non-IPv4 address",
		},
		{
			name:       "Encode fail - Invalid RData (TXT record)",
			rr:         &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{string(make([]byte, 256))}}},
			wantErr:    true,
			wantErrMsg: "rr encode: failed to encode RData for b.org (TXT): TXT record: string segment length 256 exceeds maximum 255",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			offsets := make(map[string]uint16)
			encodedBytes, err := tt.rr.Encode([]byte{}, &offsets)

			if tt.wantErr {
				assert.Error(t, err)
				if tt.wantErrType != nil {
					assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
				}
				if tt.wantErrMsg != "" {
					assert.ErrorContains(t, err, tt.wantErrMsg)
				}
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, encodedBytes)

			decodedRR := &ResourceRecord{}
			offset, decodeErr := decodedRR.Decode(encodedBytes, 0)
			require.NoError(t, decodeErr, "Failed to decode back encoded RR")

			assert.Equal(t, len(encodedBytes), offset, "Decoded offset mismatch")
			assert.Equal(t, tt.rr.Name, decodedRR.Name)
			assert.Equal(t, tt.rr.RType, decodedRR.RType)
			assert.Equal(t, tt.rr.RClass, decodedRR.RClass)
			assert.Equal(t, tt.rr.TTL, decodedRR.TTL)

			if tt.rr.RData == nil {
				// No current case has nil RData; kept for when one is added.
				assert.IsType(t, &Reserved{}, decodedRR.RData, "Nil RData should decode as Reserved")
				assert.Empty(t, decodedRR.RData.(*Reserved).Bytes, "Nil RData should decode as empty Reserved")
				assert.Equal(t, uint16(0), decodedRR.RDLength, "Nil RData should have RDLength 0")
				return
			}
			assert.Equal(t, tt.rr.RData, decodedRR.RData, "RData mismatch after round trip")
			assert.NotEqual(t, uint16(0), decodedRR.RDLength, "Non-nil RData should have non-zero RDLength")
		})
	}
}