a recursive dns resolver


Changed files: +81 -503
go.mod (+2 -2)
···
)
require (
-github.com/ClickHouse/ch-go v0.65.1 // indirect
+github.com/ClickHouse/ch-go v0.66.0 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-faster/city v1.0.1 // indirect
···
github.com/shopspring/decimal v1.4.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
-golang.org/x/sys v0.32.0 // indirect
+golang.org/x/sys v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
go.sum (+5)
···
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/ClickHouse/ch-go v0.65.1 h1:SLuxmLl5Mjj44/XbINsK2HFvzqup0s6rwKLFH347ZhU=
github.com/ClickHouse/ch-go v0.65.1/go.mod h1:bsodgURwmrkvkBe5jw1qnGDgyITsYErfONKAHn05nv4=
+github.com/ClickHouse/ch-go v0.66.0 h1:hLslxxAVb2PHpbHr4n0d6aP8CEIpUYGMVT1Yj/Q5Img=
+github.com/ClickHouse/ch-go v0.66.0/go.mod h1:noiHWyLMJAZ5wYuq3R/K0TcRhrNA8h7o1AqHX0klEhM=
+github.com/ClickHouse/clickhouse-go v1.5.4 h1:cKjXeYLNWVJIx2J1K6H2CqyRmfwVJVY1OV1coaaFcI0=
github.com/ClickHouse/clickhouse-go/v2 v2.34.0 h1:Y4rqkdrRHgExvC4o/NTbLdY5LFQ3LHS77/RNFxFX3Co=
github.com/ClickHouse/clickhouse-go/v2 v2.34.0/go.mod h1:yioSINoRLVZkLyDzdMXPLRIqhDvel8iLBlwh6Iefso8=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
···
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
main.go (-43)
···
"os/signal"
"strings"
"syscall"
-"time"
"tangled.sh/seiso.moe/alky/pkg/config"
"tangled.sh/seiso.moe/alky/pkg/dns"
···
}
logger.Info("Root hints loaded", "count", len(rootServers), "path", cfg.Server.RootHintsFile)
-cache := dns.NewLRUCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration, logger.With("component", "cache"))
-defer cache.Stop()
-logger.Info("DNS cache initialized")
-
queryHandler := &dns.QueryHandler{
RootServers: rootServers,
Timeout: cfg.Advanced.QueryTimeout.Duration,
-Cache: cache,
Logger: logger.With("component", "resolver"),
}
-go monitorCacheMetrics(cache, metricsClient, logger.With("component", "cache-monitor"))
-
var currentHandler dns.Handler = queryHandler
if cfg.Ratelimit.Rate > 0 {
···
go func() {
sig := <-sigChan
logger.Info("Received signal, shutting down gracefully...", "signal", sig.String())
-
-logger.Info("Stopping cache...")
-cache.Stop()
logger.Info("Closing metrics client...")
metricsClient.Close()
···
if err := s.ListenAndServe(); err != nil {
logger.Error("Failed to start server", "error", err)
os.Exit(1)
-
}
-
}
-
-func monitorCacheMetrics(cache dns.Cache, metricsClient *metrics.ClickHouseMetrics, logger *slog.Logger) {
-interval := 1 * time.Minute
-logger.Info("Starting cache metrics monitoring", "interval", interval)
-ticker := time.NewTicker(interval)
-defer ticker.Stop()
-
-for {
-select {
-case <-ticker.C:
-stats := cache.GetStats()
-logger.Debug("Recording cache metrics",
-"hits", stats.Hits.Load(),
-"misses", stats.Misses.Load(),
-"size", stats.Size.Load(),
-"evictions", stats.Evictions.Load(),
-"expired", stats.Expired.Load(),
-"pos_hits", stats.PositiveHits.Load(),
-"neg_hits", stats.NegativeHits.Load(),
-)
-metricsClient.RecordCacheStats(metrics.CacheMetric{
-Timestamp: time.Now(),
-CacheHits: stats.Hits.Load(),
-CacheMisses: stats.Misses.Load(),
-NegativeHits: stats.NegativeHits.Load(),
-PositiveHits: stats.PositiveHits.Load(),
-Evictions: stats.Evictions.Load() + stats.Expired.Load(),
-Size: stats.Size.Load(),
-})
-}
-}
-}
pkg/config/config.go (+2 -26)
···
import (
"fmt"
"os"
+"slices"
"strings"
"time"
···
WriteTimeout duration `toml:"write_timeout"`
}
-type CacheConfig struct {
-MaxItems int `toml:"max_items"`
-CleanupInterval duration `toml:"cleanup_interval"`
-}
-
type Config struct {
Server ServerConfig `toml:"server"`
Logging LoggingConfig `toml:"logging"`
Ratelimit RatelimitConfig `toml:"ratelimit"`
Metrics MetricsConfig `toml:"metrics"`
-Cache CacheConfig `toml:"cache"`
Advanced AdvancedConfig `toml:"advanced"`
}
···
cfg.Metrics.FlushInterval.Duration = 10 * time.Second
cfg.Metrics.RetentionPeriod.Duration = 30 * 24 * time.Hour
-cfg.Cache.MaxItems = 10000
-cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
-
cfg.Advanced.QueryTimeout.Duration = 5 * time.Second
cfg.Advanced.ReadTimeout.Duration = 2 * time.Second
cfg.Advanced.WriteTimeout.Duration = 2 * time.Second
···
}
validLevels := []string{"debug", "info", "warn", "error"}
-isValidLevel := false
-for _, level := range validLevels {
-if strings.ToLower(cfg.Logging.Level) == level {
-isValidLevel = true
-break
-}
-}
-
-if !isValidLevel {
+if !slices.Contains(validLevels, cfg.Logging.Level) {
return cfg, fmt.Errorf("invalid logging level '%s', must be one of: %v", cfg.Logging.Level, validLevels)
-}
-
-if cfg.Cache.MaxItems <= 0 {
-return cfg, fmt.Errorf("cache max_items must be positive")
-}
-
-if cfg.Cache.CleanupInterval.Duration <= 0 {
-cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
}
if cfg.Advanced.QueryTimeout.Duration <= 0 {
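
A note on the validation change above: the removed loop compared strings.ToLower(cfg.Logging.Level) against each valid level, while slices.Contains matches the configured string exactly, so an uppercase value such as "INFO" now fails validation. A small standalone sketch of the difference (hypothetical values, not part of the diff); lowercasing before the lookup would preserve the old behaviour:

```go
package main

import (
	"fmt"
	"slices"
	"strings"
)

func main() {
	validLevels := []string{"debug", "info", "warn", "error"}
	level := "INFO" // hypothetical config value

	// New check in LoadConfig: exact match only.
	fmt.Println(slices.Contains(validLevels, level)) // false

	// Equivalent of the removed loop, which lowercased before comparing.
	fmt.Println(slices.Contains(validLevels, strings.ToLower(level))) // true
}
```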
pkg/dns/resolve.go (+62 -316)
···
"context"
"fmt"
"log/slog"
-"math/rand"
"net"
"strings"
"time"
···
type QueryHandler struct {
RootServers []string
Timeout time.Duration
-Cache Cache
Logger *slog.Logger
}
···
msg := r.Message.CreateReply(r.Message)
msg.Header.RA = true
-
-if h.Cache != nil {
-cacheKey := GenerateCacheKey(question)
-if entry, found := h.Cache.Get(cacheKey); found {
-h.Logger.Debug("cache hit", "question", question.QName, "type", question.QType)
-
-r.Context = setCacheHit(r.Context)
-
-msg.Header.RCode = entry.RCode
-msg.Answer = entry.Answer
-msg.Authority = entry.Authority
-msg.Additional = entry.Additional
-msg.Header.ANCount = uint16(len(msg.Answer))
-msg.Header.NSCount = uint16(len(msg.Authority))
-msg.Header.ARCount = uint16(len(msg.Additional))
-
-w.WriteMsg(msg)
-return
-} else {
-r.Context = setCacheMiss(r.Context)
-}
-}
records, err := h.resolveQuestion(r.Context, question, h.RootServers)
···
msg = msg.SetRCode(magna.NOERROR)
}
-if h.Cache != nil {
-if err == nil || err == errNXDOMAIN {
-entry := CreateCacheEntry(question, msg)
-if entry != nil {
-h.Cache.Set(GenerateCacheKey(question), entry)
-h.Logger.Debug("cached response",
-"question", question.QName,
-"type", question.QType.String(),
-"negative", entry.IsNegative,
-"rcode", entry.RCode.String(),
-"expires", entry.ExpireAt.Sub(time.Now()).String())
-}
-}
-}
-
w.WriteMsg(msg)
}
func (h *QueryHandler) resolveQuestion(ctx context.Context, question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
-depth := 0
-if d, ok := ctx.Value("recursion_depth").(int); ok {
-depth = d
-}
+ch := make(chan queryResponse, len(servers))
-// TODO move to configuration file
-if depth > 15 {
-h.Logger.Warn("Max recursion depth exceeded", "question", question.QName, "depth", depth)
-return nil, fmt.Errorf("max recursion depth exceeded")
-}
-ctx = context.WithValue(ctx, "recursion_depth", depth+1)
-
-if h.Cache != nil {
-cacheKey := GenerateCacheKey(question)
-if entry, found := h.Cache.Get(cacheKey); found {
-h.Logger.Debug("cache hit during recursion", "question", question.QName, "type", question.QType)
-
-if entry.RCode == magna.NXDOMAIN {
-return entry.Authority, errNXDOMAIN
-}
+labels = strings.Split(question.QName, ".")
-if entry.IsNegative {
-return []magna.ResourceRecord{}, nil
-}
-
-return entry.Answer, nil
-}
-}
-
-resolveCtx, cancel := context.WithTimeout(ctx, h.Timeout)
-defer cancel()
-
-ch := make(chan queryResponse, len(servers))
-
+for label in
for _, s := range servers {
-go queryServer(resolveCtx, question, s, ch)
+go queryServer(ctx, question, s, ch, h.Timeout)
}
-
-var lastError error
for range servers {
select {
case res := <-ch:
if res.Error != nil {
-h.Logger.Warn("error", "question", question.QName, "server", res.Server, "error", res.Error)
-lastError = res.Error
-continue
+break
}
msg := res.MSG
-
-fullMsg := magna.Message{
-Header: msg.Header,
-Question: []magna.Question{question},
-Answer: msg.Answer,
-Authority: msg.Authority,
-Additional: msg.Additional,
-}
-fullMsg.Header.ANCount = uint16(len(msg.Answer))
-fullMsg.Header.NSCount = uint16(len(msg.Authority))
-fullMsg.Header.ARCount = uint16(len(msg.Additional))
-
-if msg.Header.RCode == magna.NXDOMAIN {
-if h.Cache != nil {
-entry := CreateCacheEntry(question, &fullMsg)
-if entry != nil {
-h.Cache.Set(GenerateCacheKey(question), entry)
-}
-}
-return msg.Authority, errNXDOMAIN
-}
-
if msg.Header.ANCount > 0 {
if msg.Answer[0].RType == magna.CNAMEType {
-if h.Cache != nil {
-partialEntry := CreateCacheEntry(question, &fullMsg)
-if partialEntry != nil {
-h.Cache.Set(GenerateCacheKey(question), partialEntry)
-}
-}
-
-cnameTarget := msg.Answer[0].RData.String()
-h.Logger.Debug("following CNAME", "from", question.QName, "to", cnameTarget)
-
-cname_answers, err := h.resolveQuestion(resolveCtx, magna.Question{
-QName: cnameTarget,
-QType: question.QType,
-QClass: question.QClass,
-}, h.RootServers)
+cname_answers, err := h.resolveQuestion(ctx, magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
if err != nil {
-if err == errNXDOMAIN {
-return cname_answers, errNXDOMAIN
-}
-h.Logger.Warn("error resolving CNAME", "cname", cnameTarget, "error", err)
continue
}
-
-combinedAnswer := append([]magna.ResourceRecord{}, msg.Answer[0])
-combinedAnswer = append(combinedAnswer, cname_answers...)
-
-fullMsg.Answer = combinedAnswer
-if h.Cache != nil {
-fullEntry := CreateCacheEntry(question, &fullMsg)
-if fullEntry != nil {
-h.Cache.Set(GenerateCacheKey(question), fullEntry)
-}
-}
-
-return combinedAnswer, nil
+msg.Answer = append(msg.Answer, cname_answers...)
}
-if h.Cache != nil {
-entry := CreateCacheEntry(question, &fullMsg)
-if entry != nil {
-h.Cache.Set(GenerateCacheKey(question), entry)
-}
-}
return msg.Answer, nil
}
-if msg.Header.NSCount > 0 {
-var nextServers []string
-var nsRecords []magna.ResourceRecord
-nsRecordsWithoutGlue := 0
-
-glueMap := make(map[string]string)
-for _, rr := range msg.Additional {
-if rr.RType == magna.AType {
-glueMap[strings.ToLower(rr.Name)] = rr.RData.String()
+if msg.Header.ARCount > 0 {
+var nextZone []string
+for _, ans := range msg.Additional {
+if ans.RType == magna.AType {
+nextZone = append(nextZone, ans.RData.String())
}
}
-for _, rr := range msg.Authority {
-if rr.RType == magna.NSType {
-nsRecords = append(nsRecords, rr)
-}
-}
+return h.resolveQuestion(ctx, question, nextZone)
+}
-for _, nsRR := range nsRecords {
-nsName := strings.ToLower(nsRR.RData.String())
-
-if ip, found := glueMap[nsName]; found {
-nextServers = append(nextServers, ip)
-} else {
-nsRecordsWithoutGlue++
-nsQuestion := magna.Question{QName: nsName, QType: magna.AType, QClass: magna.IN}
-nsAnswers, err := h.resolveQuestion(resolveCtx, nsQuestion, h.RootServers)
+if msg.Header.NSCount > 0 {
+var ns []string
+for _, a := range msg.Authority {
+if a.RType == magna.NSType {
+ans, err := h.resolveQuestion(ctx, magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, h.RootServers)
if err != nil {
-h.Logger.Warn("error resolving NS A record", "ns", nsName, "error", err)
-continue
-} else {
-foundAddr := false
-for _, ans := range nsAnswers {
-if ans.RType == magna.AType && strings.ToLower(ans.Name) == nsName {
-ipAddr := ans.RData.String()
-nextServers = append(nextServers, ipAddr)
-foundAddr = true
-}
-}
-if !foundAddr {
-h.Logger.Warn("A record lookup for NS succeeded but yielded no matching A records", "ns", nsName)
-}
+break
}
-
-}
-}
-
-if len(nextServers) > 0 {
-return h.resolveQuestion(ctx, question, nextServers)
-} else {
-h.Logger.Warn("Processed delegation but failed to find any next server IPs", "question", question.QName, "server", res.Server, "ns_records_count", len(nsRecords), "ns_without_glue", nsRecordsWithoutGlue, "depth", depth)
-lastError = fmt.Errorf("failed to resolve NS addresses for delegation from %s", res.Server)
-}
-
-h.Logger.Warn("could not resolve any NS records for delegation", "question", question.QName)
-continue
-}
-
-if msg.Header.RCode == magna.NOERROR && msg.Header.ANCount == 0 && msg.Header.NSCount == 0{
-h.Logger.Debug("Received NODATA response (NOERROR, ANCount=0)", "question", question.QName, "server", res.Server, "depth", depth)
-
-if h.Cache != nil {
-entry := CreateCacheEntry(question, &fullMsg)
-if entry != nil {
-h.Cache.Set(GenerateCacheKey(question), entry)
+for _, x := range ans {
+ns = append(ns, x.RData.String())
+}
}
}
-return []magna.ResourceRecord{}, nil
+return h.resolveQuestion(ctx, question, ns)
}
-h.Logger.Warn("Unhandled response state", "question", question.QName, "server", res.Server, "rcode", msg.Header.RCode.String(), "ancount", msg.Header.ANCount, "nscount", msg.Header.NSCount, "depth", depth)
-lastError = fmt.Errorf("unhandled response code %s from %s", msg.Header.RCode.String(), res.Server)
-
-case <-resolveCtx.Done():
-h.Logger.Warn("Resolution step timed out or cancelled", "question", question.QName, "elapsed", time.Since(time.Now().Add(-h.Timeout)), "depth", depth, "error", resolveCtx.Err()) // Approx elapsed
-if lastError != nil {
-return nil, fmt.Errorf("resolution timed out after error: %w", lastError)
-}
-return nil, fmt.Errorf("resolution step timed out or cancelled: %w", resolveCtx.Err())
+return []magna.ResourceRecord{}, nil
}
}
-h.Logger.Warn("all resolution paths failed", "question", question.QName)
-if lastError != nil {
-return nil, fmt.Errorf("all resolution paths failed: %w", lastError)
-}
-return []magna.ResourceRecord{}, fmt.Errorf("all resolution paths failed")
+return []magna.ResourceRecord{}, nil
}
-func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse) {
-var response queryResponse
-response.Server = server
-
-bufPtr := resolverUDPBufferPool.Get().(*[]byte)
-defer resolverUDPBufferPool.Put(bufPtr)
-p := *bufPtr
+func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) {
+done := make(chan struct{}, 1)
-defer func() {
-if ctx.Err() != nil {
-if response.Error == nil {
-response.Error = ctx.Err()
-}
+go func() {
+conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
+if err != nil {
+ch <- queryResponse{Error: err}
+return
}
+defer conn.Close()
-select {
-case ch <- response:
-case <-ctx.Done():
-slog.Debug("queryServer response channel send cancelled by context", "server", server, "question", question.QName)
-default:
-slog.Debug("queryServer response channel blocked or closed on send", "server", server, "question", question.QName)
+query := magna.CreateRequest(0, false)
+query = query.AddQuestion(question)
+msg, err := query.Encode()
+if err != nil {
+ch <-queryResponse{Server: server, Error: err}
}
-}()
-select {
-case <-ctx.Done():
-response.Error = ctx.Err()
-return
-default:
-}
-
-dialer := net.Dialer{Timeout: 2 * time.Second}
-conn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf("%s:53", server))
-if err != nil {
-if ctx.Err() != nil {
-response.Error = ctx.Err()
-} else {
-response.Error = fmt.Errorf("dial error: %w", err)
+if _, err := conn.Write(msg); err != nil {
+ch <- queryResponse{Server: server, Error: err}
+return
}
-return
-}
-defer conn.Close()
-
-deadline, ok := ctx.Deadline()
-if !ok {
-deadline = time.Now().Add(5 * time.Second)
-}
-conn.SetDeadline(deadline)
-
-query := magna.Message{
-Header: magna.Header{
-ID: uint16(rand.Int() % 65535),
-QR: false,
-OPCode: magna.QUERY,
-AA: false,
-TC: false,
-RD: false,
-RA: false,
-Z: 0,
-RCode: magna.NOERROR,
-QDCount: 1,
-ARCount: 0,
-NSCount: 0,
-ANCount: 0,
-},
-Question: []magna.Question{question},
-}
-
-msgBytes, err := query.Encode()
-if err != nil {
-response.Error = fmt.Errorf("encode error: %w", err)
-return
-}
-
-_, err = conn.Write(msgBytes)
-if err != nil {
-if ctx.Err() != nil {
-response.Error = ctx.Err()
-} else {
-response.Error = fmt.Errorf("write error: %w", err)
-}
-return
-}
+p := make([]byte, 512)
+nn, err := conn.Read(p)
-n, err := conn.Read(p)
-if err != nil {
-if ctx.Err() != nil {
-response.Error = ctx.Err()
-} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
-response.Error = fmt.Errorf("read timeout: %w", err)
-} else {
-response.Error = fmt.Errorf("read error: %w", err)
+// TODO: retry request with TCP
+if err != nil || nn > 512 {
+if err == nil {
+err = fmt.Errorf("truncated response")
+}
+ch <- queryResponse{Server: server, Error: err}
+return
}
-return
-}
-// TODO: retry request with TCP
-if n > 512 {
-response.Error = fmt.Errorf("response possibly truncated (size %d >= 512)", n)
-return
-}
+var response magna.Message
+err = response.Decode(p)
+ch <- queryResponse{MSG: response, Server: server, Error: err}
+}()
-var decodedMsg magna.Message
-err = decodedMsg.Decode(p[:n])
-if err != nil {
-response.Error = fmt.Errorf("decode error: %w", err)
-return
+select {
+case <-ctx.Done():
+ch <- queryResponse{Server: server, Error: ctx.Err()}
+case <-done:
+// goroutine finished with no cancellation
+case <-time.After(timeout):
+ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")}
}
-
-if decodedMsg.Header.ID != query.Header.ID {
-response.Error = fmt.Errorf("response ID mismatch (got %d, expected %d)", decodedMsg.Header.ID, query.Header.ID)
-return
-}
-
-response.MSG = decodedMsg
-response.Error = nil
}
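
The rewritten queryServer above moves the blocking UDP exchange into a goroutine and races it against the caller's context, a done channel, and a per-query timer. A minimal standalone sketch of that select pattern (stdlib only; sendQuery and queryWithTimeout are hypothetical stand-ins, and unlike the diff this sketch delivers the result over the done channel):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type result struct {
	msg string
	err error
}

// sendQuery stands in for the blocking dial/write/read done in queryServer.
func sendQuery(server string) (string, error) {
	time.Sleep(50 * time.Millisecond) // simulate a network round trip
	return "answer from " + server, nil
}

// queryWithTimeout races the blocking call against the context and a timer,
// mirroring the select at the end of the new queryServer.
func queryWithTimeout(ctx context.Context, server string, timeout time.Duration) (string, error) {
	done := make(chan result, 1) // buffered so the goroutine never blocks if we stop waiting
	go func() {
		msg, err := sendQuery(server)
		done <- result{msg, err}
	}()
	select {
	case <-ctx.Done():
		return "", ctx.Err()
	case r := <-done:
		return r.msg, r.err
	case <-time.After(timeout):
		return "", errors.New("timeout")
	}
}

func main() {
	msg, err := queryWithTimeout(context.Background(), "198.41.0.4", 2*time.Second)
	fmt.Println(msg, err)
}
```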
pkg/dns/server.go (+4 -19)
···
type contextKey string
const (
-cacheHitKey contextKey = "cache_hit"
maxUDPBufferSize = 4096
maxResolverUDPBufferSize = 4096
)
···
},
}
)
-
-func setCacheHit(ctx context.Context) context.Context {
-return context.WithValue(ctx, cacheHitKey, true)
-}
-
-func setCacheMiss(ctx context.Context) context.Context {
-return context.WithValue(ctx, cacheHitKey, false)
-}
-
-func GetCacheHit(ctx context.Context) bool {
-if ctx == nil {
-return false
-}
-v := ctx.Value(cacheHitKey)
-hit, ok := v.(bool)
-return ok && hit
-}
type Handler interface {
ServeDNS(ResponseWriter, *Request)
···
return
}
+ctx, cancel := context.WithTimeout(context.Background(), srv.ReadTimeout)
+defer cancel()
+
r := &Request{
-Context: context.Background(),
+Context: ctx,
Message: &query,
RemoteAddr: remoteAddr,
}
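
With the change above, each request handed to the handler carries a context that expires after srv.ReadTimeout instead of an unbounded context.Background(). A minimal illustration of how downstream work observes that deadline (hypothetical handler body, not taken from the diff):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// handle stands in for a DNS handler that receives the per-request context
// created with context.WithTimeout(context.Background(), ReadTimeout).
func handle(ctx context.Context) {
	if deadline, ok := ctx.Deadline(); ok {
		fmt.Println("must answer before", deadline)
	}
	select {
	case <-time.After(3 * time.Second): // pretend resolution takes this long
		fmt.Println("resolved in time")
	case <-ctx.Done():
		fmt.Println("gave up:", ctx.Err()) // fires once the read timeout elapses
	}
}

func main() {
	readTimeout := 2 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), readTimeout)
	defer cancel()
	handle(ctx)
}
```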
pkg/metrics/clickhouse.go (+2 -97)
···
db *sql.DB
config *config.MetricsConfig
queryBuffer []QueryMetric
-cacheBuffer []CacheMetric
mu sync.Mutex
stopChan chan struct{}
wg sync.WaitGroup
···
RemoteAddr string
ResponseCode string
Duration int64
-CacheHit bool
-}
-
-type CacheMetric struct {
-Timestamp time.Time
-CacheHits int64
-CacheMisses int64
-NegativeHits int64
-PositiveHits int64
-Evictions int64
-Size int64
}
func NewClickHouseMetrics(config *config.MetricsConfig, logger *slog.Logger) (*ClickHouseMetrics, error) {
···
db: db,
config: config,
queryBuffer: make([]QueryMetric, 0, config.BatchSize),
-cacheBuffer: make([]CacheMetric, 0, config.BatchSize),
stopChan: make(chan struct{}),
logger: logger,
}
···
}
}
-func (m *ClickHouseMetrics) RecordCacheStats(metric CacheMetric) {
-m.mu.Lock()
-defer m.mu.Unlock()
-
-m.cacheBuffer = append(m.cacheBuffer, metric)
-if len(m.cacheBuffer) >= m.config.BatchSize {
-m.flushCacheMetricsLocked()
-}
-}
-
func (m *ClickHouseMetrics) flushLoop() {
defer m.wg.Done()
ticker := time.NewTicker(m.config.FlushInterval.Duration)
···
if len(m.queryBuffer) > 0 {
m.flushQueriesLocked()
}
-if len(m.cacheBuffer) > 0 {
-m.flushCacheMetricsLocked()
-}
}
func (m *ClickHouseMetrics) checkAndUpdateTTL() error {
···
}
if err := updateTableTTL("alky_dns_queries"); err != nil {
-return err
-}
-if err := updateTableTTL("alky_dns_cache_metrics"); err != nil {
return err
}
···
INSERT INTO alky_dns_queries (
timestamp, instance_id, query_name, query_type, query_class,
remote_addr, response_code, duration, cache_hit
-) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`)
if err != nil {
m.logger.Error("Failed to prepare statement for query metrics", "error", err)
···
metric.RemoteAddr,
metric.ResponseCode,
metric.Duration,
-metric.CacheHit,
+false,
)
if err != nil {
m.logger.Error("Failed to execute statement for query metric", "error", err, "metric_index", count)
···
m.logger.Debug("Successfully flushed query metrics", "count", count, "duration", time.Since(start))
m.queryBuffer = m.queryBuffer[:0]
-}
-
-func (m *ClickHouseMetrics) flushCacheMetricsLocked() {
-if len(m.cacheBuffer) == 0 {
-return
-}
-
-m.logger.Debug("Flushing cache metrics", "count", len(m.cacheBuffer))
-start := time.Now()
-
-tx, err := m.db.Begin()
-if err != nil {
-m.logger.Error("Failed to begin transaction for cache metrics", "error", err)
-m.cacheBuffer = m.cacheBuffer[:0]
-return
-}
-defer func() {
-if err != nil {
-if rbErr := tx.Rollback(); rbErr != nil {
-m.logger.Error("Failed to rollback transaction for cache metrics", "error", rbErr)
-}
-}
-}()
-
-stmt, err := tx.Prepare(`
-INSERT INTO alky_dns_cache_metrics (
-timestamp, instance_id, cache_hits, cache_misses,
-negative_hits, positive_hits, evictions, size
-) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-`)
-if err != nil {
-m.logger.Error("Failed to prepare statement for cache metrics", "error", err)
-m.cacheBuffer = m.cacheBuffer[:0]
-return
-}
-defer stmt.Close()
-
-count := 0
-for _, metric := range m.cacheBuffer {
-instanceID := GetInstanceID()
-
-_, err = stmt.Exec(
-metric.Timestamp,
-instanceID,
-metric.CacheHits,
-metric.CacheMisses,
-metric.NegativeHits,
-metric.PositiveHits,
-metric.Evictions,
-metric.Size,
-)
-if err != nil {
-m.logger.Error("Failed to execute statement for cache metric", "error", err, "metric_index", count)
-return
-}
-count++
-}
-
-err = tx.Commit()
-if err != nil {
-m.logger.Error("Failed to commit transaction for cache metrics", "error", err)
-return
-}
-
-m.logger.Debug("Successfully flushed cache metrics", "count", count, "duration", time.Since(start))
-m.cacheBuffer = m.cacheBuffer[:0]
}
func (m *ClickHouseMetrics) Close() error {
pkg/rootservers/loader.go (+4)
···
}
for _, line := range bytes.Split(data, []byte{'\n'}) {
+if len(line) == 0 {
+continue
+}
+// skip comments
if line[0] == ';' {
continue