a recursive dns resolver

add nx domain handling

Changed files
+65 -11
pkg
+65 -11
pkg/dns/resolve.go
···
"tangled.sh/seiso.moe/magna"
)
-
var errNXDOMAIN = fmt.Errorf("nxdomain")
+
var (
+
errNXDOMAIN = fmt.Errorf("alky: domain does not exist")
+
errOnlySOA = fmt.Errorf("alky: only an soa was reffered")
+
errNoServers = fmt.Errorf("alky: no servers responded")
+
errMaxDepth = fmt.Errorf("alky: maximum recursion depth exceeded")
+
errNonMatchingID = fmt.Errorf("alky: response ID mismatch")
+
errNonMatchingQuestion = fmt.Errorf("alky: response question mismatch")
+
errQueryTimeout = fmt.Errorf("alky: query timeout")
+
)
const (
depthKey contextKey = "dns_recursion_depth"
···
func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) {
depth := getDepth(ctx)
if depth >= maxDepth {
-
return nil, fmt.Errorf("maximum recursion depth (%d) exceeded", maxDepth)
+
return nil, errMaxDepth
}
return context.WithValue(ctx, depthKey, depth+1), nil
}
···
msg.Authority = records
msg.Header.NSCount = uint16(len(records))
msg.Header.ANCount = 0
+
msg.Answer = nil
+
} else if err == errOnlySOA {
+
msg = msg.SetRCode(magna.NOERROR)
+
msg.Authority = records
+
msg.Header.NSCount = uint16(len(records))
+
msg.Header.ANCount = 0
+
msg.Header.ARCount = 0
msg.Answer = nil
} else if err != nil {
+
h.Logger.Warn("error", "error", err)
msg = msg.SetRCode(magna.SERVFAIL)
} else {
msg.Answer = records
···
for _, s := range servers {
msg, err := queryServer(ctx, question, s, h.Timeout)
if err != nil {
-
h.Logger.Warn("unable to resolve question", "server", s)
+
h.Logger.Warn("unable to resolve question", "server", s, "error", err)
continue
}
+
if msg.Header.RCode == magna.NXDOMAIN {
+
_, authority := ExtractSOA(msg)
+
return authority, errNXDOMAIN
+
}
+
if ok, answers := ExtractAnswer(question, msg); ok {
if msg.Answer[0].RType == magna.CNAMEType {
cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}
···
if ok, answers := h.HandleReferral(ctx, question, msg); ok {
return h.resolveQuestion(ctx, question, answers)
}
+
+
if ok, answers := ExtractSOA(msg); ok {
+
return answers, errOnlySOA
+
}
}
-
return []magna.ResourceRecord{}, nil
+
return []magna.ResourceRecord{}, errNoServers
}
func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) {
+
ctx, cancel := context.WithTimeout(ctx, timeout)
+
defer cancel()
+
var d net.Dialer
conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
if err != nil {
···
}
defer conn.Close()
-
go func() {
-
<-ctx.Done()
-
conn.Close()
-
}()
-
conn.SetDeadline(time.Now().Add(timeout))
-
query := magna.CreateRequest(0, false)
+
query := magna.CreateRequest(magna.QUERY, false)
query = query.AddQuestion(question)
msg, err := query.Encode()
if err != nil {
···
}
var response magna.Message
-
err = response.Decode(p)
+
if err := response.Decode(p); err != nil {
+
return magna.Message{}, err
+
}
+
+
if err := validateResponse(*query, response, question); err != nil {
+
return magna.Message{}, err
+
}
return response, err
}
+
func validateResponse(query magna.Message, response magna.Message, question magna.Question) error {
+
if response.Header.ID != query.Header.ID {
+
return errNonMatchingID
+
}
+
if len(response.Question) < 1 || response.Question[0] != question {
+
return errNonMatchingQuestion
+
}
+
return nil
+
}
+
func ExtractAnswer(q magna.Question, r magna.Message) (bool, []magna.ResourceRecord) {
answers := make([]magna.ResourceRecord, 0, r.Header.ANCount)
for _, a := range r.Answer {
if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) {
+
answers = append(answers, a)
+
}
+
}
+
+
if len(answers) <= 0 {
+
return false, []magna.ResourceRecord{}
+
}
+
+
return true, answers
+
}
+
+
func ExtractSOA(r magna.Message) (bool, []magna.ResourceRecord) {
+
answers := make([]magna.ResourceRecord, 0, r.Header.NSCount)
+
for _, a := range r.Authority {
+
if a.RType == magna.SOAType {
answers = append(answers, a)
}
}