1package dns
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "strings"
9 "time"
10
11 "tangled.sh/seiso.moe/magna"
12)
13
// Sentinel errors produced while resolving a query. They are returned
// unwrapped so callers can compare them directly.
var (
	// errNXDOMAIN means a queried server answered with RCODE NXDOMAIN.
	errNXDOMAIN = fmt.Errorf("alky: domain does not exist")

	// errOnlySOA means a server returned no answer, glue or usable referral,
	// only SOA records in the Authority section.
	errOnlySOA = fmt.Errorf("alky: only an soa was referred")

	// errNoServers means none of the candidate servers produced a usable reply.
	errNoServers = fmt.Errorf("alky: no servers responded")

	// errMaxDepth means the nested resolution limit was reached.
	errMaxDepth = fmt.Errorf("alky: maximum recursion depth exceeded")

	// errNonMatchingID means a reply's message ID differs from the query's.
	errNonMatchingID = fmt.Errorf("alky: response ID mismatch")

	// errNonMatchingQuestion means a reply does not echo the query's question.
	errNonMatchingQuestion = fmt.Errorf("alky: response question mismatch")

	// errQueryTimeout is not returned by the code in this section.
	// NOTE(review): presumably used elsewhere in the package; confirm.
	errQueryTimeout = fmt.Errorf("alky: query timeout")
)
23
24const (
25 depthKey contextKey = "dns_recursion_depth"
26)
27
28func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) {
29 depth := getDepth(ctx)
30 if depth >= maxDepth {
31 return nil, errMaxDepth
32 }
33 return context.WithValue(ctx, depthKey, depth+1), nil
34}
35
36func getDepth(ctx context.Context) int {
37 if depth, ok := ctx.Value(depthKey).(int); ok {
38 return depth
39 }
40 return 0
41}
42
// QueryHandler answers client queries by resolving them itself, walking the
// DNS hierarchy from RootServers (see ServeDNS).
type QueryHandler struct {
	// RootServers are the servers every resolution starts from. They are also
	// used when resolving CNAME targets and name-server names. Each entry is
	// dialed on port 53.
	RootServers []string

	// Timeout bounds each individual upstream query.
	Timeout time.Duration

	// Logger receives diagnostics. It is called without a nil check, so it
	// must be set before the handler is used.
	Logger *slog.Logger
}
48
49type queryResponse struct {
50 MSG magna.Message
51 Server string
52 Error error
53}
54
55func (h *QueryHandler) ServeDNS(w ResponseWriter, r *Request) {
56 if len(r.Message.Question) < 1 {
57 h.Logger.Debug("received query with no questions")
58 msg := r.Message.CreateReply(r.Message)
59 msg = msg.SetRCode(magna.FORMERR)
60 w.WriteMsg(msg)
61 return
62 }
63
64 question := r.Message.Question[0]
65 msg := r.Message.CreateReply(r.Message)
66 msg.Header.RA = true
67
68 records, err := h.resolveQuestion(r.Context, question, h.RootServers)
69
70 if err == errNXDOMAIN {
71 msg = msg.SetRCode(magna.NXDOMAIN)
72 msg.Authority = records
73 msg.Header.NSCount = uint16(len(records))
74 msg.Header.ANCount = 0
75 msg.Answer = nil
76 } else if err == errOnlySOA {
77 msg = msg.SetRCode(magna.NOERROR)
78 msg.Authority = records
79 msg.Header.NSCount = uint16(len(records))
80 msg.Header.ANCount = 0
81 msg.Header.ARCount = 0
82 msg.Answer = nil
83 } else if err != nil {
84 h.Logger.Warn("error", "error", err)
85 msg = msg.SetRCode(magna.SERVFAIL)
86 } else {
87 msg.Answer = records
88 msg.Header.ANCount = uint16(len(records))
89 msg = msg.SetRCode(magna.NOERROR)
90 }
91
92 w.WriteMsg(msg)
93}
94
95func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
96 const maxDepth = 16
97
98 newCtx, err := withIncrementedDepth(ctx, maxDepth)
99 if err != nil {
100 return nil, err
101 }
102 ctx = newCtx
103
104 for _, s := range servers {
105 msg, err := queryServer(ctx, question, s, h.Timeout)
106 if err != nil {
107 h.Logger.Warn("unable to resolve question", "server", s, "error", err)
108 continue
109 }
110
111 if msg.Header.RCode == magna.NXDOMAIN {
112 _, authority := ExtractSOA(msg)
113 return authority, errNXDOMAIN
114 }
115
116 if ok, answers := ExtractAnswer(question, msg); ok {
117 if msg.Answer[0].RType == magna.CNAMEType {
118 cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}
119 ips, err := h.resolveQuestion(ctx, cnameQuestion, h.RootServers)
120 if err != nil {
121 h.Logger.Info("unable to resolve CNAME target", "question", cnameQuestion, "depth", getDepth(ctx), "error", err)
122 }
123
124 answers = append(answers, ips...)
125 }
126 return answers, nil
127 }
128
129 if ok, answers := HandleGlueRecords(question, msg); ok {
130 return h.resolveQuestion(ctx, question, answers)
131 }
132
133 if ok, answers := h.HandleReferral(ctx, question, msg); ok {
134 return h.resolveQuestion(ctx, question, answers)
135 }
136
137 if ok, answers := ExtractSOA(msg); ok {
138 return answers, errOnlySOA
139 }
140 }
141
142 return []magna.ResourceRecord{}, errNoServers
143}
144
145func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) {
146 ctx, cancel := context.WithTimeout(ctx, timeout)
147 defer cancel()
148
149 var d net.Dialer
150 conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
151 if err != nil {
152 return magna.Message{}, err
153 }
154 defer conn.Close()
155
156 conn.SetDeadline(time.Now().Add(timeout))
157
158 query := magna.CreateRequest(magna.QUERY, false)
159 query = query.AddQuestion(question)
160 msg, err := query.Encode()
161 if err != nil {
162 return magna.Message{}, err
163 }
164
165 if _, err := conn.Write(msg); err != nil {
166 return magna.Message{}, err
167 }
168
169 p := make([]byte, 512)
170 nn, err := conn.Read(p)
171
172 // TODO: retry request with TCP
173 if err != nil || nn > 512 {
174 if err == nil {
175 err = fmt.Errorf("truncated response")
176 }
177 return magna.Message{}, err
178 }
179
180 var response magna.Message
181 if err := response.Decode(p); err != nil {
182 return magna.Message{}, err
183 }
184
185 if err := validateResponse(*query, response, question); err != nil {
186 return magna.Message{}, err
187 }
188 return response, err
189}
190
191func validateResponse(query magna.Message, response magna.Message, question magna.Question) error {
192 if response.Header.ID != query.Header.ID {
193 return errNonMatchingID
194 }
195 if len(response.Question) < 1 || response.Question[0] != question {
196 return errNonMatchingQuestion
197 }
198 return nil
199}
200
201func ExtractAnswer(q magna.Question, r magna.Message) (bool, []magna.ResourceRecord) {
202 answers := make([]magna.ResourceRecord, 0, r.Header.ANCount)
203 for _, a := range r.Answer {
204 if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) {
205 answers = append(answers, a)
206 }
207 }
208
209 if len(answers) <= 0 {
210 return false, []magna.ResourceRecord{}
211 }
212
213 return true, answers
214}
215
216func ExtractSOA(r magna.Message) (bool, []magna.ResourceRecord) {
217 answers := make([]magna.ResourceRecord, 0, r.Header.NSCount)
218 for _, a := range r.Authority {
219 if a.RType == magna.SOAType {
220 answers = append(answers, a)
221 }
222 }
223
224 if len(answers) <= 0 {
225 return false, []magna.ResourceRecord{}
226 }
227
228 return true, answers
229}
230
231func HandleGlueRecords(q magna.Question, r magna.Message) (bool, []string) {
232 answers := make([]string, 0, r.Header.ARCount)
233 for _, a := range r.Authority {
234 if a.RType != magna.NSType {
235 continue
236 }
237
238 ns, ok := a.RData.(*magna.NS)
239 if !ok {
240 // this should not happen but better safe than sorry
241 continue
242 }
243
244 for _, ad := range r.Additional {
245 // XXX: add AAAAType when magna supports it
246 if ad.RType == magna.AType && strings.ToLower(ad.Name) == strings.ToLower(ns.NSDName) {
247 answers = append(answers, ad.RData.String())
248 }
249 }
250 }
251
252 if len(answers) <= 0 {
253 return false, []string{}
254 }
255
256 return true, answers
257}
258
259func (h *QueryHandler) HandleReferral(ctx context.Context, q magna.Question, r magna.Message) (bool, []string) {
260 servers := make([]string, 0, r.Header.NSCount)
261
262 for _, auth := range r.Authority {
263 if auth.RType == magna.NSType {
264 nsQuestion := magna.Question{
265 QName: auth.RData.String(),
266 QType: magna.AType,
267 QClass: magna.IN,
268 }
269
270 answers, err := h.resolveQuestion(ctx, nsQuestion, h.RootServers)
271 if err != nil {
272 h.Logger.Warn("error handling referral",
273 "question", nsQuestion,
274 "depth", getDepth(ctx),
275 "error", err)
276 continue
277 }
278
279 for _, ans := range answers {
280 servers = append(servers, ans.RData.String())
281 }
282 }
283 }
284
285 if len(servers) <= 0 {
286 return false, []string{}
287 }
288
289 return true, servers
290}