package magna

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"
	"reflect"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMessageDecode decodes hand-built wire-format messages and checks the
// resulting Message.
//
// Error cases check that an error of the wanted type appears somewhere in the
// wrapped chain and that the message contains the wanted substring. Success
// cases compare the decoded header and sections. They then re-encode the
// decoded Message and require the bytes to equal the original input.
func TestMessageDecode(t *testing.T) {
	// writeHeader writes a 12-byte DNS header: ID, flags, then the
	// QDCOUNT, ANCOUNT, NSCOUNT and ARCOUNT section counts, all big-endian.
	writeHeader := func(buf *bytes.Buffer, id, flags, qd, an, ns, ar uint16) {
		require.NoError(t, binary.Write(buf, binary.BigEndian, []uint16{id, flags, qd, an, ns, ar}))
	}

	// buildQuery returns a query (flags 0x0100: RD set) with a single question.
	buildQuery := func(id uint16, qname string, qtype DNSType, qclass DNSClass) []byte {
		buf := new(bytes.Buffer)
		writeHeader(buf, id, 0x0100, 1, 0, 0, 0)
		offsets := make(map[string]uint16)
		qBytes, err := encodeDomain([]byte{}, qname, &offsets)
		require.NoError(t, err)
		buf.Write(qBytes)
		// Writes to a bytes.Buffer cannot fail, so these errors are ignored.
		binary.Write(buf, binary.BigEndian, uint16(qtype))
		binary.Write(buf, binary.BigEndian, uint16(qclass))
		return buf.Bytes()
	}

	// buildAnswer returns a response (flags 0x8180: QR, RD, RA set) carrying a
	// single answer record and no question.
	buildAnswer := func(id uint16, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdata ResourceRecordData) []byte {
		buf := new(bytes.Buffer)
		writeHeader(buf, id, 0x8180, 0, 1, 0, 0)
		rr := ResourceRecord{
			Name:   name,
			RType:  rtype,
			RClass: rclass,
			TTL:    ttl,
			RData:  rdata,
		}
		offsets := make(map[string]uint16)
		rrBytes, err := rr.Encode([]byte{}, &offsets)
		require.NoError(t, err)
		buf.Write(rrBytes)
		return buf.Bytes()
	}

	tests := []struct {
		name        string
		input       []byte
		expected    Message
		wantErr     bool
		wantErrType error
		wantErrMsg  string
	}{
		{
			name:  "Valid DNS query message with one question",
			input: buildQuery(1234, "www.example.com", AType, IN),
			expected: Message{
				HasEDNS: false,
				Header: Header{
					ID:      1234,
					QR:      false,
					RD:      true,
					OPCode:  OPCode(0),
					QDCount: 1,
					RCode:   NOERROR,
				},
				Question: []Question{
					{
						QName:  "www.example.com",
						QType:  AType,
						QClass: IN,
					},
				},
				Answer:     []ResourceRecord{},
				Additional: []ResourceRecord{},
				Authority:  []ResourceRecord{},
			},
			wantErr: false,
		},
		{
			name: "Valid DNS answer message with one A record",
			input: buildAnswer(5678, "www.example.com", AType, IN, 3600,
				&A{Address: net.ParseIP("10.0.0.1").To4()},
			),
			expected: Message{
				HasEDNS: false,
				Header: Header{
					ID:      5678,
					QR:      true,
					OPCode:  0,
					AA:      false,
					RD:      true,
					RA:      true,
					RCode:   0,
					ANCount: 1,
				},
				Question: []Question{},
				Answer: []ResourceRecord{
					{
						Name:     "www.example.com",
						RType:    AType,
						RClass:   IN,
						TTL:      3600,
						RDLength: 4,
						RData:    &A{Address: net.IP([]byte{10, 0, 0, 1})},
					},
				},
				Additional: []ResourceRecord{},
				Authority:  []ResourceRecord{},
			},
			wantErr: false,
		},
		{
			name:        "Invalid input - empty buffer",
			input:       []byte{},
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "failed to decode message header: header decode: failed to read ID",
		},
		{
			name:        "Invalid input - truncated header",
			input:       []byte{0x12, 0x34},
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "failed to decode message header: header decode: failed to read flags",
		},
		{
			name: "Invalid input - truncated question name",
			input: func() []byte {
				buf := new(bytes.Buffer)
				writeHeader(buf, 1235, 0x0100, 1, 0, 0, 0)
				// Label claims 7 bytes but only 3 follow.
				buf.Write([]byte{7, 'e', 'x', 'a'})
				return buf.Bytes()
			}(),
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "failed to decode question #1:",
		},
		{
			name: "Invalid input - truncated answer record data",
			input: func() []byte {
				buf := new(bytes.Buffer)
				writeHeader(buf, 5679, 0x8180, 0, 1, 0, 0)
				offsets := make(map[string]uint16)
				nameBytes, err := encodeDomain([]byte{}, "example.com", &offsets)
				require.NoError(t, err)
				buf.Write(nameBytes)
				binary.Write(buf, binary.BigEndian, uint16(AType))
				binary.Write(buf, binary.BigEndian, uint16(IN))
				binary.Write(buf, binary.BigEndian, uint32(300))
				// RDLENGTH says 4 bytes, but only 2 are supplied.
				binary.Write(buf, binary.BigEndian, uint16(4))
				buf.Write([]byte{192, 168})
				return buf.Bytes()
			}(),
			wantErr:     true,
			wantErrType: &BufferOverflowError{},
			wantErrMsg:  "failed to decode answer record #1:",
		},
		{
			name: "EDNS Record",
			input: []byte{
				// Header: ID 0xea7c, flags 0x0120, QDCOUNT 1, ANCOUNT 0, NSCOUNT 0, ARCOUNT 1.
				0xea, 0x7c, 0x1, 0x20, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1,
				// Question: "lobste.rs", type 1 (A), class 1 (IN).
				0x6, 0x6c, 0x6f, 0x62, 0x73, 0x74, 0x65, 0x2, 0x72, 0x73, 0x0, 0x0, 0x1, 0x0, 0x1,
				// OPT RR: root name, type 0x0029, class 0x04d0 (1232),
				// TTL 0x00640000, RDLENGTH 8.
				0x0, 0x0, 0x29, 0x4, 0xd0, 0x0, 0x64, 0x0, 0x0, 0x0, 0x8,
				// Option: code 100, length 4, data "foo\n".
				0x0, 0x64, 0x0, 0x4, 0x66, 0x6f, 0x6f, 0xa,
			},
			wantErr: false,
			expected: Message{
				HasEDNS: true,
				Header: Header{
					ID:      0xea7c,
					QR:      false,
					OPCode:  0,
					RD:      true,
					RCode:   0,
					QDCount: 1,
					ARCount: 1,
				},
				Question: []Question{
					{
						QName:  "lobste.rs",
						QType:  AType,
						QClass: IN,
					},
				},
				Answer: []ResourceRecord{},
				Additional: []ResourceRecord{
					{
						Name:     "",
						RType:    OPTType,
						RClass:   1232,
						TTL:      6553600,
						RDLength: 8,
						RData: &OPT{
							[]EDNSOption{
								{
									Code: uint16(100),
									Data: []byte("foo\n"),
								},
							},
						},
					},
				},
				Authority: []ResourceRecord{},
				EDNSOptions: []EDNSOption{
					{
						Code: 100,
						Data: []byte("foo\n"),
					},
				},
				EDNSVersion: 0x64,
				UDPSize:     0x4d0,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := &Message{}
			err := m.Decode(tt.input)

			if tt.wantErr {
				require.Error(t, err, "Expected an error but got nil")
				if tt.wantErrType != nil {
					// With a freshly allocated pointer, errors.Is only matches by
					// identity unless the type defines an Is method. So also
					// accept any error of the same dynamic type in the chain.
					target := reflect.New(reflect.TypeOf(tt.wantErrType)).Interface()
					assert.True(t, errors.Is(err, tt.wantErrType) || errors.As(err, target),
						"Error type mismatch. Got %v, expected %T in the error chain", err, tt.wantErrType)
				}
				if tt.wantErrMsg != "" {
					assert.ErrorContains(t, err, tt.wantErrMsg, "Error message mismatch")
				}
				return
			}

			// Nothing below is meaningful if decoding failed.
			require.NoError(t, err, "Expected no error but got one")

			want, got := tt.expected.Header, m.Header
			assert.Equal(t, want.ID, got.ID, "Header ID mismatch")
			assert.Equal(t, want.QR, got.QR, "Header QR mismatch")
			assert.Equal(t, want.OPCode, got.OPCode, "Header OPCode mismatch")
			assert.Equal(t, want.AA, got.AA, "Header AA mismatch")
			assert.Equal(t, want.RD, got.RD, "Header RD mismatch")
			assert.Equal(t, want.RA, got.RA, "Header RA mismatch")
			assert.Equal(t, want.RCode, got.RCode, "Header RCode mismatch")
			assert.Equal(t, want.QDCount, got.QDCount, "Header QDCount mismatch")
			assert.Equal(t, want.ANCount, got.ANCount, "Header ANCount mismatch")
			assert.Equal(t, want.ARCount, got.ARCount, "Header ARCount mismatch")

			assert.Equal(t, tt.expected.Question, m.Question, "Question section mismatch")
			assert.Equal(t, tt.expected.Answer, m.Answer, "Answer section mismatch")
			assert.Equal(t, tt.expected.Authority, m.Authority, "Authority section mismatch")
			assert.Equal(t, tt.expected.Additional, m.Additional, "Additional section mismatch")
			assert.Equal(t, tt.expected.HasEDNS, m.HasEDNS, "HasEDNS mismatch")

			if m.HasEDNS {
				assert.Equal(t, tt.expected.EDNSOptions, m.EDNSOptions, "EDNS Options mismatch")
				assert.Equal(t, tt.expected.ExtendedRCode, m.ExtendedRCode, "ExtendedRCode mismatch")
				assert.Equal(t, tt.expected.EDNSVersion, m.EDNSVersion, "EDNSVersion mismatch")
				assert.Equal(t, tt.expected.EDNSFlags, m.EDNSFlags, "EDNSFlags mismatch")
				assert.Equal(t, tt.expected.UDPSize, m.UDPSize, "UDPSize mismatch")
			}

			encoded, err := m.Encode()
			assert.NoError(t, err, "Expected no error on round trip")
			assert.Equal(t, tt.input, encoded, "Expected equal inputs on round trip")
		})
	}
}

