A recursive DNS resolver

General improvements: replace the hand-rolled per-IP rate limiter with token buckets from golang.org/x/time/rate (plus a background cleanup loop and Stop()), add sync.Pool-backed UDP buffer reuse in the server and resolver, cap recursion depth and propagate the last error during resolution, lowercase NS/glue name matching, cache NODATA responses, and fix a log.Fatalf call that was missing its %v format verb.

+4
go.mod
···
require (
github.com/BurntSushi/toml v1.5.0
github.com/ClickHouse/clickhouse-go/v2 v2.34.0
+
github.com/stretchr/testify v1.10.0
+
golang.org/x/time v0.11.0
tangled.sh/seiso.moe/magna v0.0.1
)
require (
github.com/ClickHouse/ch-go v0.65.1 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
+
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-faster/city v1.0.1 // indirect
github.com/go-faster/errors v0.7.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/paulmach/orb v0.11.1 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
+
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
+2
go.sum
···
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
+
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+1 -1
main.go
···
func main() {
cfg, err := config.LoadConfig(configFlag)
if err != nil {
-
log.Fatalf("failed to load config: ", err)
+
log.Fatalf("failed to load config: %v", err)
}
logger := setupLogger(&cfg)
+3 -8
pkg/dns/cache.go
···
Additional: make([]magna.ResourceRecord, len(original.Additional)),
}
-
remainingDuration := original.ExpireAt.Sub(now)
-
if remainingDuration < 0 {
-
remainingDuration = 0
-
}
+
remainingDuration := max(original.ExpireAt.Sub(now), 0)
remainingTTL := uint32(remainingDuration.Seconds())
copyAndAdjust := func(dest *[]magna.ResourceRecord, src []magna.ResourceRecord) {
for i, rr := range src {
(*dest)[i] = rr
rrExpireAt := original.CacheTime.Add(time.Duration(rr.TTL) * time.Second)
-
rrRemainingDuration := rrExpireAt.Sub(now)
-
if rrRemainingDuration < 0 {
-
rrRemainingDuration = 0
-
}
+
rrRemainingDuration := max(rrExpireAt.Sub(now), 0)
rrRemainingTTL := uint32(rrRemainingDuration.Seconds())
finalTTL := min(rrRemainingTTL, remainingTTL)
···
}
func (c *LRUCache) Set(key string, entry *CacheEntry) {
+
c.logger.Info("setting key", "key", key)
if entry == nil {
c.logger.Warn("attempted to set nil entry in cache", "key", key)
return
+57 -38
pkg/dns/ratelimit.go
···
"sync"
"time"
+
"golang.org/x/time/rate"
"tangled.sh/seiso.moe/magna"
)
···
type rateLimiter struct {
config RateLimitConfig
-
ipData map[string]*ipRateData
+
limiters map[string]*ipRateLimiterEntry
mu sync.RWMutex
+
stopCleanup chan struct{}
}
-
type ipRateData struct {
-
time time.Time
+
type ipRateLimiterEntry struct {
+
limiter *rate.Limiter
+
lastAccess time.Time
}
func NewDefaultRateLimitConfig() *RateLimitConfig {
return &RateLimitConfig{
Rate: 1,
Burst: 1,
-
WindowLength: time.Hour,
-
ExpirationTime: time.Hour,
+
WindowLength: time.Second,
+
ExpirationTime: 5 * time.Minute,
}
}
func newRateLimiter(config RateLimitConfig) *rateLimiter {
-
return &rateLimiter{
+
rl := &rateLimiter{
config: config,
-
ipData: make(map[string]*ipRateData),
+
limiters: make(map[string]*ipRateLimiterEntry),
+
stopCleanup: make(chan struct{}),
}
+
+
go rl.cleanupLoop()
+
return rl
}
func (rl *rateLimiter) allow(ip string) bool {
-
rl.mu.Lock()
+
rl.mu.Lock()
defer rl.mu.Unlock()
+
entry, exists := rl.limiters[ip]
now := time.Now()
-
cost := time.Duration(float64(time.Second) / rl.config.Rate)
-
data, exists := rl.ipData[ip]
if !exists {
-
data = &ipRateData{time: now.Add(-rl.config.WindowLength)}
-
rl.ipData[ip] = data
-
}
+
limiter := rate.NewLimiter(rate.Limit(rl.config.Rate), rl.config.Burst)
+
entry := &ipRateLimiterEntry{
+
limiter: limiter,
+
lastAccess: now,
+
}
-
if data.time.Before(now.Add(-rl.config.WindowLength)) {
-
data.time = now.Add(-rl.config.WindowLength)
+
rl.limiters[ip] = entry
+
return entry.limiter.Allow()
}
-
nextTime := data.time.Add(cost)
-
if now.Before(nextTime) {
-
return false
-
}
+
entry.lastAccess = now
+
return entry.limiter.Allow()
+
}
+
+
func (rl *rateLimiter) cleanupLoop() {
+
ticker := time.NewTicker(rl.config.ExpirationTime)
+
defer ticker.Stop()
-
if nextTime.Sub(now.Add(-rl.config.WindowLength)) > time.Duration(rl.config.Burst)*cost {
-
nextTime = now.Add(cost)
+
for {
+
select {
+
case <-ticker.C:
+
rl.performCleanup()
+
case <-rl.stopCleanup:
+
return
+
}
}
-
-
data.time = nextTime
-
return true
}
-
func (rl *rateLimiter) cleanup() {
+
func (rl *rateLimiter) performCleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
-
for ip, data := range rl.ipData {
-
if data.time.Before(now.Add(-rl.config.WindowLength)) {
-
delete(rl.ipData, ip)
+
expirationCutoff := now.Add(-rl.config.ExpirationTime)
+
+
for ip, entry := range rl.limiters {
+
if entry.lastAccess.Before(expirationCutoff) {
+
delete(rl.limiters, ip)
}
}
}
+
func (rl *rateLimiter) Stop() {
+
close(rl.stopCleanup)
+
}
+
func extractIP(addr net.Addr) string {
switch v := addr.(type) {
case *net.UDPAddr:
···
config = NewDefaultRateLimitConfig()
}
-
rl := newRateLimiter(*config)
-
-
go func() {
-
ticker := time.NewTicker(config.ExpirationTime)
-
for range ticker.C {
-
rl.cleanup()
+
if config.Rate <= 0 {
+
config.Rate = 1
+
if config.Burst <= 0 {
+
config.Burst = 1
}
-
}()
+
}
+
+
rl := newRateLimiter(*config)
return func(next Handler) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
-
if !rl.allow(extractIP(r.RemoteAddr)) {
-
r.Message.Header.RA = true
+
clientIP := extractIP(r.RemoteAddr)
+
if !rl.allow(clientIP) {
msg := r.Message.CreateReply(r.Message)
msg = r.Message.SetRCode(magna.REFUSED)
+
msg.Header.RA = true
-
// XXX: dont support edns yet and these get copied over on responses
msg.Header.ANCount = 0
msg.Header.NSCount = 0
msg.Header.ARCount = 0
+74 -21
pkg/dns/resolve.go
···
"log/slog"
"math/rand"
"net"
+
"strings"
"time"
"tangled.sh/seiso.moe/magna"
···
}
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
+
depth := 0
+
if d, ok := ctx.Value("recursion_depth").(int); ok {
+
depth = d
+
}
+
+
// TODO move to configuration file
+
if depth > 15 {
+
h.Logger.Warn("Max recursion depth exceeded", "question", question.QName, "depth", depth)
+
return nil, fmt.Errorf("max recursion depth exceeded")
+
}
+
ctx = context.WithValue(ctx, "recursion_depth", depth+1)
+
if h.Cache != nil {
cacheKey := GenerateCacheKey(question)
if entry, found := h.Cache.Get(cacheKey); found {
···
resolveCtx, cancel := context.WithTimeout(ctx, h.Timeout)
defer cancel()
+
ch := make(chan queryResponse, len(servers))
for _, s := range servers {
go queryServer(resolveCtx, question, s, ch)
}
+
+
var lastError error
for range servers {
select {
case res := <-ch:
if res.Error != nil {
h.Logger.Warn("error", "question", question.QName, "server", res.Server, "error", res.Error)
+
lastError = res.Error
continue
}
···
Authority: msg.Authority,
Additional: msg.Additional,
}
+
fullMsg.Header.ANCount = uint16(len(msg.Answer))
+
fullMsg.Header.NSCount = uint16(len(msg.Authority))
+
fullMsg.Header.ARCount = uint16(len(msg.Additional))
if msg.Header.RCode == magna.NXDOMAIN {
if h.Cache != nil {
···
if msg.Header.NSCount > 0 {
var nextServers []string
var nsRecords []magna.ResourceRecord
+
nsRecordsWithoutGlue := 0
glueMap := make(map[string]string)
for _, rr := range msg.Additional {
if rr.RType == magna.AType {
-
for _, nsRR := range msg.Authority {
-
if nsRR.RType == magna.NSType && nsRR.RData.String() == rr.Name {
-
glueMap[rr.Name] = rr.RData.String()
-
break
-
}
-
}
+
glueMap[strings.ToLower(rr.Name)] = rr.RData.String()
}
}
···
}
for _, nsRR := range nsRecords {
-
nsName := nsRR.RData.String()
+
nsName := strings.ToLower(nsRR.RData.String())
+
if ip, found := glueMap[nsName]; found {
nextServers = append(nextServers, ip)
} else {
+
nsRecordsWithoutGlue++
nsQuestion := magna.Question{QName: nsName, QType: magna.AType, QClass: magna.IN}
nsAnswers, err := h.resolveQuestion(resolveCtx, nsQuestion, h.RootServers)
if err != nil {
h.Logger.Warn("error resolving NS A record", "ns", nsName, "error", err)
continue
+
} else {
+
foundAddr := false
+
for _, ans := range nsAnswers {
+
if ans.RType == magna.AType && strings.ToLower(ans.Name) == nsName {
+
ipAddr := ans.RData.String()
+
nextServers = append(nextServers, ipAddr)
+
foundAddr = true
+
}
+
}
+
if !foundAddr {
+
h.Logger.Warn("A record lookup for NS succeeded but yielded no matching A records", "ns", nsName)
+
}
}
-
for _, ans := range nsAnswers {
-
if ans.RType == magna.AType {
-
nextServers = append(nextServers, ans.RData.String())
-
}
-
}
}
}
if len(nextServers) > 0 {
-
return h.resolveQuestion(resolveCtx, question, nextServers)
+
return h.resolveQuestion(ctx, question, nextServers)
+
} else {
+
h.Logger.Warn("Processed delegation but failed to find any next server IPs", "question", question.QName, "server", res.Server, "ns_records_count", len(nsRecords), "ns_without_glue", nsRecordsWithoutGlue, "depth", depth)
+
lastError = fmt.Errorf("failed to resolve NS addresses for delegation from %s", res.Server)
}
h.Logger.Warn("could not resolve any NS records for delegation", "question", question.QName)
continue
}
-
if msg.Header.RCode == magna.NOERROR && msg.Header.ANCount == 0 {
+
if msg.Header.RCode == magna.NOERROR && msg.Header.ANCount == 0 && msg.Header.NSCount == 0{
+
h.Logger.Debug("Received NODATA response (NOERROR, ANCount=0)", "question", question.QName, "server", res.Server, "depth", depth)
+
if h.Cache != nil {
entry := CreateCacheEntry(question, &fullMsg)
if entry != nil {
h.Cache.Set(GenerateCacheKey(question), entry)
}
}
+
return []magna.ResourceRecord{}, nil
}
-
h.Logger.Warn("unexpected response state", "question", question.QName, "server", res.Server, "rcode", msg.Header.RCode)
-
continue
+
h.Logger.Warn("Unhandled response state", "question", question.QName, "server", res.Server, "rcode", msg.Header.RCode.String(), "ancount", msg.Header.ANCount, "nscount", msg.Header.NSCount, "depth", depth)
+
lastError = fmt.Errorf("unhandled response code %s from %s", msg.Header.RCode.String(), res.Server)
case <-resolveCtx.Done():
-
h.Logger.Debug("resolution cancelled or timed out", "question", question.QName)
-
return []magna.ResourceRecord{}, fmt.Errorf("resolution timed out or cancelled")
+
h.Logger.Warn("Resolution step timed out or cancelled", "question", question.QName, "elapsed", time.Since(time.Now().Add(-h.Timeout)), "depth", depth, "error", resolveCtx.Err()) // Approx elapsed
+
if lastError != nil {
+
return nil, fmt.Errorf("resolution timed out after error: %w", lastError)
+
}
+
return nil, fmt.Errorf("resolution step timed out or cancelled: %w", resolveCtx.Err())
}
}
h.Logger.Warn("all resolution paths failed", "question", question.QName)
+
if lastError != nil {
+
return nil, fmt.Errorf("all resolution paths failed: %w", lastError)
+
}
return []magna.ResourceRecord{}, fmt.Errorf("all resolution paths failed")
}
···
var response queryResponse
response.Server = server
+
bufPtr := resolverUDPBufferPool.Get().(*[]byte)
+
defer resolverUDPBufferPool.Put(bufPtr)
+
p := *bufPtr
+
defer func() {
+
if ctx.Err() != nil {
+
if response.Error == nil {
+
response.Error = ctx.Err()
+
}
+
}
+
select {
case ch <- response:
+
case <-ctx.Done():
+
slog.Debug("queryServer response channel send cancelled by context", "server", server, "question", question.QName)
default:
-
slog.Debug("queryServer response channel blocked or closed", "server", server)
+
slog.Debug("queryServer response channel blocked or closed on send", "server", server, "question", question.QName)
}
}()
···
dialer := net.Dialer{Timeout: 2 * time.Second}
conn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
if err != nil {
-
response.Error = fmt.Errorf("dial error: %w", err)
+
if ctx.Err() != nil {
+
response.Error = ctx.Err()
+
} else {
+
response.Error = fmt.Errorf("dial error: %w", err)
+
}
+
return
}
defer conn.Close()
···
return
}
-
p := make([]byte, 512)
n, err := conn.Read(p)
if err != nil {
if ctx.Err() != nil {
+36 -4
pkg/dns/server.go
···
type contextKey string
const (
-
cacheHitKey contextKey = "cache_hit"
+
cacheHitKey contextKey = "cache_hit"
+
maxUDPBufferSize = 4096
+
maxResolverUDPBufferSize = 4096
+
)
+
+
var (
+
serverUDPBufferPool = sync.Pool {
+
New: func() any {
+
b := make([]byte, maxUDPBufferSize)
+
return &b
+
},
+
}
+
+
resolverUDPBufferPool = sync.Pool {
+
New: func() any {
+
b := make([]byte, maxResolverUDPBufferSize)
+
return &b
+
},
+
}
)
func setCacheHit(ctx context.Context) context.Context {
···
srv.Logger.Info("UDP listener started", "address", conn.LocalAddr())
for {
-
buf := make([]byte, srv.UDPSize)
+
bufPtr := serverUDPBufferPool.Get().(*[]byte)
+
buffer := *bufPtr
+
readDeadlineSet := false
if srv.ReadTimeout > 0 {
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
if err != nil {
+
serverUDPBufferPool.Put(bufPtr)
return fmt.Errorf("error setting UDP read deadline: %w", err)
}
+
readDeadlineSet = true
}
-
n, remoteAddr, err := conn.ReadFromUDP(buf)
+
n, remoteAddr, err := conn.ReadFromUDP(buffer)
if err != nil {
+
serverUDPBufferPool.Put(bufPtr)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+
if readDeadlineSet {
+
continue
+
}
+
+
srv.Logger.Warn("UDP read timeout occurred without explicit deadline set", "error", err)
continue
}
···
continue
}
-
go srv.handleUDPQuery(conn, buf[:n], remoteAddr)
+
queryData := make([]byte, n)
+
copy(queryData, buffer[:n])
+
serverUDPBufferPool.Put(bufPtr)
+
+
go srv.handleUDPQuery(conn, queryData, remoteAddr)
}
}