A Go DNS packet parser.
1package magna 2 3import ( 4 "encoding/binary" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8) 9 10func TestHeaderDecode(t *testing.T) { 11 tests := []struct { 12 name string 13 input []byte 14 expectedHeader Header 15 expectedOffset int 16 expectedErr error 17 }{ 18 { 19 name: "Valid header", 20 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04}, 21 expectedHeader: Header{ 22 ID: 0x1234, 23 QR: true, 24 OPCode: OPCode(0), 25 AA: false, 26 TC: false, 27 RD: true, 28 RA: true, 29 Z: 0, 30 RCode: RCode(0), 31 QDCount: 1, 32 ANCount: 2, 33 NSCount: 3, 34 ARCount: 4, 35 }, 36 expectedOffset: 12, 37 expectedErr: nil, 38 }, 39 { 40 name: "Insufficient buffer length", 41 input: []byte{0x12, 0x34, 0x81}, 42 expectedHeader: Header{}, 43 expectedOffset: 3, 44 expectedErr: &BufferOverflowError{Length: 3, Offset: 3}, 45 }, 46 { 47 name: "Invalid ID", 48 input: []byte{0x12}, 49 expectedHeader: Header{}, 50 expectedOffset: 1, 51 expectedErr: &BufferOverflowError{Length: 1, Offset: 1}, 52 }, 53 { 54 name: "Missing QDCount", 55 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00}, 56 expectedHeader: Header{}, 57 expectedOffset: 5, 58 expectedErr: &BufferOverflowError{}, 59 }, 60 { 61 name: "Missing ANCount", 62 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01}, 63 expectedHeader: Header{}, 64 expectedOffset: 6, 65 expectedErr: &BufferOverflowError{}, 66 }, 67 { 68 name: "Missing NSCount", 69 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02}, 70 expectedHeader: Header{}, 71 expectedOffset: 8, 72 expectedErr: &BufferOverflowError{}, 73 }, 74 { 75 name: "Missing ARCount", 76 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03}, 77 expectedHeader: Header{}, 78 expectedOffset: 10, 79 expectedErr: &BufferOverflowError{}, 80 }, 81 } 82 83 for _, tt := range tests { 84 t.Run(tt.name, func(t *testing.T) { 85 h := &Header{} 86 offset, err := h.Decode(tt.input, 0) 87 88 if tt.expectedErr != nil { 89 assert.Error(t, err) 90 
assert.IsType(t, tt.expectedErr, err) 91 } else { 92 assert.NoError(t, err) 93 assert.Equal(t, tt.expectedHeader, *h) 94 } 95 96 assert.Equal(t, tt.expectedOffset, offset) 97 }) 98 } 99} 100 101func TestHeaderDecodeFlags(t *testing.T) { 102 tests := []struct { 103 name string 104 flags uint16 105 expected Header 106 }{ 107 { 108 name: "All flags set", 109 flags: 0xFFFF, 110 expected: Header{ 111 QR: true, 112 OPCode: OPCode(15), 113 AA: true, 114 TC: true, 115 RD: true, 116 RA: true, 117 Z: 7, 118 RCode: RCode(15), 119 }, 120 }, 121 { 122 name: "No flags set", 123 flags: 0x0000, 124 expected: Header{ 125 QR: false, 126 OPCode: OPCode(0), 127 AA: false, 128 TC: false, 129 RD: false, 130 RA: false, 131 Z: 0, 132 RCode: RCode(0), 133 }, 134 }, 135 { 136 name: "Mixed flags", 137 flags: 0x8510, 138 expected: Header{ 139 QR: true, 140 OPCode: OPCode(0), 141 AA: true, 142 TC: false, 143 RD: true, 144 RA: false, 145 Z: 1, 146 RCode: RCode(0), 147 }, 148 }, 149 } 150 151 for _, tt := range tests { 152 t.Run(tt.name, func(t *testing.T) { 153 input := []byte{ 154 0x00, 0x00, 155 byte(tt.flags >> 8), byte(tt.flags), 156 0x00, 0x00, 157 0x00, 0x00, 158 0x00, 0x00, 159 0x00, 0x00, 160 } 161 162 h := &Header{} 163 _, err := h.Decode(input, 0) 164 165 assert.NoError(t, err) 166 assert.Equal(t, tt.expected.QR, h.QR) 167 assert.Equal(t, tt.expected.OPCode, h.OPCode) 168 assert.Equal(t, tt.expected.AA, h.AA) 169 assert.Equal(t, tt.expected.TC, h.TC) 170 assert.Equal(t, tt.expected.RD, h.RD) 171 assert.Equal(t, tt.expected.RA, h.RA) 172 assert.Equal(t, tt.expected.Z, h.Z) 173 assert.Equal(t, tt.expected.RCode, h.RCode) 174 }) 175 } 176} 177 178func TestHeaderEncode(t *testing.T) { 179 tests := []struct { 180 name string 181 header Header 182 expected []byte 183 }{ 184 { 185 name: "All fields set", 186 header: Header{ 187 ID: 0x1234, 188 QR: true, 189 OPCode: OPCode(1), 190 AA: true, 191 TC: true, 192 RD: true, 193 RA: true, 194 Z: 5, 195 RCode: RCode(3), 196 QDCount: 1, 197 ANCount: 
2, 198 NSCount: 3, 199 ARCount: 4, 200 }, 201 expected: []byte{ 202 0x12, 0x34, 203 0x8f, 0xd3, 204 0x00, 0x01, 205 0x00, 0x02, 206 0x00, 0x03, 207 0x00, 0x04, 208 }, 209 }, 210 { 211 name: "No flags set", 212 header: Header{ 213 ID: 0x5678, 214 QR: false, 215 OPCode: OPCode(0), 216 AA: false, 217 TC: false, 218 RD: false, 219 RA: false, 220 Z: 0, 221 RCode: RCode(0), 222 QDCount: 5, 223 ANCount: 6, 224 NSCount: 7, 225 ARCount: 8, 226 }, 227 expected: []byte{ 228 0x56, 0x78, 229 0x00, 0x00, 230 0x00, 0x05, 231 0x00, 0x06, 232 0x00, 0x07, 233 0x00, 0x08, 234 }, 235 }, 236 { 237 name: "Mixed flags", 238 header: Header{ 239 ID: 0x9abc, 240 QR: true, 241 OPCode: OPCode(2), 242 AA: false, 243 TC: true, 244 RD: false, 245 RA: true, 246 Z: 3, 247 RCode: RCode(4), 248 QDCount: 9, 249 ANCount: 10, 250 NSCount: 11, 251 ARCount: 12, 252 }, 253 expected: []byte{ 254 0x9a, 0xbc, 255 0x92, 0xb4, 256 0x00, 0x09, 257 0x00, 0x0a, 258 0x00, 0x0b, 259 0x00, 0x0c, 260 }, 261 }, 262 } 263 264 for _, tt := range tests { 265 t.Run(tt.name, func(t *testing.T) { 266 encoded := tt.header.Encode() 267 assert.Equal(t, tt.expected, encoded) 268 }) 269 } 270} 271 272func TestHeaderEncodeDecodeRoundTrip(t *testing.T) { 273 originalHeader := Header{ 274 ID: 0xdcba, 275 QR: true, 276 OPCode: OPCode(3), 277 AA: true, 278 TC: false, 279 RD: true, 280 RA: false, 281 Z: 6, 282 RCode: RCode(2), 283 QDCount: 13, 284 ANCount: 14, 285 NSCount: 15, 286 ARCount: 16, 287 } 288 289 encoded := originalHeader.Encode() 290 291 decodedHeader := &Header{} 292 offset, err := decodedHeader.Decode(encoded, 0) 293 294 assert.NoError(t, err) 295 assert.Equal(t, len(encoded), offset) 296 assert.Equal(t, originalHeader, *decodedHeader) 297} 298 299func TestHeaderEncodeFlagCombinations(t *testing.T) { 300 testCases := []struct { 301 name string 302 header Header 303 expected uint16 304 }{ 305 {"QR flag", Header{QR: true}, 0x8000}, 306 {"OPCode", Header{OPCode: OPCode(5)}, 0x2800}, 307 {"AA flag", Header{AA: true}, 
0x0400}, 308 {"TC flag", Header{TC: true}, 0x0200}, 309 {"RD flag", Header{RD: true}, 0x0100}, 310 {"RA flag", Header{RA: true}, 0x0080}, 311 {"Z value", Header{Z: 5}, 0x0050}, 312 {"RCode", Header{RCode: RCode(7)}, 0x0007}, 313 {"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff}, 314 } 315 316 for _, tc := range testCases { 317 t.Run(tc.name, func(t *testing.T) { 318 encoded := tc.header.Encode() 319 flags := binary.BigEndian.Uint16(encoded[2:4]) 320 assert.Equal(t, tc.expected, flags) 321 }) 322 } 323} 324 325func FuzzDecodeHeader(f *testing.F) { 326 testcases := [][]byte{ 327 {0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04}, 328 {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 329 {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 330 } 331 332 for _, tc := range testcases { 333 f.Add(tc) 334 } 335 336 f.Fuzz(func(t *testing.T, data []byte) { 337 // limit to only 12 bytes 338 if len(data) > 12 { 339 data = data[0:12] 340 } 341 342 h := &Header{} 343 offset, err := h.Decode(data, 0) 344 if err != nil { 345 switch err.(type) { 346 case *BufferOverflowError, *InvalidLabelError: 347 // these are expected error types 348 default: 349 t.Errorf("unexpected error type: %T", err) 350 } 351 return 352 } 353 354 if offset != len(data) { 355 t.Errorf("offset (%d) does not match data length (%d)", offset, len(data)) 356 } 357 358 if h.OPCode > 15 { 359 t.Errorf("invalid OPCode: %d", h.OPCode) 360 } 361 362 if h.Z > 7 { 363 t.Errorf("invalid Z value: %d", h.Z) 364 } 365 366 if h.RCode > 15 { 367 t.Errorf("invalid RCode: %d", h.RCode) 368 } 369 370 encoded := h.Encode() 371 if len(encoded) != len(data) { 372 t.Errorf("encoded length (%d) does not match input length (%d)", len(encoded), len(data)) 373 374 for i := 0; i < len(data); i++ { 375 t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i]) 376 
} 377 } 378 }) 379}