1package magna
2
3import (
4 "encoding/binary"
5 "errors"
6 "net"
7 "testing"
8
9 "github.com/stretchr/testify/assert"
10 "github.com/stretchr/testify/require"
11)
12
// buildRRBytes assembles the wire form of one resource record for tests: the
// owner name produced by encodeDomain (using a fresh name-offset map), then
// TYPE, CLASS, TTL and RDLENGTH in big-endian order, followed by the raw RDATA.
// The test stops immediately if the name cannot be encoded.
func buildRRBytes(t *testing.T, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdataBytes []byte) []byte {
	t.Helper()

	nameOffsets := make(map[string]uint16)
	out, err := encodeDomain([]byte{}, name, &nameOffsets)
	require.NoError(t, err, "Failed to encode name in test helper")

	be := binary.BigEndian
	out = be.AppendUint16(out, uint16(rtype))
	out = be.AppendUint16(out, uint16(rclass))
	out = be.AppendUint32(out, ttl)
	out = be.AppendUint16(out, uint16(len(rdataBytes)))
	return append(out, rdataBytes...)
}
31
// encodeRData serializes rdata on its own, starting from an empty buffer and a
// fresh name-offset map, and stops the test if Encode returns an error.
func encodeRData(t *testing.T, rdata ResourceRecordData) []byte {
	t.Helper()
	nameOffsets := map[string]uint16{}
	out, err := rdata.Encode([]byte{}, &nameOffsets)
	require.NoError(t, err, "Failed to encode RDATA in test helper")
	return out
}
40
// TestARecord checks A RDATA decoding (empty buffer, exact 4-byte input, and
// RDLENGTH values of 3 and 5), then checks that encoding reproduces the
// original four address bytes.
func TestARecord(t *testing.T) {
	addr := net.ParseIP("192.168.1.1").To4()
	rdataBytes := []byte(addr)
	a := &A{}

	_, err := a.Decode([]byte{}, 0, 4)
	assert.Error(t, err, "Decode should fail with empty buffer")
	assert.True(t, errors.Is(err, &BufferOverflowError{}))

	offset, err := a.Decode(rdataBytes, 0, 4)
	assert.NoError(t, err)
	assert.Equal(t, 4, offset)
	assert.Equal(t, addr, a.Address)

	// ErrorContains reports a failure when err is nil, whereas calling
	// err.Error() directly would panic and abort the whole test.
	_, err = a.Decode([]byte{1, 2, 3}, 0, 3)
	assert.ErrorContains(t, err, "A record:")

	_, err = a.Decode([]byte{1, 2, 3, 4, 5}, 0, 5)
	assert.ErrorContains(t, err, "A record:")

	// encodeRData already fails the test on an encode error, so no separate
	// error assertion is needed for the encode path.
	encoded := encodeRData(t, &A{Address: addr})
	assert.Equal(t, rdataBytes, encoded)
}
70
// TestNSRecord round-trips an NS record's NSDNAME and checks that a name cut
// short by two bytes fails with a wrapped BufferOverflowError.
func TestNSRecord(t *testing.T) {
	nsName := "ns1.example.com"
	offsets := make(map[string]uint16)
	// Fail fast if the fixture cannot be built: a nil buffer here would make
	// the truncation slice below panic instead of reporting a clear failure.
	rdataBytes, err := encodeDomain([]byte{}, nsName, &offsets)
	require.NoError(t, err, "failed to encode NS fixture")
	ns := &NS{}

	offset, err := ns.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, nsName, ns.NSDName)

	_, err = ns.Decode(rdataBytes[:len(rdataBytes)-2], 0, len(rdataBytes)-2)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, &BufferOverflowError{}))
	assert.ErrorContains(t, err, "NS record: failed to decode NSDName")

	encoded := encodeRData(t, &NS{NSDName: nsName})
	assert.Equal(t, rdataBytes, encoded)
}
91
// TestCNAMERecord round-trips a CNAME target and checks that a 5-byte
// truncation of the encoded name is rejected.
func TestCNAMERecord(t *testing.T) {
	cname := "target.example.com"
	offsets := make(map[string]uint16)
	// The fixture error must be checked: slicing a nil buffer below would
	// panic rather than report why the test is broken.
	rdataBytes, err := encodeDomain([]byte{}, cname, &offsets)
	require.NoError(t, err, "failed to encode CNAME fixture")
	c := &CNAME{}

	offset, err := c.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, cname, c.CName)

	_, err = c.Decode(rdataBytes[:5], 0, 5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "CNAME record")

	encoded := encodeRData(t, &CNAME{CName: cname})
	assert.Equal(t, rdataBytes, encoded)
}
111
// TestSOARecord round-trips a full SOA record, then checks that decoding fails
// both when the trailing numeric fields are cut short and when the RDATA holds
// only MNAME and RNAME.
func TestSOARecord(t *testing.T) {
	mname := "ns.example.com"
	rname := "admin.example.com"
	serial := uint32(2023010101)
	refresh := uint32(7200)
	retry := uint32(3600)
	expire := uint32(1209600)
	minimum := uint32(3600)

	soaEncode := &SOA{MName: mname, RName: rname, Serial: serial, Refresh: refresh, Retry: retry, Expire: expire, Minimum: minimum}
	rdataBytes := encodeRData(t, soaEncode)
	soa := &SOA{}

	offset, err := soa.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, *soaEncode, *soa)

	_, err = soa.Decode(rdataBytes[:len(rdataBytes)-5], 0, len(rdataBytes)-5)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "SOA record:")

	// Append both names to one buffer with one offset map. The offsets map
	// presumably drives name compression; sharing it across two separately
	// built buffers and then concatenating them can leave pointers that
	// reference the wrong buffer.
	nameOffsets := make(map[string]uint16)
	namesOnly, err := encodeDomain([]byte{}, mname, &nameOffsets)
	require.NoError(t, err, "failed to encode MNAME fixture")
	namesOnly, err = encodeDomain(namesOnly, rname, &nameOffsets)
	require.NoError(t, err, "failed to encode RNAME fixture")

	_, err = soa.Decode(namesOnly, 0, len(namesOnly))
	assert.Error(t, err)
	assert.ErrorContains(t, err, "SOA record")
}
142
// TestPTRRecord round-trips a PTR record's PTRDNAME and checks that a 3-byte
// truncation fails with a PTRDName decode error.
func TestPTRRecord(t *testing.T) {
	ptrName := "host.example.com"
	offsets := make(map[string]uint16)
	// Check the fixture error so a broken encoder surfaces here rather than
	// as a panic when the buffer is sliced below.
	rdataBytes, err := encodeDomain([]byte{}, ptrName, &offsets)
	require.NoError(t, err, "failed to encode PTR fixture")
	ptr := &PTR{}

	offset, err := ptr.Decode(rdataBytes, 0, len(rdataBytes))
	assert.NoError(t, err)
	assert.Equal(t, len(rdataBytes), offset)
	assert.Equal(t, ptrName, ptr.PTRDName)

	_, err = ptr.Decode(rdataBytes[:3], 0, 3)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "PTR record: failed to decode PTRDName")

	encoded := encodeRData(t, &PTR{PTRDName: ptrName})
	assert.Equal(t, rdataBytes, encoded)
}
162
163func TestMXRecord(t *testing.T) {
164 preference := uint16(10)
165 exchange := "mail.example.com"
166
167 mxEncode := &MX{Preference: preference, Exchange: exchange}
168 rdataBytes := encodeRData(t, mxEncode)
169 mx := &MX{}
170
171 offset, err := mx.Decode(rdataBytes, 0, len(rdataBytes))
172 assert.NoError(t, err)
173 assert.Equal(t, len(rdataBytes), offset)
174 assert.Equal(t, *mxEncode, *mx)
175
176 _, err = mx.Decode([]byte{0}, 0, 1)
177 assert.Error(t, err)
178 assert.ErrorContains(t, err, "MX record")
179
180 buf := make([]byte, 2)
181 binary.BigEndian.PutUint16(buf, preference)
182 buf = append(buf, []byte{4, 'm', 'a'}...)
183 _, err = mx.Decode(buf, 0, len(buf))
184 assert.Error(t, err)
185 assert.ErrorContains(t, err, "MX record: failed to decode Exchange")
186}
187
// TestTXTRecord round-trips three TXT payloads through one reused decoder.
// It then checks that zero-length RDATA decodes cleanly, that a segment longer
// than RDLENGTH is rejected, and that re-encoding the first payload is stable.
func TestTXTRecord(t *testing.T) {
	pair := []string{"abc", "def"}
	pairRecord := &TXT{TxtData: pair}
	pairWire := encodeRData(t, pairRecord)
	got := &TXT{}

	n, err := got.Decode(pairWire, 0, len(pairWire))
	require.NoError(t, err, "TXT Decode failed")
	assert.Equal(t, len(pairWire), n)
	assert.Equal(t, pair, got.TxtData, "Decoded TXT data mismatch")

	// A single zero-length character-string.
	blank := []string{""}
	blankWire := encodeRData(t, &TXT{TxtData: blank})
	n, err = got.Decode(blankWire, 0, len(blankWire))
	require.NoError(t, err, "TXT Decode with empty string failed")
	assert.Equal(t, len(blankWire), n)
	assert.Equal(t, blank, got.TxtData)

	spf := []string{"v=spf1", "include:_spf.google.com", "~all"}
	spfWire := encodeRData(t, &TXT{TxtData: spf})
	n, err = got.Decode(spfWire, 0, len(spfWire))
	require.NoError(t, err, "TXT Decode with multiple strings failed")
	assert.Equal(t, len(spfWire), n)
	assert.Equal(t, spf, got.TxtData)

	// Zero-length RDATA is accepted.
	_, err = got.Decode([]byte{}, 0, 0)
	assert.NoError(t, err)

	// The segment length byte says 5, but only 3 data bytes fit in RDLENGTH 4.
	_, err = got.Decode([]byte{5, 'd', 'a', 't'}, 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "TXT record: string segment length 5 exceeds RDLENGTH boundary 4")

	assert.Equal(t, pairWire, encodeRData(t, pairRecord))
}
225
// TestHINFORecord round-trips HINFO with normal and empty strings. It then
// feeds a series of malformed buffers (each with RDLENGTH equal to its length)
// and checks the expected error substring for each.
func TestHINFORecord(t *testing.T) {
	cpuName, osName := "Intel", "Linux"
	wire := encodeRData(t, &HINFO{CPU: cpuName, OS: osName})
	got := &HINFO{}

	n, err := got.Decode(wire, 0, len(wire))
	require.NoError(t, err, "HINFO Decode failed")
	assert.Equal(t, len(wire), n)
	assert.Equal(t, cpuName, got.CPU)
	assert.Equal(t, osName, got.OS)

	emptyWire := encodeRData(t, &HINFO{CPU: "", OS: ""})
	n, err = got.Decode(emptyWire, 0, len(emptyWire))
	require.NoError(t, err, "HINFO Decode with empty strings failed")
	assert.Equal(t, len(emptyWire), n)
	assert.Equal(t, "", got.CPU)
	assert.Equal(t, "", got.OS)

	malformed := []struct {
		in   []byte
		want string
	}{
		// Empty RDATA.
		{[]byte{}, "HINFO record:"},
		// CPU length byte says 5, only 2 bytes follow.
		{[]byte{5, 'I', 'n'}, "buffer overflow:"},
		// Complete CPU string, nothing left for OS.
		{[]byte{5, 'I', 'n', 't', 'e', 'l'}, "HINFO record:"},
		// Complete CPU string, OS length byte says 5, only 2 bytes follow.
		{[]byte{5, 'I', 'n', 't', 'e', 'l', 5, 'L', 'i'}, "buffer overflow:"},
		// A valid record plus one trailing byte (cloned so wire is untouched).
		{append(append([]byte{}, wire...), 0xFF), "HINFO record:"},
		// CPU length byte says 10, only 3 bytes follow.
		{[]byte{10, 'a', 'b', 'c'}, "buffer overflow:"},
	}
	for _, tc := range malformed {
		_, err = got.Decode(tc.in, 0, len(tc.in))
		assert.Error(t, err)
		assert.ErrorContains(t, err, tc.want)
	}
}
272
273func TestWKSRecord(t *testing.T) {
274 addr := net.ParseIP("192.168.1.1").To4()
275 proto := byte(6)
276 bitmap := []byte{0x01, 0x80}
277
278 wksEncode := &WKS{Address: addr, Protocol: proto, BitMap: bitmap}
279 rdataBytes := encodeRData(t, wksEncode)
280 wks := &WKS{}
281
282 offset, err := wks.Decode(rdataBytes, 0, len(rdataBytes))
283 require.NoError(t, err, "WKS Decode failed")
284 assert.Equal(t, len(rdataBytes), offset)
285 assert.Equal(t, addr, wks.Address.To4())
286 assert.Equal(t, proto, wks.Protocol)
287 assert.Equal(t, bitmap, wks.BitMap)
288
289 wksEncodeNoBitmap := &WKS{Address: addr, Protocol: proto, BitMap: []byte{}}
290 rdataBytesNoBitmap := encodeRData(t, wksEncodeNoBitmap)
291 wks = &WKS{}
292 offset, err = wks.Decode(rdataBytesNoBitmap, 0, len(rdataBytesNoBitmap))
293 require.NoError(t, err, "WKS Decode without bitmap failed")
294 assert.Equal(t, len(rdataBytesNoBitmap), offset)
295 assert.Equal(t, addr, wks.Address.To4())
296 assert.Equal(t, proto, wks.Protocol)
297 assert.Empty(t, wks.BitMap)
298
299 _, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 4)
300 assert.Error(t, err)
301 assert.ErrorContains(t, err, "WKS record: RDLENGTH 4 is too short")
302
303 _, err = wks.Decode([]byte{1, 2, 3}, 0, 5)
304 assert.Error(t, err)
305 assert.ErrorContains(t, err, "WKS record: failed to read address")
306
307 _, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 5)
308 assert.Error(t, err)
309 assert.ErrorContains(t, err, "WKS record: failed to read protocol")
310
311 _, err = wks.Decode([]byte{1, 2, 3, 4, 6, 0x01}, 0, 7)
312 assert.Error(t, err)
313 assert.ErrorContains(t, err, "WKS record: failed to read bitmap")
314}
315
// TestReservedRecord checks that Reserved keeps RDATA as opaque bytes on
// decode, rejects an RDLENGTH longer than the buffer, and encodes its bytes
// verbatim (nil encodes to nothing).
func TestReservedRecord(t *testing.T) {
	payload := []byte{0xDE, 0xAD, 0xBE, 0xEF}

	got := &Reserved{}
	n, err := got.Decode(payload, 0, len(payload))
	assert.NoError(t, err)
	assert.Equal(t, len(payload), n)
	assert.Equal(t, payload, got.Bytes)

	// RDLENGTH claims 4 bytes but only 2 are present.
	_, err = got.Decode(payload[:2], 0, 4)
	assert.Error(t, err)
	assert.ErrorContains(t, err, "reserved record: failed to read data")

	assert.Equal(t, payload, encodeRData(t, &Reserved{Bytes: payload}))
	assert.Empty(t, encodeRData(t, &Reserved{Bytes: nil}))
}
337
338func TestResourceRecordDecode(t *testing.T) {
339 tests := []struct {
340 name string
341 input []byte
342 expectedRR *ResourceRecord
343 wantErr bool
344 wantErrType error
345 wantErrMsg string
346 }{
347 {
348 name: "Valid A record",
349 input: buildRRBytes(t, "a.com", AType, IN, 60, []byte{1, 1, 1, 1}),
350 expectedRR: &ResourceRecord{
351 Name: "a.com", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.IP{1, 1, 1, 1}},
352 },
353 },
354 {
355 name: "Valid TXT record",
356 input: buildRRBytes(t, "b.org", TXTType, IN, 300, encodeRData(t, &TXT{[]string{"hello", "world"}})),
357 expectedRR: &ResourceRecord{
358 Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RDLength: 12, RData: &TXT{[]string{"hello", "world"}},
359 },
360 },
361 {
362 name: "Unknown record type",
363 input: buildRRBytes(t, "c.net", DNSType(9999), IN, 10, []byte{0xca, 0xfe}),
364 expectedRR: &ResourceRecord{
365 Name: "c.net", RType: DNSType(9999), RClass: IN, TTL: 10, RDLength: 2, RData: &Reserved{[]byte{0xca, 0xfe}},
366 },
367 },
368 {
369 name: "Truncated name",
370 input: []byte{3, 'a', 'b'},
371 wantErr: true,
372 wantErrType: &BufferOverflowError{},
373 wantErrMsg: "rr decode:",
374 },
375 {
376 name: "Truncated type",
377 input: buildRRBytes(t, "d.com", AType, IN, 60, []byte{1})[:5],
378 wantErr: true,
379 wantErrType: &BufferOverflowError{},
380 wantErrMsg: "rr decode:",
381 },
382 {
383 name: "Truncated RDATA section",
384 input: buildRRBytes(t, "e.com", AType, IN, 60, []byte{1, 2, 3, 4})[:15],
385 wantErr: true,
386 wantErrType: &BufferOverflowError{},
387 wantErrMsg: "rr decode:",
388 },
389 {
390 name: "RDLENGTH mismatch (claims longer than buffer)",
391 input: func() []byte {
392 buf := buildRRBytes(t, "f.com", AType, IN, 60, []byte{1, 2, 3, 4})
393 binary.BigEndian.PutUint16(buf[10:12], 10)
394 return buf[:14]
395 }(),
396 wantErr: true,
397 wantErrType: &BufferOverflowError{},
398 wantErrMsg: "rr decode:",
399 },
400 {
401 name: "RDLENGTH mismatch (RData decoder consumes less)",
402 input: func() []byte {
403 rdataBytes := encodeRData(t, &TXT{[]string{"short"}})
404 buf := buildRRBytes(t, "g.com", TXTType, IN, 60, rdataBytes)
405 nameLen := len(buf) - 10 - len(rdataBytes)
406 rdlenPos := nameLen + 8
407 binary.BigEndian.PutUint16(buf[rdlenPos:rdlenPos+2], uint16(len(rdataBytes)+5))
408 return buf
409 }(),
410 wantErr: true,
411 wantErrMsg: "rr decode:",
412 },
413 }
414
415 for _, tt := range tests {
416 t.Run(tt.name, func(t *testing.T) {
417 rr := &ResourceRecord{}
418 offset, err := rr.Decode(tt.input, 0)
419
420 if tt.wantErr {
421 assert.Error(t, err)
422 if tt.wantErrType != nil {
423 assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
424 }
425 if tt.wantErrMsg != "" {
426 assert.ErrorContains(t, err, tt.wantErrMsg)
427 }
428 } else {
429 assert.NoError(t, err)
430 assert.Equal(t, len(tt.input), offset, "Offset should match input length")
431 assert.Equal(t, tt.expectedRR.Name, rr.Name)
432 assert.Equal(t, tt.expectedRR.RType, rr.RType)
433 assert.Equal(t, tt.expectedRR.RClass, rr.RClass)
434 assert.Equal(t, tt.expectedRR.TTL, rr.TTL)
435 assert.Equal(t, tt.expectedRR.RDLength, rr.RDLength)
436 assert.Equal(t, tt.expectedRR.RData, rr.RData)
437 }
438 })
439 }
440}
441
// TestResourceRecordEncode encodes whole resource records. Valid records must
// survive an encode→decode round trip. Invalid owner names and invalid RDATA
// must fail with the expected wrapped error text.
func TestResourceRecordEncode(t *testing.T) {
	tests := []struct {
		name        string
		rr          *ResourceRecord
		wantErr     bool
		wantErrType error
		wantErrMsg  string
	}{
		{
			name: "Valid A record",
			rr:   &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
		},
		{
			name: "Valid TXT record",
			rr:   &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{"hello", "world"}}},
		},
		{
			name:        "Encode fail - Invalid Name",
			rr:          &ResourceRecord{Name: "a..b", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}},
			wantErr:     true,
			wantErrType: &InvalidLabelError{},
			wantErrMsg:  "rr encode: failed to encode record name a..b",
		},
		{
			name:       "Encode fail - Invalid RData (A record)",
			rr:         &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.ParseIP("::1")}},
			wantErr:    true,
			wantErrMsg: "rr encode: failed to encode RData for a.com (A): A record: cannot encode non-IPv4 address",
		},
		{
			name:       "Encode fail - Invalid RData (TXT record)",
			rr:         &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{string(make([]byte, 256))}}},
			wantErr:    true,
			wantErrMsg: "rr encode: failed to encode RData for b.org (TXT): TXT record: string segment length 256 exceeds maximum 255",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			offsets := make(map[string]uint16)
			encodedBytes, err := tt.rr.Encode([]byte{}, &offsets)

			if tt.wantErr {
				assert.Error(t, err)
				if tt.wantErrType != nil {
					assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err)
				}
				if tt.wantErrMsg != "" {
					assert.ErrorContains(t, err, tt.wantErrMsg)
				}
				return
			}

			// Stop here on failure: decoding the bytes from a failed encode
			// would only produce noise.
			require.NoError(t, err)
			assert.NotEmpty(t, encodedBytes)

			decodedRR := &ResourceRecord{}
			offset, err := decodedRR.Decode(encodedBytes, 0)
			require.NoError(t, err, "Failed to decode back encoded RR")

			assert.Equal(t, len(encodedBytes), offset, "Decoded offset mismatch")
			assert.Equal(t, tt.rr.Name, decodedRR.Name)
			assert.Equal(t, tt.rr.RType, decodedRR.RType)
			assert.Equal(t, tt.rr.RClass, decodedRR.RClass)
			assert.Equal(t, tt.rr.TTL, decodedRR.TTL)
			if tt.rr.RData == nil {
				assert.IsType(t, &Reserved{}, decodedRR.RData, "Nil RData should decode as Reserved")
				assert.Empty(t, decodedRR.RData.(*Reserved).Bytes, "Nil RData should decode as empty Reserved")
				assert.Equal(t, uint16(0), decodedRR.RDLength, "Nil RData should have RDLength 0")
				return
			}
			assert.Equal(t, tt.rr.RData, decodedRR.RData, "RData mismatch after round trip")
			assert.NotEqual(t, uint16(0), decodedRR.RDLength, "Non-nil RData should have non-zero RDLength")
		})
	}
}
519
// TestAAAARecord checks AAAA RDATA decoding (empty buffer, exact 16-byte
// input, 15-byte input). It also checks that encoding reproduces the 16
// address bytes and that an IPv4 address is refused.
func TestAAAARecord(t *testing.T) {
	addr := net.ParseIP("2001:db8::1")
	rdataBytes := []byte(addr)
	aaaa := &AAAA{}

	_, err := aaaa.Decode([]byte{}, 0, 16)
	assert.Error(t, err, "Decode should fail with empty buffer")
	assert.True(t, errors.Is(err, &BufferOverflowError{}))

	offset, err := aaaa.Decode(rdataBytes, 0, 16)
	assert.NoError(t, err)
	assert.Equal(t, 16, offset)
	assert.Equal(t, addr, aaaa.Address)

	// ErrorContains reports a failure on a nil error instead of panicking the
	// way err.Error() would.
	_, err = aaaa.Decode([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, 0, 15)
	assert.ErrorContains(t, err, "AAAA record:")

	encoded := encodeRData(t, &AAAA{Address: addr})
	assert.Equal(t, rdataBytes, encoded)

	// net.ParseIP returns IPv4 addresses in 16-byte form, so this rejection
	// cannot rest on slice length alone.
	ipv4 := net.ParseIP("192.168.1.1")
	aaaaEncodeInvalid := &AAAA{Address: ipv4}
	_, err = aaaaEncodeInvalid.Encode([]byte{}, &map[string]uint16{})
	assert.ErrorContains(t, err, "cannot encode non-IPv6 address")
}