package magna

import (
	"bytes"
	"encoding/binary"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestHeaderDecode checks Header.Decode on a complete 12-byte header and on
// inputs truncated inside each of the six 16-bit header fields.
//
// On failure the returned error must wrap a *BufferOverflowError, mention the
// field that could not be read, and Decode must return the listed offset. The
// partially decoded header is deliberately not inspected on failure: its
// contents are not part of Decode's contract.
func TestHeaderDecode(t *testing.T) {
	tests := []struct {
		name           string
		input          []byte
		expectedHeader Header // only checked when wantErr is false
		expectedOffset int
		wantErr        bool
		wantErrMsg     string
	}{
		{
			name:  "Valid header",
			input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
			expectedHeader: Header{
				ID:      0x1234,
				QR:      true,
				OPCode:  OPCode(0),
				AA:      false,
				TC:      false,
				RD:      true,
				RA:      true,
				Z:       0,
				RCode:   RCode(0),
				QDCount: 1,
				ANCount: 2,
				NSCount: 3,
				ARCount: 4,
			},
			expectedOffset: 12,
		},
		{
			name:           "Insufficient buffer length for ID",
			input:          []byte{0x12},
			expectedOffset: 1,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read ID",
		},
		{
			name:           "Insufficient buffer length for Flags",
			input:          []byte{0x12, 0x34, 0x81},
			expectedOffset: 3,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read flags",
		},
		{
			name:           "Missing QDCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00},
			expectedOffset: 5,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read QDCount",
		},
		{
			name:           "Missing ANCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00},
			expectedOffset: 7,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read ANCount",
		},
		{
			name:           "Missing NSCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00},
			expectedOffset: 9,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read NSCount",
		},
		{
			name:           "Missing ARCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00},
			expectedOffset: 11,
			wantErr:        true,
			wantErrMsg:     "header decode: failed to read ARCount",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &Header{}
			offset, err := h.Decode(tt.input, 0)

			if tt.wantErr {
				require.Error(t, err, "Expected an error but got nil")
				// errors.As checks the dynamic type through any wrapping;
				// errors.Is against a fresh &BufferOverflowError{} would only
				// match if the type defined its own Is method.
				var bofErr *BufferOverflowError
				assert.ErrorAs(t, err, &bofErr, "Error type mismatch: got %T, want wrapped *BufferOverflowError", err)
				if tt.wantErrMsg != "" {
					assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch")
				}
				assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error")
				return
			}

			require.NoError(t, err, "Expected no error but got one")
			assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch")
			assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success")
		})
	}
}

// TestHeaderDecodeFlags checks how Decode splits the 16-bit flags word
// (bytes 2-3 of the header) into QR, OPCode, AA, TC, RD, RA, Z and RCode.
// ID and all section counts are zero in the input, so the entire decoded
// header can be compared against the expected value.
func TestHeaderDecodeFlags(t *testing.T) {
	tests := []struct {
		name     string
		flags    uint16
		expected Header
	}{
		{
			name:  "All flags set",
			flags: 0xFFFF,
			expected: Header{
				QR:     true,
				OPCode: OPCode(15),
				AA:     true,
				TC:     true,
				RD:     true,
				RA:     true,
				Z:      7,
				RCode:  RCode(15),
			},
		},
		{
			name:  "No flags set",
			flags: 0x0000,
			expected: Header{
				QR:     false,
				OPCode: OPCode(0),
				AA:     false,
				TC:     false,
				RD:     false,
				RA:     false,
				Z:      0,
				RCode:  RCode(0),
			},
		},
		{
			// 0x8510 = QR | AA | RD with Z = 1.
			name:  "Mixed flags",
			flags: 0x8510,
			expected: Header{
				QR:     true,
				OPCode: OPCode(0),
				AA:     true,
				TC:     false,
				RD:     true,
				RA:     false,
				Z:      1,
				RCode:  RCode(0),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			input := make([]byte, 12)
			binary.BigEndian.PutUint16(input[2:4], tt.flags)

			h := &Header{}
			offset, err := h.Decode(input, 0)
			require.NoError(t, err)
			assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header")
			assert.Equal(t, tt.expected, *h, "Decoded header mismatch")
		})
	}
}

