a recursive dns resolver

Added better cache handling and metrics.

+4
Justfile
···
init:
+
docker-compose up -d
+
goose up
+
+
status:
goose status
format:
+7
docs/alky.toml
···
# This uses time.ParseDuration semantics
retention_period = "720h"
+
[cache]
+
# The maximum number of items to store in the cache.
+
max_items = 5000
+
+
# How often the cache will evict items.
+
cleanup_interval = "5m"
+
[advanced]
# Timeout (in milliseconds) for outgoing queries before being cancelled.
query_timeout = 100
+1 -1
main.go
···
log.Fatal(err)
}
-
cache := dns.NewMemoryCache(5000, 5*time.Minute)
+
cache := dns.NewMemoryCache(cfg.Cache.MaxItems, cfg.Cache.CleanupInterval.Duration)
handler := &dns.QueryHandler{
RootServers: rootServers,
+14
pkg/config/config.go
···
QueryTimeout int `toml:"query_timeout"`
}
+
type CacheConfig struct {
+
MaxItems int `toml:"max_items"`
+
CleanupInterval duration `toml:"cleanup_interval"`
+
}
+
type Config struct {
Server ServerConfig `toml:"server"`
Logging LoggingConfig `toml:"logging"`
Ratelimit RatelimitConfig `toml:"ratelimit"`
Metrics MetricsConfig `toml:"metrics"`
+
Cache CacheConfig `toml:"cache"`
Advanced AdvancedConfig `toml:"advanced"`
}
···
if cfg.Metrics.RetentionPeriod.Duration == 0 {
cfg.Metrics.RetentionPeriod.Duration = 30 * 24 * time.Hour
+
}
+
+
if cfg.Cache.MaxItems == 0 {
+
cfg.Cache.MaxItems = 5000
+
}
+
+
if cfg.Cache.CleanupInterval.Duration == 0 {
+
cfg.Cache.CleanupInterval.Duration = 5 * time.Minute
}
if cfg.Advanced.QueryTimeout == 0 {
+3 -42
pkg/dns/cache.go
···
package dns
import (
-
"log"
"strings"
"sync"
"time"
···
ExpireAt time.Time
CreateTime time.Time
IsNegative bool
-
Bailiwick string
}
type CachedResourceRecord struct {
···
NegativeHits int64
PositiveHits int64
Evictions int64
-
Size int
+
Size int64
}
type Cache interface {
Get(key string) (*CacheEntry, bool)
Set(key string, entry *CacheEntry)
-
Clear()
GetStats() *CacheStats
}
···
c.mu.RUnlock()
c.stats.TotalQueries++
-
-
log.Println("looking for: ", key)
-
for k := range c.entries {
-
log.Printf("\t%s\n", k)
-
}
-
-
log.Println(c.entries[key])
entry, exists := c.entries[key]
if !exists {
-
log.Println("cache miss")
c.stats.CacheMisses++
return nil, false
}
if time.Now().After(entry.ExpireAt) {
c.stats.CacheMisses++
-
log.Println("cache expire")
return nil, false
}
···
}
c.entries[key] = entry
-
c.stats.Size = len(c.entries)
-
}
-
-
func (c *MemoryCache) Clear() {
-
c.mu.Lock()
-
defer c.mu.Unlock()
-
-
c.entries = make(map[string]*CacheEntry)
-
c.stats = CacheStats{}
+
c.stats.Size = int64(len(c.entries))
}
func (c *MemoryCache) GetStats() *CacheStats {
···
}
}
-
c.stats.Size = len(c.entries)
+
c.stats.Size = int64(len(c.entries))
}
func getMinTTL(records []magna.ResourceRecord) uint32 {
···
entry := &CacheEntry{
CreateTime: now,
IsNegative: msg.Header.RCode == magna.NXDOMAIN,
-
Bailiwick: zone,
}
if entry.IsNegative {
···
return BailiwickOutside
}
-
-
func extractZone(msg *magna.Message) string {
-
for _, auth := range msg.Authority {
-
if auth.RType == magna.SOAType {
-
return auth.Name
-
}
-
-
if auth.RType == magna.NSType {
-
return auth.Name
-
}
-
}
-
-
if len(msg.Question) > 0 {
-
return msg.Question[0].QName
-
}
-
-
return ""
-
}
+20 -8
pkg/dns/resolve.go
···
import (
"context"
"fmt"
-
"log"
"log/slog"
"net"
"time"
···
cacheKey := fmt.Sprintf("%s:%s:%s", question.QName, question.QType.String(), question.QClass.String())
if h.Cache != nil {
-
log.Println("cache is set")
if entry, hit := h.Cache.Get(cacheKey); hit {
-
log.Println("womp womp here here", entry, hit)
-
log.Println("request context: ", getCurrentRequest(ctx))
now := time.Now()
if r := getCurrentRequest(ctx); r != nil {
···
}
func getCurrentRequest(ctx context.Context) *Request {
-
log.Println(ctx)
if ctx == nil {
-
log.Println(">>>>")
return nil
}
+
if r, ok := ctx.Value(contextKey("request")).(*Request); ok {
-
log.Println("<<<<")
return r
}
-
log.Println("====")
+
return nil
}
+
+
func extractZone(msg *magna.Message) string {
+
for _, auth := range msg.Authority {
+
if auth.RType == magna.SOAType {
+
return auth.Name
+
}
+
+
if auth.RType == magna.NSType {
+
return auth.Name
+
}
+
}
+
+
if len(msg.Question) > 0 {
+
return msg.Question[0].QName
+
}
+
+
return ""
+
}
+1 -4
pkg/dns/server.go
···
"encoding/binary"
"fmt"
"io"
-
"log"
"log/slog"
"net"
"sync"
···
)
func setCacheHit(ctx context.Context) context.Context {
-
log.Println("setting to true")
return context.WithValue(ctx, cacheHitKey, true)
}
func setCacheMiss(ctx context.Context) context.Context {
-
log.Println("setting to false")
return context.WithValue(ctx, cacheHitKey, false)
}
···
RemoteAddr: remoteAddr,
}
-
r.Context = context.WithValue(context.Background(), contextKey("request"), r)
+
r.Context = context.WithValue(context.Background(), contextKey("request"), r)
srv.Handler.ServeDNS(w, r)
}
+13 -5
pkg/metrics/clickhouse.go
···
import (
"database/sql"
"fmt"
-
"log"
"log/slog"
"sync"
"time"
···
CacheHits int64
CacheMisses int64
NegativeHits int64
+
PositiveHits int64
+
Evictions int64
+
Size int64
}
func NewClickHouseMetrics(config *config.MetricsConfig, logger *slog.Logger) (*ClickHouseMetrics, error) {
···
CacheHits: stats.CacheHits,
CacheMisses: stats.CacheMisses,
NegativeHits: stats.NegativeHits,
+
PositiveHits: stats.PositiveHits,
+
Evictions: stats.Evictions,
+
Size: stats.Size,
})
}
···
defer stmt.Close()
for _, metric := range m.queryBuffer {
-
log.Println(metric)
_, err := stmt.Exec(
metric.Timestamp,
metric.InstanceID,
···
stmt, err := tx.Prepare(`
INSERT INTO alky_dns_cache_metrics (
-
timestamp, total_queries, cache_hits, cache_misses,
-
negative_hits
-
) VALUES (?, ?, ?, ?, ?)
+
timestamp, instance_id, total_queries, cache_hits, cache_misses,
+
negative_hits, positive_hits, evictions, size
+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
if err != nil {
return err
···
for _, metric := range m.cacheBuffer {
_, err := stmt.Exec(
metric.Timestamp,
+
GetInstanceID(),
metric.TotalQueries,
metric.CacheHits,
metric.CacheMisses,
metric.NegativeHits,
+
metric.PositiveHits,
+
metric.Evictions,
+
metric.Size,
)
if err != nil {
return err
+9 -9
pkg/metrics/middleware.go
···
import (
"context"
"fmt"
-
"log"
"os"
"time"
···
)
var (
-
hostID string
-
version string
+
instanceID string
+
version string
)
func init() {
···
}
if version != "" {
-
hostID = fmt.Sprintf("%s-%s", hostname, version)
+
instanceID = fmt.Sprintf("%s-%s", hostname, version)
} else {
-
hostID = hostname
+
instanceID = hostname
}
}
···
duration := time.Since(start)
question := r.Message.Question[0]
-
log.Println(hostID)
-
log.Println(dns.GetCacheHit(r.Context))
-
metrics.RecordQuery(QueryMetric{
Timestamp: time.Now(),
-
InstanceID: hostID,
+
InstanceID: instanceID,
QueryName: question.QName,
QueryType: question.QType.String(),
QueryClass: question.QClass.String(),
···
})
}
}
+
+
func GetInstanceID() string {
+
return instanceID
+
}