package magna

import (
	"encoding/binary"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestHeaderDecode verifies that Header.Decode parses a complete 12-byte
// header into the expected field values, and that truncated input yields a
// *BufferOverflowError together with the offset at which decoding stopped.
func TestHeaderDecode(t *testing.T) {
	tests := []struct {
		name           string
		input          []byte
		expectedHeader Header
		expectedOffset int
		expectedErr    error
	}{
		{
			name:  "Valid header",
			input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
			expectedHeader: Header{
				ID:      0x1234,
				QR:      true,
				OPCode:  OPCode(0),
				AA:      false,
				TC:      false,
				RD:      true,
				RA:      true,
				Z:       0,
				RCode:   RCode(0),
				QDCount: 1,
				ANCount: 2,
				NSCount: 3,
				ARCount: 4,
			},
			expectedOffset: 12,
			expectedErr:    nil,
		},
		{
			name:           "Insufficient buffer length",
			input:          []byte{0x12, 0x34, 0x81},
			expectedHeader: Header{},
			expectedOffset: 3,
			expectedErr:    &BufferOverflowError{Length: 3, Offset: 3},
		},
		{
			name:           "Invalid ID",
			input:          []byte{0x12},
			expectedHeader: Header{},
			expectedOffset: 1,
			expectedErr:    &BufferOverflowError{Length: 1, Offset: 1},
		},
		{
			name:           "Missing QDCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00},
			expectedHeader: Header{},
			expectedOffset: 5,
			expectedErr:    &BufferOverflowError{},
		},
		{
			name:           "Missing ANCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01},
			expectedHeader: Header{},
			expectedOffset: 6,
			expectedErr:    &BufferOverflowError{},
		},
		{
			name:           "Missing NSCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02},
			expectedHeader: Header{},
			expectedOffset: 8,
			expectedErr:    &BufferOverflowError{},
		},
		{
			name:           "Missing ARCount",
			input:          []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03},
			expectedHeader: Header{},
			expectedOffset: 10,
			expectedErr:    &BufferOverflowError{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &Header{}
			offset, err := h.Decode(tt.input, 0)

			if tt.expectedErr != nil {
				assert.Error(t, err)
				// Only the dynamic error type is compared; the field values
				// in expectedErr are informational.
				assert.IsType(t, tt.expectedErr, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.expectedHeader, *h)
			}
			assert.Equal(t, tt.expectedOffset, offset)
		})
	}
}

// TestHeaderDecodeFlags verifies how Header.Decode unpacks the 16-bit flags
// word (bytes 2-3) into QR, OPCode, AA, TC, RD, RA, Z and RCode.
func TestHeaderDecodeFlags(t *testing.T) {
	tests := []struct {
		name     string
		flags    uint16
		expected Header
	}{
		{
			name:  "All flags set",
			flags: 0xFFFF,
			expected: Header{
				QR:     true,
				OPCode: OPCode(15),
				AA:     true,
				TC:     true,
				RD:     true,
				RA:     true,
				Z:      7,
				RCode:  RCode(15),
			},
		},
		{
			name:  "No flags set",
			flags: 0x0000,
			expected: Header{
				QR:     false,
				OPCode: OPCode(0),
				AA:     false,
				TC:     false,
				RD:     false,
				RA:     false,
				Z:      0,
				RCode:  RCode(0),
			},
		},
		{
			// 0x8510 = QR | AA | RD | Z=1.
			name:  "Mixed flags",
			flags: 0x8510,
			expected: Header{
				QR:     true,
				OPCode: OPCode(0),
				AA:     true,
				TC:     false,
				RD:     true,
				RA:     false,
				Z:      1,
				RCode:  RCode(0),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			input := []byte{
				0x00, 0x00, // ID
				byte(tt.flags >> 8), byte(tt.flags), // flags
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // section counts
			}

			h := &Header{}
			_, err := h.Decode(input, 0)
			assert.NoError(t, err)
			// ID and all counts are zero in both the input and the expected
			// value, so comparing the whole struct also catches flag bits
			// leaking into non-flag fields.
			assert.Equal(t, tt.expected, *h)
		})
	}
}

// TestHeaderEncode verifies the exact 12-byte wire form produced by
// Header.Encode for several field combinations.
func TestHeaderEncode(t *testing.T) {
	tests := []struct {
		name     string
		header   Header
		expected []byte
	}{
		{
			name: "All fields set",
			header: Header{
				ID:      0x1234,
				QR:      true,
				OPCode:  OPCode(1),
				AA:      true,
				TC:      true,
				RD:      true,
				RA:      true,
				Z:       5,
				RCode:   RCode(3),
				QDCount: 1,
				ANCount: 2,
				NSCount: 3,
				ARCount: 4,
			},
			expected: []byte{
				0x12, 0x34, // ID
				0x8f, 0xd3, // flags
				0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04, // section counts
			},
		},
		{
			name: "No flags set",
			header: Header{
				ID:      0x5678,
				QR:      false,
				OPCode:  OPCode(0),
				AA:      false,
				TC:      false,
				RD:      false,
				RA:      false,
				Z:       0,
				RCode:   RCode(0),
				QDCount: 5,
				ANCount: 6,
				NSCount: 7,
				ARCount: 8,
			},
			expected: []byte{
				0x56, 0x78, // ID
				0x00, 0x00, // flags
				0x00, 0x05, 0x00, 0x06, 0x00, 0x07, 0x00, 0x08, // section counts
			},
		},
		{
			name: "Mixed flags",
			header: Header{
				ID:      0x9abc,
				QR:      true,
				OPCode:  OPCode(2),
				AA:      false,
				TC:      true,
				RD:      false,
				RA:      true,
				Z:       3,
				RCode:   RCode(4),
				QDCount: 9,
				ANCount: 10,
				NSCount: 11,
				ARCount: 12,
			},
			expected: []byte{
				0x9a, 0xbc, // ID
				0x92, 0xb4, // flags
				0x00, 0x09, 0x00, 0x0a, 0x00, 0x0b, 0x00, 0x0c, // section counts
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			encoded := tt.header.Encode()
			assert.Equal(t, tt.expected, encoded)
		})
	}
}

