// A Go DNS packet parser.
1package magna 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "testing" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11) 12 13func TestHeaderDecode(t *testing.T) { 14 tests := []struct { 15 name string 16 input []byte 17 expectedHeader Header 18 expectedOffset int 19 expectedErr error 20 wantErrMsg string 21 }{ 22 { 23 name: "Valid header", 24 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04}, 25 expectedHeader: Header{ 26 ID: 0x1234, 27 QR: true, 28 OPCode: OPCode(0), 29 AA: false, 30 TC: false, 31 RD: true, 32 RA: true, 33 Z: 0, 34 RCode: RCode(0), 35 QDCount: 1, 36 ANCount: 2, 37 NSCount: 3, 38 ARCount: 4, 39 }, 40 expectedOffset: 12, 41 expectedErr: nil, 42 }, 43 { 44 name: "Insufficient buffer length for Flags", 45 input: []byte{0x12, 0x34, 0x81}, 46 expectedHeader: Header{ID: 0x1234}, 47 expectedOffset: 3, 48 expectedErr: &BufferOverflowError{}, 49 wantErrMsg: "header decode: failed to read flags", 50 }, 51 { 52 name: "Insufficient buffer length for ID", 53 input: []byte{0x12}, 54 expectedHeader: Header{}, 55 expectedOffset: 1, 56 expectedErr: &BufferOverflowError{}, 57 wantErrMsg: "header decode: failed to read ID", 58 }, 59 { 60 name: "Missing QDCount", 61 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00}, 62 expectedHeader: Header{ID: 0x1234}, 63 expectedOffset: 5, 64 expectedErr: &BufferOverflowError{}, 65 wantErrMsg: "header decode: failed to read QDCount", 66 }, 67 { 68 name: "Missing ANCount", 69 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00}, 70 expectedHeader: Header{ID: 0x1234, QDCount: 1}, 71 expectedOffset: 7, 72 expectedErr: &BufferOverflowError{}, 73 wantErrMsg: "header decode: failed to read ANCount", 74 }, 75 { 76 name: "Missing NSCount", 77 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00}, 78 expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2}, 79 expectedOffset: 9, 80 expectedErr: &BufferOverflowError{}, 81 wantErrMsg: "header 
decode: failed to read NSCount", 82 }, 83 { 84 name: "Missing ARCount", 85 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00}, 86 expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2, NSCount: 3}, 87 expectedOffset: 11, 88 expectedErr: &BufferOverflowError{}, 89 wantErrMsg: "header decode: failed to read ARCount", 90 }, 91 } 92 93 for _, tt := range tests { 94 t.Run(tt.name, func(t *testing.T) { 95 h := &Header{} 96 offset, err := h.Decode(tt.input, 0) 97 98 if tt.expectedErr != nil { 99 assert.Error(t, err, "Expected an error but got nil") 100 101 assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr) 102 103 if tt.wantErrMsg != "" { 104 assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch") 105 } 106 107 assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error") 108 } else { 109 assert.NoError(t, err, "Expected no error but got one") 110 111 assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch") 112 assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success") 113 } 114 }) 115 } 116} 117 118func TestHeaderDecodeFlags(t *testing.T) { 119 tests := []struct { 120 name string 121 flags uint16 122 expected Header 123 }{ 124 { 125 name: "All flags set", 126 flags: 0xFFFF, 127 expected: Header{ 128 QR: true, 129 OPCode: OPCode(15), 130 AA: true, 131 TC: true, 132 RD: true, 133 RA: true, 134 Z: 7, 135 RCode: RCode(15), 136 }, 137 }, 138 { 139 name: "No flags set", 140 flags: 0x0000, 141 expected: Header{ 142 QR: false, 143 OPCode: OPCode(0), 144 AA: false, 145 TC: false, 146 RD: false, 147 RA: false, 148 Z: 0, 149 RCode: RCode(0), 150 }, 151 }, 152 { 153 name: "Mixed flags", 154 flags: 0x8510, 155 expected: Header{ 156 QR: true, 157 OPCode: OPCode(0), 158 AA: true, 159 TC: false, 160 RD: true, 161 RA: false, 162 Z: 1, 163 RCode: RCode(0), 164 }, 165 }, 166 } 167 168 for _, tt := range tests { 169 t.Run(tt.name, func(t 
*testing.T) { 170 input := []byte{ 171 0x00, 0x00, 172 byte(tt.flags >> 8), byte(tt.flags), 173 0x00, 0x00, 174 0x00, 0x00, 175 0x00, 0x00, 176 0x00, 0x00, 177 } 178 179 h := &Header{} 180 offset, err := h.Decode(input, 0) 181 182 assert.NoError(t, err) 183 assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header") 184 185 assert.Equal(t, tt.expected.QR, h.QR, "QR flag mismatch") 186 assert.Equal(t, tt.expected.OPCode, h.OPCode, "OPCode mismatch") 187 assert.Equal(t, tt.expected.AA, h.AA, "AA flag mismatch") 188 assert.Equal(t, tt.expected.TC, h.TC, "TC flag mismatch") 189 assert.Equal(t, tt.expected.RD, h.RD, "RD flag mismatch") 190 assert.Equal(t, tt.expected.RA, h.RA, "RA flag mismatch") 191 assert.Equal(t, tt.expected.Z, h.Z, "Z value mismatch") 192 assert.Equal(t, tt.expected.RCode, h.RCode, "RCode mismatch") 193 }) 194 } 195} 196 197func TestHeaderEncode(t *testing.T) { 198 tests := []struct { 199 name string 200 header Header 201 expected []byte 202 }{ 203 { 204 name: "All fields set", 205 header: Header{ 206 ID: 0x1234, 207 QR: true, 208 OPCode: OPCode(1), 209 AA: true, 210 TC: true, 211 RD: true, 212 RA: true, 213 Z: 5, 214 RCode: RCode(3), 215 QDCount: 1, 216 ANCount: 2, 217 NSCount: 3, 218 ARCount: 4, 219 }, 220 expected: []byte{ 221 0x12, 0x34, 222 0x8f, 0xd3, 223 0x00, 0x01, 224 0x00, 0x02, 225 0x00, 0x03, 226 0x00, 0x04, 227 }, 228 }, 229 { 230 name: "No flags set, different counts", 231 header: Header{ 232 ID: 0x5678, 233 QR: false, 234 OPCode: OPCode(0), 235 AA: false, 236 TC: false, 237 RD: false, 238 RA: false, 239 Z: 0, 240 RCode: RCode(0), 241 QDCount: 5, 242 ANCount: 6, 243 NSCount: 7, 244 ARCount: 8, 245 }, 246 expected: []byte{ 247 0x56, 0x78, 248 0x00, 0x00, 249 0x00, 0x05, 250 0x00, 0x06, 251 0x00, 0x07, 252 0x00, 0x08, 253 }, 254 }, 255 { 256 name: "Mixed flags and counts", 257 header: Header{ 258 ID: 0x9abc, 259 QR: true, 260 OPCode: OPCode(2), 261 AA: false, 262 TC: true, 263 RD: false, 264 RA: true, 265 Z: 3, 266 
RCode: RCode(4), 267 QDCount: 9, 268 ANCount: 10, 269 NSCount: 11, 270 ARCount: 12, 271 }, 272 expected: []byte{ 273 0x9a, 0xbc, 274 0x92, 0xb4, 275 0x00, 0x09, 276 0x00, 0x0a, 277 0x00, 0x0b, 278 0x00, 0x0c, 279 }, 280 }, 281 } 282 283 for _, tt := range tests { 284 t.Run(tt.name, func(t *testing.T) { 285 encoded := tt.header.Encode() 286 assert.Equal(t, tt.expected, encoded, "Encoded header mismatch") 287 }) 288 } 289} 290 291func TestHeaderEncodeDecodeRoundTrip(t *testing.T) { 292 originalHeader := Header{ 293 ID: 0xdcba, 294 QR: true, 295 OPCode: OPCode(3), 296 AA: true, 297 TC: false, 298 RD: true, 299 RA: false, 300 Z: 6, 301 RCode: RCode(2), 302 QDCount: 13, 303 ANCount: 14, 304 NSCount: 15, 305 ARCount: 16, 306 } 307 308 encoded := originalHeader.Encode() 309 assert.Len(t, encoded, 12, "Encoded header should be 12 bytes") 310 311 decodedHeader := &Header{} 312 offset, err := decodedHeader.Decode(encoded, 0) 313 314 assert.NoError(t, err, "Decoding failed unexpectedly") 315 assert.Equal(t, 12, offset, "Offset after decoding should be 12") 316 assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original") 317} 318 319func TestHeaderEncodeFlagCombinations(t *testing.T) { 320 testCases := []struct { 321 name string 322 header Header 323 expected uint16 324 }{ 325 {"QR flag", Header{QR: true}, 0x8000}, 326 {"OPCode", Header{OPCode: OPCode(5)}, 0x2800}, 327 {"AA flag", Header{AA: true}, 0x0400}, 328 {"TC flag", Header{TC: true}, 0x0200}, 329 {"RD flag", Header{RD: true}, 0x0100}, 330 {"RA flag", Header{RA: true}, 0x0080}, 331 {"Z value", Header{Z: 5}, 0x0050}, 332 {"RCode", Header{RCode: RCode(7)}, 0x0007}, 333 {"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff}, 334 } 335 336 for _, tc := range testCases { 337 t.Run(tc.name, func(t *testing.T) { 338 encoded := tc.header.Encode() 339 require.Len(t, encoded, 12, "Encoded header length invariant") 340 341 flags := 
binary.BigEndian.Uint16(encoded[2:4]) 342 assert.Equal(t, tc.expected, flags, "Flags value mismatch") 343 }) 344 } 345} 346 347func FuzzDecodeHeader(f *testing.F) { 348 testcases := [][]byte{ 349 {0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04}, 350 {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 351 {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 352 {0x12, 0x34}, 353 {0x12, 0x34, 0x81, 0x80, 0x00}, 354 {}, 355 } 356 357 for _, tc := range testcases { 358 f.Add(tc) 359 } 360 361 f.Fuzz(func(t *testing.T, data []byte) { 362 h := &Header{} 363 offset, err := h.Decode(data, 0) 364 if err != nil { 365 var bofErr *BufferOverflowError 366 if !errors.As(err, &bofErr) { 367 t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err) 368 } 369 if offset > len(data) { 370 t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data)) 371 } 372 return 373 } 374 375 if len(data) < 12 { 376 t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data)) 377 return 378 } 379 if offset != 12 { 380 t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset) 381 } 382 383 if h.OPCode > 15 { 384 t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode) 385 } 386 if h.Z > 7 { 387 t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z) 388 } 389 if h.RCode > 15 { 390 t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode) 391 } 392 393 if len(data) >= 12 { 394 encoded := h.Encode() 395 if !bytes.Equal(encoded, data[:12]) { 396 t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded) 397 } 398 } 399 }) 400}