// Tests for magna, a Go DNS packet parser.
1package magna 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "net" 8 "strings" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13) 14 15func TestMessageDecode(t *testing.T) { 16 buildQuery := func(id uint16, qname string, qtype DNSType, qclass DNSClass) []byte { 17 buf := new(bytes.Buffer) 18 binary.Write(buf, binary.BigEndian, id) 19 binary.Write(buf, binary.BigEndian, uint16(0x0100)) 20 binary.Write(buf, binary.BigEndian, uint16(1)) 21 binary.Write(buf, binary.BigEndian, uint16(0)) 22 binary.Write(buf, binary.BigEndian, uint16(0)) 23 binary.Write(buf, binary.BigEndian, uint16(0)) 24 offsets := make(map[string]uint16) 25 qBytes, err := encodeDomain([]byte{}, qname, &offsets) 26 require.NoError(t, err) 27 buf.Write(qBytes) 28 binary.Write(buf, binary.BigEndian, uint16(qtype)) 29 binary.Write(buf, binary.BigEndian, uint16(qclass)) 30 return buf.Bytes() 31 } 32 33 buildAnswer := func(id uint16, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdata ResourceRecordData) []byte { 34 buf := new(bytes.Buffer) 35 36 binary.Write(buf, binary.BigEndian, id) 37 binary.Write(buf, binary.BigEndian, uint16(0x8180)) 38 binary.Write(buf, binary.BigEndian, uint16(0)) 39 binary.Write(buf, binary.BigEndian, uint16(1)) 40 binary.Write(buf, binary.BigEndian, uint16(0)) 41 binary.Write(buf, binary.BigEndian, uint16(0)) 42 rr := ResourceRecord{ 43 Name: name, 44 RType: rtype, 45 RClass: rclass, 46 TTL: ttl, 47 RData: rdata, 48 } 49 offsets := make(map[string]uint16) 50 rrBytes, err := rr.Encode([]byte{}, &offsets) 51 require.NoError(t, err) 52 buf.Write(rrBytes) 53 return buf.Bytes() 54 } 55 56 tests := []struct { 57 name string 58 input []byte 59 expected Message 60 wantErr bool 61 wantErrType error 62 wantErrMsg string 63 }{ 64 { 65 name: "Valid DNS query message with one question", 66 input: buildQuery(1234, "www.example.com", AType, IN), 67 expected: Message{ 68 Header: Header{ 69 ID: 1234, 70 QR: false, 71 RD: true, 72 
OPCode: OPCode(0), 73 QDCount: 1, 74 RCode: NOERROR, 75 }, 76 Question: []Question{ 77 { 78 QName: "www.example.com", 79 QType: AType, 80 QClass: IN, 81 }, 82 }, 83 Answer: []ResourceRecord{}, 84 Additional: []ResourceRecord{}, 85 Authority: []ResourceRecord{}, 86 }, 87 wantErr: false, 88 }, 89 { 90 name: "Valid DNS answer message with one A record", 91 input: buildAnswer(5678, "www.example.com", AType, IN, 3600, 92 &A{Address: net.ParseIP("10.0.0.1").To4()}, 93 ), 94 expected: Message{ 95 Header: Header{ 96 ID: 5678, 97 QR: true, 98 OPCode: 0, 99 AA: false, 100 RD: true, 101 RA: true, 102 RCode: 0, 103 ANCount: 1, 104 }, 105 Question: []Question{}, 106 Answer: []ResourceRecord{ 107 { 108 Name: "www.example.com", 109 RType: AType, 110 RClass: IN, 111 TTL: 3600, 112 RDLength: 4, 113 RData: &A{Address: net.IP([]byte{10, 0, 0, 1})}, 114 }, 115 }, 116 Additional: []ResourceRecord{}, 117 Authority: []ResourceRecord{}, 118 }, 119 wantErr: false, 120 }, 121 { 122 name: "Invalid input - empty buffer", 123 input: []byte{}, 124 wantErr: true, 125 wantErrType: &BufferOverflowError{}, 126 wantErrMsg: "failed to decode message header: header decode: failed to read ID", 127 }, 128 { 129 name: "Invalid input - truncated header", 130 input: []byte{0x12, 0x34}, 131 wantErr: true, 132 wantErrType: &BufferOverflowError{}, 133 wantErrMsg: "failed to decode message header: header decode: failed to read flags", 134 }, 135 { 136 name: "Invalid input - truncated question name", 137 input: func() []byte { 138 buf := new(bytes.Buffer) 139 binary.Write(buf, binary.BigEndian, uint16(1235)) 140 binary.Write(buf, binary.BigEndian, uint16(0x0100)) 141 binary.Write(buf, binary.BigEndian, uint16(1)) 142 binary.Write(buf, binary.BigEndian, uint16(0)) 143 binary.Write(buf, binary.BigEndian, uint16(0)) 144 binary.Write(buf, binary.BigEndian, uint16(0)) 145 buf.Write([]byte{7, 'e', 'x', 'a'}) 146 return buf.Bytes() 147 }(), 148 wantErr: true, 149 wantErrType: &BufferOverflowError{}, 150 wantErrMsg: 
"failed to decode question #1:", 151 }, 152 { 153 name: "Invalid input - truncated answer record data", 154 input: func() []byte { 155 buf := new(bytes.Buffer) 156 157 binary.Write(buf, binary.BigEndian, uint16(5679)) 158 binary.Write(buf, binary.BigEndian, uint16(0x8180)) 159 binary.Write(buf, binary.BigEndian, uint16(0)) 160 binary.Write(buf, binary.BigEndian, uint16(1)) 161 binary.Write(buf, binary.BigEndian, uint16(0)) 162 binary.Write(buf, binary.BigEndian, uint16(0)) 163 164 offsets := make(map[string]uint16) 165 nameBytes, _ := encodeDomain([]byte{}, "example.com", &offsets) 166 buf.Write(nameBytes) 167 binary.Write(buf, binary.BigEndian, uint16(AType)) 168 binary.Write(buf, binary.BigEndian, uint16(IN)) 169 binary.Write(buf, binary.BigEndian, uint32(300)) 170 binary.Write(buf, binary.BigEndian, uint16(4)) 171 172 buf.Write([]byte{192, 168}) 173 return buf.Bytes() 174 }(), 175 wantErr: true, 176 wantErrType: &BufferOverflowError{}, 177 wantErrMsg: "failed to decode answer record #1:", 178 }, 179 } 180 181 for _, tt := range tests { 182 t.Run(tt.name, func(t *testing.T) { 183 m := &Message{} 184 err := m.Decode(tt.input) 185 186 if tt.wantErr { 187 assert.Error(t, err, "Expected an error but got nil") 188 if tt.wantErrType != nil { 189 assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. 
Got %T, expected %T", err, tt.wantErrType) 190 } 191 if tt.wantErrMsg != "" { 192 assert.ErrorContains(t, err, tt.wantErrMsg, "Error message mismatch") 193 } 194 } else { 195 assert.NoError(t, err, "Expected no error but got one") 196 197 assert.Equal(t, tt.expected.Header.ID, m.Header.ID, "Header ID mismatch") 198 assert.Equal(t, tt.expected.Header.QR, m.Header.QR, "Header QR mismatch") 199 assert.Equal(t, tt.expected.Header.OPCode, m.Header.OPCode, "Header OPCode mismatch") 200 assert.Equal(t, tt.expected.Header.RCode, m.Header.RCode, "Header RCode mismatch") 201 assert.Equal(t, tt.expected.Header.QDCount, m.Header.QDCount, "Header QDCount mismatch") 202 assert.Equal(t, tt.expected.Header.ANCount, m.Header.ANCount, "Header ANCount mismatch") 203 204 assert.Equal(t, tt.expected.Question, m.Question, "Question section mismatch") 205 assert.Equal(t, tt.expected.Answer, m.Answer, "Answer section mismatch") 206 assert.Equal(t, tt.expected.Authority, m.Authority, "Authority section mismatch") 207 assert.Equal(t, tt.expected.Additional, m.Additional, "Additional section mismatch") 208 } 209 }) 210 } 211} 212 213func TestMessageEncodeDecodeRoundTrip(t *testing.T) { 214 tests := []struct { 215 name string 216 message *Message 217 }{ 218 { 219 name: "Query with one question", 220 message: CreateRequest(QUERY, true).AddQuestion(Question{ 221 QName: "google.com", 222 QType: AType, 223 QClass: IN, 224 }), 225 }, 226 { 227 name: "Response with one A answer", 228 message: &Message{ 229 Header: Header{ 230 ID: 12345, QR: true, OPCode: QUERY, RD: true, RA: true, RCode: NOERROR, ANCount: 1, 231 }, 232 Question: []Question{}, 233 Answer: []ResourceRecord{ 234 {Name: "test.local", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.ParseIP("192.0.2.1").To4()}}, 235 }, 236 Additional: []ResourceRecord{}, 237 Authority: []ResourceRecord{}, 238 }, 239 }, 240 { 241 name: "Response with multiple answers and compression", 242 message: &Message{ 243 Header: Header{ID: 54321, QR: 
true, RCode: NOERROR, ANCount: 2}, 244 Question: []Question{}, 245 Answer: []ResourceRecord{ 246 {Name: "www.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.2").To4()}}, 247 {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.3").To4()}}, 248 }, 249 Additional: []ResourceRecord{}, 250 Authority: []ResourceRecord{}, 251 }, 252 }, 253 { 254 name: "Message with various record types", 255 message: &Message{ 256 Header: Header{ID: 1111, QR: true, RCode: NOERROR, ANCount: 3}, 257 Question: []Question{}, 258 Answer: []ResourceRecord{ 259 {Name: "example.com", RType: MXType, RClass: IN, TTL: 3600, RDLength: 9, RData: &MX{Preference: 10, Exchange: "mail.example.com"}}, 260 {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.4").To4()}}, 261 {Name: "example.com", RType: TXTType, RClass: IN, TTL: 600, RDLength: 36, RData: &TXT{TxtData: []string{"v=spf1 include:_spf.google.com ~all"}}}, 262 }, 263 Additional: []ResourceRecord{}, 264 Authority: []ResourceRecord{}, 265 }, 266 }, 267 } 268 269 for _, tt := range tests { 270 t.Run(tt.name, func(t *testing.T) { 271 encodedBytes, err := tt.message.Encode() 272 require.NoError(t, err, "Encoding failed unexpectedly") 273 require.NotEmpty(t, encodedBytes, "Encoded bytes should not be empty") 274 275 decodedMsg := &Message{} 276 err = decodedMsg.Decode(encodedBytes) 277 require.NoError(t, err, "Decoding failed unexpectedly") 278 279 assert.Equal(t, tt.message.Header.ID, decodedMsg.Header.ID, "Header ID mismatch") 280 assert.Equal(t, tt.message.Header.QR, decodedMsg.Header.QR, "Header QR mismatch") 281 assert.Equal(t, tt.message.Header.OPCode, decodedMsg.Header.OPCode, "Header OPCode mismatch") 282 assert.Equal(t, tt.message.Header.RCode, decodedMsg.Header.RCode, "Header RCode mismatch") 283 284 assert.Equal(t, tt.message.Question, decodedMsg.Question, "Question section mismatch") 285 
assert.Equal(t, tt.message.Answer, decodedMsg.Answer, "Answer section mismatch") 286 assert.Equal(t, tt.message.Authority, decodedMsg.Authority, "Authority section mismatch") 287 assert.Equal(t, tt.message.Additional, decodedMsg.Additional, "Additional section mismatch") 288 }) 289 } 290} 291 292func FuzzDecodeMessage(f *testing.F) { 293 testcases := [][]byte{ 294 { 295 0x8e, 0x19, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x6e, 0x65, 296 0x77, 0x73, 0x0b, 0x79, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x03, 297 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 298 0x00, 0x00, 0x01, 0x00, 0x04, 0xd1, 0xd8, 0xe6, 0xcf, 299 }, 300 } 301 for _, tc := range testcases { 302 f.Add(tc) 303 } 304 f.Fuzz(func(t *testing.T, msg []byte) { 305 var m Message 306 err := m.Decode(msg) 307 if err != nil { 308 var bufErr *BufferOverflowError 309 var labelErr *InvalidLabelError 310 var compErr *DomainCompressionError 311 if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr) || strings.Contains(err.Error(), "record:")) { 312 t.Errorf("FuzzDecodeMessage: unexpected error type %T: %v for input %x", err, err, msg) 313 } 314 } 315 }) 316}