···
"tangled.sh/seiso.moe/magna"
···
var errNXDOMAIN = fmt.Errorf("nxdomain")
17
+
depthKey contextKey = "dns_recursion_depth"
20
+
func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) {
21
+
depth := getDepth(ctx)
22
+
if depth >= maxDepth {
23
+
return nil, fmt.Errorf("maximum recursion depth (%d) exceeded", maxDepth)
25
+
return context.WithValue(ctx, depthKey, depth+1), nil
28
+
func getDepth(ctx context.Context) int {
29
+
if depth, ok := ctx.Value(depthKey).(int); ok {
type QueryHandler struct {
···
question := r.Message.Question[0]
msg := r.Message.CreateReply(r.Message)
···
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
61
-
ch := make(chan queryResponse, len(servers))
63
-
for _, s := range servers {
64
-
go queryServer(ctx, question, s, ch, h.Timeout)
82
+
newCtx, err := withIncrementedDepth(ctx, maxDepth)
70
-
if res.Error != nil {
75
-
if msg.Header.ANCount > 0 {
76
-
if msg.Answer[0].RType == magna.CNAMEType {
77
-
cname_answers, err := h.resolveQuestion(ctx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
81
-
msg.Answer = append(msg.Answer, cname_answers...)
84
-
return msg.Answer, nil
88
+
for _, s := range servers {
89
+
msg, err := queryServer(ctx, question, s, h.Timeout)
91
+
h.Logger.Warn("unable to resolve question", "server", s)
87
-
if msg.Header.ARCount > 0 {
88
-
var nextZone []string
89
-
for _, ans := range msg.Additional {
90
-
if ans.RType == magna.AType {
91
-
nextZone = append(nextZone, ans.RData.String())
95
+
if ok, answers := ExtractAnswer(question, msg); ok {
96
+
if msg.Answer[0].RType == magna.CNAMEType {
97
+
cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}
98
+
ips, err := h.resolveQuestion(ctx, cnameQuestion, h.RootServers)
100
+
h.Logger.Info("unable to resolve CNAME target", "question", cnameQuestion, "depth", getDepth(ctx), "error", err)
95
-
return h.resolveQuestion(ctx, question, nextZone)
103
+
answers = append(answers, ips...)
105
+
return answers, nil
98
-
if msg.Header.NSCount > 0 {
100
-
for _, a := range msg.Authority {
101
-
if a.RType == magna.NSType {
102
-
ans, err := h.resolveQuestion(ctx, magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, h.RootServers)
106
-
for _, x := range ans {
107
-
ns = append(ns, x.RData.String())
108
+
if ok, answers := HandleGlueRecords(question, msg); ok {
109
+
return h.resolveQuestion(ctx, question, answers)
112
-
return h.resolveQuestion(ctx, question, ns)
115
-
return []magna.ResourceRecord{}, nil
112
+
if ok, answers := h.HandleReferral(ctx, question, msg); ok {
113
+
return h.resolveQuestion(ctx, question, answers)
return []magna.ResourceRecord{}, nil
122
-
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) {
123
-
done := make(chan struct{}, 1)
120
+
func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) {
122
+
conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
124
+
return magna.Message{}, err
126
-
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
128
-
ch <- queryResponse{Error: err}
133
+
conn.SetDeadline(time.Now().Add(timeout))
135
+
query := magna.CreateRequest(0, false)
136
+
query = query.AddQuestion(question)
137
+
msg, err := query.Encode()
139
+
return magna.Message{}, err
142
+
if _, err := conn.Write(msg); err != nil {
143
+
return magna.Message{}, err
146
+
p := make([]byte, 512)
147
+
nn, err := conn.Read(p)
149
+
// TODO: retry request with TCP
150
+
if err != nil || nn > 512 {
152
+
err = fmt.Errorf("truncated response")
154
+
return magna.Message{}, err
133
-
query := magna.CreateRequest(0, false)
134
-
query = query.AddQuestion(question)
135
-
msg, err := query.Encode()
137
-
ch <-queryResponse{Server: server, Error: err}
157
+
var response magna.Message
158
+
err = response.Decode(p)
159
+
return response, err
162
+
func ExtractAnswer(q magna.Question, r magna.Message) (bool, []magna.ResourceRecord) {
163
+
answers := make([]magna.ResourceRecord, 0, r.Header.ANCount)
164
+
for _, a := range r.Answer {
165
+
if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) {
166
+
answers = append(answers, a)
140
-
if _, err := conn.Write(msg); err != nil {
141
-
ch <- queryResponse{Server: server, Error: err}
170
+
if len(answers) <= 0 {
171
+
return false, []magna.ResourceRecord{}
174
+
return true, answers
177
+
func HandleGlueRecords(q magna.Question, r magna.Message) (bool, []string) {
178
+
answers := make([]string, 0, r.Header.ARCount)
179
+
for _, a := range r.Authority {
180
+
if a.RType != magna.NSType {
145
-
p := make([]byte, 512)
146
-
nn, err := conn.Read(p)
184
+
ns, ok := a.RData.(*magna.NS)
186
+
// this should not happen but better safe than sorry
148
-
// TODO: retry request with TCP
149
-
if err != nil || nn > 512 {
151
-
err = fmt.Errorf("truncated response")
190
+
for _, ad := range r.Additional {
191
+
// XXX: add AAAAType when magna supports it
192
+
if ad.RType == magna.AType && strings.ToLower(ad.Name) == strings.ToLower(ns.NSDName) {
193
+
answers = append(answers, ad.RData.String())
153
-
ch <- queryResponse{Server: server, Error: err}
157
-
var response magna.Message
158
-
err = response.Decode(p)
159
-
ch <- queryResponse{MSG: response, Server: server, Error: err}
198
+
if len(answers) <= 0 {
199
+
return false, []string{}
202
+
return true, answers
205
+
func (h *QueryHandler) HandleReferral(ctx context.Context, q magna.Question, r magna.Message) (bool, []string) {
206
+
servers := make([]string, 0, r.Header.NSCount)
208
+
for _, auth := range r.Authority {
209
+
if auth.RType == magna.NSType {
210
+
nsQuestion := magna.Question{
211
+
QName: auth.RData.String(),
212
+
QType: magna.AType,
216
+
answers, err := h.resolveQuestion(ctx, nsQuestion, h.RootServers)
218
+
h.Logger.Warn("error handling referral",
219
+
"question", nsQuestion,
220
+
"depth", getDepth(ctx),
164
-
ch <- queryResponse{Server: server, Error: ctx.Err()}
166
-
// goroutine finished with no cancellation
167
-
case <-time.After(timeout):
168
-
ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")}
225
+
for _, ans := range answers {
226
+
servers = append(servers, ans.RData.String())
231
+
if len(servers) <= 0 {
232
+
return false, []string{}
235
+
return true, servers