// TestHeaderEncode checks that Header.Encode produces the expected 12-byte
// big-endian wire form for several field combinations.
func TestHeaderEncode(t *testing.T) {
	tests := []struct {
		name     string
		header   Header
		expected []byte
	}{
		{
			name: "All fields set",
			header: Header{
				ID:      0x1234,
				QR:      true,
				OPCode:  OPCode(1),
				AA:      true,
				TC:      true,
				RD:      true,
				RA:      true,
				Z:       5,
				RCode:   RCode(3),
				QDCount: 1,
				ANCount: 2,
				NSCount: 3,
				ARCount: 4,
			},
			expected: []byte{
				0x12, 0x34, 0x8f, 0xd3, 0x00, 0x01,
				0x00, 0x02, 0x00, 0x03, 0x00, 0x04,
			},
		},
		{
			name: "No flags set, different counts",
			header: Header{
				ID:      0x5678,
				QR:      false,
				OPCode:  OPCode(0),
				AA:      false,
				TC:      false,
				RD:      false,
				RA:      false,
				Z:       0,
				RCode:   RCode(0),
				QDCount: 5,
				ANCount: 6,
				NSCount: 7,
				ARCount: 8,
			},
			expected: []byte{
				0x56, 0x78, 0x00, 0x00, 0x00, 0x05,
				0x00, 0x06, 0x00, 0x07, 0x00, 0x08,
			},
		},
		{
			name: "Mixed flags and counts",
			header: Header{
				ID:      0x9abc,
				QR:      true,
				OPCode:  OPCode(2),
				AA:      false,
				TC:      true,
				RD:      false,
				RA:      true,
				Z:       3,
				RCode:   RCode(4),
				QDCount: 9,
				ANCount: 10,
				NSCount: 11,
				ARCount: 12,
			},
			expected: []byte{
				0x9a, 0xbc, 0x92, 0xb4, 0x00, 0x09,
				0x00, 0x0a, 0x00, 0x0b, 0x00, 0x0c,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			encoded := tt.header.Encode()
			assert.Equal(t, tt.expected, encoded, "Encoded header mismatch")
		})
	}
}

// TestHeaderEncodeDecodeRoundTrip checks that decoding the output of Encode
// reproduces the original header and consumes exactly 12 bytes.
func TestHeaderEncodeDecodeRoundTrip(t *testing.T) {
	originalHeader := Header{
		ID:      0xdcba,
		QR:      true,
		OPCode:  OPCode(3),
		AA:      true,
		TC:      false,
		RD:      true,
		RA:      false,
		Z:       6,
		RCode:   RCode(2),
		QDCount: 13,
		ANCount: 14,
		NSCount: 15,
		ARCount: 16,
	}

	encoded := originalHeader.Encode()
	require.Len(t, encoded, 12, "Encoded header should be 12 bytes")

	decodedHeader := &Header{}
	offset, err := decodedHeader.Decode(encoded, 0)
	require.NoError(t, err, "Decoding failed unexpectedly")
	assert.Equal(t, 12, offset, "Offset after decoding should be 12")
	assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original")
}

// TestHeaderEncodeFlagCombinations checks the bit position of each flag and
// sub-field by encoding headers with a single field set and reading back the
// big-endian flags word at bytes 2-3.
func TestHeaderEncodeFlagCombinations(t *testing.T) {
	testCases := []struct {
		name     string
		header   Header
		expected uint16
	}{
		{"QR flag", Header{QR: true}, 0x8000},
		{"OPCode", Header{OPCode: OPCode(5)}, 0x2800},
		{"AA flag", Header{AA: true}, 0x0400},
		{"TC flag", Header{TC: true}, 0x0200},
		{"RD flag", Header{RD: true}, 0x0100},
		{"RA flag", Header{RA: true}, 0x0080},
		{"Z value", Header{Z: 5}, 0x0050},
		{"RCode", Header{RCode: RCode(7)}, 0x0007},
		{"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			encoded := tc.header.Encode()
			require.Len(t, encoded, 12, "Encoded header length invariant")
			flags := binary.BigEndian.Uint16(encoded[2:4])
			assert.Equal(t, tc.expected, flags, "Flags value mismatch")
		})
	}
}

// FuzzDecodeHeader decodes arbitrary bytes from offset 0 and checks that
// Decode either:
//   - fails with an error wrapping *BufferOverflowError, only for inputs
//     shorter than 12 bytes, without reporting an offset past the input; or
//   - succeeds on at least 12 bytes, returns offset 12, yields in-range
//     sub-fields, and re-encodes to exactly the first 12 input bytes.
func FuzzDecodeHeader(f *testing.F) {
	seeds := [][]byte{
		{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
		{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
		{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		{0x12, 0x34},
		{0x12, 0x34, 0x81, 0x80, 0x00},
		{},
	}
	for _, seed := range seeds {
		f.Add(seed)
	}

	f.Fuzz(func(t *testing.T, data []byte) {
		h := &Header{}
		offset, err := h.Decode(data, 0)

		if err != nil {
			var bofErr *BufferOverflowError
			if !errors.As(err, &bofErr) {
				t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err)
			}
			if offset > len(data) {
				t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data))
			}
			// Every flags bit pattern is valid (see the 0xffff seed), so a
			// complete 12-byte header must never fail to decode.
			if len(data) >= 12 {
				t.Errorf("FuzzDecodeHeader: decode failed on %d-byte input: %v", len(data), err)
			}
			return
		}

		if len(data) < 12 {
			t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data))
			return
		}
		if offset != 12 {
			t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset)
		}
		if h.OPCode > 15 {
			t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode)
		}
		if h.Z > 7 {
			t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z)
		}
		if h.RCode > 15 {
			t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode)
		}

		// len(data) >= 12 is guaranteed here by the early return above.
		encoded := h.Encode()
		if !bytes.Equal(encoded, data[:12]) {
			t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded)
		}
	})
}