A recursive DNS resolver.
at main 7.2 kB view raw
1package dns 2 3import ( 4 "context" 5 "fmt" 6 "log/slog" 7 "net" 8 "strings" 9 "time" 10 11 "tangled.sh/seiso.moe/magna" 12) 13 14var ( 15 errNXDOMAIN = fmt.Errorf("alky: domain does not exist") 16 errOnlySOA = fmt.Errorf("alky: only an soa was reffered") 17 errNoServers = fmt.Errorf("alky: no servers responded") 18 errMaxDepth = fmt.Errorf("alky: maximum recursion depth exceeded") 19 errNonMatchingID = fmt.Errorf("alky: response ID mismatch") 20 errNonMatchingQuestion = fmt.Errorf("alky: response question mismatch") 21 errQueryTimeout = fmt.Errorf("alky: query timeout") 22) 23 24const ( 25 depthKey contextKey = "dns_recursion_depth" 26) 27 28func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) { 29 depth := getDepth(ctx) 30 if depth >= maxDepth { 31 return nil, errMaxDepth 32 } 33 return context.WithValue(ctx, depthKey, depth+1), nil 34} 35 36func getDepth(ctx context.Context) int { 37 if depth, ok := ctx.Value(depthKey).(int); ok { 38 return depth 39 } 40 return 0 41} 42 43type QueryHandler struct { 44 RootServers []string 45 Timeout time.Duration 46 Logger *slog.Logger 47} 48 49type queryResponse struct { 50 MSG magna.Message 51 Server string 52 Error error 53} 54 55func (h *QueryHandler) ServeDNS(w ResponseWriter, r *Request) { 56 if len(r.Message.Question) < 1 { 57 h.Logger.Debug("received query with no questions") 58 msg := r.Message.CreateReply(r.Message) 59 msg = msg.SetRCode(magna.FORMERR) 60 w.WriteMsg(msg) 61 return 62 } 63 64 question := r.Message.Question[0] 65 msg := r.Message.CreateReply(r.Message) 66 msg.Header.RA = true 67 68 records, err := h.resolveQuestion(r.Context, question, h.RootServers) 69 70 if err == errNXDOMAIN { 71 msg = msg.SetRCode(magna.NXDOMAIN) 72 msg.Authority = records 73 msg.Header.NSCount = uint16(len(records)) 74 msg.Header.ANCount = 0 75 msg.Answer = nil 76 } else if err == errOnlySOA { 77 msg = msg.SetRCode(magna.NOERROR) 78 msg.Authority = records 79 msg.Header.NSCount = 
uint16(len(records)) 80 msg.Header.ANCount = 0 81 msg.Header.ARCount = 0 82 msg.Answer = nil 83 } else if err != nil { 84 h.Logger.Warn("error", "error", err) 85 msg = msg.SetRCode(magna.SERVFAIL) 86 } else { 87 msg.Answer = records 88 msg.Header.ANCount = uint16(len(records)) 89 msg = msg.SetRCode(magna.NOERROR) 90 } 91 92 w.WriteMsg(msg) 93} 94 95func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) { 96 const maxDepth = 16 97 98 newCtx, err := withIncrementedDepth(ctx, maxDepth) 99 if err != nil { 100 return nil, err 101 } 102 ctx = newCtx 103 104 for _, s := range servers { 105 msg, err := queryServer(ctx, question, s, h.Timeout) 106 if err != nil { 107 h.Logger.Warn("unable to resolve question", "server", s, "error", err) 108 continue 109 } 110 111 if msg.Header.RCode == magna.NXDOMAIN { 112 _, authority := ExtractSOA(msg) 113 return authority, errNXDOMAIN 114 } 115 116 if ok, answers := ExtractAnswer(question, msg); ok { 117 if msg.Answer[0].RType == magna.CNAMEType { 118 cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass} 119 ips, err := h.resolveQuestion(ctx, cnameQuestion, h.RootServers) 120 if err != nil { 121 h.Logger.Info("unable to resolve CNAME target", "question", cnameQuestion, "depth", getDepth(ctx), "error", err) 122 } 123 124 answers = append(answers, ips...) 
125 } 126 return answers, nil 127 } 128 129 if ok, answers := HandleGlueRecords(question, msg); ok { 130 return h.resolveQuestion(ctx, question, answers) 131 } 132 133 if ok, answers := h.HandleReferral(ctx, question, msg); ok { 134 return h.resolveQuestion(ctx, question, answers) 135 } 136 137 if ok, answers := ExtractSOA(msg); ok { 138 return answers, errOnlySOA 139 } 140 } 141 142 return []magna.ResourceRecord{}, errNoServers 143} 144 145func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) { 146 ctx, cancel := context.WithTimeout(ctx, timeout) 147 defer cancel() 148 149 var d net.Dialer 150 conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server)) 151 if err != nil { 152 return magna.Message{}, err 153 } 154 defer conn.Close() 155 156 conn.SetDeadline(time.Now().Add(timeout)) 157 158 query := magna.CreateRequest(magna.QUERY, false) 159 query = query.AddQuestion(question) 160 msg, err := query.Encode() 161 if err != nil { 162 return magna.Message{}, err 163 } 164 165 if _, err := conn.Write(msg); err != nil { 166 return magna.Message{}, err 167 } 168 169 p := make([]byte, 512) 170 nn, err := conn.Read(p) 171 172 // TODO: retry request with TCP 173 if err != nil || nn > 512 { 174 if err == nil { 175 err = fmt.Errorf("truncated response") 176 } 177 return magna.Message{}, err 178 } 179 180 var response magna.Message 181 if err := response.Decode(p); err != nil { 182 return magna.Message{}, err 183 } 184 185 if err := validateResponse(*query, response, question); err != nil { 186 return magna.Message{}, err 187 } 188 return response, err 189} 190 191func validateResponse(query magna.Message, response magna.Message, question magna.Question) error { 192 if response.Header.ID != query.Header.ID { 193 return errNonMatchingID 194 } 195 if len(response.Question) < 1 || response.Question[0] != question { 196 return errNonMatchingQuestion 197 } 198 return nil 199} 200 201func ExtractAnswer(q 
magna.Question, r magna.Message) (bool, []magna.ResourceRecord) { 202 answers := make([]magna.ResourceRecord, 0, r.Header.ANCount) 203 for _, a := range r.Answer { 204 if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) { 205 answers = append(answers, a) 206 } 207 } 208 209 if len(answers) <= 0 { 210 return false, []magna.ResourceRecord{} 211 } 212 213 return true, answers 214} 215 216func ExtractSOA(r magna.Message) (bool, []magna.ResourceRecord) { 217 answers := make([]magna.ResourceRecord, 0, r.Header.NSCount) 218 for _, a := range r.Authority { 219 if a.RType == magna.SOAType { 220 answers = append(answers, a) 221 } 222 } 223 224 if len(answers) <= 0 { 225 return false, []magna.ResourceRecord{} 226 } 227 228 return true, answers 229} 230 231func HandleGlueRecords(q magna.Question, r magna.Message) (bool, []string) { 232 answers := make([]string, 0, r.Header.ARCount) 233 for _, a := range r.Authority { 234 if a.RType != magna.NSType { 235 continue 236 } 237 238 ns, ok := a.RData.(*magna.NS) 239 if !ok { 240 // this should not happen but better safe than sorry 241 continue 242 } 243 244 for _, ad := range r.Additional { 245 // XXX: add AAAAType when magna supports it 246 if ad.RType == magna.AType && strings.ToLower(ad.Name) == strings.ToLower(ns.NSDName) { 247 answers = append(answers, ad.RData.String()) 248 } 249 } 250 } 251 252 if len(answers) <= 0 { 253 return false, []string{} 254 } 255 256 return true, answers 257} 258 259func (h *QueryHandler) HandleReferral(ctx context.Context, q magna.Question, r magna.Message) (bool, []string) { 260 servers := make([]string, 0, r.Header.NSCount) 261 262 for _, auth := range r.Authority { 263 if auth.RType == magna.NSType { 264 nsQuestion := magna.Question{ 265 QName: auth.RData.String(), 266 QType: magna.AType, 267 QClass: magna.IN, 268 } 269 270 answers, err := h.resolveQuestion(ctx, nsQuestion, h.RootServers) 271 if err != nil { 272 h.Logger.Warn("error handling referral", 273 "question", 
nsQuestion, 274 "depth", getDepth(ctx), 275 "error", err) 276 continue 277 } 278 279 for _, ans := range answers { 280 servers = append(servers, ans.RData.String()) 281 } 282 } 283 } 284 285 if len(servers) <= 0 { 286 return false, []string{} 287 } 288 289 return true, servers 290}