a recursive dns resolver

refactor code to be more testable, hopefully

+4 -2
docs/alky.toml
···
port = 2053
# Location of root hints file.
root_hints_file = "/etc/dns/root.hints"
+
udp_payload_size = 512
[logging]
# Logging output: "stdout" or "file".
output = "stdout"
+
level = "debug"
# This is used only if logging.output is "file".
file_path = "/var/log/alky.log"
···
# window_seconds: The interval at which the rate limit is checked and potentially reset
# Implements a sliding window rate limit mechanism
# Type: Duration string (e.g. "1s") or integer seconds
-
window = 3
+
window_seconds = "1s"
# cleanup_seconds: Duration for keeping a client's rate limit data in memory
# After this period of inactivity, a client's rate limit data is removed to free up memory
# Type: Duration string (e.g. "5m") or integer seconds
-
expiration_time = 300
+
cleanup_seconds = "5m"
[metrics]
# ClickHouse connection string
+109 -43
main.go
···
"log/slog"
"os"
"os/signal"
+
"strings"
"syscall"
"time"
···
func init() {
flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky")
-
flag.Parse()
}
func main() {
cfg, err := config.LoadConfig(configFlag)
if err != nil {
-
log.Fatal(err)
+
log.Fatalf("failed to load config: %v", err)
}
logger := setupLogger(&cfg)
+
slog.SetDefault(logger)
-
metricsClient, err := metrics.NewClickHouseMetrics(&cfg.Metrics, logger)
+
logger.Info("Configuration loaded", "path", configFlag)
+
+
metricsClient, err := metrics.NewClickHouseMetrics(&cfg.Metrics, logger.With("component", "metrics"))
if err != nil {
-
log.Fatal(err)
+
logger.Error("failed to initialize metrics client", "error", err)
+
os.Exit(1)
}
defer metricsClient.Close()
+
logger.Info("Metrics client initialized")
rootServers, err := rootservers.DecodeRootHints(cfg.Server.RootHintsFile)
-
if err != nil {
-
log.Fatal(err)
+
if err != nil || len(rootServers) == 0 {
+
logger.Error("Failed to load root hints or no root servers found", "path", cfg.Server.RootHintsFile, "error", err)
+
os.Exit(1)
}
+
logger.Info("Root hints loaded", "count", len(rootServers), "path", cfg.Server.RootHintsFile)
-
cache := dns.NewLRUCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration)
+
cache := dns.NewLRUCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration, logger.With("component", "cache"))
defer cache.Stop()
+
logger.Info("DNS cache initialized")
-
handler := &dns.QueryHandler{
+
queryHandler := &dns.QueryHandler{
RootServers: rootServers,
-
Timeout: time.Duration(cfg.Advanced.QueryTimeout) * time.Second,
+
Timeout: cfg.Advanced.QueryTimeout.Duration,
Cache: cache,
-
Logger: logger,
+
Logger: logger.With("component", "resolver"),
}
-
go monitorCacheMetrics(cache, metricsClient, logger)
+
go monitorCacheMetrics(cache, metricsClient, logger.With("component", "cache-monitor"))
-
rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{
-
Rate: float64(cfg.Ratelimit.Rate),
-
Burst: cfg.Ratelimit.Burst,
-
WindowLength: time.Duration(cfg.Ratelimit.Window) * time.Second,
-
ExpirationTime: time.Duration(cfg.Ratelimit.ExpirationTime) * time.Second,
-
})(handler)
+
var currentHandler dns.Handler = queryHandler
-
metricsHandler := metrics.MetricsMiddleware(metricsClient)(rateLimitHandler)
+
if cfg.Ratelimit.Rate > 0 {
+
rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{
+
Rate: float64(cfg.Ratelimit.Rate),
+
Burst: cfg.Ratelimit.Burst,
+
WindowLength: cfg.Ratelimit.Window.Duration,
+
ExpirationTime: cfg.Ratelimit.ExpirationTime.Duration,
+
})(currentHandler)
+
currentHandler = rateLimitHandler
+
logger.Info("Rate limiting enabled", "rate", cfg.Ratelimit.Rate, "burst", cfg.Ratelimit.Burst)
+
} else {
+
logger.Info("Rate limiting disabled")
+
}
+
+
metricsHandler := metrics.MetricsMiddleware(metricsClient)(currentHandler)
+
currentHandler = metricsHandler
+
logger.Info("Metrics middleware enabled")
loggingHandler := dns.LoggingMiddleware(&dns.LogConfig{
-
Logger: logger,
+
Logger: logger.With("component", "query-log"),
Level: slog.LevelInfo,
-
})(metricsHandler)
+
})(currentHandler)
+
currentHandler = loggingHandler
+
logger.Info("Logging middleware enabled")
+
sigChan := make(chan os.Signal, 1)
+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
-
sigChan := make(chan os.Signal, 1)
-
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
<-sigChan
-
slog.Info("Shutting down...")
+
sig := <-sigChan
+
logger.Info("Received signal, shutting down gracefully...", "signal", sig.String())
+
+
logger.Info("Stopping cache...")
cache.Stop()
+
+
logger.Info("Closing metrics client...")
metricsClient.Close()
+
+
logger.Info("Shutdown complete.")
os.Exit(0)
}()
s := dns.Server{
Address: cfg.Server.Address,
Port: cfg.Server.Port,
-
Handler: loggingHandler,
-
UDPSize: 512,
-
ReadTimeout: 2 * time.Second,
-
WriteTimeout: 2 * time.Second,
-
Logger: logger,
+
Handler: currentHandler,
+
UDPSize: cfg.Server.UDPSize,
+
ReadTimeout: cfg.Advanced.ReadTimeout.Duration,
+
WriteTimeout: cfg.Advanced.WriteTimeout.Duration,
+
Logger: logger.With("component", "server"),
}
+
logger.Info("Starting DNS server listener...")
if err := s.ListenAndServe(); err != nil {
-
slog.Error("Failed to start server", "error", err)
-
cache.Stop()
-
metricsClient.Close()
+
logger.Error("Failed to start server", "error", err)
+
os.Exit(1)
}
}
func monitorCacheMetrics(cache dns.Cache, metricsClient *metrics.ClickHouseMetrics, logger *slog.Logger) {
-
ticker := time.NewTicker(1 * time.Minute)
+
interval := 1 * time.Minute
+
logger.Info("Starting cache metrics monitoring", "interval", interval)
+
ticker := time.NewTicker(interval)
defer ticker.Stop()
-
for range ticker.C {
-
stats := cache.GetStats()
-
metricsClient.RecordCacheStats(stats)
-
logger.Info("Cache metrics recorded to ClickHouse")
+
for {
+
select {
+
case <-ticker.C:
+
stats := cache.GetStats()
+
logger.Debug("Recording cache metrics",
+
"hits", stats.Hits.Load(),
+
"misses", stats.Misses.Load(),
+
"size", stats.Size.Load(),
+
"evictions", stats.Evictions.Load(),
+
"expired", stats.Expired.Load(),
+
"pos_hits", stats.PositiveHits.Load(),
+
"neg_hits", stats.NegativeHits.Load(),
+
)
+
metricsClient.RecordCacheStats(metrics.CacheMetric{
+
Timestamp: time.Now(),
+
CacheHits: stats.Hits.Load(),
+
CacheMisses: stats.Misses.Load(),
+
NegativeHits: stats.NegativeHits.Load(),
+
PositiveHits: stats.PositiveHits.Load(),
+
Evictions: stats.Evictions.Load() + stats.Expired.Load(),
+
Size: stats.Size.Load(),
+
})
+
}
}
}
func setupLogger(cfg *config.Config) *slog.Logger {
-
var logger *slog.Logger
+
var logLevel slog.Level
+
switch strings.ToLower(cfg.Logging.Level) {
+
case "debug":
+
logLevel = slog.LevelDebug
+
case "info":
+
logLevel = slog.LevelInfo
+
case "warn":
+
logLevel = slog.LevelWarn
+
case "error":
+
logLevel = slog.LevelError
+
default:
+
logLevel = slog.LevelInfo
+
}
handlerOpts := &slog.HandlerOptions{
-
Level: slog.LevelDebug,
+
Level: logLevel,
+
AddSource: logLevel <= slog.LevelDebug,
}
+
var handler slog.Handler
switch cfg.Logging.Output {
case "file":
f, err := os.OpenFile(cfg.Logging.FilePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644)
if err != nil {
-
log.Fatal(err)
+
log.Fatalf("Failed to open log file %s: %v", cfg.Logging.FilePath, err)
}
-
-
logger = slog.New(slog.NewJSONHandler(f, handlerOpts))
+
handler = slog.NewJSONHandler(f, handlerOpts)
+
log.Printf("Logging to file: %s at level: %s", cfg.Logging.FilePath, logLevel.String())
+
case "text":
+
handler = slog.NewTextHandler(os.Stdout, handlerOpts)
+
log.Printf("Logging to stdout (text) at level: %s", logLevel.String())
default:
-
logger = slog.New(slog.NewJSONHandler(os.Stdout, handlerOpts))
+
handler = slog.NewJSONHandler(os.Stdout, handlerOpts)
+
log.Printf("Logging to stdout (json) at level: %s", logLevel.String())
}
+
logger := slog.New(handler)
return logger
}
+60
migrations/00002_modified_metrics.sql
···
+
-- +goose Up
+
ALTER TABLE alky_dns_queries
+
MODIFY COLUMN timestamp DateTime CODEC(Delta, ZSTD(1)),
+
MODIFY COLUMN instance_id String CODEC(ZSTD(1)),
+
MODIFY COLUMN query_name String CODEC(ZSTD(1)),
+
MODIFY COLUMN query_type LowCardinality(String) CODEC(ZSTD(1)),
+
MODIFY COLUMN query_class LowCardinality(String) CODEC(ZSTD(1)),
+
MODIFY COLUMN remote_addr String CODEC(ZSTD(1)),
+
MODIFY COLUMN response_code LowCardinality(String) CODEC(ZSTD(1)),
+
MODIFY COLUMN duration Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN cache_hit Bool CODEC(ZSTD(1));
+
+
ALTER TABLE alky_dns_queries MODIFY TTL timestamp + INTERVAL 30 DAY;
+
+
ALTER TABLE alky_dns_cache_metrics
+
DROP COLUMN IF EXISTS total_queries,
+
MODIFY COLUMN timestamp DateTime CODEC(Delta, ZSTD(1)),
+
MODIFY COLUMN instance_id String CODEC(ZSTD(1)),
+
MODIFY COLUMN cache_hits Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN cache_misses Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN negative_hits Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN positive_hits Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN evictions Int64 CODEC(T64, ZSTD(1)),
+
MODIFY COLUMN size Int64 CODEC(T64, ZSTD(1));
+
+
ALTER TABLE alky_dns_cache_metrics
+
ADD COLUMN IF NOT EXISTS expired_count Int64 CODEC(T64, ZSTD(1));
+
+
ALTER TABLE alky_dns_cache_metrics MODIFY TTL timestamp + INTERVAL 30 DAY;
+
+
-- +goose Down
+
ALTER TABLE alky_dns_queries
+
MODIFY COLUMN timestamp DateTime,
+
MODIFY COLUMN instance_id String,
+
MODIFY COLUMN query_name String,
+
MODIFY COLUMN query_type String,
+
MODIFY COLUMN query_class String,
+
MODIFY COLUMN remote_addr String,
+
MODIFY COLUMN response_code String,
+
MODIFY COLUMN duration Int64,
+
MODIFY COLUMN cache_hit Bool;
+
+
ALTER TABLE alky_dns_queries MODIFY TTL timestamp + toIntervalDay(30);
+
+
ALTER TABLE alky_dns_cache_metrics
+
ADD COLUMN IF NOT EXISTS total_queries Int64 AFTER instance_id;
+
+
ALTER TABLE alky_dns_cache_metrics
+
MODIFY COLUMN timestamp DateTime,
+
MODIFY COLUMN instance_id String,
+
MODIFY COLUMN cache_hits Int64,
+
MODIFY COLUMN cache_misses Int64,
+
MODIFY COLUMN negative_hits Int64,
+
MODIFY COLUMN positive_hits Int64,
+
MODIFY COLUMN evictions Int64,
+
MODIFY COLUMN size Int;
+
+
ALTER TABLE alky_dns_cache_metrics DROP COLUMN IF EXISTS expired_count;
+
+
ALTER TABLE alky_dns_cache_metrics MODIFY TTL timestamp + toIntervalDay(30);
+91 -30
pkg/config/config.go
···
import (
"fmt"
+
"os"
+
"strings"
"time"
"github.com/BurntSushi/toml"
···
Address string `toml:"address"`
Port int `toml:"port"`
RootHintsFile string `toml:"root_hints_file"`
+
UDPSize int `toml:"udp_payload_size"`
}
type LoggingConfig struct {
Output string `toml:"output"`
FilePath string `toml:"file_path"`
+
Level string `toml:"level"`
}
type RatelimitConfig struct {
-
Rate int `toml:"rate"`
-
Burst int `toml:"burst"`
-
Window int `toml:"window"`
-
ExpirationTime int `toml:"expiration_time"`
+
Rate int `toml:"rate"`
+
Burst int `toml:"burst"`
+
Window duration `toml:"window_seconds"`
+
ExpirationTime duration `toml:"cleanup_seconds"`
}
type MetricsConfig struct {
···
}
type AdvancedConfig struct {
-
QueryTimeout int `toml:"query_timeout"`
+
QueryTimeout duration `toml:"query_timeout"`
+
ReadTimeout duration `toml:"read_timeout"`
+
WriteTimeout duration `toml:"write_timeout"`
}
type CacheConfig struct {
···
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
-
return err
+
if err != nil {
+
var seconds int64
+
seconds, err = parseInt64(string(text))
+
if err == nil {
+
d.Duration = time.Duration(seconds) * time.Second
+
return nil
+
}
+
return fmt.Errorf("invalid duration format: %w", err)
+
}
+
return nil
+
}
+
+
func parseInt64(s string) (int64, error) {
+
var i int64
+
_, err := fmt.Sscan(s, &i)
+
return i, err
}
func LoadConfig(path string) (Config, error) {
cfg := Config{}
-
if _, err := toml.DecodeFile(path, &cfg); err != nil {
-
return cfg, err
+
cfg.Server.Address = "127.0.0.1"
+
cfg.Server.Port = 53
+
cfg.Server.RootHintsFile = "/etc/alky/root.hints"
+
cfg.Server.UDPSize = 512
+
+
cfg.Logging.Output = "stdout"
+
cfg.Logging.Level = "info"
+
+
cfg.Ratelimit.Rate = 100
+
cfg.Ratelimit.Burst = 200
+
cfg.Ratelimit.Window.Duration = 1 * time.Second
+
cfg.Ratelimit.ExpirationTime.Duration = 1 * time.Minute
+
+
cfg.Metrics.DSN = "clickhouse://localhost:9000/default"
+
cfg.Metrics.BatchSize = 1000
+
cfg.Metrics.FlushInterval.Duration = 10 * time.Second
+
cfg.Metrics.RetentionPeriod.Duration = 30 * 24 * time.Hour
+
+
cfg.Cache.MaxItems = 10000
+
cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
+
+
cfg.Advanced.QueryTimeout.Duration = 5 * time.Second
+
cfg.Advanced.ReadTimeout.Duration = 2 * time.Second
+
cfg.Advanced.WriteTimeout.Duration = 2 * time.Second
+
+
md, err := toml.DecodeFile(path, &cfg)
+
if err != nil {
+
if os.IsNotExist(err) {
+
fmt.Printf("Warning: Config file '%s' not found, using default settings.\n", path)
+
} else {
+
return cfg, fmt.Errorf("error decoding config file '%s': %w", path, err)
+
}
}
-
if cfg.Server.Address == "" {
-
cfg.Server.Address = "127.0.0.1"
+
if len(md.Undecoded()) > 0 {
+
return cfg, fmt.Errorf("unknown configuration keys found: %v", md.Undecoded())
}
-
if cfg.Server.Port == 0 {
-
cfg.Server.Port = 53
+
if cfg.Logging.Output == "file" && cfg.Logging.FilePath == "" {
+
return cfg, fmt.Errorf("logging output is 'file' but 'file_path' is not set")
}
-
if cfg.Server.RootHintsFile == "" {
-
cfg.Server.RootHintsFile = "/etc/dns/root.hints"
+
validLevels := []string{"debug", "info", "warn", "error"}
+
isValidLevel := false
+
for _, level := range validLevels {
+
if strings.ToLower(cfg.Logging.Level) == level {
+
isValidLevel = true
+
break
+
}
}
-
if cfg.Logging.Output == "file" && cfg.Logging.FilePath == "" {
-
return cfg, fmt.Errorf("If `[logging.output]` is `file` then `file_path` must be set.")
+
if !isValidLevel {
+
return cfg, fmt.Errorf("invalid logging level '%s', must be one of: %v", cfg.Logging.Level, validLevels)
}
-
if cfg.Metrics.DSN == "" {
-
cfg.Metrics.DSN = "clickhouse://localhost:9000/default"
+
if cfg.Cache.MaxItems <= 0 {
+
return cfg, fmt.Errorf("cache max_items must be positive")
}
-
if cfg.Metrics.BatchSize == 0 {
-
cfg.Metrics.BatchSize = 1000
+
if cfg.Cache.CleanupInterval.Duration <= 0 {
+
cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
}
-
if cfg.Metrics.FlushInterval.Duration == 0 {
-
cfg.Metrics.FlushInterval.Duration = 10 * time.Second
+
if cfg.Advanced.QueryTimeout.Duration <= 0 {
+
cfg.Advanced.QueryTimeout.Duration = 5 * time.Second
}
-
if cfg.Metrics.RetentionPeriod.Duration == 0 {
-
cfg.Metrics.RetentionPeriod.Duration = 30 * 24 * time.Hour
+
if cfg.Advanced.ReadTimeout.Duration <= 0 {
+
cfg.Advanced.ReadTimeout.Duration = 2 * time.Second
}
-
if cfg.Cache.MaxItems == 0 {
-
cfg.Cache.MaxItems = 5000
+
if cfg.Advanced.WriteTimeout.Duration <= 0 {
+
cfg.Advanced.WriteTimeout.Duration = 2 * time.Second
}
-
if cfg.Cache.CleanupInterval.Duration == 0 {
-
cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
+
if cfg.Server.UDPSize <= 0 || cfg.Server.UDPSize > 4096 {
+
cfg.Server.UDPSize = 512
}
-
if cfg.Advanced.QueryTimeout == 0 {
-
cfg.Advanced.QueryTimeout = 100
+
if cfg.Ratelimit.Rate > 0 {
+
if cfg.Ratelimit.Window.Duration <= 0 {
+
cfg.Ratelimit.Window.Duration = 1 * time.Second
+
}
+
if cfg.Ratelimit.ExpirationTime.Duration <= 0 {
+
cfg.Ratelimit.ExpirationTime.Duration = 1 * time.Minute
+
}
}
return cfg, nil
+188 -171
pkg/dns/cache.go
···
import (
"container/list"
+
"fmt"
+
"log/slog"
"strings"
"sync"
"sync/atomic"
···
"tangled.sh/seiso.moe/magna"
)
-
type BailiwickRule int
-
const (
-
BailiwickSame BailiwickRule = iota
-
BailiwickChild
-
BailiwickOutside
+
defaultNegativeTTL = 5 * time.Minute
+
maxNegativeTTL = 3 * time.Hour
)
type CacheEntry struct {
-
Answer []CachedResourceRecord
-
Authority []CachedResourceRecord
-
Additional []CachedResourceRecord
-
NegativeTTL time.Duration
-
ExpireAt time.Time
-
IsNegative bool
-
}
+
Key string
+
RCode magna.RCode
+
IsNegative bool
+
+
Answer []magna.ResourceRecord
+
Authority []magna.ResourceRecord
+
Additional []magna.ResourceRecord
-
type CachedResourceRecord struct {
-
Record magna.ResourceRecord
-
ExpireAt time.Time
-
BailiwickRule BailiwickRule
+
CacheTime time.Time
+
ExpireAt time.Time
}
type CacheStats struct {
-
TotalQueries atomic.Int64
-
CacheHits atomic.Int64
-
CacheMisses atomic.Int64
-
NegativeHits atomic.Int64
-
PositiveHits atomic.Int64
+
Hits atomic.Int64
+
Misses atomic.Int64
Evictions atomic.Int64
+
Expired atomic.Int64
Size atomic.Int64
+
NegativeHits atomic.Int64
+
PositiveHits atomic.Int64
}
type LRUCache struct {
···
cacheMap map[string]*list.Element
stats CacheStats
stopCleanup chan struct{}
+
logger *slog.Logger
}
type lruItem struct {
key string
-
value *CacheEntry
+
entry *CacheEntry
}
type Cache interface {
Get(key string) (*CacheEntry, bool)
Set(key string, entry *CacheEntry)
-
GetStats() *CacheStats
+
GetStats() CacheStats
Stop()
}
-
func NewLRUCache(maxSize int, cleanupInterval time.Duration) *LRUCache {
+
func NewLRUCache(maxSize int, cleanupInterval time.Duration, logger *slog.Logger) *LRUCache {
if maxSize <= 0 {
maxSize = 5000
}
···
maxSize: maxSize,
cleanupInterval: cleanupInterval,
lruList: list.New(),
-
cacheMap: make(map[string]*list.Element),
+
cacheMap: make(map[string]*list.Element, maxSize),
stopCleanup: make(chan struct{}),
+
logger: logger.With("component", "cache"),
}
+
cache.logger.Info("starting LRU cache", "max_size", maxSize, "cleanup_interval", cleanupInterval)
go cache.periodicCleanup()
return cache
}
+
func GenerateCacheKey(q magna.Question) string {
+
name := strings.ToLower(q.QName)
+
if !strings.HasSuffix(name, ".") {
+
name += "."
+
}
+
return fmt.Sprintf("%s:%s:%s", name, q.QType.String(), q.QClass.String())
+
}
+
func (c *LRUCache) Get(key string) (*CacheEntry, bool) {
-
c.stats.TotalQueries.Add(1)
-
c.mu.RLock()
element, exists := c.cacheMap[key]
c.mu.RUnlock()
if !exists {
-
c.stats.CacheMisses.Add(1)
+
c.stats.Misses.Add(1)
return nil, false
}
···
defer c.mu.Unlock()
element, exists = c.cacheMap[key]
-
// element might have been evicted or cleaned between RUnlock and Lock
if !exists {
+
c.stats.Misses.Add(1)
return nil, false
}
item := element.Value.(*lruItem)
-
entry := item.value
+
entry := item.entry
-
if time.Now().After(entry.ExpireAt) {
+
now := time.Now()
+
if now.After(entry.ExpireAt) {
+
c.stats.Misses.Add(1)
+
c.stats.Expired.Add(1)
c.removeItem(element)
-
c.stats.CacheMisses.Add(1)
return nil, false
}
+
c.stats.Hits.Add(1)
if entry.IsNegative {
c.stats.NegativeHits.Add(1)
} else {
c.stats.PositiveHits.Add(1)
}
-
c.stats.CacheHits.Add(1)
+
+
c.lruList.MoveToFront(element)
+
respEntry := c.adjustTTLs(entry, now)
-
return entry, true
+
return respEntry, true
+
}
+
+
func (c *LRUCache) adjustTTLs(original *CacheEntry, now time.Time) *CacheEntry {
+
adjustedEntry := &CacheEntry{
+
Key: original.Key,
+
RCode: original.RCode,
+
IsNegative: original.IsNegative,
+
CacheTime: original.CacheTime,
+
ExpireAt: original.ExpireAt,
+
Answer: make([]magna.ResourceRecord, len(original.Answer)),
+
Authority: make([]magna.ResourceRecord, len(original.Authority)),
+
Additional: make([]magna.ResourceRecord, len(original.Additional)),
+
}
+
+
remainingDuration := original.ExpireAt.Sub(now)
+
if remainingDuration < 0 {
+
remainingDuration = 0
+
}
+
remainingTTL := uint32(remainingDuration.Seconds())
+
+
copyAndAdjust := func(dest *[]magna.ResourceRecord, src []magna.ResourceRecord) {
+
for i, rr := range src {
+
(*dest)[i] = rr
+
rrExpireAt := original.CacheTime.Add(time.Duration(rr.TTL) * time.Second)
+
rrRemainingDuration := rrExpireAt.Sub(now)
+
if rrRemainingDuration < 0 {
+
rrRemainingDuration = 0
+
}
+
rrRemainingTTL := uint32(rrRemainingDuration.Seconds())
+
+
finalTTL := min(rrRemainingTTL, remainingTTL)
+
(*dest)[i].TTL = finalTTL
+
}
+
}
+
+
copyAndAdjust(&adjustedEntry.Answer, original.Answer)
+
copyAndAdjust(&adjustedEntry.Authority, original.Authority)
+
copyAndAdjust(&adjustedEntry.Additional, original.Additional)
+
+
return adjustedEntry
}
func (c *LRUCache) Set(key string, entry *CacheEntry) {
-
if time.Now().After(entry.ExpireAt) {
+
if entry == nil {
+
c.logger.Warn("attempted to set nil entry in cache", "key", key)
+
return
+
}
+
+
if entry.ExpireAt.Before(time.Now().Add(1 * time.Second)) {
return
}
···
defer c.mu.Unlock()
if element, exists := c.cacheMap[key]; exists {
-
element.Value.(*lruItem).value = entry
c.lruList.MoveToFront(element)
+
element.Value.(*lruItem).entry = entry
return
}
newItem := &lruItem{
key: key,
-
value: entry,
+
entry: entry,
}
element := c.lruList.PushFront(newItem)
c.cacheMap[key] = element
-
currentSize := int64(c.lruList.Len())
-
c.stats.Size.Store(currentSize)
+
c.stats.Size.Store(int64(c.lruList.Len()))
for int64(c.lruList.Len()) > int64(c.maxSize) {
c.evictLRU()
···
c.stats.Size.Store(int64(c.lruList.Len()))
}
-
func (c *LRUCache) GetStats() *CacheStats {
-
statsSnapshot := &CacheStats{}
-
statsSnapshot.TotalQueries.Store(c.stats.TotalQueries.Load())
-
statsSnapshot.CacheHits.Store(c.stats.CacheHits.Load())
-
statsSnapshot.CacheMisses.Store(c.stats.CacheMisses.Load())
-
statsSnapshot.NegativeHits.Store(c.stats.NegativeHits.Load())
-
statsSnapshot.PositiveHits.Store(c.stats.PositiveHits.Load())
+
func (c *LRUCache) GetStats() CacheStats {
+
statsSnapshot := CacheStats{}
+
statsSnapshot.Hits.Store(c.stats.Hits.Load())
+
statsSnapshot.Misses.Store(c.stats.Misses.Load())
statsSnapshot.Evictions.Store(c.stats.Evictions.Load())
+
statsSnapshot.Expired.Store(c.stats.Expired.Load())
statsSnapshot.Size.Store(c.stats.Size.Load())
+
statsSnapshot.NegativeHits.Store(c.stats.NegativeHits.Load())
+
statsSnapshot.PositiveHits.Store(c.stats.PositiveHits.Load())
return statsSnapshot
}
···
defer c.mu.Unlock()
now := time.Now()
+
cleanedCount := 0
element := c.lruList.Back()
+
for element != nil {
item := element.Value.(*lruItem)
-
nextElement := element.Prev()
+
prevElement := element.Prev()
-
if now.After(item.value.ExpireAt) {
+
if now.After(item.entry.ExpireAt) {
c.removeItem(element)
+
c.stats.Expired.Add(1)
+
cleanedCount++
}
-
element = nextElement
+
element = prevElement
+
}
+
if cleanedCount > 0 {
+
c.logger.Info("Cache cleanup finished", "items_removed", cleanedCount, "current_size", c.stats.Size.Load())
}
}
···
close(c.stopCleanup)
}
-
func getMinTTL(records []magna.ResourceRecord) uint32 {
-
if len(records) == 0 {
-
return 60
-
}
-
minTTL := records[0].TTL
-
for _, rr := range records[1:] {
-
if rr.TTL == 0 {
-
return 60
-
}
-
if rr.TTL < minTTL {
-
minTTL = rr.TTL
-
}
-
}
-
if minTTL < 60 {
-
return 60
-
}
-
return minTTL
-
}
-
-
func CreateCacheEntry(msg *magna.Message, zone string) *CacheEntry {
+
func CreateCacheEntry(query magna.Question, response *magna.Message) *CacheEntry {
now := time.Now()
entry := &CacheEntry{
-
IsNegative: msg.Header.RCode == magna.NXDOMAIN,
+
Key: GenerateCacheKey(query),
+
RCode: response.Header.RCode,
+
Answer: response.Answer,
+
Authority: response.Authority,
+
Additional: response.Additional,
+
CacheTime: now,
}
-
if entry.IsNegative {
-
var soaTTL uint32
-
for _, auth := range msg.Authority {
-
if auth.RType == magna.SOAType {
-
soa, ok := auth.RData.(*magna.SOA)
-
if ok {
-
soaTTL = soa.Minimum
-
if auth.TTL < soaTTL {
-
soaTTL = auth.TTL
+
var minTTL time.Duration = -1
+
+
if response.Header.RCode == magna.NXDOMAIN || (response.Header.RCode == magna.NOERROR && len(response.Answer) == 0) {
+
entry.IsNegative = true
+
negativeTTL := defaultNegativeTTL
+
+
for _, rr := range response.Authority {
+
if rr.RType == magna.SOAType {
+
soa, ok := rr.RData.(*magna.SOA)
+
if ok && soa != nil {
+
ttl := time.Duration(rr.TTL) * time.Second
+
minimum := time.Duration(soa.Minimum) * time.Second
+
+
if minimum > maxNegativeTTL {
+
negativeTTL = ttl
+
} else {
+
negativeTTL = minDuration(ttl, minimum)
}
+
} else {
-
soaTTL = auth.TTL
+
negativeTTL = time.Duration(rr.TTL) * time.Second
}
break
}
}
-
if soaTTL == 0 {
-
soaTTL = 900
-
}
-
-
if soaTTL < 60 {
-
soaTTL = 60
+
if negativeTTL < 0 {
+
negativeTTL = defaultNegativeTTL
}
-
-
entry.NegativeTTL = time.Duration(soaTTL) * time.Second
-
entry.ExpireAt = now.Add(entry.NegativeTTL)
-
-
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
-
for i, rr := range msg.Authority {
-
recordTTL := rr.TTL
-
if recordTTL < 60 {
-
recordTTL = 60
-
}
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Authority[i] = CachedResourceRecord{
-
Record: rr,
-
ExpireAt: now.Add(time.Duration(recordTTL) * time.Second),
-
BailiwickRule: rule,
-
}
+
if negativeTTL > maxNegativeTTL {
+
negativeTTL = maxNegativeTTL
}
-
} else {
-
minAnswerTTL := getMinTTL(msg.Answer)
+
minTTL = negativeTTL
+
} else if response.Header.RCode == magna.NOERROR {
+
entry.IsNegative = false
-
if len(msg.Answer) == 0 && len(msg.Authority) > 0 && msg.Authority[0].RType == magna.SOAType {
+
if len(response.Answer) > 0 {
+
minTTL = time.Duration(response.Answer[0].TTL) * time.Second
-
soa, ok := msg.Authority[0].RData.(*magna.SOA)
-
soaTTL := uint32(900)
-
if ok {
-
soaTTL = soa.Minimum
-
if msg.Authority[0].TTL < soaTTL {
-
soaTTL = msg.Authority[0].TTL
+
for _, rr := range response.Answer[1:] {
+
currentTTL := time.Duration(rr.TTL) * time.Second
+
if currentTTL < minTTL {
+
minTTL = currentTTL
}
-
} else {
-
soaTTL = msg.Authority[0].TTL
}
-
if soaTTL < 60 {
-
soaTTL = 60
-
}
-
+
} else {
entry.IsNegative = true
-
entry.NegativeTTL = time.Duration(soaTTL) * time.Second
-
entry.ExpireAt = now.Add(entry.NegativeTTL)
+
negativeTTL := defaultNegativeTTL
+
for _, rr := range response.Authority {
+
if rr.RType == magna.SOAType {
+
soa, ok := rr.RData.(*magna.SOA)
+
if ok && soa != nil {
+
ttl := time.Duration(rr.TTL) * time.Second
+
minimum := time.Duration(soa.Minimum) * time.Second
+
if minimum > maxNegativeTTL {
+
negativeTTL = ttl
+
} else {
+
negativeTTL = minDuration(ttl, minimum)
+
}
-
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
-
for i, rr := range msg.Authority {
-
recordTTL := rr.TTL
-
if recordTTL < 60 {
-
recordTTL = 60
+
} else {
+
negativeTTL = time.Duration(rr.TTL) * time.Second
+
}
+
break
}
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Authority[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
}
+
if negativeTTL < 0 {
+
negativeTTL = defaultNegativeTTL
}
-
-
} else {
-
entry.ExpireAt = now.Add(time.Duration(minAnswerTTL) * time.Second)
-
-
entry.Answer = make([]CachedResourceRecord, len(msg.Answer))
-
for i, rr := range msg.Answer {
-
recordTTL := rr.TTL
-
if recordTTL < 60 {
-
recordTTL = 60
-
}
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Answer[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
+
if negativeTTL > maxNegativeTTL {
+
negativeTTL = maxNegativeTTL
}
+
minTTL = negativeTTL
+
}
-
entry.Authority = make([]CachedResourceRecord, len(msg.Authority))
-
for i, rr := range msg.Authority {
-
recordTTL := rr.TTL
-
if recordTTL < 60 {
-
recordTTL = 60
-
}
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Authority[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
-
}
+
} else {
+
return nil
+
}
-
entry.Additional = make([]CachedResourceRecord, len(msg.Additional))
-
for i, rr := range msg.Additional {
-
recordTTL := rr.TTL
-
if recordTTL < 60 {
-
recordTTL = 60
-
}
-
rule := determineBailiwickRule(zone, rr.Name)
-
entry.Additional[i] = CachedResourceRecord{Record: rr, ExpireAt: now.Add(time.Duration(recordTTL) * time.Second), BailiwickRule: rule}
-
}
-
}
+
if minTTL < 0 {
+
return nil
}
+
entry.ExpireAt = now.Add(minTTL)
return entry
}
-
func isSubdomainOf(a, b string) bool {
-
if a == b {
-
return true
+
func minDuration(a, b time.Duration) time.Duration {
+
if a < b {
+
return a
}
-
-
return strings.HasSuffix(b, "."+a)
+
return b
}
-
func determineBailiwickRule(zone, name string) BailiwickRule {
-
if zone == name {
-
return BailiwickSame
+
func min(a, b uint32) uint32 {
+
if a < b {
+
return a
}
-
-
if isSubdomainOf(zone, name) {
-
return BailiwickChild
-
}
-
-
return BailiwickOutside
+
return b
}
+122 -14
pkg/dns/resolve.go
···
}
question := r.Message.Question[0]
-
records, err := h.resolveQuestion(r.Context, question, h.RootServers)
msg := r.Message.CreateReply(r.Message)
msg.Header.RA = true
+
if h.Cache != nil {
+
cacheKey := GenerateCacheKey(question)
+
if entry, found := h.Cache.Get(cacheKey); found {
+
h.Logger.Debug("cache hit", "question", question.QName, "type", question.QType)
+
+
r.Context = setCacheHit(r.Context)
+
+
msg.Header.RCode = entry.RCode
+
msg.Answer = entry.Answer
+
msg.Authority = entry.Authority
+
msg.Additional = entry.Additional
+
msg.Header.ANCount = uint16(len(msg.Answer))
+
msg.Header.NSCount = uint16(len(msg.Authority))
+
msg.Header.ARCount = uint16(len(msg.Additional))
+
+
w.WriteMsg(msg)
+
return
+
} else {
+
r.Context = setCacheMiss(r.Context)
+
}
+
}
+
+
records, err := h.resolveQuestion(r.Context, question, h.RootServers)
+
if err == errNXDOMAIN {
msg = msg.SetRCode(magna.NXDOMAIN)
msg.Authority = records
···
msg.Header.ANCount = 0
msg.Answer = nil
} else if err != nil {
-
msg.SetRCode(magna.SERVFAIL)
+
msg = msg.SetRCode(magna.SERVFAIL)
} else {
msg.Answer = records
msg.Header.ANCount = uint16(len(records))
msg = msg.SetRCode(magna.NOERROR)
}
+
if h.Cache != nil {
+
if err == nil || err == errNXDOMAIN {
+
entry := CreateCacheEntry(question, msg)
+
if entry != nil {
+
h.Cache.Set(GenerateCacheKey(question), entry)
+
h.Logger.Debug("cached response",
+
"question", question.QName,
+
"type", question.QType.String(),
+
"negative", entry.IsNegative,
+
"rcode", entry.RCode.String(),
+
"expires", entry.ExpireAt.Sub(time.Now()).String())
+
}
+
}
+
}
+
w.WriteMsg(msg)
}
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
+
if h.Cache != nil {
+
cacheKey := GenerateCacheKey(question)
+
if entry, found := h.Cache.Get(cacheKey); found {
+
h.Logger.Debug("cache hit during recursion", "question", question.QName, "type", question.QType)
+
+
if entry.RCode == magna.NXDOMAIN {
+
return entry.Authority, errNXDOMAIN
+
}
+
+
if entry.IsNegative {
+
return []magna.ResourceRecord{}, nil
+
}
+
+
return entry.Answer, nil
+
}
+
}
+
resolveCtx, cancel := context.WithTimeout(ctx, h.Timeout)
defer cancel()
ch := make(chan queryResponse, len(servers))
···
select {
case res := <-ch:
if res.Error != nil {
-
slog.Warn("error", "question", question, "server", res.Server, "error", res.Error)
-
break
+
h.Logger.Warn("error", "question", question.QName, "server", res.Server, "error", res.Error)
+
continue
}
msg := res.MSG
+
+
fullMsg := magna.Message{
+
Header: msg.Header,
+
Question: []magna.Question{question},
+
Answer: msg.Answer,
+
Authority: msg.Authority,
+
Additional: msg.Additional,
+
}
+
if msg.Header.RCode == magna.NXDOMAIN {
+
if h.Cache != nil {
+
entry := CreateCacheEntry(question, &fullMsg)
+
if entry != nil {
+
h.Cache.Set(GenerateCacheKey(question), entry)
+
}
+
}
return msg.Authority, errNXDOMAIN
}
if msg.Header.ANCount > 0 {
if msg.Answer[0].RType == magna.CNAMEType {
-
cname_answers, err := h.resolveQuestion(resolveCtx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
+
if h.Cache != nil {
+
partialEntry := CreateCacheEntry(question, &fullMsg)
+
if partialEntry != nil {
+
h.Cache.Set(GenerateCacheKey(question), partialEntry)
+
}
+
}
+
+
cnameTarget := msg.Answer[0].RData.String()
+
h.Logger.Debug("following CNAME", "from", question.QName, "to", cnameTarget)
+
+
cname_answers, err := h.resolveQuestion(resolveCtx, magna.Question{
+
QName: cnameTarget,
+
QType: question.QType,
+
QClass: question.QClass,
+
}, h.RootServers)
if err != nil {
if err == errNXDOMAIN {
return cname_answers, errNXDOMAIN
}
-
slog.Warn("error resolving CNAME", "cname", msg.Answer[0].RData.String(), "error", err)
-
+
h.Logger.Warn("error resolving CNAME", "cname", cnameTarget, "error", err)
continue
}
-
msg.Answer = append(msg.Answer, cname_answers...)
+
combinedAnswer := append([]magna.ResourceRecord{}, msg.Answer[0])
+
combinedAnswer = append(combinedAnswer, cname_answers...)
+
+
fullMsg.Answer = combinedAnswer
+
if h.Cache != nil {
+
fullEntry := CreateCacheEntry(question, &fullMsg)
+
if fullEntry != nil {
+
h.Cache.Set(GenerateCacheKey(question), fullEntry)
+
}
+
}
+
+
return combinedAnswer, nil
}
+
if h.Cache != nil {
+
entry := CreateCacheEntry(question, &fullMsg)
+
if entry != nil {
+
h.Cache.Set(GenerateCacheKey(question), entry)
+
}
+
}
return msg.Answer, nil
}
···
if ip, found := glueMap[nsName]; found {
nextServers = append(nextServers, ip)
} else {
-
nsAnswers, err := h.resolveQuestion(resolveCtx, magna.Question{QName: nsName, QType: magna.AType, QClass: magna.IN}, h.RootServers)
+
nsQuestion := magna.Question{QName: nsName, QType: magna.AType, QClass: magna.IN}
+
nsAnswers, err := h.resolveQuestion(resolveCtx, nsQuestion, h.RootServers)
if err != nil {
-
slog.Warn("error resolving NS A record", "ns", nsName, "error", err)
+
h.Logger.Warn("error resolving NS A record", "ns", nsName, "error", err)
continue
}
···
return h.resolveQuestion(resolveCtx, question, nextServers)
}
-
slog.Warn("could not resolve any NS records for delegation", "question", question.QName)
+
h.Logger.Warn("could not resolve any NS records for delegation", "question", question.QName)
continue
}
if msg.Header.RCode == magna.NOERROR && msg.Header.ANCount == 0 {
+
if h.Cache != nil {
+
entry := CreateCacheEntry(question, &fullMsg)
+
if entry != nil {
+
h.Cache.Set(GenerateCacheKey(question), entry)
+
}
+
}
return []magna.ResourceRecord{}, nil
}
-
slog.Warn("unexpected response state", "question", question, "server", res.Server, "rcode", msg.Header.RCode)
+
h.Logger.Warn("unexpected response state", "question", question.QName, "server", res.Server, "rcode", msg.Header.RCode)
continue
+
case <-resolveCtx.Done():
-
slog.Debug("resolution cancelled or timed out", "question", question)
+
h.Logger.Debug("resolution cancelled or timed out", "question", question.QName)
return []magna.ResourceRecord{}, fmt.Errorf("resolution timed out or cancelled")
}
}
-
slog.Warn("all resolution paths failed", "question", question)
+
h.Logger.Warn("all resolution paths failed", "question", question.QName)
return []magna.ResourceRecord{}, fmt.Errorf("all resolution paths failed")
}
+155 -38
pkg/dns/server.go
···
}
func GetCacheHit(ctx context.Context) bool {
-
v := ctx.Value(cacheHitKey)
-
if v == nil {
+
if ctx == nil {
return false
}
-
-
return v.(bool)
+
v := ctx.Value(cacheHitKey)
+
hit, ok := v.(bool)
+
return ok && hit
}
type Handler interface {
···
addr *net.UDPAddr
logger *slog.Logger
writeTimeout time.Duration
+
udpSize int
}
func (w *udpResponseWriter) WriteMsg(msg *magna.Message) {
ans, err := msg.Encode()
if err != nil {
w.logger.Warn("err encoding msg", "error", err)
-
return
+
+
failMsg := msg.CreateReply(msg)
+
failMsg = failMsg.SetRCode(magna.SERVFAIL)
+
failMsg.Answer = nil
+
failMsg.Authority = nil
+
failMsg.Additional = nil
+
failMsg.Header.ANCount = 0
+
failMsg.Header.NSCount = 0
+
failMsg.Header.ARCount = 0
+
ans, _ = failMsg.Encode()
+
if ans == nil {
+
return
+
}
+
}
+
+
if len(ans) > w.udpSize {
+
w.logger.Debug("Response exceeds UDP size, setting TC bit", "size", len(ans), "limit", w.udpSize, "client", w.addr)
+
+
tcMsg := msg.CreateReply(msg)
+
tcMsg.Header.TC = true
+
tcMsg.Answer = nil
+
tcMsg.Authority = nil
+
tcMsg.Additional = nil
+
tcMsg.Header.ANCount = 0
+
tcMsg.Header.NSCount = 0
+
tcMsg.Header.ARCount = 0
+
tcMsg.Header.RCode = magna.NOERROR
+
+
ans, err = tcMsg.Encode()
+
if err != nil {
+
w.logger.Error("Error encoding truncated UDP response", "error", err, "client", w.addr)
+
return
+
}
+
if len(ans) > w.udpSize {
+
w.logger.Warn("Truncated message still exceeds UDP size limit!", "size", len(ans), "limit", w.udpSize)
+
}
}
-
if len(ans) > 512 {
-
ans[3] |= 1 << 6 // set the truncated bit
+
if w.udpConn == nil || w.addr == nil {
+
w.logger.Error("UDP response writer used incorrectly (nil conn or addr)")
+
return
}
err = w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
···
ans, err := msg.Encode()
if err != nil {
w.logger.Warn("err encoding msg", "error", err)
+
+
failMsg := msg.CreateReply(msg)
+
failMsg = failMsg.SetRCode(magna.SERVFAIL)
+
failMsg.Answer = nil
+
failMsg.Authority = nil
+
failMsg.Additional = nil
+
failMsg.Header.ANCount = 0
+
failMsg.Header.NSCount = 0
+
failMsg.Header.ARCount = 0
+
ans, _ = failMsg.Encode()
+
return
+
}
+
+
if w.tcpConn == nil {
+
w.logger.Error("TCP response writer used incorrectly (nil conn)")
return
}
···
ReadTimeout time.Duration
WriteTimeout time.Duration
Logger *slog.Logger
-
Cache Cache
}
func (srv *Server) ListenAndServe() error {
+
if srv.Logger == nil {
+
srv.Logger = slog.Default()
+
}
+
+
if srv.UDPSize <= 0 {
+
srv.UDPSize = 512
+
}
+
srv.Logger.Info("Starting DNS server", "address", srv.Address, "port", srv.Port, "udp_size", srv.UDPSize)
+
var wg sync.WaitGroup
errChan := make(chan error, 2)
···
go func() {
defer wg.Done()
if err := srv.serveTCP(); err != nil {
-
errChan <- fmt.Errorf("TCP server error: %w", err)
+
err = fmt.Errorf("TCP server error: %w", err)
+
srv.Logger.Error(err.Error())
+
select {
+
case errChan <- err:
+
default:
+
srv.Logger.Warn("Error channel full, discarding TCP error")
+
}
}
}()
go func() {
defer wg.Done()
if err := srv.serveUDP(); err != nil {
-
errChan <- fmt.Errorf("TCP server error: %w", err)
+
err = fmt.Errorf("UDP server error: %w", err)
+
srv.Logger.Error(err.Error())
+
select {
+
case errChan <- err:
+
default:
+
srv.Logger.Warn("Error channel full, discarding UDP error")
+
}
}
}()
···
close(errChan)
}()
-
for err := range errChan {
-
return err
-
}
-
-
return nil
+
err := <-errChan
+
return err
}
func (srv *Server) serveUDP() error {
-
addr := net.UDPAddr{
+
addr := &net.UDPAddr{
Port: srv.Port,
IP: net.ParseIP(srv.Address),
}
-
conn, err := net.ListenUDP("udp", &addr)
+
conn, err := net.ListenUDP("udp", addr)
if err != nil {
-
return err
+
return fmt.Errorf("failed to listen on UDP %s:%d: %w", srv.Address, srv.Port, err)
}
defer conn.Close()
+
srv.Logger.Info("UDP listener started", "address", conn.LocalAddr())
for {
buf := make([]byte, srv.UDPSize)
-
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
-
if err != nil {
-
return fmt.Errorf("error setting read deadline: %w", err)
+
if srv.ReadTimeout > 0 {
+
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+
if err != nil {
+
return fmt.Errorf("error setting UDP read deadline: %w", err)
+
}
}
n, remoteAddr, err := conn.ReadFromUDP(buf)
if err != nil {
-
// skip logging timeout errors
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
-
srv.Logger.Warn(err.Error())
+
if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" {
+
srv.Logger.Info("UDP listener stopping: connection closed")
+
return nil
+
}
+
+
srv.Logger.Warn("UDP read error", "error", err)
continue
}
···
addr: remoteAddr,
logger: srv.Logger,
writeTimeout: srv.WriteTimeout,
+
udpSize: srv.UDPSize,
}
srv.handleQuery(query, w, remoteAddr)
···
listener, err := net.ListenTCP("tcp", &addr)
if err != nil {
-
return err
+
return fmt.Errorf("failed to listen on TCP %s:%d: %w", srv.Address, srv.Port, err)
}
defer listener.Close()
+
srv.Logger.Info("TCP listener started", "address", listener.Addr())
for {
-
conn, err := listener.Accept()
+
conn, err := listener.AcceptTCP()
if err != nil {
-
srv.Logger.Warn("tcp accept error:", "error", err)
+
if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" {
+
srv.Logger.Info("TCP listener stopping: listener closed")
+
return nil
+
}
+
srv.Logger.Warn("TCP accept error", "error", err)
continue
}
+
conn.SetKeepAlive(true)
+
conn.SetKeepAlivePeriod(3 * time.Minute)
go srv.handleTCPQuery(conn)
}
}
···
func (srv *Server) handleTCPQuery(conn net.Conn) {
defer conn.Close()
-
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
-
if err != nil {
-
srv.Logger.Error("error setting read deadline", "error", err)
-
return
+
if srv.ReadTimeout > 0 {
+
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+
if err != nil {
+
srv.Logger.Warn("Error setting TCP initial read deadline", "error", err, "client", conn.RemoteAddr())
+
return
+
}
}
sizeBuffer := make([]byte, 2)
-
if _, err := io.ReadFull(conn, sizeBuffer); err != nil {
-
srv.Logger.Warn("tcp error occurred", "error", err)
+
_, err := io.ReadFull(conn, sizeBuffer)
+
if err != nil {
+
if err != io.EOF && err != io.ErrUnexpectedEOF {
+
srv.Logger.Warn("TCP error reading message length", "error", err, "client", conn.RemoteAddr())
+
}
return
}
size := binary.BigEndian.Uint16(sizeBuffer)
+
if size == 0 {
+
srv.Logger.Debug("TCP received zero-length message", "client", conn.RemoteAddr())
+
return
+
}
+
+
maxTCPMessageSize := 65535
+
if size > uint16(maxTCPMessageSize) {
+
srv.Logger.Warn("TCP message size exceeds limit", "size", size, "limit", maxTCPMessageSize, "client", conn.RemoteAddr())
+
return
+
}
+
data := make([]byte, size)
-
if _, err := io.ReadFull(conn, data); err != nil {
-
srv.Logger.Warn("tcp error occurred", "error", err)
+
if srv.ReadTimeout > 0 {
+
err = conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+
if err != nil {
+
srv.Logger.Warn("Error setting TCP read deadline for body", "error", err, "client", conn.RemoteAddr())
+
return
+
}
+
}
+
+
_, err = io.ReadFull(conn, data)
+
if err != nil {
+
srv.Logger.Warn("TCP error reading message body", "error", err, "client", conn.RemoteAddr())
return
}
···
func (srv *Server) handleQuery(messageBuffer []byte, w ResponseWriter, remoteAddr net.Addr) {
var query magna.Message
-
if err := query.Decode(messageBuffer); err != nil {
-
srv.Logger.Warn("message decode error", "error", err)
+
err := query.Decode(messageBuffer)
+
if err != nil {
+
srv.Logger.Warn("Message decode error", "error", err, "client", remoteAddr)
+
// TODO: find better way to handle failed decode drop for now.
return
}
r := &Request{
+
Context: context.Background(),
Message: &query,
RemoteAddr: remoteAddr,
}
-
r.Context = context.WithValue(context.Background(), contextKey("request"), r)
-
+
if srv.Handler == nil {
+
srv.Logger.Error("No DNS handler configured!")
+
reply := query.CreateReply(&query)
+
reply = reply.SetRCode(magna.SERVFAIL)
+
w.WriteMsg(reply)
+
return
+
}
srv.Handler.ServeDNS(w, r)
}
+137 -66
pkg/metrics/clickhouse.go
···
import (
"database/sql"
+
"errors"
"fmt"
"log/slog"
"sync"
···
_ "github.com/ClickHouse/clickhouse-go/v2"
"tangled.sh/seiso.moe/alky/pkg/config"
-
"tangled.sh/seiso.moe/alky/pkg/dns"
)
type ClickHouseMetrics struct {
···
cacheBuffer []CacheMetric
mu sync.Mutex
stopChan chan struct{}
+
wg sync.WaitGroup
logger *slog.Logger
}
···
type CacheMetric struct {
Timestamp time.Time
-
TotalQueries int64
CacheHits int64
CacheMisses int64
NegativeHits int64
···
}
func NewClickHouseMetrics(config *config.MetricsConfig, logger *slog.Logger) (*ClickHouseMetrics, error) {
+
if logger == nil {
+
logger = slog.Default()
+
}
+
logger.Info("Connecting to ClickHouse", "dsn", config.DSN)
+
db, err := sql.Open("clickhouse", config.DSN)
if err != nil {
-
return nil, fmt.Errorf("failed to connect to ClickHouse: %w", err)
+
return nil, fmt.Errorf("failed to initialize ClickHouse driver: %w", err)
}
+
if err := db.Ping(); err != nil {
+
db.Close()
+
return nil, fmt.Errorf("failed to connect to ClickHouse (%s): %w", config.DSN, err)
+
}
+
logger.Info("Successfully connected to ClickHouse")
+
m := &ClickHouseMetrics{
db: db,
config: config,
···
logger: logger,
}
-
if err := m.changeTTL(); err != nil {
-
db.Close()
-
return nil, fmt.Errorf("failed to initialize tables: %w", err)
+
if err := m.checkAndUpdateTTL(); err != nil {
+
logger.Warn("Failed to check/update table TTLs, using existing TTL", "error", err)
+
} else {
+
logger.Info("Table TTLs verified/updated", "retention", config.RetentionPeriod.Duration)
}
+
m.wg.Add(1)
go m.flushLoop()
return m, nil
}
···
m.queryBuffer = append(m.queryBuffer, metric)
if len(m.queryBuffer) >= m.config.BatchSize {
-
m.flush()
+
m.flushQueriesLocked()
}
}
-
func (m *ClickHouseMetrics) RecordCacheMetrics(metric CacheMetric) {
+
func (m *ClickHouseMetrics) RecordCacheStats(metric CacheMetric) {
m.mu.Lock()
defer m.mu.Unlock()
m.cacheBuffer = append(m.cacheBuffer, metric)
if len(m.cacheBuffer) >= m.config.BatchSize {
-
m.flush()
+
m.flushCacheMetricsLocked()
}
}
-
func (m *ClickHouseMetrics) RecordCacheStats(stats *dns.CacheStats) {
-
m.RecordCacheMetrics(CacheMetric{
-
Timestamp: time.Now(),
-
TotalQueries: stats.TotalQueries.Load(),
-
CacheHits: stats.CacheHits.Load(),
-
CacheMisses: stats.CacheMisses.Load(),
-
NegativeHits: stats.NegativeHits.Load(),
-
PositiveHits: stats.PositiveHits.Load(),
-
Evictions: stats.Evictions.Load(),
-
Size: stats.Size.Load(),
-
})
-
}
-
func (m *ClickHouseMetrics) flushLoop() {
+
defer m.wg.Done()
ticker := time.NewTicker(m.config.FlushInterval.Duration)
defer ticker.Stop()
+
m.logger.Info("Metrics flush loop started", "interval", m.config.FlushInterval.Duration)
+
for {
select {
case <-ticker.C:
m.mu.Lock()
-
m.flush()
+
m.flushLocked()
m.mu.Unlock()
case <-m.stopChan:
+
m.logger.Info("Metrics flush loop received stop signal. Flushing remaining data...")
+
m.mu.Lock()
+
m.flushLocked()
+
m.mu.Unlock()
+
m.logger.Info("Metrics flush loop stopped.")
return
}
}
}
-
func (m *ClickHouseMetrics) flush() {
+
func (m *ClickHouseMetrics) flushLocked() {
if len(m.queryBuffer) > 0 {
-
if err := m.flushQueries(); err != nil {
-
m.logger.Error("Failed to flush query metrics", "error", err)
-
}
-
m.queryBuffer = m.queryBuffer[:0]
+
m.flushQueriesLocked()
}
-
if len(m.cacheBuffer) > 0 {
-
if err := m.flushCacheMetrics(); err != nil {
-
m.logger.Error("Failed to flush cache metrics", "error", err)
-
}
-
m.cacheBuffer = m.cacheBuffer[:0]
+
m.flushCacheMetricsLocked()
}
}
-
func (m *ClickHouseMetrics) changeTTL() error {
-
if _, err := m.db.Exec(
-
"ALTER TABLE alky_dns_queries MODIFY TTL timestamp + toIntervalSecond(?)",
-
int(m.config.RetentionPeriod.Seconds()),
-
); err != nil {
-
return fmt.Errorf("failed to update alky_dns_queries TTL: %w", err)
+
func (m *ClickHouseMetrics) checkAndUpdateTTL() error {
+
retentionSeconds := int(m.config.RetentionPeriod.Duration.Seconds())
+
if retentionSeconds <= 0 {
+
m.logger.Warn("Invalid retention period configured, skipping TTL update", "retention", m.config.RetentionPeriod.Duration)
+
return nil
+
}
+
+
m.logger.Info("Checking/Updating table TTLs", "retention_seconds", retentionSeconds)
+
+
updateTableTTL := func(tableName string) error {
+
query := fmt.Sprintf("ALTER TABLE %s MODIFY TTL timestamp + toIntervalSecond(?)", tableName)
+
m.logger.Debug("Executing TTL update query", "query", query)
+
_, err := m.db.Exec(query, retentionSeconds)
+
if err != nil && errors.Is(err, sql.ErrNoRows) {
+
m.logger.Warn("Table not found while updating TTL, likely needs migration", "table", tableName)
+
return nil
+
} else if err != nil {
+
return fmt.Errorf("failed to update %s TTL: %w", tableName, err)
+
}
+
m.logger.Debug("TTL updated successfully", "table", tableName)
+
return nil
}
-
if _, err := m.db.Exec(
-
"ALTER TABLE alky_dns_cache_metrics MODIFY TTL timestamp + toIntervalSecond(?)",
-
int(m.config.RetentionPeriod.Seconds()),
-
); err != nil {
-
return fmt.Errorf("failed to update alky_dns_cache_metrics TTL: %w", err)
+
if err := updateTableTTL("alky_dns_queries"); err != nil {
+
return err
+
}
+
if err := updateTableTTL("alky_dns_cache_metrics"); err != nil {
+
return err
}
return nil
}
-
func (m *ClickHouseMetrics) flushQueries() error {
+
func (m *ClickHouseMetrics) flushQueriesLocked() {
+
if len(m.queryBuffer) == 0 {
+
return
+
}
+
+
m.logger.Debug("Flushing query metrics", "count", len(m.queryBuffer))
+
start := time.Now()
+
tx, err := m.db.Begin()
if err != nil {
-
return err
+
m.logger.Error("Failed to begin transaction for query metrics", "error", err)
+
m.queryBuffer = m.queryBuffer[:0]
+
return
}
-
defer tx.Rollback()
+
defer func() {
+
if err != nil {
+
if rbErr := tx.Rollback(); rbErr != nil {
+
m.logger.Error("Failed to rollback transaction for query metrics", "error", rbErr)
+
}
+
}
+
}()
stmt, err := tx.Prepare(`
INSERT INTO alky_dns_queries (
···
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
if err != nil {
-
return err
+
m.logger.Error("Failed to prepare statement for query metrics", "error", err)
+
m.queryBuffer = m.queryBuffer[:0]
+
return
}
defer stmt.Close()
+
count := 0
for _, metric := range m.queryBuffer {
-
_, err := stmt.Exec(
+
_, err = stmt.Exec(
metric.Timestamp,
metric.InstanceID,
metric.QueryName,
···
metric.CacheHit,
)
if err != nil {
-
return err
+
m.logger.Error("Failed to execute statement for query metric", "error", err, "metric_index", count)
+
return
}
+
count++
}
-
return tx.Commit()
+
err = tx.Commit()
+
if err != nil {
+
m.logger.Error("Failed to commit transaction for query metrics", "error", err)
+
return
+
}
+
+
m.logger.Debug("Successfully flushed query metrics", "count", count, "duration", time.Since(start))
+
m.queryBuffer = m.queryBuffer[:0]
}
-
func (m *ClickHouseMetrics) flushCacheMetrics() error {
+
func (m *ClickHouseMetrics) flushCacheMetricsLocked() {
+
if len(m.cacheBuffer) == 0 {
+
return
+
}
+
+
m.logger.Debug("Flushing cache metrics", "count", len(m.cacheBuffer))
+
start := time.Now()
+
tx, err := m.db.Begin()
if err != nil {
-
return err
+
m.logger.Error("Failed to begin transaction for cache metrics", "error", err)
+
m.cacheBuffer = m.cacheBuffer[:0]
+
return
}
-
defer tx.Rollback()
+
defer func() {
+
if err != nil {
+
if rbErr := tx.Rollback(); rbErr != nil {
+
m.logger.Error("Failed to rollback transaction for cache metrics", "error", rbErr)
+
}
+
}
+
}()
stmt, err := tx.Prepare(`
INSERT INTO alky_dns_cache_metrics (
-
timestamp, instance_id, total_queries, cache_hits, cache_misses,
-
negative_hits, positive_hits, evictions, size
-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+
timestamp, instance_id, cache_hits, cache_misses,
+
negative_hits, positive_hits, evictions, size
+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`)
if err != nil {
-
return err
+
m.logger.Error("Failed to prepare statement for cache metrics", "error", err)
+
m.cacheBuffer = m.cacheBuffer[:0]
+
return
}
defer stmt.Close()
+
count := 0
for _, metric := range m.cacheBuffer {
-
_, err := stmt.Exec(
+
instanceID := GetInstanceID()
+
+
_, err = stmt.Exec(
metric.Timestamp,
-
GetInstanceID(),
-
metric.TotalQueries,
+
instanceID,
metric.CacheHits,
metric.CacheMisses,
metric.NegativeHits,
···
metric.Size,
)
if err != nil {
-
return err
+
m.logger.Error("Failed to execute statement for cache metric", "error", err, "metric_index", count)
+
return
}
+
count++
}
-
return tx.Commit()
+
err = tx.Commit()
+
if err != nil {
+
m.logger.Error("Failed to commit transaction for cache metrics", "error", err)
+
return
+
}
+
+
m.logger.Debug("Successfully flushed cache metrics", "count", count, "duration", time.Since(start))
+
m.cacheBuffer = m.cacheBuffer[:0]
}
func (m *ClickHouseMetrics) Close() error {
+
m.logger.Info("Closing metrics client...")
close(m.stopChan)
-
m.mu.Lock()
-
defer m.mu.Unlock()
-
m.flush()
+
m.wg.Wait()
+
m.logger.Info("Flush loop stopped. Closing database connection.")
return m.db.Close()
}
+39 -17
pkg/metrics/middleware.go
···
import (
"context"
"fmt"
+
"log/slog"
"os"
+
"sync"
"time"
"tangled.sh/seiso.moe/alky/pkg/dns"
···
var (
instanceID string
+
initOnce sync.Once
version string
)
-
func init() {
-
var err error
+
func initializeInstanceID() {
hostname, err := os.Hostname()
if err != nil {
-
hostname = "unknown"
+
hostname = "unknown-host"
}
+
pid := os.Getpid()
+
if version != "" {
-
instanceID = fmt.Sprintf("%s-%s", hostname, version)
+
instanceID = fmt.Sprintf("%s-%d-%s", hostname, pid, version)
} else {
-
instanceID = hostname
+
instanceID = fmt.Sprintf("%s-%d-dev", hostname, pid)
}
}
func MetricsMiddleware(metrics *ClickHouseMetrics) func(dns.Handler) dns.Handler {
+
if metrics == nil {
+
slog := slog.Default()
+
slog.Warn("Metrics client is nil, metrics middleware will be a no-op")
+
return func(next dns.Handler) dns.Handler {
+
return next
+
}
+
}
+
+
instanceID := GetInstanceID()
+
return func(next dns.Handler) dns.Handler {
return dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Request) {
if r.Context == nil {
···
next.ServeDNS(w, r)
duration := time.Since(start)
-
question := r.Message.Question[0]
-
metrics.RecordQuery(QueryMetric{
-
Timestamp: time.Now(),
-
InstanceID: instanceID,
-
QueryName: question.QName,
-
QueryType: question.QType.String(),
-
QueryClass: question.QClass.String(),
-
RemoteAddr: r.RemoteAddr.String(),
-
ResponseCode: r.Message.Header.RCode.String(),
-
Duration: duration.Nanoseconds(),
-
CacheHit: dns.GetCacheHit(r.Context),
-
})
+
if r.Message != nil && len(r.Message.Question) > 0 {
+
question := r.Message.Question[0]
+
cacheHit := dns.GetCacheHit(r.Context)
+
+
metrics.RecordQuery(QueryMetric{
+
Timestamp: start,
+
InstanceID: instanceID,
+
QueryName: question.QName,
+
QueryType: question.QType.String(),
+
QueryClass: question.QClass.String(),
+
RemoteAddr: r.RemoteAddr.String(),
+
ResponseCode: r.Message.Header.RCode.String(),
+
Duration: duration.Nanoseconds(),
+
CacheHit: cacheHit,
+
})
+
} else {
+
slog := slog.Default()
+
slog.Warn("Metrics middleware received request with missing message or question", "remote_addr", r.RemoteAddr)
+
}
})
}
}
// GetInstanceID returns this process's metrics instance identifier,
// computing it exactly once on first use.
func GetInstanceID() string {
	initOnce.Do(initializeInstanceID)
	return instanceID
}