A recursive DNS resolver.
at main 4.4 kB view raw
1package main 2 3import ( 4 "flag" 5 "log" 6 "log/slog" 7 "os" 8 "os/signal" 9 "strings" 10 "syscall" 11 12 "tangled.sh/seiso.moe/alky/pkg/config" 13 "tangled.sh/seiso.moe/alky/pkg/dns" 14 "tangled.sh/seiso.moe/alky/pkg/metrics" 15 "tangled.sh/seiso.moe/alky/pkg/rootservers" 16) 17 18var configFlag string 19 20func init() { 21 flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky") 22 flag.Parse() 23} 24 25func main() { 26 cfg, err := config.LoadConfig(configFlag) 27 if err != nil { 28 log.Fatalf("failed to load config: %v", err) 29 } 30 31 logger := setupLogger(&cfg) 32 slog.SetDefault(logger) 33 34 logger.Info("Configuration loaded", "path", configFlag) 35 36 metricsClient, err := metrics.NewClickHouseMetrics(&cfg.Metrics, logger.With("component", "metrics")) 37 if err != nil { 38 logger.Error("failed to initialize metrics client", "error", err) 39 os.Exit(1) 40 } 41 defer metricsClient.Close() 42 logger.Info("Metrics client initialized") 43 44 rootServers, err := rootservers.DecodeRootHints(cfg.Server.RootHintsFile) 45 if err != nil || len(rootServers) == 0 { 46 logger.Error("Failed to load root hints or no root servers found", "path", cfg.Server.RootHintsFile, "error", err) 47 os.Exit(1) 48 } 49 logger.Info("Root hints loaded", "count", len(rootServers), "path", cfg.Server.RootHintsFile) 50 51 queryHandler := &dns.QueryHandler{ 52 RootServers: rootServers, 53 Timeout: cfg.Advanced.QueryTimeout.Duration, 54 Logger: logger.With("component", "resolver"), 55 } 56 57 var currentHandler dns.Handler = queryHandler 58 59 if cfg.Ratelimit.Rate > 0 { 60 rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{ 61 Rate: float64(cfg.Ratelimit.Rate), 62 Burst: cfg.Ratelimit.Burst, 63 WindowLength: cfg.Ratelimit.Window.Duration, 64 ExpirationTime: cfg.Ratelimit.ExpirationTime.Duration, 65 })(currentHandler) 66 currentHandler = rateLimitHandler 67 logger.Info("Rate limiting enabled", "rate", cfg.Ratelimit.Rate, "burst", 
cfg.Ratelimit.Burst) 68 } else { 69 logger.Info("Rate limiting disabled") 70 } 71 72 metricsHandler := metrics.MetricsMiddleware(metricsClient)(currentHandler) 73 currentHandler = metricsHandler 74 logger.Info("Metrics middleware enabled") 75 76 loggingHandler := dns.LoggingMiddleware(&dns.LogConfig{ 77 Logger: logger.With("component", "query-log"), 78 Level: slog.LevelInfo, 79 })(currentHandler) 80 currentHandler = loggingHandler 81 logger.Info("Logging middleware enabled") 82 83 sigChan := make(chan os.Signal, 1) 84 signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 85 go func() { 86 sig := <-sigChan 87 logger.Info("Received signal, shutting down gracefully...", "signal", sig.String()) 88 89 logger.Info("Closing metrics client...") 90 metricsClient.Close() 91 92 logger.Info("Shutdown complete.") 93 os.Exit(0) 94 }() 95 96 s := dns.Server{ 97 Address: cfg.Server.Address, 98 Port: cfg.Server.Port, 99 Handler: currentHandler, 100 UDPSize: cfg.Server.UDPSize, 101 ReadTimeout: cfg.Advanced.ReadTimeout.Duration, 102 WriteTimeout: cfg.Advanced.WriteTimeout.Duration, 103 Logger: logger.With("component", "server"), 104 } 105 106 logger.Info("Starting DNS server listener...") 107 if err := s.ListenAndServe(); err != nil { 108 logger.Error("Failed to start server", "error", err) 109 os.Exit(1) 110 } 111} 112 113func setupLogger(cfg *config.Config) *slog.Logger { 114 var logLevel slog.Level 115 switch strings.ToLower(cfg.Logging.Level) { 116 case "debug": 117 logLevel = slog.LevelDebug 118 case "info": 119 logLevel = slog.LevelInfo 120 case "warn": 121 logLevel = slog.LevelWarn 122 case "error": 123 logLevel = slog.LevelError 124 default: 125 logLevel = slog.LevelInfo 126 } 127 128 handlerOpts := &slog.HandlerOptions{ 129 Level: logLevel, 130 AddSource: logLevel <= slog.LevelDebug, 131 } 132 133 var handler slog.Handler 134 switch cfg.Logging.Output { 135 case "file": 136 f, err := os.OpenFile(cfg.Logging.FilePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644) 137 if err 
!= nil { 138 log.Fatalf("Failed to open log file %s: %v", cfg.Logging.FilePath, err) 139 } 140 handler = slog.NewJSONHandler(f, handlerOpts) 141 log.Printf("Logging to file: %s at level: %s", cfg.Logging.FilePath, logLevel.String()) 142 case "text": 143 handler = slog.NewTextHandler(os.Stdout, handlerOpts) 144 log.Printf("Logging to stdout (text) at level: %s", logLevel.String()) 145 default: 146 handler = slog.NewJSONHandler(os.Stdout, handlerOpts) 147 log.Printf("Logging to stdout (json) at level: %s", logLevel.String()) 148 } 149 150 logger := slog.New(handler) 151 return logger 152}