a recursive dns resolver

feat: rewrite resolution logic to be more modular

Changed files
+159 -91
pkg
+1 -1
go.mod
···
github.com/ClickHouse/clickhouse-go/v2 v2.34.0
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.11.0
-
tangled.sh/seiso.moe/magna v0.0.1
+
tangled.sh/seiso.moe/magna v0.0.2
)
require (
+2
go.sum
···
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
tangled.sh/seiso.moe/magna v0.0.1 h1:v8GM2y3xEinc0jGVxYf/33xtWJ74ES9EuTaMxXL8zxo=
tangled.sh/seiso.moe/magna v0.0.1/go.mod h1:bqm+DTo2Pv4ITT0EnR079l++BJgoChBswSB/3KeijUk=
+
tangled.sh/seiso.moe/magna v0.0.2 h1:4VGPlqv/7tVyTtsR4Qkk8ZNypuNbmaeLogWzkpbHrRs=
+
tangled.sh/seiso.moe/magna v0.0.2/go.mod h1:bqm+DTo2Pv4ITT0EnR079l++BJgoChBswSB/3KeijUk=
+7 -7
pkg/dns/ratelimit.go
···
}
type rateLimiter struct {
-
config RateLimitConfig
-
limiters map[string]*ipRateLimiterEntry
-
mu sync.RWMutex
+
config RateLimitConfig
+
limiters map[string]*ipRateLimiterEntry
+
mu sync.RWMutex
stopCleanup chan struct{}
}
···
func newRateLimiter(config RateLimitConfig) *rateLimiter {
rl := &rateLimiter{
-
config: config,
-
limiters: make(map[string]*ipRateLimiterEntry),
+
config: config,
+
limiters: make(map[string]*ipRateLimiterEntry),
stopCleanup: make(chan struct{}),
}
···
}
func (rl *rateLimiter) allow(ip string) bool {
-
rl.mu.Lock()
+
rl.mu.Lock()
defer rl.mu.Unlock()
entry, exists := rl.limiters[ip]
···
if !exists {
limiter := rate.NewLimiter(rate.Limit(rl.config.Rate), rl.config.Burst)
entry := &ipRateLimiterEntry{
-
limiter: limiter,
+
limiter: limiter,
lastAccess: now,
}
+147 -81
pkg/dns/resolve.go
···
"fmt"
"log/slog"
"net"
+
"strings"
"time"
"tangled.sh/seiso.moe/magna"
···
var errNXDOMAIN = fmt.Errorf("nxdomain")
+
const (
+
depthKey contextKey = "dns_recursion_depth"
+
)
+
+
func withIncrementedDepth(ctx context.Context, maxDepth int) (context.Context, error) {
+
depth := getDepth(ctx)
+
if depth >= maxDepth {
+
return nil, fmt.Errorf("maximum recursion depth (%d) exceeded", maxDepth)
+
}
+
return context.WithValue(ctx, depthKey, depth+1), nil
+
}
+
+
func getDepth(ctx context.Context) int {
+
if depth, ok := ctx.Value(depthKey).(int); ok {
+
return depth
+
}
+
return 0
+
}
+
type QueryHandler struct {
RootServers []string
Timeout time.Duration
···
}
question := r.Message.Question[0]
-
msg := r.Message.CreateReply(r.Message)
msg.Header.RA = true
···
}
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
-
ch := make(chan queryResponse, len(servers))
+
const maxDepth = 16
-
for _, s := range servers {
-
go queryServer(ctx, question, s, ch, h.Timeout)
+
newCtx, err := withIncrementedDepth(ctx, maxDepth)
+
if err != nil {
+
return nil, err
}
+
ctx = newCtx
-
for range servers {
-
select {
-
case res := <-ch:
-
if res.Error != nil {
-
break
-
}
-
-
msg := res.MSG
-
if msg.Header.ANCount > 0 {
-
if msg.Answer[0].RType == magna.CNAMEType {
-
cname_answers, err := h.resolveQuestion(ctx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
-
if err != nil {
-
continue
-
}
-
msg.Answer = append(msg.Answer, cname_answers...)
-
}
-
-
return msg.Answer, nil
-
}
+
for _, s := range servers {
+
msg, err := queryServer(ctx, question, s, h.Timeout)
+
if err != nil {
+
h.Logger.Warn("unable to resolve question", "server", s)
+
continue
+
}
-
if msg.Header.ARCount > 0 {
-
var nextZone []string
-
for _, ans := range msg.Additional {
-
if ans.RType == magna.AType {
-
nextZone = append(nextZone, ans.RData.String())
-
}
+
if ok, answers := ExtractAnswer(question, msg); ok {
+
if msg.Answer[0].RType == magna.CNAMEType {
+
cnameQuestion := magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}
+
ips, err := h.resolveQuestion(ctx, cnameQuestion, h.RootServers)
+
if err != nil {
+
h.Logger.Info("unable to resolve CNAME target", "question", cnameQuestion, "depth", getDepth(ctx), "error", err)
}
-
return h.resolveQuestion(ctx, question, nextZone)
+
answers = append(answers, ips...)
}
+
return answers, nil
+
}
-
if msg.Header.NSCount > 0 {
-
var ns []string
-
for _, a := range msg.Authority {
-
if a.RType == magna.NSType {
-
ans, err := h.resolveQuestion(ctx, magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, h.RootServers)
-
if err != nil {
-
break
-
}
-
for _, x := range ans {
-
ns = append(ns, x.RData.String())
-
}
-
}
-
}
+
if ok, answers := HandleGlueRecords(question, msg); ok {
+
return h.resolveQuestion(ctx, question, answers)
+
}
-
return h.resolveQuestion(ctx, question, ns)
-
}
-
-
return []magna.ResourceRecord{}, nil
+
if ok, answers := h.HandleReferral(ctx, question, msg); ok {
+
return h.resolveQuestion(ctx, question, answers)
}
}
return []magna.ResourceRecord{}, nil
}
-
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) {
-
done := make(chan struct{}, 1)
+
func queryServer(ctx context.Context, question magna.Question, server string, timeout time.Duration) (magna.Message, error) {
+
var d net.Dialer
+
conn, err := d.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
+
if err != nil {
+
return magna.Message{}, err
+
}
+
defer conn.Close()
go func() {
-
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
-
if err != nil {
-
ch <- queryResponse{Error: err}
-
return
+
<-ctx.Done()
+
conn.Close()
+
}()
+
+
conn.SetDeadline(time.Now().Add(timeout))
+
+
query := magna.CreateRequest(0, false)
+
query = query.AddQuestion(question)
+
msg, err := query.Encode()
+
if err != nil {
+
return magna.Message{}, err
+
}
+
+
if _, err := conn.Write(msg); err != nil {
+
return magna.Message{}, err
+
}
+
+
p := make([]byte, 512)
+
nn, err := conn.Read(p)
+
+
// TODO: retry request with TCP
+
if err != nil || nn > 512 {
+
if err == nil {
+
err = fmt.Errorf("truncated response")
}
-
defer conn.Close()
+
return magna.Message{}, err
+
}
-
query := magna.CreateRequest(0, false)
-
query = query.AddQuestion(question)
-
msg, err := query.Encode()
-
if err != nil {
-
ch <-queryResponse{Server: server, Error: err}
+
var response magna.Message
+
err = response.Decode(p)
+
return response, err
+
}
+
+
func ExtractAnswer(q magna.Question, r magna.Message) (bool, []magna.ResourceRecord) {
+
answers := make([]magna.ResourceRecord, 0, r.Header.ANCount)
+
for _, a := range r.Answer {
+
if a.RClass == q.QClass && strings.ToLower(a.Name) == strings.ToLower(q.QName) {
+
answers = append(answers, a)
}
+
}
-
if _, err := conn.Write(msg); err != nil {
-
ch <- queryResponse{Server: server, Error: err}
-
return
+
if len(answers) <= 0 {
+
return false, []magna.ResourceRecord{}
+
}
+
+
return true, answers
+
}
+
+
func HandleGlueRecords(q magna.Question, r magna.Message) (bool, []string) {
+
answers := make([]string, 0, r.Header.ARCount)
+
for _, a := range r.Authority {
+
if a.RType != magna.NSType {
+
continue
}
-
p := make([]byte, 512)
-
nn, err := conn.Read(p)
+
ns, ok := a.RData.(*magna.NS)
+
if !ok {
+
// this should not happen but better safe than sorry
+
continue
+
}
-
// TODO: retry request with TCP
-
if err != nil || nn > 512 {
-
if err == nil {
-
err = fmt.Errorf("truncated response")
+
for _, ad := range r.Additional {
+
// XXX: add AAAAType when magna supports it
+
if ad.RType == magna.AType && strings.ToLower(ad.Name) == strings.ToLower(ns.NSDName) {
+
answers = append(answers, ad.RData.String())
}
-
ch <- queryResponse{Server: server, Error: err}
-
return
}
+
}
-
var response magna.Message
-
err = response.Decode(p)
-
ch <- queryResponse{MSG: response, Server: server, Error: err}
-
}()
+
if len(answers) <= 0 {
+
return false, []string{}
+
}
+
+
return true, answers
+
}
+
+
func (h *QueryHandler) HandleReferral(ctx context.Context, q magna.Question, r magna.Message) (bool, []string) {
+
servers := make([]string, 0, r.Header.NSCount)
+
+
for _, auth := range r.Authority {
+
if auth.RType == magna.NSType {
+
nsQuestion := magna.Question{
+
QName: auth.RData.String(),
+
QType: magna.AType,
+
QClass: magna.IN,
+
}
+
+
answers, err := h.resolveQuestion(ctx, nsQuestion, h.RootServers)
+
if err != nil {
+
h.Logger.Warn("error handling referral",
+
"question", nsQuestion,
+
"depth", getDepth(ctx),
+
"error", err)
+
continue
+
}
-
select {
-
case <-ctx.Done():
-
ch <- queryResponse{Server: server, Error: ctx.Err()}
-
case <-done:
-
// goroutine finished with no cancellation
-
case <-time.After(timeout):
-
ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")}
+
for _, ans := range answers {
+
servers = append(servers, ans.RData.String())
+
}
+
}
}
+
+
if len(servers) <= 0 {
+
return false, []string{}
+
}
+
+
return true, servers
}
+2 -2
pkg/dns/server.go
···
)
var (
-
serverUDPBufferPool = sync.Pool {
+
serverUDPBufferPool = sync.Pool{
New: func() any {
b := make([]byte, maxUDPBufferSize)
return &b
},
}
-
resolverUDPBufferPool = sync.Pool {
+
resolverUDPBufferPool = sync.Pool{
New: func() any {
b := make([]byte, maxResolverUDPBufferSize)
return &b