1package magna
2
3import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "testing"
8
9 "github.com/stretchr/testify/assert"
10 "github.com/stretchr/testify/require"
11)
12
13func TestHeaderDecode(t *testing.T) {
14 tests := []struct {
15 name string
16 input []byte
17 expectedHeader Header
18 expectedOffset int
19 expectedErr error
20 wantErrMsg string
21 }{
22 {
23 name: "Valid header",
24 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
25 expectedHeader: Header{
26 ID: 0x1234,
27 QR: true,
28 OPCode: OPCode(0),
29 AA: false,
30 TC: false,
31 RD: true,
32 RA: true,
33 Z: 0,
34 RCode: RCode(0),
35 QDCount: 1,
36 ANCount: 2,
37 NSCount: 3,
38 ARCount: 4,
39 },
40 expectedOffset: 12,
41 expectedErr: nil,
42 },
43 {
44 name: "Insufficient buffer length for Flags",
45 input: []byte{0x12, 0x34, 0x81},
46 expectedHeader: Header{ID: 0x1234},
47 expectedOffset: 3,
48 expectedErr: &BufferOverflowError{},
49 wantErrMsg: "header decode: failed to read flags",
50 },
51 {
52 name: "Insufficient buffer length for ID",
53 input: []byte{0x12},
54 expectedHeader: Header{},
55 expectedOffset: 1,
56 expectedErr: &BufferOverflowError{},
57 wantErrMsg: "header decode: failed to read ID",
58 },
59 {
60 name: "Missing QDCount",
61 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00},
62 expectedHeader: Header{ID: 0x1234},
63 expectedOffset: 5,
64 expectedErr: &BufferOverflowError{},
65 wantErrMsg: "header decode: failed to read QDCount",
66 },
67 {
68 name: "Missing ANCount",
69 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00},
70 expectedHeader: Header{ID: 0x1234, QDCount: 1},
71 expectedOffset: 7,
72 expectedErr: &BufferOverflowError{},
73 wantErrMsg: "header decode: failed to read ANCount",
74 },
75 {
76 name: "Missing NSCount",
77 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00},
78 expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2},
79 expectedOffset: 9,
80 expectedErr: &BufferOverflowError{},
81 wantErrMsg: "header decode: failed to read NSCount",
82 },
83 {
84 name: "Missing ARCount",
85 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00},
86 expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2, NSCount: 3},
87 expectedOffset: 11,
88 expectedErr: &BufferOverflowError{},
89 wantErrMsg: "header decode: failed to read ARCount",
90 },
91 }
92
93 for _, tt := range tests {
94 t.Run(tt.name, func(t *testing.T) {
95 h := &Header{}
96 offset, err := h.Decode(tt.input, 0)
97
98 if tt.expectedErr != nil {
99 assert.Error(t, err, "Expected an error but got nil")
100
101 assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr)
102
103 if tt.wantErrMsg != "" {
104 assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
105 }
106
107 assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
108 } else {
109 assert.NoError(t, err, "Expected no error but got one")
110
111 assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch")
112 assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
113 }
114 })
115 }
116}
117
118func TestHeaderDecodeFlags(t *testing.T) {
119 tests := []struct {
120 name string
121 flags uint16
122 expected Header
123 }{
124 {
125 name: "All flags set",
126 flags: 0xFFFF,
127 expected: Header{
128 QR: true,
129 OPCode: OPCode(15),
130 AA: true,
131 TC: true,
132 RD: true,
133 RA: true,
134 Z: 7,
135 RCode: RCode(15),
136 },
137 },
138 {
139 name: "No flags set",
140 flags: 0x0000,
141 expected: Header{
142 QR: false,
143 OPCode: OPCode(0),
144 AA: false,
145 TC: false,
146 RD: false,
147 RA: false,
148 Z: 0,
149 RCode: RCode(0),
150 },
151 },
152 {
153 name: "Mixed flags",
154 flags: 0x8510,
155 expected: Header{
156 QR: true,
157 OPCode: OPCode(0),
158 AA: true,
159 TC: false,
160 RD: true,
161 RA: false,
162 Z: 1,
163 RCode: RCode(0),
164 },
165 },
166 }
167
168 for _, tt := range tests {
169 t.Run(tt.name, func(t *testing.T) {
170 input := []byte{
171 0x00, 0x00,
172 byte(tt.flags >> 8), byte(tt.flags),
173 0x00, 0x00,
174 0x00, 0x00,
175 0x00, 0x00,
176 0x00, 0x00,
177 }
178
179 h := &Header{}
180 offset, err := h.Decode(input, 0)
181
182 assert.NoError(t, err)
183 assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header")
184
185 assert.Equal(t, tt.expected.QR, h.QR, "QR flag mismatch")
186 assert.Equal(t, tt.expected.OPCode, h.OPCode, "OPCode mismatch")
187 assert.Equal(t, tt.expected.AA, h.AA, "AA flag mismatch")
188 assert.Equal(t, tt.expected.TC, h.TC, "TC flag mismatch")
189 assert.Equal(t, tt.expected.RD, h.RD, "RD flag mismatch")
190 assert.Equal(t, tt.expected.RA, h.RA, "RA flag mismatch")
191 assert.Equal(t, tt.expected.Z, h.Z, "Z value mismatch")
192 assert.Equal(t, tt.expected.RCode, h.RCode, "RCode mismatch")
193 })
194 }
195}
196
197func TestHeaderEncode(t *testing.T) {
198 tests := []struct {
199 name string
200 header Header
201 expected []byte
202 }{
203 {
204 name: "All fields set",
205 header: Header{
206 ID: 0x1234,
207 QR: true,
208 OPCode: OPCode(1),
209 AA: true,
210 TC: true,
211 RD: true,
212 RA: true,
213 Z: 5,
214 RCode: RCode(3),
215 QDCount: 1,
216 ANCount: 2,
217 NSCount: 3,
218 ARCount: 4,
219 },
220 expected: []byte{
221 0x12, 0x34,
222 0x8f, 0xd3,
223 0x00, 0x01,
224 0x00, 0x02,
225 0x00, 0x03,
226 0x00, 0x04,
227 },
228 },
229 {
230 name: "No flags set, different counts",
231 header: Header{
232 ID: 0x5678,
233 QR: false,
234 OPCode: OPCode(0),
235 AA: false,
236 TC: false,
237 RD: false,
238 RA: false,
239 Z: 0,
240 RCode: RCode(0),
241 QDCount: 5,
242 ANCount: 6,
243 NSCount: 7,
244 ARCount: 8,
245 },
246 expected: []byte{
247 0x56, 0x78,
248 0x00, 0x00,
249 0x00, 0x05,
250 0x00, 0x06,
251 0x00, 0x07,
252 0x00, 0x08,
253 },
254 },
255 {
256 name: "Mixed flags and counts",
257 header: Header{
258 ID: 0x9abc,
259 QR: true,
260 OPCode: OPCode(2),
261 AA: false,
262 TC: true,
263 RD: false,
264 RA: true,
265 Z: 3,
266 RCode: RCode(4),
267 QDCount: 9,
268 ANCount: 10,
269 NSCount: 11,
270 ARCount: 12,
271 },
272 expected: []byte{
273 0x9a, 0xbc,
274 0x92, 0xb4,
275 0x00, 0x09,
276 0x00, 0x0a,
277 0x00, 0x0b,
278 0x00, 0x0c,
279 },
280 },
281 }
282
283 for _, tt := range tests {
284 t.Run(tt.name, func(t *testing.T) {
285 encoded := tt.header.Encode()
286 assert.Equal(t, tt.expected, encoded, "Encoded header mismatch")
287 })
288 }
289}
290
291func TestHeaderEncodeDecodeRoundTrip(t *testing.T) {
292 originalHeader := Header{
293 ID: 0xdcba,
294 QR: true,
295 OPCode: OPCode(3),
296 AA: true,
297 TC: false,
298 RD: true,
299 RA: false,
300 Z: 6,
301 RCode: RCode(2),
302 QDCount: 13,
303 ANCount: 14,
304 NSCount: 15,
305 ARCount: 16,
306 }
307
308 encoded := originalHeader.Encode()
309 assert.Len(t, encoded, 12, "Encoded header should be 12 bytes")
310
311 decodedHeader := &Header{}
312 offset, err := decodedHeader.Decode(encoded, 0)
313
314 assert.NoError(t, err, "Decoding failed unexpectedly")
315 assert.Equal(t, 12, offset, "Offset after decoding should be 12")
316 assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original")
317}
318
319func TestHeaderEncodeFlagCombinations(t *testing.T) {
320 testCases := []struct {
321 name string
322 header Header
323 expected uint16
324 }{
325 {"QR flag", Header{QR: true}, 0x8000},
326 {"OPCode", Header{OPCode: OPCode(5)}, 0x2800},
327 {"AA flag", Header{AA: true}, 0x0400},
328 {"TC flag", Header{TC: true}, 0x0200},
329 {"RD flag", Header{RD: true}, 0x0100},
330 {"RA flag", Header{RA: true}, 0x0080},
331 {"Z value", Header{Z: 5}, 0x0050},
332 {"RCode", Header{RCode: RCode(7)}, 0x0007},
333 {"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff},
334 }
335
336 for _, tc := range testCases {
337 t.Run(tc.name, func(t *testing.T) {
338 encoded := tc.header.Encode()
339 require.Len(t, encoded, 12, "Encoded header length invariant")
340
341 flags := binary.BigEndian.Uint16(encoded[2:4])
342 assert.Equal(t, tc.expected, flags, "Flags value mismatch")
343 })
344 }
345}
346
347func FuzzDecodeHeader(f *testing.F) {
348 testcases := [][]byte{
349 {0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
350 {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
351 {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
352 {0x12, 0x34},
353 {0x12, 0x34, 0x81, 0x80, 0x00},
354 {},
355 }
356
357 for _, tc := range testcases {
358 f.Add(tc)
359 }
360
361 f.Fuzz(func(t *testing.T, data []byte) {
362 h := &Header{}
363 offset, err := h.Decode(data, 0)
364 if err != nil {
365 var bofErr *BufferOverflowError
366 if !errors.As(err, &bofErr) {
367 t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err)
368 }
369 if offset > len(data) {
370 t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data))
371 }
372 return
373 }
374
375 if len(data) < 12 {
376 t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data))
377 return
378 }
379 if offset != 12 {
380 t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset)
381 }
382
383 if h.OPCode > 15 {
384 t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode)
385 }
386 if h.Z > 7 {
387 t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z)
388 }
389 if h.RCode > 15 {
390 t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode)
391 }
392
393 if len(data) >= 12 {
394 encoded := h.Encode()
395 if !bytes.Equal(encoded, data[:12]) {
396 t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded)
397 }
398 }
399 })
400}