// TestHeaderEncodeDecodeRoundTrip checks that decoding the output of Encode
// consumes exactly the encoded bytes and reproduces the original Header.
func TestHeaderEncodeDecodeRoundTrip(t *testing.T) {
	originalHeader := Header{
		ID:      0xdcba,
		QR:      true,
		OPCode:  OPCode(3),
		AA:      true,
		TC:      false,
		RD:      true,
		RA:      false,
		Z:       6,
		RCode:   RCode(2),
		QDCount: 13,
		ANCount: 14,
		NSCount: 15,
		ARCount: 16,
	}

	encoded := originalHeader.Encode()
	decodedHeader := &Header{}
	offset, err := decodedHeader.Decode(encoded, 0)

	assert.NoError(t, err)
	assert.Equal(t, len(encoded), offset)
	assert.Equal(t, originalHeader, *decodedHeader)
}

// TestHeaderEncodeFlagCombinations checks the bit position of each flag field
// within the big-endian flags word (bytes 2-3) produced by Header.Encode.
func TestHeaderEncodeFlagCombinations(t *testing.T) {
	testCases := []struct {
		name     string
		header   Header
		expected uint16
	}{
		{"QR flag", Header{QR: true}, 0x8000},
		{"OPCode", Header{OPCode: OPCode(5)}, 0x2800},
		{"AA flag", Header{AA: true}, 0x0400},
		{"TC flag", Header{TC: true}, 0x0200},
		{"RD flag", Header{RD: true}, 0x0100},
		{"RA flag", Header{RA: true}, 0x0080},
		{"Z value", Header{Z: 5}, 0x0050},
		{"RCode", Header{RCode: RCode(7)}, 0x0007},
		{"All flags set", Header{QR: true, OPCode: OPCode(15), AA: true, TC: true, RD: true, RA: true, Z: 7, RCode: RCode(15)}, 0xffff},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			encoded := tc.header.Encode()
			// Guard the slice below: a short encoding would otherwise panic
			// instead of producing a readable test failure.
			if !assert.Len(t, encoded, 12) {
				return
			}
			flags := binary.BigEndian.Uint16(encoded[2:4])
			assert.Equal(t, tc.expected, flags)
		})
	}
}

// FuzzDecodeHeader feeds arbitrary inputs (truncated to 12 bytes) to
// Header.Decode. Decode must either fail with one of the expected error
// types, or succeed with every field in range. On success, re-encoding the
// header must reproduce the input byte-for-byte.
func FuzzDecodeHeader(f *testing.F) {
	seeds := [][]byte{
		{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04},
		{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
		{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		// Truncated input: seeds the error path.
		{0x12, 0x34, 0x81},
	}
	for _, seed := range seeds {
		f.Add(seed)
	}

	f.Fuzz(func(t *testing.T, data []byte) {
		// limit to only 12 bytes
		if len(data) > 12 {
			data = data[:12]
		}

		h := &Header{}
		offset, err := h.Decode(data, 0)
		if err != nil {
			switch err.(type) {
			case *BufferOverflowError, *InvalidLabelError:
				// these are expected error types
			default:
				t.Errorf("unexpected error type: %T", err)
			}
			return
		}

		if offset != len(data) {
			t.Errorf("offset (%d) does not match data length (%d)", offset, len(data))
		}
		if h.OPCode > 15 {
			t.Errorf("invalid OPCode: %d", h.OPCode)
		}
		if h.Z > 7 {
			t.Errorf("invalid Z value: %d", h.Z)
		}
		if h.RCode > 15 {
			t.Errorf("invalid RCode: %d", h.RCode)
		}

		encoded := h.Encode()
		if len(encoded) != len(data) {
			// Stop here: the byte comparison below would index past the end
			// of the shorter slice.
			t.Fatalf("encoded length (%d) does not match input length (%d)", len(encoded), len(data))
		}
		// Every bit of the 12-byte header maps to a Header field (see the
		// all-flags 0xffff case above), so a successful decode must re-encode
		// to exactly the original bytes. Report only positions that differ.
		for i := range data {
			if encoded[i] != data[i] {
				t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i])
			}
		}
	})
}