1package magna
2
3import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "net"
8 "strings"
9 "testing"
10
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
15func TestMessageDecode(t *testing.T) {
16 buildQuery := func(id uint16, qname string, qtype DNSType, qclass DNSClass) []byte {
17 buf := new(bytes.Buffer)
18 binary.Write(buf, binary.BigEndian, id)
19 binary.Write(buf, binary.BigEndian, uint16(0x0100))
20 binary.Write(buf, binary.BigEndian, uint16(1))
21 binary.Write(buf, binary.BigEndian, uint16(0))
22 binary.Write(buf, binary.BigEndian, uint16(0))
23 binary.Write(buf, binary.BigEndian, uint16(0))
24 offsets := make(map[string]uint16)
25 qBytes, err := encodeDomain([]byte{}, qname, &offsets)
26 require.NoError(t, err)
27 buf.Write(qBytes)
28 binary.Write(buf, binary.BigEndian, uint16(qtype))
29 binary.Write(buf, binary.BigEndian, uint16(qclass))
30 return buf.Bytes()
31 }
32
33 buildAnswer := func(id uint16, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdata ResourceRecordData) []byte {
34 buf := new(bytes.Buffer)
35
36 binary.Write(buf, binary.BigEndian, id)
37 binary.Write(buf, binary.BigEndian, uint16(0x8180))
38 binary.Write(buf, binary.BigEndian, uint16(0))
39 binary.Write(buf, binary.BigEndian, uint16(1))
40 binary.Write(buf, binary.BigEndian, uint16(0))
41 binary.Write(buf, binary.BigEndian, uint16(0))
42 rr := ResourceRecord{
43 Name: name,
44 RType: rtype,
45 RClass: rclass,
46 TTL: ttl,
47 RData: rdata,
48 }
49 offsets := make(map[string]uint16)
50 rrBytes, err := rr.Encode([]byte{}, &offsets)
51 require.NoError(t, err)
52 buf.Write(rrBytes)
53 return buf.Bytes()
54 }
55
56 tests := []struct {
57 name string
58 input []byte
59 expected Message
60 wantErr bool
61 wantErrType error
62 wantErrMsg string
63 }{
64 {
65 name: "Valid DNS query message with one question",
66 input: buildQuery(1234, "www.example.com", AType, IN),
67 expected: Message{
68 HasEDNS: false,
69 Header: Header{
70 ID: 1234,
71 QR: false,
72 RD: true,
73 OPCode: OPCode(0),
74 QDCount: 1,
75 RCode: NOERROR,
76 },
77 Question: []Question{
78 {
79 QName: "www.example.com",
80 QType: AType,
81 QClass: IN,
82 },
83 },
84 Answer: []ResourceRecord{},
85 Additional: []ResourceRecord{},
86 Authority: []ResourceRecord{},
87 },
88 wantErr: false,
89 },
90 {
91 name: "Valid DNS answer message with one A record",
92 input: buildAnswer(5678, "www.example.com", AType, IN, 3600,
93 &A{Address: net.ParseIP("10.0.0.1").To4()},
94 ),
95 expected: Message{
96 HasEDNS: false,
97 Header: Header{
98 ID: 5678,
99 QR: true,
100 OPCode: 0,
101 AA: false,
102 RD: true,
103 RA: true,
104 RCode: 0,
105 ANCount: 1,
106 },
107 Question: []Question{},
108 Answer: []ResourceRecord{
109 {
110 Name: "www.example.com",
111 RType: AType,
112 RClass: IN,
113 TTL: 3600,
114 RDLength: 4,
115 RData: &A{Address: net.IP([]byte{10, 0, 0, 1})},
116 },
117 },
118 Additional: []ResourceRecord{},
119 Authority: []ResourceRecord{},
120 },
121 wantErr: false,
122 },
123 {
124 name: "Invalid input - empty buffer",
125 input: []byte{},
126 wantErr: true,
127 wantErrType: &BufferOverflowError{},
128 wantErrMsg: "failed to decode message header: header decode: failed to read ID",
129 },
130 {
131 name: "Invalid input - truncated header",
132 input: []byte{0x12, 0x34},
133 wantErr: true,
134 wantErrType: &BufferOverflowError{},
135 wantErrMsg: "failed to decode message header: header decode: failed to read flags",
136 },
137 {
138 name: "Invalid input - truncated question name",
139 input: func() []byte {
140 buf := new(bytes.Buffer)
141 binary.Write(buf, binary.BigEndian, uint16(1235))
142 binary.Write(buf, binary.BigEndian, uint16(0x0100))
143 binary.Write(buf, binary.BigEndian, uint16(1))
144 binary.Write(buf, binary.BigEndian, uint16(0))
145 binary.Write(buf, binary.BigEndian, uint16(0))
146 binary.Write(buf, binary.BigEndian, uint16(0))
147 buf.Write([]byte{7, 'e', 'x', 'a'})
148 return buf.Bytes()
149 }(),
150 wantErr: true,
151 wantErrType: &BufferOverflowError{},
152 wantErrMsg: "failed to decode question #1:",
153 },
154 {
155 name: "Invalid input - truncated answer record data",
156 input: func() []byte {
157 buf := new(bytes.Buffer)
158
159 binary.Write(buf, binary.BigEndian, uint16(5679))
160 binary.Write(buf, binary.BigEndian, uint16(0x8180))
161 binary.Write(buf, binary.BigEndian, uint16(0))
162 binary.Write(buf, binary.BigEndian, uint16(1))
163 binary.Write(buf, binary.BigEndian, uint16(0))
164 binary.Write(buf, binary.BigEndian, uint16(0))
165
166 offsets := make(map[string]uint16)
167 nameBytes, _ := encodeDomain([]byte{}, "example.com", &offsets)
168 buf.Write(nameBytes)
169 binary.Write(buf, binary.BigEndian, uint16(AType))
170 binary.Write(buf, binary.BigEndian, uint16(IN))
171 binary.Write(buf, binary.BigEndian, uint32(300))
172 binary.Write(buf, binary.BigEndian, uint16(4))
173
174 buf.Write([]byte{192, 168})
175 return buf.Bytes()
176 }(),
177 wantErr: true,
178 wantErrType: &BufferOverflowError{},
179 wantErrMsg: "failed to decode answer record #1:",
180 },
181 {
182 name: "EDNS Record",
183 input: []byte{0xea, 0x7c, 0x1, 0x20, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x6, 0x6c, 0x6f, 0x62, 0x73, 0x74, 0x65, 0x2, 0x72, 0x73, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x29, 0x4, 0xd0, 0x0, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x64, 0x0, 0x4, 0x66, 0x6f, 0x6f, 0xa},
184 wantErr: false,
185 expected: Message{
186 HasEDNS: true,
187 Header: Header{
188 ID: 0xea7c,
189 QR: false,
190 OPCode: 0,
191 RD: true,
192 RCode: 0,
193 QDCount: 1,
194 ARCount: 1,
195 },
196 Question: []Question{
197 {
198 QName: "lobste.rs",
199 QType: AType,
200 QClass: IN,
201 },
202 },
203 Answer: []ResourceRecord{},
204 Additional: []ResourceRecord{
205 {
206 Name: "",
207 RType: OPTType,
208 RClass: 1232,
209 TTL: 6553600,
210 RDLength: 8,
211 RData: &OPT{
212 []EDNSOption{
213 {
214 Code: uint16(100),
215 Data: []byte("foo\n"),
216 },
217 },
218 },
219 },
220 },
221 Authority: []ResourceRecord{},
222 EDNSOptions: []EDNSOption{
223 {
224 Code: 100,
225 Data: []byte("foo\n"),
226 },
227 },
228 EDNSVersion: 0x64,
229 UDPSize: 0x4d0,
230 },
231 },
232 }
233
234 for _, tt := range tests {
235 t.Run(tt.name, func(t *testing.T) {
236 m := &Message{}
237 err := m.Decode(tt.input)
238
239 if tt.wantErr {
240 assert.Error(t, err, "Expected an error but got nil")
241 if tt.wantErrType != nil {
242 assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T, expected %T", err, tt.wantErrType)
243 }
244 if tt.wantErrMsg != "" {
245 assert.ErrorContains(t, err, tt.wantErrMsg, "Error message mismatch")
246 }
247 } else {
248 assert.NoError(t, err, "Expected no error but got one")
249
250 assert.Equal(t, tt.expected.Header.ID, m.Header.ID, "Header ID mismatch")
251 assert.Equal(t, tt.expected.Header.QR, m.Header.QR, "Header QR mismatch")
252 assert.Equal(t, tt.expected.Header.OPCode, m.Header.OPCode, "Header OPCode mismatch")
253 assert.Equal(t, tt.expected.Header.RCode, m.Header.RCode, "Header RCode mismatch")
254 assert.Equal(t, tt.expected.Header.QDCount, m.Header.QDCount, "Header QDCount mismatch")
255 assert.Equal(t, tt.expected.Header.ANCount, m.Header.ANCount, "Header ANCount mismatch")
256
257 assert.Equal(t, tt.expected.Question, m.Question, "Question section mismatch")
258 assert.Equal(t, tt.expected.Answer, m.Answer, "Answer section mismatch")
259 assert.Equal(t, tt.expected.Authority, m.Authority, "Authority section mismatch")
260 assert.Equal(t, tt.expected.Additional, m.Additional, "Additional section mismatch")
261
262 assert.Equal(t, tt.expected.HasEDNS, m.HasEDNS, "HasEDNS mismatch")
263 if m.HasEDNS {
264 assert.Equal(t, tt.expected.EDNSOptions, m.EDNSOptions, "EDNS Options mismatch")
265 assert.Equal(t, tt.expected.ExtendedRCode, m.ExtendedRCode, "ExtendedRCode mismatch")
266 assert.Equal(t, tt.expected.EDNSVersion, m.EDNSVersion, "EDNSVersion mismatch")
267 assert.Equal(t, tt.expected.EDNSFlags, m.EDNSFlags, "EDNSFlags mismatch")
268 assert.Equal(t, tt.expected.UDPSize, m.UDPSize, "UDPSize mismatch")
269 }
270
271 b, err := m.Encode()
272 assert.NoError(t, err, "Expected no error on round trip")
273 assert.Equal(t, tt.input, b, "Expected equal inputs on round trip")
274 }
275 })
276 }
277}
278
279func TestMessageEncodeDecodeRoundTrip(t *testing.T) {
280 tests := []struct {
281 name string
282 message *Message
283 }{
284 {
285 name: "Query with one question",
286 message: CreateRequest(QUERY, true).AddQuestion(Question{
287 QName: "google.com",
288 QType: AType,
289 QClass: IN,
290 }),
291 },
292 {
293 name: "Response with one A answer",
294 message: &Message{
295 Header: Header{
296 ID: 12345, QR: true, OPCode: QUERY, RD: true, RA: true, RCode: NOERROR, ANCount: 1,
297 },
298 Question: []Question{},
299 Answer: []ResourceRecord{
300 {Name: "test.local", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.ParseIP("192.0.2.1").To4()}},
301 },
302 Additional: []ResourceRecord{},
303 Authority: []ResourceRecord{},
304 },
305 },
306 {
307 name: "Response with multiple answers and compression",
308 message: &Message{
309 Header: Header{ID: 54321, QR: true, RCode: NOERROR, ANCount: 2},
310 Question: []Question{},
311 Answer: []ResourceRecord{
312 {Name: "www.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.2").To4()}},
313 {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.3").To4()}},
314 },
315 Additional: []ResourceRecord{},
316 Authority: []ResourceRecord{},
317 },
318 },
319 {
320 name: "Message with various record types",
321 message: &Message{
322 Header: Header{ID: 1111, QR: true, RCode: NOERROR, ANCount: 3},
323 Question: []Question{},
324 Answer: []ResourceRecord{
325 {Name: "example.com", RType: MXType, RClass: IN, TTL: 3600, RDLength: 9, RData: &MX{Preference: 10, Exchange: "mail.example.com"}},
326 {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.4").To4()}},
327 {Name: "example.com", RType: TXTType, RClass: IN, TTL: 600, RDLength: 36, RData: &TXT{TxtData: []string{"v=spf1 include:_spf.google.com ~all"}}},
328 },
329 Additional: []ResourceRecord{},
330 Authority: []ResourceRecord{},
331 },
332 },
333 }
334
335 for _, tt := range tests {
336 t.Run(tt.name, func(t *testing.T) {
337 encodedBytes, err := tt.message.Encode()
338 require.NoError(t, err, "Encoding failed unexpectedly")
339 require.NotEmpty(t, encodedBytes, "Encoded bytes should not be empty")
340
341 decodedMsg := &Message{}
342 err = decodedMsg.Decode(encodedBytes)
343 require.NoError(t, err, "Decoding failed unexpectedly")
344
345 assert.Equal(t, tt.message.Header.ID, decodedMsg.Header.ID, "Header ID mismatch")
346 assert.Equal(t, tt.message.Header.QR, decodedMsg.Header.QR, "Header QR mismatch")
347 assert.Equal(t, tt.message.Header.OPCode, decodedMsg.Header.OPCode, "Header OPCode mismatch")
348 assert.Equal(t, tt.message.Header.RCode, decodedMsg.Header.RCode, "Header RCode mismatch")
349
350 assert.Equal(t, tt.message.Question, decodedMsg.Question, "Question section mismatch")
351 assert.Equal(t, tt.message.Answer, decodedMsg.Answer, "Answer section mismatch")
352 assert.Equal(t, tt.message.Authority, decodedMsg.Authority, "Authority section mismatch")
353 assert.Equal(t, tt.message.Additional, decodedMsg.Additional, "Additional section mismatch")
354 })
355 }
356}
357
358func FuzzDecodeMessage(f *testing.F) {
359 testcases := [][]byte{
360 {
361 0x8e, 0x19, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x6e, 0x65,
362 0x77, 0x73, 0x0b, 0x79, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x03,
363 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00,
364 0x00, 0x00, 0x01, 0x00, 0x04, 0xd1, 0xd8, 0xe6, 0xcf,
365 },
366 }
367 for _, tc := range testcases {
368 f.Add(tc)
369 }
370 f.Fuzz(func(t *testing.T, msg []byte) {
371 var m Message
372 err := m.Decode(msg)
373 if err != nil {
374 var bufErr *BufferOverflowError
375 var labelErr *InvalidLabelError
376 var compErr *DomainCompressionError
377 if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr) || strings.Contains(err.Error(), "record:")) {
378 t.Errorf("FuzzDecodeMessage: unexpected error type %T: %v for input %x", err, err, msg)
379 }
380 }
381 })
382}