A recursive DNS resolver.
1package main 2 3import ( 4 "flag" 5 "log" 6 "log/slog" 7 "os" 8 "os/signal" 9 "strings" 10 "syscall" 11 "time" 12 13 "tangled.sh/seiso.moe/alky/pkg/config" 14 "tangled.sh/seiso.moe/alky/pkg/dns" 15 "tangled.sh/seiso.moe/alky/pkg/metrics" 16 "tangled.sh/seiso.moe/alky/pkg/rootservers" 17) 18 19var configFlag string 20 21func init() { 22 flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky") 23 flag.Parse() 24} 25 26func main() { 27 cfg, err := config.LoadConfig(configFlag) 28 if err != nil { 29 log.Fatalf("failed to load config: %v", err) 30 } 31 32 logger := setupLogger(&cfg) 33 slog.SetDefault(logger) 34 35 logger.Info("Configuration loaded", "path", configFlag) 36 37 metricsClient, err := metrics.NewClickHouseMetrics(&cfg.Metrics, logger.With("component", "metrics")) 38 if err != nil { 39 logger.Error("failed to initialize metrics client", "error", err) 40 os.Exit(1) 41 } 42 defer metricsClient.Close() 43 logger.Info("Metrics client initialized") 44 45 rootServers, err := rootservers.DecodeRootHints(cfg.Server.RootHintsFile) 46 if err != nil || len(rootServers) == 0 { 47 logger.Error("Failed to load root hints or no root servers found", "path", cfg.Server.RootHintsFile, "error", err) 48 os.Exit(1) 49 } 50 logger.Info("Root hints loaded", "count", len(rootServers), "path", cfg.Server.RootHintsFile) 51 52 cache := dns.NewLRUCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration, logger.With("component", "cache")) 53 defer cache.Stop() 54 logger.Info("DNS cache initialized") 55 56 queryHandler := &dns.QueryHandler{ 57 RootServers: rootServers, 58 Timeout: cfg.Advanced.QueryTimeout.Duration, 59 Cache: cache, 60 Logger: logger.With("component", "resolver"), 61 } 62 63 go monitorCacheMetrics(cache, metricsClient, logger.With("component", "cache-monitor")) 64 65 var currentHandler dns.Handler = queryHandler 66 67 if cfg.Ratelimit.Rate > 0 { 68 rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{ 69 Rate: 
float64(cfg.Ratelimit.Rate), 70 Burst: cfg.Ratelimit.Burst, 71 WindowLength: cfg.Ratelimit.Window.Duration, 72 ExpirationTime: cfg.Ratelimit.ExpirationTime.Duration, 73 })(currentHandler) 74 currentHandler = rateLimitHandler 75 logger.Info("Rate limiting enabled", "rate", cfg.Ratelimit.Rate, "burst", cfg.Ratelimit.Burst) 76 } else { 77 logger.Info("Rate limiting disabled") 78 } 79 80 metricsHandler := metrics.MetricsMiddleware(metricsClient)(currentHandler) 81 currentHandler = metricsHandler 82 logger.Info("Metrics middleware enabled") 83 84 loggingHandler := dns.LoggingMiddleware(&dns.LogConfig{ 85 Logger: logger.With("component", "query-log"), 86 Level: slog.LevelInfo, 87 })(currentHandler) 88 currentHandler = loggingHandler 89 logger.Info("Logging middleware enabled") 90 91 sigChan := make(chan os.Signal, 1) 92 signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 93 go func() { 94 sig := <-sigChan 95 logger.Info("Received signal, shutting down gracefully...", "signal", sig.String()) 96 97 logger.Info("Stopping cache...") 98 cache.Stop() 99 100 logger.Info("Closing metrics client...") 101 metricsClient.Close() 102 103 logger.Info("Shutdown complete.") 104 os.Exit(0) 105 }() 106 107 s := dns.Server{ 108 Address: cfg.Server.Address, 109 Port: cfg.Server.Port, 110 Handler: currentHandler, 111 UDPSize: cfg.Server.UDPSize, 112 ReadTimeout: cfg.Advanced.ReadTimeout.Duration, 113 WriteTimeout: cfg.Advanced.WriteTimeout.Duration, 114 Logger: logger.With("component", "server"), 115 } 116 117 logger.Info("Starting DNS server listener...") 118 if err := s.ListenAndServe(); err != nil { 119 logger.Error("Failed to start server", "error", err) 120 os.Exit(1) 121 } 122} 123 124func monitorCacheMetrics(cache dns.Cache, metricsClient *metrics.ClickHouseMetrics, logger *slog.Logger) { 125 interval := 1 * time.Minute 126 logger.Info("Starting cache metrics monitoring", "interval", interval) 127 ticker := time.NewTicker(interval) 128 defer ticker.Stop() 129 130 for { 131 select 
{ 132 case <-ticker.C: 133 stats := cache.GetStats() 134 logger.Debug("Recording cache metrics", 135 "hits", stats.Hits.Load(), 136 "misses", stats.Misses.Load(), 137 "size", stats.Size.Load(), 138 "evictions", stats.Evictions.Load(), 139 "expired", stats.Expired.Load(), 140 "pos_hits", stats.PositiveHits.Load(), 141 "neg_hits", stats.NegativeHits.Load(), 142 ) 143 metricsClient.RecordCacheStats(metrics.CacheMetric{ 144 Timestamp: time.Now(), 145 CacheHits: stats.Hits.Load(), 146 CacheMisses: stats.Misses.Load(), 147 NegativeHits: stats.NegativeHits.Load(), 148 PositiveHits: stats.PositiveHits.Load(), 149 Evictions: stats.Evictions.Load() + stats.Expired.Load(), 150 Size: stats.Size.Load(), 151 }) 152 } 153 } 154} 155 156func setupLogger(cfg *config.Config) *slog.Logger { 157 var logLevel slog.Level 158 switch strings.ToLower(cfg.Logging.Level) { 159 case "debug": 160 logLevel = slog.LevelDebug 161 case "info": 162 logLevel = slog.LevelInfo 163 case "warn": 164 logLevel = slog.LevelWarn 165 case "error": 166 logLevel = slog.LevelError 167 default: 168 logLevel = slog.LevelInfo 169 } 170 171 handlerOpts := &slog.HandlerOptions{ 172 Level: logLevel, 173 AddSource: logLevel <= slog.LevelDebug, 174 } 175 176 var handler slog.Handler 177 switch cfg.Logging.Output { 178 case "file": 179 f, err := os.OpenFile(cfg.Logging.FilePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644) 180 if err != nil { 181 log.Fatalf("Failed to open log file %s: %v", cfg.Logging.FilePath, err) 182 } 183 handler = slog.NewJSONHandler(f, handlerOpts) 184 log.Printf("Logging to file: %s at level: %s", cfg.Logging.FilePath, logLevel.String()) 185 case "text": 186 handler = slog.NewTextHandler(os.Stdout, handlerOpts) 187 log.Printf("Logging to stdout (text) at level: %s", logLevel.String()) 188 default: 189 handler = slog.NewJSONHandler(os.Stdout, handlerOpts) 190 log.Printf("Logging to stdout (json) at level: %s", logLevel.String()) 191 } 192 193 logger := slog.New(handler) 194 return logger 195}