// TestMessageEncodeDecodeRoundTrip encodes each Message, decodes the bytes
// into a fresh Message, and checks that the header fields and all four
// sections survive unchanged.
func TestMessageEncodeDecodeRoundTrip(t *testing.T) {
	tests := []struct {
		name    string
		message *Message
	}{
		{
			name: "Query with one question",
			message: CreateRequest(QUERY, true).AddQuestion(Question{
				QName:  "google.com",
				QType:  AType,
				QClass: IN,
			}),
		},
		{
			name: "Response with one A answer",
			message: &Message{
				Header: Header{
					ID:      12345,
					QR:      true,
					OPCode:  QUERY,
					RD:      true,
					RA:      true,
					RCode:   NOERROR,
					ANCount: 1,
				},
				Question: []Question{},
				Answer: []ResourceRecord{
					{Name: "test.local", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.ParseIP("192.0.2.1").To4()}},
				},
				Additional: []ResourceRecord{},
				Authority:  []ResourceRecord{},
			},
		},
		{
			name: "Response with multiple answers and compression",
			message: &Message{
				Header:   Header{ID: 54321, QR: true, RCode: NOERROR, ANCount: 2},
				Question: []Question{},
				Answer: []ResourceRecord{
					{Name: "www.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.2").To4()}},
					{Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.3").To4()}},
				},
				Additional: []ResourceRecord{},
				Authority:  []ResourceRecord{},
			},
		},
		{
			name: "Message with various record types",
			message: &Message{
				Header:   Header{ID: 1111, QR: true, RCode: NOERROR, ANCount: 3},
				Question: []Question{},
				Answer: []ResourceRecord{
					{Name: "example.com", RType: MXType, RClass: IN, TTL: 3600, RDLength: 9, RData: &MX{Preference: 10, Exchange: "mail.example.com"}},
					{Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.4").To4()}},
					{Name: "example.com", RType: TXTType, RClass: IN, TTL: 600, RDLength: 36, RData: &TXT{TxtData: []string{"v=spf1 include:_spf.google.com ~all"}}},
				},
				Additional: []ResourceRecord{},
				Authority:  []ResourceRecord{},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			encodedBytes, err := tt.message.Encode()
			require.NoError(t, err, "Encoding failed unexpectedly")
			require.NotEmpty(t, encodedBytes, "Encoded bytes should not be empty")

			decodedMsg := &Message{}
			err = decodedMsg.Decode(encodedBytes)
			require.NoError(t, err, "Decoding failed unexpectedly")

			want, got := tt.message.Header, decodedMsg.Header
			assert.Equal(t, want.ID, got.ID, "Header ID mismatch")
			assert.Equal(t, want.QR, got.QR, "Header QR mismatch")
			assert.Equal(t, want.OPCode, got.OPCode, "Header OPCode mismatch")
			assert.Equal(t, want.AA, got.AA, "Header AA mismatch")
			assert.Equal(t, want.RD, got.RD, "Header RD mismatch")
			assert.Equal(t, want.RA, got.RA, "Header RA mismatch")
			assert.Equal(t, want.RCode, got.RCode, "Header RCode mismatch")

			assert.Equal(t, tt.message.Question, decodedMsg.Question, "Question section mismatch")
			assert.Equal(t, tt.message.Answer, decodedMsg.Answer, "Answer section mismatch")
			assert.Equal(t, tt.message.Authority, decodedMsg.Authority, "Authority section mismatch")
			assert.Equal(t, tt.message.Additional, decodedMsg.Additional, "Additional section mismatch")
		})
	}
}

// FuzzDecodeMessage feeds arbitrary bytes to Message.Decode. Decoding may
// fail, but any error must either be one of the decoder's typed errors
// (found anywhere in the wrapped chain) or mention "record:".
func FuzzDecodeMessage(f *testing.F) {
	testcases := [][]byte{
		{
			0x8e, 0x19, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x6e, 0x65, 0x77, 0x73, 0x0b, 0x79, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x04, 0xd1, 0xd8, 0xe6, 0xcf,
		},
	}
	for _, tc := range testcases {
		f.Add(tc)
	}

	f.Fuzz(func(t *testing.T, msg []byte) {
		var m Message
		err := m.Decode(msg)
		if err == nil {
			return
		}

		var (
			bufErr   *BufferOverflowError
			labelErr *InvalidLabelError
			compErr  *DomainCompressionError
		)
		if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr) || strings.Contains(err.Error(), "record:")) {
			t.Errorf("FuzzDecodeMessage: unexpected error type %T: %v for input %x", err, err, msg)
		}
	})
}