1package main
2
import (
	"flag"
	"log"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"

	"tangled.sh/seiso.moe/alky/pkg/config"
	"tangled.sh/seiso.moe/alky/pkg/dns"
	"tangled.sh/seiso.moe/alky/pkg/metrics"
	"tangled.sh/seiso.moe/alky/pkg/rootservers"
)
17
18var configFlag string
19
20func init() {
21 flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky")
22 flag.Parse()
23}
24
25func main() {
26 cfg, err := config.LoadConfig(configFlag)
27 if err != nil {
28 log.Fatalf("failed to load config: %v", err)
29 }
30
31 logger := setupLogger(&cfg)
32 slog.SetDefault(logger)
33
34 logger.Info("Configuration loaded", "path", configFlag)
35
36 metricsClient, err := metrics.NewClickHouseMetrics(&cfg.Metrics, logger.With("component", "metrics"))
37 if err != nil {
38 logger.Error("failed to initialize metrics client", "error", err)
39 os.Exit(1)
40 }
41 defer metricsClient.Close()
42 logger.Info("Metrics client initialized")
43
44 rootServers, err := rootservers.DecodeRootHints(cfg.Server.RootHintsFile)
45 if err != nil || len(rootServers) == 0 {
46 logger.Error("Failed to load root hints or no root servers found", "path", cfg.Server.RootHintsFile, "error", err)
47 os.Exit(1)
48 }
49 logger.Info("Root hints loaded", "count", len(rootServers), "path", cfg.Server.RootHintsFile)
50
51 queryHandler := &dns.QueryHandler{
52 RootServers: rootServers,
53 Timeout: cfg.Advanced.QueryTimeout.Duration,
54 Logger: logger.With("component", "resolver"),
55 }
56
57 var currentHandler dns.Handler = queryHandler
58
59 if cfg.Ratelimit.Rate > 0 {
60 rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{
61 Rate: float64(cfg.Ratelimit.Rate),
62 Burst: cfg.Ratelimit.Burst,
63 WindowLength: cfg.Ratelimit.Window.Duration,
64 ExpirationTime: cfg.Ratelimit.ExpirationTime.Duration,
65 })(currentHandler)
66 currentHandler = rateLimitHandler
67 logger.Info("Rate limiting enabled", "rate", cfg.Ratelimit.Rate, "burst", cfg.Ratelimit.Burst)
68 } else {
69 logger.Info("Rate limiting disabled")
70 }
71
72 metricsHandler := metrics.MetricsMiddleware(metricsClient)(currentHandler)
73 currentHandler = metricsHandler
74 logger.Info("Metrics middleware enabled")
75
76 loggingHandler := dns.LoggingMiddleware(&dns.LogConfig{
77 Logger: logger.With("component", "query-log"),
78 Level: slog.LevelInfo,
79 })(currentHandler)
80 currentHandler = loggingHandler
81 logger.Info("Logging middleware enabled")
82
83 sigChan := make(chan os.Signal, 1)
84 signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
85 go func() {
86 sig := <-sigChan
87 logger.Info("Received signal, shutting down gracefully...", "signal", sig.String())
88
89 logger.Info("Closing metrics client...")
90 metricsClient.Close()
91
92 logger.Info("Shutdown complete.")
93 os.Exit(0)
94 }()
95
96 s := dns.Server{
97 Address: cfg.Server.Address,
98 Port: cfg.Server.Port,
99 Handler: currentHandler,
100 UDPSize: cfg.Server.UDPSize,
101 ReadTimeout: cfg.Advanced.ReadTimeout.Duration,
102 WriteTimeout: cfg.Advanced.WriteTimeout.Duration,
103 Logger: logger.With("component", "server"),
104 }
105
106 logger.Info("Starting DNS server listener...")
107 if err := s.ListenAndServe(); err != nil {
108 logger.Error("Failed to start server", "error", err)
109 os.Exit(1)
110 }
111}
112
113func setupLogger(cfg *config.Config) *slog.Logger {
114 var logLevel slog.Level
115 switch strings.ToLower(cfg.Logging.Level) {
116 case "debug":
117 logLevel = slog.LevelDebug
118 case "info":
119 logLevel = slog.LevelInfo
120 case "warn":
121 logLevel = slog.LevelWarn
122 case "error":
123 logLevel = slog.LevelError
124 default:
125 logLevel = slog.LevelInfo
126 }
127
128 handlerOpts := &slog.HandlerOptions{
129 Level: logLevel,
130 AddSource: logLevel <= slog.LevelDebug,
131 }
132
133 var handler slog.Handler
134 switch cfg.Logging.Output {
135 case "file":
136 f, err := os.OpenFile(cfg.Logging.FilePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644)
137 if err != nil {
138 log.Fatalf("Failed to open log file %s: %v", cfg.Logging.FilePath, err)
139 }
140 handler = slog.NewJSONHandler(f, handlerOpts)
141 log.Printf("Logging to file: %s at level: %s", cfg.Logging.FilePath, logLevel.String())
142 case "text":
143 handler = slog.NewTextHandler(os.Stdout, handlerOpts)
144 log.Printf("Logging to stdout (text) at level: %s", logLevel.String())
145 default:
146 handler = slog.NewJSONHandler(os.Stdout, handlerOpts)
147 log.Printf("Logging to stdout (json) at level: %s", logLevel.String())
148 }
149
150 logger := slog.New(handler)
151 return logger
152}