A recursive DNS resolver

initial code review/remove complicated cache

+1 -1
docker-compose.yml
···
soft: 262144
hard: 262144
healthcheck:
-
test: wget --no-verbose --tries=1 --spider http://localhost:8123/ping || exit 1
+
# --fail makes curl exit non-zero on HTTP error statuses (otherwise only connection failures are detected); --silent/--show-error keep the health log limited to real errors.
test: curl --fail --silent --show-error http://localhost:8123/ping || exit 1
interval: 30s
timeout: 5s
retries: 3
+11 -10
go.mod
···
module tangled.sh/seiso.moe/alky
-
go 1.22.5
+
go 1.23.0
+
+
toolchain go1.24.2
require (
-
github.com/BurntSushi/toml v1.4.0
-
github.com/ClickHouse/clickhouse-go/v2 v2.31.0
-
tangled.sh/seiso.moe/magna v0.0.0-20250326021922-01ca5bbcb720
+
github.com/BurntSushi/toml v1.5.0
+
github.com/ClickHouse/clickhouse-go/v2 v2.34.0
+
tangled.sh/seiso.moe/magna v0.0.1
)
require (
-
github.com/ClickHouse/ch-go v0.64.1 // indirect
+
github.com/ClickHouse/ch-go v0.65.1 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/go-faster/city v1.0.1 // indirect
github.com/go-faster/errors v0.7.1 // indirect
github.com/google/uuid v1.6.0 // indirect
-
github.com/klauspost/compress v1.17.11 // indirect
+
github.com/klauspost/compress v1.18.0 // indirect
github.com/paulmach/orb v0.11.1 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
-
github.com/pkg/errors v0.9.1 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
-
go.opentelemetry.io/otel v1.34.0 // indirect
-
go.opentelemetry.io/otel/trace v1.34.0 // indirect
-
golang.org/x/sys v0.30.0 // indirect
+
go.opentelemetry.io/otel v1.35.0 // indirect
+
go.opentelemetry.io/otel/trace v1.35.0 // indirect
+
golang.org/x/sys v0.32.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+18 -19
go.sum
···
-
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
-
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
-
github.com/ClickHouse/ch-go v0.64.1 h1:FWpP+QU4KchgzpEekuv8YoI/fUc4H2r6Bwc5WwrzvcI=
-
github.com/ClickHouse/ch-go v0.64.1/go.mod h1:RBUynvczWwVzhS6Up9lPKlH1mrk4UAmle6uzCiW4Pkc=
-
github.com/ClickHouse/clickhouse-go/v2 v2.31.0 h1:9MNHRDYXjFTJizGEJM1DfYAqdra/ohprPoZ+LPiuHXQ=
-
github.com/ClickHouse/clickhouse-go/v2 v2.31.0/go.mod h1:V1aZaG0ctMbd8KVi+D4loXi97duWYtHiQHMCgipKJcI=
+
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
+
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
+
github.com/ClickHouse/ch-go v0.65.1 h1:SLuxmLl5Mjj44/XbINsK2HFvzqup0s6rwKLFH347ZhU=
+
github.com/ClickHouse/ch-go v0.65.1/go.mod h1:bsodgURwmrkvkBe5jw1qnGDgyITsYErfONKAHn05nv4=
+
github.com/ClickHouse/clickhouse-go/v2 v2.34.0 h1:Y4rqkdrRHgExvC4o/NTbLdY5LFQ3LHS77/RNFxFX3Co=
+
github.com/ClickHouse/clickhouse-go/v2 v2.34.0/go.mod h1:yioSINoRLVZkLyDzdMXPLRIqhDvel8iLBlwh6Iefso8=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
···
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
-
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
+
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
+
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
···
github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
-
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
···
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g=
-
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
-
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
-
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
-
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
+
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
+
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
+
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
+
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
···
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
-
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
+
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
···
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-
tangled.sh/seiso.moe/magna v0.0.0-20250326021922-01ca5bbcb720 h1:e19STzd/6HFUBcdm5JnBE7txDVGdLqoyVUV+CMPVuuQ=
-
tangled.sh/seiso.moe/magna v0.0.0-20250326021922-01ca5bbcb720/go.mod h1:bqm+DTo2Pv4ITT0EnR079l++BJgoChBswSB/3KeijUk=
+
tangled.sh/seiso.moe/magna v0.0.1 h1:v8GM2y3xEinc0jGVxYf/33xtWJ74ES9EuTaMxXL8zxo=
+
tangled.sh/seiso.moe/magna v0.0.1/go.mod h1:bqm+DTo2Pv4ITT0EnR079l++BJgoChBswSB/3KeijUk=
+17 -2
main.go
···
"log"
"log/slog"
"os"
+
"os/signal"
+
"syscall"
"time"
"tangled.sh/seiso.moe/alky/pkg/config"
···
log.Fatal(err)
}
-
cache := dns.NewMemoryCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration)
+
cache := dns.NewLRUCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration)
+
defer cache.Stop()
handler := &dns.QueryHandler{
RootServers: rootServers,
···
Level: slog.LevelInfo,
})(metricsHandler)
+
go func() {
+
sigChan := make(chan os.Signal, 1)
+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
<-sigChan
+
slog.Info("Shutting down...")
+
cache.Stop()
+
metricsClient.Close()
+
os.Exit(0)
+
}()
+
s := dns.Server{
Address: cfg.Server.Address,
Port: cfg.Server.Port,
···
if err := s.ListenAndServe(); err != nil {
slog.Error("Failed to start server", "error", err)
+
cache.Stop()
+
metricsClient.Close()
}
}
-
func monitorCacheMetrics(cache *dns.MemoryCache, metricsClient *metrics.ClickHouseMetrics, logger *slog.Logger) {
+
func monitorCacheMetrics(cache dns.Cache, metricsClient *metrics.ClickHouseMetrics, logger *slog.Logger) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
+212 -98
pkg/dns/cache.go
···
package dns
import (
+
"container/list"
"strings"
"sync"
+
"sync/atomic"
"time"
"tangled.sh/seiso.moe/magna"
···
Additional []CachedResourceRecord
NegativeTTL time.Duration
ExpireAt time.Time
-
CreateTime time.Time
IsNegative bool
}
···
BailiwickRule BailiwickRule
}
-
type MemoryCache struct {
-
entries map[string]*CacheEntry
-
mu sync.RWMutex
-
stats CacheStats
-
maxSize int
-
cleaupInterval time.Duration
+
// CacheStats holds the cache counters. Every field is an atomic.Int64 so the
// counters can be updated and read without holding the cache mutex.
type CacheStats struct {
+
// TotalQueries is incremented once per LRUCache.Get call.
TotalQueries atomic.Int64
+
// CacheHits is incremented when Get returns an unexpired entry.
CacheHits atomic.Int64
+
// CacheMisses is incremented when Get finds no entry or an expired one.
CacheMisses atomic.Int64
+
// NegativeHits counts hits whose entry has IsNegative set.
NegativeHits atomic.Int64
+
// PositiveHits counts hits whose entry does not have IsNegative set.
PositiveHits atomic.Int64
+
// Evictions is incremented by evictLRU for every entry it removes.
Evictions atomic.Int64
+
// Size mirrors the length of the LRU list after inserts and removals.
Size atomic.Int64
+
}
+
+
// LRUCache is a size-bounded cache of *CacheEntry values keyed by string.
// Entries live in a doubly linked list (front = most recently inserted or
// updated) and are indexed by cacheMap for direct lookup.
type LRUCache struct {
+
// mu guards lruList and cacheMap.
mu sync.RWMutex
+
// maxSize is the entry count above which Set evicts from the back of lruList.
maxSize int
+
// cleanupInterval is the period of the background expiry sweep.
cleanupInterval time.Duration
+
// lruList holds *lruItem values; eviction removes from the back.
lruList *list.List
+
// cacheMap maps a key to its element in lruList.
cacheMap map[string]*list.Element
+
// stats holds the live counters; GetStats returns a copy of them.
stats CacheStats
+
// stopCleanup is closed by Stop to end the periodicCleanup goroutine.
stopCleanup chan struct{}
}
-
type CacheStats struct {
-
TotalQueries int64
-
CacheHits int64
-
CacheMisses int64
-
NegativeHits int64
-
PositiveHits int64
-
Evictions int64
-
Size int64
+
// lruItem is the value stored in each lruList element. It carries the key
// alongside the entry so that removing a list element can also delete the
// matching cacheMap slot.
type lruItem struct {
+
key string
+
value *CacheEntry
}
// Cache is the interface the resolver and server use to talk to a cache
// implementation; LRUCache provides these methods.
type Cache interface {
// Get returns the entry stored under key and whether it was found.
Get(key string) (*CacheEntry, bool)
// Set stores entry under key.
Set(key string, entry *CacheEntry)
// GetStats returns the cache counters.
GetStats() *CacheStats
+
// Stop releases background resources held by the cache.
Stop()
}
-
func NewMemoryCache(maxSize int, cleanupInterval time.Duration) *MemoryCache {
-
cache := &MemoryCache{
-
entries: make(map[string]*CacheEntry),
-
maxSize: maxSize,
-
cleaupInterval: cleanupInterval,
+
func NewLRUCache(maxSize int, cleanupInterval time.Duration) *LRUCache {
+
if maxSize <= 0 {
+
maxSize = 5000
+
}
+
+
if cleanupInterval <= 0 {
+
cleanupInterval = 5 * time.Minute
+
}
+
+
cache := &LRUCache{
+
maxSize: maxSize,
+
cleanupInterval: cleanupInterval,
+
lruList: list.New(),
+
cacheMap: make(map[string]*list.Element),
+
stopCleanup: make(chan struct{}),
}
go cache.periodicCleanup()
return cache
}
-
func (c *MemoryCache) Get(key string) (*CacheEntry, bool) {
+
func (c *LRUCache) Get(key string) (*CacheEntry, bool) {
+
c.stats.TotalQueries.Add(1)
+
c.mu.RLock()
+
element, exists := c.cacheMap[key]
c.mu.RUnlock()
-
c.stats.TotalQueries++
+
if !exists {
+
c.stats.CacheMisses.Add(1)
+
return nil, false
+
}
+
+
c.mu.Lock()
+
defer c.mu.Unlock()
-
entry, exists := c.entries[key]
+
element, exists = c.cacheMap[key]
+
// element might have been evicted or cleaned between RUnlock and Lock
if !exists {
-
c.stats.CacheMisses++
return nil, false
}
+
item := element.Value.(*lruItem)
+
entry := item.value
+
if time.Now().After(entry.ExpireAt) {
-
c.stats.CacheMisses++
+
c.removeItem(element)
+
c.stats.CacheMisses.Add(1)
return nil, false
}
if entry.IsNegative {
-
c.stats.NegativeHits++
+
c.stats.NegativeHits.Add(1)
} else {
-
c.stats.PositiveHits++
+
c.stats.PositiveHits.Add(1)
}
-
c.stats.CacheHits++
+
c.stats.CacheHits.Add(1)
return entry, true
}
-
func (c *MemoryCache) Set(key string, entry *CacheEntry) {
+
func (c *LRUCache) Set(key string, entry *CacheEntry) {
+
if time.Now().After(entry.ExpireAt) {
+
return
+
}
+
c.mu.Lock()
defer c.mu.Unlock()
-
if len(c.entries) >= c.maxSize {
-
c.evictOldest()
+
if element, exists := c.cacheMap[key]; exists {
+
element.Value.(*lruItem).value = entry
+
c.lruList.MoveToFront(element)
+
return
}
-
c.entries[key] = entry
-
c.stats.Size = int64(len(c.entries))
-
}
+
newItem := &lruItem{
+
key: key,
+
value: entry,
+
}
+
element := c.lruList.PushFront(newItem)
+
c.cacheMap[key] = element
+
currentSize := int64(c.lruList.Len())
+
c.stats.Size.Store(currentSize)
-
func (c *MemoryCache) GetStats() *CacheStats {
-
c.mu.RLock()
-
defer c.mu.RUnlock()
-
-
stats := c.stats
-
return &stats
+
for int64(c.lruList.Len()) > int64(c.maxSize) {
+
c.evictLRU()
+
}
}
-
func (c *MemoryCache) evictOldest() {
-
var oldestKey string
-
var oldestTime time.Time
-
-
first := true
-
for k, e := range c.entries {
-
if first || e.CreateTime.Before(oldestTime) {
-
oldestKey = k
-
oldestTime = e.CreateTime
-
first = false
-
}
+
func (c *LRUCache) evictLRU() {
+
element := c.lruList.Back()
+
if element != nil {
+
c.removeItem(element)
+
c.stats.Evictions.Add(1)
}
+
}
-
if oldestKey != "" {
-
delete(c.entries, oldestKey)
-
c.stats.Evictions++
-
}
+
// removeItem drops element from both the index map and the LRU list and then
// refreshes the Size counter from the list length. The caller must hold c.mu
// for writing.
func (c *LRUCache) removeItem(element *list.Element) {
	key := element.Value.(*lruItem).key

	delete(c.cacheMap, key)
	c.lruList.Remove(element)

	remaining := int64(c.lruList.Len())
	c.stats.Size.Store(remaining)
}
-
func (c *MemoryCache) periodicCleanup() {
-
ticker := time.NewTicker(c.cleaupInterval)
+
func (c *LRUCache) GetStats() *CacheStats {
+
statsSnapshot := &CacheStats{}
+
statsSnapshot.TotalQueries.Store(c.stats.TotalQueries.Load())
+
statsSnapshot.CacheHits.Store(c.stats.CacheHits.Load())
+
statsSnapshot.CacheMisses.Store(c.stats.CacheMisses.Load())
+
statsSnapshot.NegativeHits.Store(c.stats.NegativeHits.Load())
+
statsSnapshot.PositiveHits.Store(c.stats.PositiveHits.Load())
+
statsSnapshot.Evictions.Store(c.stats.Evictions.Load())
+
statsSnapshot.Size.Store(c.stats.Size.Load())
+
return statsSnapshot
+
}
+
+
func (c *LRUCache) periodicCleanup() {
+
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
-
for range ticker.C {
-
c.cleanup()
+
for {
+
select {
+
case <-ticker.C:
+
c.cleanupExpired()
+
case <-c.stopCleanup:
+
return
+
}
}
}
-
func (c *MemoryCache) cleanup() {
+
func (c *LRUCache) cleanupExpired() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
-
for k, e := range c.entries {
-
if now.After(e.ExpireAt) {
-
delete(c.entries, k)
+
element := c.lruList.Back()
+
for element != nil {
+
item := element.Value.(*lruItem)
+
nextElement := element.Prev()
+
+
if now.After(item.value.ExpireAt) {
+
c.removeItem(element)
}
+
+
element = nextElement
}
+
}
-
c.stats.Size = int64(len(c.entries))
+
// Stop ends the background cleanup goroutine by closing stopCleanup.
//
// Stop is safe to call more than once and from several goroutines: closing an
// already closed channel would panic, so the close is guarded by c.mu and
// skipped when the channel is already closed.
func (c *LRUCache) Stop() {
	c.mu.Lock()
	defer c.mu.Unlock()

	select {
	case <-c.stopCleanup:
		// Already stopped; nothing left to do.
	default:
		close(c.stopCleanup)
	}
}
func getMinTTL(records []magna.ResourceRecord) uint32 {
if len(records) == 0 {
-
return 0
+
return 60
}
-
minTTL := records[0].TTL
for _, rr := range records[1:] {
+
if rr.TTL == 0 {
+
return 60
+
}
if rr.TTL < minTTL {
minTTL = rr.TTL
}
+
}
+
if minTTL < 60 {
+
return 60
}
return minTTL
}
···
func CreateCacheEntry(msg *magna.Message, zone string) *CacheEntry {
now := time.Now()
entry := &CacheEntry{
-
CreateTime: now,
IsNegative: msg.Header.RCode == magna.NXDOMAIN,
}
···
var soaTTL uint32
for _, auth := range msg.Authority {
if auth.RType == magna.SOAType {
-
soaTTL = auth.TTL
+
soa, ok := auth.RData.(*magna.SOA)
+
if ok {
+
soaTTL = soa.Minimum
+
if auth.TTL < soaTTL {
+
soaTTL = auth.TTL
+
}
+
} else {
+
soaTTL = auth.TTL
+
}
break
}
}
···
soaTTL = 900
}
+
if soaTTL < 60 {
+
soaTTL = 60
+
}
+
entry.NegativeTTL = time.Duration(soaTTL) * time.Second
entry.ExpireAt = now.Add(entry.NegativeTTL)
+
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
for i, rr := range msg.Authority {
+
recordTTL := rr.TTL
+
if recordTTL < 60 {
+
recordTTL = 60
+
}
rule := determineBailiwickRule(zone, rr.Name)
entry.Authority[i] = CachedResourceRecord{
Record: rr,
-
ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second),
+
ExpireAt: now.Add(time.Duration(recordTTL) * time.Second),
BailiwickRule: rule,
}
}
-
return entry
-
}
+
} else {
+
minAnswerTTL := getMinTTL(msg.Answer)
+
+
if len(msg.Answer) == 0 && len(msg.Authority) > 0 && msg.Authority[0].RType == magna.SOAType {
+
+
soa, ok := msg.Authority[0].RData.(*magna.SOA)
+
soaTTL := uint32(900)
+
if ok {
+
soaTTL = soa.Minimum
+
if msg.Authority[0].TTL < soaTTL {
+
soaTTL = msg.Authority[0].TTL
+
}
+
} else {
+
soaTTL = msg.Authority[0].TTL
+
}
+
if soaTTL < 60 {
+
soaTTL = 60
+
}
+
+
entry.IsNegative = true
+
entry.NegativeTTL = time.Duration(soaTTL) * time.Second
+
entry.ExpireAt = now.Add(entry.NegativeTTL)
+
+
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
+
for i, rr := range msg.Authority {
+
recordTTL := rr.TTL
+
if recordTTL < 60 {
+
recordTTL = 60
+
}
+
rule := determineBailiwickRule(zone, rr.Name)
+
entry.Authority[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
}
-
minTTL := getMinTTL(msg.Answer)
-
entry.ExpireAt = now.Add(time.Duration(minTTL) * time.Second)
+
} else {
+
entry.ExpireAt = now.Add(time.Duration(minAnswerTTL) * time.Second)
-
entry.Answer = make([]CachedResourceRecord, len(msg.Answer))
-
for i, rr := range msg.Answer {
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Answer[i] = CachedResourceRecord{
-
Record: rr,
-
ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second),
-
BailiwickRule: rule,
-
}
-
}
+
entry.Answer = make([]CachedResourceRecord, len(msg.Answer))
+
for i, rr := range msg.Answer {
+
recordTTL := rr.TTL
+
if recordTTL < 60 {
+
recordTTL = 60
+
}
+
rule := determineBailiwickRule(zone, rr.Name)
+
entry.Answer[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
}
-
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
-
for i, rr := range msg.Authority {
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Authority[i] = CachedResourceRecord{
-
Record: rr,
-
ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second),
-
BailiwickRule: rule,
-
}
-
}
+
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
+
for i, rr := range msg.Authority {
+
recordTTL := rr.TTL
+
if recordTTL < 60 {
+
recordTTL = 60
+
}
+
rule := determineBailiwickRule(zone, rr.Name)
+
entry.Authority[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
}
-
entry.Additional = make([]CachedResourceRecord, len(msg.Additional))
-
for i, rr := range msg.Additional {
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Additional[i] = CachedResourceRecord{
-
Record: rr,
-
ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second),
-
BailiwickRule: rule,
+
entry.Additional = make([]CachedResourceRecord, len(msg.Additional))
+
for i, rr := range msg.Additional {
+
recordTTL := rr.TTL
+
if recordTTL < 60 {
+
recordTTL = 60
+
}
+
rule := determineBailiwickRule(zone, rr.Name)
+
entry.Additional[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
}
}
}
+151 -193
pkg/dns/resolve.go
···
"context"
"fmt"
"log/slog"
+
"math/rand"
"net"
"time"
"tangled.sh/seiso.moe/magna"
)
+
+
// errNXDOMAIN is the sentinel returned by resolveQuestion when an upstream
// server answers with RCode NXDOMAIN; callers compare against it directly.
var errNXDOMAIN = fmt.Errorf("nxdomain")
// QueryHandler answers incoming DNS questions by resolving them iteratively,
// starting from RootServers.
type QueryHandler struct {
// RootServers are the server addresses every fresh resolution starts from.
RootServers []string
// Timeout bounds each resolveQuestion call via context.WithTimeout.
Timeout time.Duration
-
Cache *MemoryCache
+
// NOTE(review): the resolver code visible in this change does not read Cache — confirm where it is used.
Cache Cache
Logger *slog.Logger
}
···
}
question := r.Message.Question[0]
-
answers, authority, err := h.resolveQuestion(r.Context, question, h.RootServers)
+
records, err := h.resolveQuestion(r.Context, question, h.RootServers)
msg := r.Message.CreateReply(r.Message)
msg.Header.RA = true
-
if err != nil {
-
if err.Error() == "nxdomain" {
-
msg = msg.SetRCode(magna.NXDOMAIN)
-
msg.Authority = authority
-
msg.Header.NSCount = uint16(len(authority))
-
} else {
-
msg = msg.SetRCode(magna.SERVFAIL)
-
}
+
if err == errNXDOMAIN {
+
msg = msg.SetRCode(magna.NXDOMAIN)
+
msg.Authority = records
+
msg.Header.NSCount = uint16(len(records))
+
msg.Header.ANCount = 0
+
msg.Answer = nil
+
} else if err != nil {
+
msg.SetRCode(magna.SERVFAIL)
} else {
-
msg.Answer = answers
-
msg.Header.ANCount = uint16(len(answers))
+
msg.Answer = records
+
msg.Header.ANCount = uint16(len(records))
msg = msg.SetRCode(magna.NOERROR)
}
w.WriteMsg(msg)
}
-
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, []magna.ResourceRecord, error) {
-
return h.resolveQuestionWithZone(ctx, question, servers, ".")
-
}
-
-
func (h *QueryHandler) resolveQuestionWithZone(ctx context.Context, question magna.Question, servers []string, currentZone string) ([]magna.ResourceRecord, []magna.ResourceRecord, error) {
-
cacheKey := fmt.Sprintf("%s:%s:%s", question.QName, question.QType.String(), question.QClass.String())
-
-
if h.Cache != nil {
-
if entry, hit := h.Cache.Get(cacheKey); hit {
-
now := time.Now()
-
-
if r := getCurrentRequest(ctx); r != nil {
-
r.Context = setCacheHit(r.Context)
-
}
-
-
if entry.IsNegative {
-
return nil, convertCachedToMagna(entry.Authority, now), fmt.Errorf("nxdomain")
-
}
-
-
validAnswers := convertCachedToMagna(entry.Answer, now)
-
if len(validAnswers) > 0 {
-
return validAnswers, nil, nil
-
}
-
} else {
-
if r := getCurrentRequest(ctx); r != nil {
-
r.Context = setCacheMiss(r.Context)
-
}
-
}
-
}
-
-
ctx, cancel := context.WithTimeout(ctx, h.Timeout)
+
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
+
resolveCtx, cancel := context.WithTimeout(ctx, h.Timeout)
defer cancel()
-
ch := make(chan queryResponse, len(servers))
+
for _, s := range servers {
-
go queryServer(ctx, question, s, ch, h.Timeout)
+
go queryServer(resolveCtx, question, s, ch)
}
-
for i := 0; i < len(servers); i++ {
+
for range servers {
select {
case res := <-ch:
if res.Error != nil {
-
h.Logger.Debug("server query failed",
-
"server", res.Server,
-
"error", res.Error)
-
continue
+
slog.Warn("error", "question", question, "server", res.Server, "error", res.Error)
+
break
}
msg := res.MSG
-
zone := extractZone(&msg)
-
if zone == "" {
-
zone = currentZone
-
}
-
if msg.Header.RCode == magna.NXDOMAIN {
-
entry := CreateCacheEntry(&msg, zone)
-
h.Cache.Set(cacheKey, entry)
-
now := time.Now()
-
return nil, convertCachedToMagna(entry.Authority, now), fmt.Errorf("nxdomain")
+
return msg.Authority, errNXDOMAIN
}
if msg.Header.ANCount > 0 {
if msg.Answer[0].RType == magna.CNAMEType {
-
h.Logger.Debug("following CNAME",
-
"cname", msg.Answer[0].RData.String())
-
-
entry := CreateCacheEntry(&msg, zone)
-
h.Cache.Set(cacheKey, entry)
+
cname_answers, err := h.resolveQuestion(resolveCtx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
+
if err != nil {
+
if err == errNXDOMAIN {
+
return cname_answers, errNXDOMAIN
+
}
+
slog.Warn("error resolving CNAME", "cname", msg.Answer[0].RData.String(), "error", err)
-
answers, auth, err := h.resolveQuestionWithZone(ctx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers, zone)
-
if err == nil {
-
return append(msg.Answer, answers...), auth, nil
+
continue
}
-
return nil, auth, err
+
msg.Answer = append(msg.Answer, cname_answers...)
}
-
entry := CreateCacheEntry(&msg, zone)
-
h.Cache.Set(cacheKey, entry)
-
return msg.Answer, nil, nil
-
}
-
-
if msg.Header.ARCount > 0 {
-
var nextZone []string
-
for _, ans := range msg.Additional {
-
if ans.RType == magna.AType {
-
nextZone = append(nextZone, ans.RData.String())
-
}
-
}
-
-
if len(nextZone) > 0 {
-
return h.resolveQuestion(ctx, question, nextZone)
-
}
+
return msg.Answer, nil
}
if msg.Header.NSCount > 0 {
-
nsRecords := make(map[string]string)
-
glueRecords := make(map[string]string)
+
var nextServers []string
+
var nsRecords []magna.ResourceRecord
-
if msg.Header.ARCount > 0 {
-
for _, additional := range msg.Additional {
-
if additional.RType == magna.AType {
-
rule := determineBailiwickRule(zone, additional.Name)
-
if rule != BailiwickOutside {
-
glueRecords[additional.Name] = additional.RData.String()
+
glueMap := make(map[string]string)
+
for _, rr := range msg.Additional {
+
if rr.RType == magna.AType {
+
for _, nsRR := range msg.Authority {
+
if nsRR.RType == magna.NSType && nsRR.RData.String() == rr.Name {
+
glueMap[rr.Name] = rr.RData.String()
+
break
}
}
}
}
-
var nextServers []string
-
var needResolution []string
+
for _, rr := range msg.Authority {
+
if rr.RType == magna.NSType {
+
nsRecords = append(nsRecords, rr)
+
}
+
}
-
for _, auth := range msg.Authority {
-
if auth.RType == magna.NSType {
-
rule := determineBailiwickRule(zone, auth.Name)
-
if rule != BailiwickOutside {
-
nsName := auth.RData.String()
-
nsRecords[nsName] = ""
+
for _, nsRR := range nsRecords {
+
nsName := nsRR.RData.String()
+
if ip, found := glueMap[nsName]; found {
+
nextServers = append(nextServers, ip)
+
} else {
+
nsAnswers, err := h.resolveQuestion(resolveCtx, magna.Question{QName: nsName, QType: magna.AType, QClass: magna.IN}, h.RootServers)
+
if err != nil {
+
slog.Warn("error resolving NS A record", "ns", nsName, "error", err)
+
continue
+
}
-
if ip, exists := glueRecords[nsName]; exists {
-
nextServers = append(nextServers, ip)
-
} else {
-
needResolution = append(needResolution, nsName)
+
for _, ans := range nsAnswers {
+
if ans.RType == magna.AType {
+
nextServers = append(nextServers, ans.RData.String())
}
}
}
}
if len(nextServers) > 0 {
-
h.Logger.Debug("using glue records for resolution",
-
"servers", nextServers)
-
return h.resolveQuestionWithZone(ctx, question, nextServers, zone)
+
return h.resolveQuestion(resolveCtx, question, nextServers)
}
-
for _, ns := range needResolution {
-
answers, _, err := h.resolveQuestionWithZone(
-
ctx,
-
magna.Question{
-
QName: ns,
-
QType: magna.AType,
-
QClass: magna.IN,
-
},
-
h.RootServers,
-
zone,
-
)
-
if err == nil {
-
for _, ans := range answers {
-
nextServers = append(nextServers, ans.RData.String())
-
}
-
}
-
}
+
slog.Warn("could not resolve any NS records for delegation", "question", question.QName)
+
continue
+
}
-
if len(nextServers) > 0 {
-
return h.resolveQuestionWithZone(ctx, question, nextServers, zone)
-
}
+
if msg.Header.RCode == magna.NOERROR && msg.Header.ANCount == 0 {
+
return []magna.ResourceRecord{}, nil
}
-
case <-ctx.Done():
-
return nil, nil, ctx.Err()
+
slog.Warn("unexpected response state", "question", question, "server", res.Server, "rcode", msg.Header.RCode)
+
continue
+
case <-resolveCtx.Done():
+
slog.Debug("resolution cancelled or timed out", "question", question)
+
return []magna.ResourceRecord{}, fmt.Errorf("resolution timed out or cancelled")
}
}
-
return nil, nil, fmt.Errorf("all queries failed")
+
slog.Warn("all resolution paths failed", "question", question)
+
return []magna.ResourceRecord{}, fmt.Errorf("all resolution paths failed")
}
-
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) {
-
done := make(chan struct{}, 1)
-
-
go func() {
-
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
-
if err != nil {
-
ch <- queryResponse{Error: err}
-
return
-
}
-
defer conn.Close()
-
-
query := magna.CreateRequest(0, false)
-
query = query.AddQuestion(question)
-
if _, err := conn.Write(query.Encode()); err != nil {
-
ch <- queryResponse{Server: server, Error: err}
-
return
-
}
+
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse) {
+
var response queryResponse
+
response.Server = server
-
p := make([]byte, 512)
-
nn, err := conn.Read(p)
-
-
// TODO: retry request with TCP
-
if err != nil || nn > 512 {
-
if err == nil {
-
err = fmt.Errorf("truncated response")
-
}
-
ch <- queryResponse{Server: server, Error: err}
-
return
+
defer func() {
+
select {
+
case ch <- response:
+
default:
+
slog.Debug("queryServer response channel blocked or closed", "server", server)
}
-
-
var response magna.Message
-
err = response.Decode(p)
-
ch <- queryResponse{MSG: response, Server: server, Error: err}
}()
select {
case <-ctx.Done():
-
ch <- queryResponse{Server: server, Error: ctx.Err()}
-
case <-done:
-
// goroutine finished with no cancellation
-
case <-time.After(timeout):
-
ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")}
+
response.Error = ctx.Err()
+
return
+
default:
}
-
}
-
func convertCachedToMagna(cached []CachedResourceRecord, now time.Time) []magna.ResourceRecord {
-
result := make([]magna.ResourceRecord, 0, len(cached))
-
for _, record := range cached {
-
if now.Before(record.ExpireAt) {
-
rr := record.Record
-
rr.TTL = uint32(record.ExpireAt.Sub(now).Seconds())
-
result = append(result, rr)
-
}
+
dialer := net.Dialer{Timeout: 2 * time.Second}
+
conn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
+
if err != nil {
+
response.Error = fmt.Errorf("dial error: %w", err)
+
return
}
-
return result
-
}
+
defer conn.Close()
-
func getCurrentRequest(ctx context.Context) *Request {
-
if ctx == nil {
-
return nil
+
deadline, ok := ctx.Deadline()
+
if !ok {
+
deadline = time.Now().Add(5 * time.Second)
}
+
conn.SetDeadline(deadline)
-
if r, ok := ctx.Value(contextKey("request")).(*Request); ok {
-
return r
+
query := magna.Message{
+
Header: magna.Header{
+
ID: uint16(rand.Int() % 65535),
+
QR: false,
+
OPCode: magna.QUERY,
+
AA: false,
+
TC: false,
+
RD: false,
+
RA: false,
+
Z: 0,
+
RCode: magna.NOERROR,
+
QDCount: 1,
+
ARCount: 0,
+
NSCount: 0,
+
ANCount: 0,
+
},
+
Question: []magna.Question{question},
}
-
return nil
-
}
+
msgBytes, err := query.Encode()
+
if err != nil {
+
response.Error = fmt.Errorf("encode error: %w", err)
+
return
+
}
-
func extractZone(msg *magna.Message) string {
-
for _, auth := range msg.Authority {
-
if auth.RType == magna.SOAType {
-
return auth.Name
+
_, err = conn.Write(msgBytes)
+
if err != nil {
+
if ctx.Err() != nil {
+
response.Error = ctx.Err()
+
} else {
+
response.Error = fmt.Errorf("write error: %w", err)
}
+
return
+
}
-
if auth.RType == magna.NSType {
-
return auth.Name
+
p := make([]byte, 512)
+
n, err := conn.Read(p)
+
if err != nil {
+
if ctx.Err() != nil {
+
response.Error = ctx.Err()
+
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+
response.Error = fmt.Errorf("read timeout: %w", err)
+
} else {
+
response.Error = fmt.Errorf("read error: %w", err)
}
+
return
}
-
if len(msg.Question) > 0 {
-
return msg.Question[0].QName
+
// TODO: retry request with TCP
+
if n >= 512 {
+
response.Error = fmt.Errorf("response possibly truncated (size %d >= 512)", n)
+
return
+
}
+
+
var decodedMsg magna.Message
+
err = decodedMsg.Decode(p[:n])
+
if err != nil {
+
response.Error = fmt.Errorf("decode error: %w", err)
+
return
}
-
return ""
+
if decodedMsg.Header.ID != query.Header.ID {
+
response.Error = fmt.Errorf("response ID mismatch (got %d, expected %d)", decodedMsg.Header.ID, query.Header.ID)
+
return
+
}
+
+
response.MSG = decodedMsg
+
response.Error = nil
}
+14 -5
pkg/dns/server.go
···
}
func (w *udpResponseWriter) WriteMsg(msg *magna.Message) {
-
ans := msg.Encode()
+
ans, err := msg.Encode()
+
if err != nil {
+
w.logger.Warn("err encoding msg", "error", err)
+
return
+
}
+
if len(ans) > 512 {
ans[3] |= 1 << 6 // set the truncated bit
}
-
err := w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
+
err = w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
if err != nil {
w.logger.Warn("error setting write deadline for UDP", "error", err)
}
···
}
func (w *tcpResponseWriter) WriteMsg(msg *magna.Message) {
-
ans := msg.Encode()
+
ans, err := msg.Encode()
+
if err != nil {
+
w.logger.Warn("err encoding msg", "error", err)
+
return
+
}
-
err := w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout))
+
err = w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout))
if err != nil {
w.logger.Warn("error setting write deadline for TCP", "error", err)
}
···
ReadTimeout time.Duration
WriteTimeout time.Duration
Logger *slog.Logger
-
Cache *MemoryCache
+
Cache Cache
}
func (srv *Server) ListenAndServe() error {
+7 -7
pkg/metrics/clickhouse.go
···
// RecordCacheStats converts the cache counters in stats into a CacheMetric
// stamped with the current time and passes it to RecordCacheMetrics.
func (m *ClickHouseMetrics) RecordCacheStats(stats *dns.CacheStats) {
	metric := CacheMetric{Timestamp: time.Now()}

	metric.TotalQueries = stats.TotalQueries.Load()
	metric.CacheHits = stats.CacheHits.Load()
	metric.CacheMisses = stats.CacheMisses.Load()
	metric.NegativeHits = stats.NegativeHits.Load()
	metric.PositiveHits = stats.PositiveHits.Load()
	metric.Evictions = stats.Evictions.Load()
	metric.Size = stats.Size.Load()

	m.RecordCacheMetrics(metric)
}