package dns import ( "context" "fmt" "log/slog" "net" "strings" "time" "tangled.sh/seiso.moe/magna" ) var ( errNXDOMAIN = fmt.Errorf("alky: domain does not exist") errOnlySOA = fmt.Errorf("alky: only an soa was reffered") errNoServers = fmt.Errorf("alky: no servers responded") errMaxDepth = fmt.Errorf("alky: maximum recursion depth exceeded") errNonMatchingID = fmt.Errorf("alky: response ID mismatch") errNonMatchingQuestion = fmt.Errorf("alky: response question mismatch") errQueryTimeout = fmt.Errorf("alky: query timeout") ) const ( depthKey contextKey = "dns_recursion_depth" ) func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) { depth := getDepth(ctx) if depth >= maxDepth { return nil, errMaxDepth } return context.WithValue(ctx, depthKey, depth+1), nil } func getDepth(ctx context.Context) int { if depth, ok := ctx.Value(depthKey).(int); ok { return depth } return 0 } type QueryHandler struct { RootServers []string Timeout time.Duration Logger *slog.Logger } type queryResponse struct { MSG magna.Message Server string Error error } func (h *QueryHandler) ServeDNS(w ResponseWriter, r *Request) { if len(r.Message.Question) < 1 { h.Logger.Debug("received query with no questions") msg := r.Message.CreateReply(r.Message) msg = msg.SetRCode(magna.FORMERR) w.WriteMsg(msg) return } question := r.Message.Question[0] msg := r.Message.CreateReply(r.Message) msg.Header.RA = true records, err := h.resolveQuestion(r.Context, question, h.RootServers) if err == errNXDOMAIN { msg = msg.SetRCode(magna.NXDOMAIN) msg.Authority = records msg.Header.NSCount = uint16(len(records)) msg.Header.ANCount = 0 msg.Answer = nil } else if err == errOnlySOA { msg = msg.SetRCode(magna.NOERROR) msg.Authority = records msg.Header.NSCount = uint16(len(records)) msg.Header.ANCount = 0 msg.Header.ARCount = 0 msg.Answer = nil } else if err != nil { h.Logger.Warn("error", "error", err) msg = msg.SetRCode(magna.SERVFAIL) } else { msg.Answer = records msg.Header.ANCount = uint16(len(records)) msg = msg.SetRCode(magna.NOERROR) } w.WriteMsg(msg) } func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) { const maxDepth = 16 newCtx, err := withIncrementedDepth(ctx, maxDepth) if err != nil { return nil, err } ctx = newCtx for _, s := range servers { msg, err := queryServer(ctx, question, s, h.Timeout) if err != nil { h.Logger.Warn("unable to resolve question", "server", s, "error", err) continue } if msg.Header.RCode == magna.NXDOMAIN { _, authority := ExtractSOA(msg) return authority, errNXDOMAIN } if ok, answers := ExtractAnswer(question, msg); ok { if msg.Answer[0].RType == magna.CNAMEType { cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass} ips, err := h.resolveQuestion(ctx, cnameQuestion, h.RootServers) if err != nil { h.Logger.Info("unable to resolve CNAME target", "question", cnameQuestion, "depth", getDepth(ctx), "error", err) } answers = append(answers, ips...) } return answers, nil } if ok, answers := HandleGlueRecords(question, msg); ok { return h.resolveQuestion(ctx, question, answers) } if ok, answers := h.HandleReferral(ctx, question, msg); ok { return h.resolveQuestion(ctx, question, answers) } if ok, answers := ExtractSOA(msg); ok { return answers, errOnlySOA } } return []magna.ResourceRecord{}, errNoServers } func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var d net.Dialer conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server)) if err != nil { return magna.Message{}, err } defer conn.Close() conn.SetDeadline(time.Now().Add(timeout)) query := magna.CreateRequest(magna.QUERY, false) query = query.AddQuestion(question) msg, err := query.Encode() if err != nil { return magna.Message{}, err } if _, err := conn.Write(msg); err != nil { return magna.Message{}, err } p := make([]byte, 512) nn, err := conn.Read(p) // TODO: retry request with TCP if err != nil || nn > 512 { if err == nil { err = fmt.Errorf("truncated response") } return magna.Message{}, err } var response magna.Message if err := response.Decode(p); err != nil { return magna.Message{}, err } if err := validateResponse(*query, response, question); err != nil { return magna.Message{}, err } return response, err } func validateResponse(query magna.Message, response magna.Message, question magna.Question) error { if response.Header.ID != query.Header.ID { return errNonMatchingID } if len(response.Question) < 1 || response.Question[0] != question { return errNonMatchingQuestion } return nil } func ExtractAnswer(q magna.Question, r magna.Message) (bool, []magna.ResourceRecord) { answers := make([]magna.ResourceRecord, 0, r.Header.ANCount) for _, a := range r.Answer { if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) { answers = append(answers, a) } } if len(answers) <= 0 { return false, []magna.ResourceRecord{} } return true, answers } func ExtractSOA(r magna.Message) (bool, []magna.ResourceRecord) { answers := make([]magna.ResourceRecord, 0, r.Header.NSCount) for _, a := range r.Authority { if a.RType == magna.SOAType { answers = append(answers, a) } } if len(answers) <= 0 { return false, []magna.ResourceRecord{} } return true, answers } func HandleGlueRecords(q magna.Question, r magna.Message) (bool, []string) { answers := make([]string, 0, r.Header.ARCount) for _, a := range r.Authority { if a.RType != magna.NSType { continue } ns, ok := a.RData.(*magna.NS) if !ok { // this should not happen but better safe than sorry continue } for _, ad := range r.Additional { // XXX: add AAAAType when magna supports it if ad.RType == magna.AType && strings.ToLower(ad.Name) == strings.ToLower(ns.NSDName) { answers = append(answers, ad.RData.String()) } } } if len(answers) <= 0 { return false, []string{} } return true, answers } func (h *QueryHandler) HandleReferral(ctx context.Context, q magna.Question, r magna.Message) (bool, []string) { servers := make([]string, 0, r.Header.NSCount) for _, auth := range r.Authority { if auth.RType == magna.NSType { nsQuestion := magna.Question{ QName: auth.RData.String(), QType: magna.AType, QClass: magna.IN, } answers, err := h.resolveQuestion(ctx, nsQuestion, h.RootServers) if err != nil { h.Logger.Warn("error handling referral", "question", nsQuestion, "depth", getDepth(ctx), "error", err) continue } for _, ans := range answers { servers = append(servers, ans.RData.String()) } } } if len(servers) <= 0 { return false, []string{} } return true, servers }