a recursive dns resolver

add rate limiting and caching

+22
docs/alky.toml
···
# This is used only if logging.output is "file".
file_path = "/var/log/alky.log"
+
[ratelimit]
+
# rate: The steady-state rate of requests allowed (r in GCRA)
+
# Defines how many requests are allowed per second in normal conditions
+
# Type: Integer
+
rate = 250
+
+
# burst: Maximum number of requests allowed to exceed the steady-state rate temporarily
+
# Allows for short bursts of traffic above the defined rate
+
# Type: Integer
+
burst = 500
+
+
# window: The interval (in seconds) at which the rate limit is checked and potentially reset
+
# Implements a sliding window rate limit mechanism
+
# Type: Integer
+
window = 3
+
+
# expiration_time: Duration (in seconds) for keeping a client's rate limit data in memory
+
# After this period of inactivity, a client's rate limit data is removed to free up memory
+
# Type: Integer
+
expiration_time = 300
+
+
[advanced]
# Timeout (in milliseconds) for outgoing queries before being cancelled.
query_timeout = 100
+1 -1
go.mod
···
go 1.22.5
require (
-
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84
+
code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469
github.com/BurntSushi/toml v1.4.0
)
+2
go.sum
···
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84 h1:igzBX4k3REg0WZExjGLWW7/wu/X+U6QlbMc8aeO2030=
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84/go.mod h1:gSzCiTKyKlUEjGgl/qTb8rxF0QUVuWOEORAsTXA0qyI=
+
code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469 h1:LUvvGcJ7DuW3eo7yblNH2igCJzYsbWJQ08iZEXBWplc=
+
code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469/go.mod h1:gSzCiTKyKlUEjGgl/qTb8rxF0QUVuWOEORAsTXA0qyI=
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+31 -11
main.go
···
package main
import (
+
"flag"
"log"
"log/slog"
"os"
-
"flag"
+
"time"
"code.kiri.systems/kiri/alky/pkg/config"
"code.kiri.systems/kiri/alky/pkg/dns"
···
var configFlag string
func init() {
-
flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky")
+
flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky")
-
flag.Parse()
+
flag.Parse()
}
func main() {
···
logger = slog.New(slog.NewJSONHandler(os.Stdout, nil))
}
-
s := dns.Server{
-
Address: cfg.Server.Address,
-
Port: cfg.Server.Port,
-
Timeout: cfg.Advanced.QueryTimeout,
+
memCache := dns.NewMemoryCache()
+
var cache dns.Cache = memCache
+
+
handler := &dns.QueryHandler{
RootServers: rootServers,
+
Timeout: time.Duration(cfg.Advanced.QueryTimeout) * time.Second,
+
Cache: &cache,
+
}
+
+
logConfig := &dns.LogConfig{Logger: logger}
+
+
rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{
+
Rate: float64(cfg.Ratelimit.Rate),
+
Burst: cfg.Ratelimit.Burst,
+
WindowLength: time.Duration(cfg.Ratelimit.Window) * time.Second,
+
ExpirationTime: time.Duration(cfg.Ratelimit.ExpirationTime) * time.Second,
+
})(handler)
+
loggingHandler := dns.LoggingMiddleware(logConfig)(rateLimitHandler)
+
+
s := dns.Server{
+
Address: cfg.Server.Address,
+
Port: cfg.Server.Port,
+
Handler: loggingHandler,
+
UDPSize: 512,
+
ReadTimeout: 2 * time.Second,
+
WriteTimeout: 2 * time.Second,
Logger: logger,
}
-
go s.TCPListenAndServe()
-
go s.UDPListenAndServe()
-
-
for {
+
if err := s.ListenAndServe(); err != nil {
+
slog.Error("Failed to start server", "error", err)
}
}
+11 -3
pkg/config/config.go
···
FilePath string `toml:"file_path"`
}
+
type RatelimitConfig struct {
+
Rate int `toml:"rate"`
+
Burst int `toml:"burst"`
+
Window int `toml:"window"`
+
ExpirationTime int `toml:"expiration_time"`
+
}
+
type AdvancedConfig struct {
QueryTimeout int `toml:"query_timeout"`
}
type Config struct {
-
Server ServerConfig `toml:"server"`
-
Logging LoggingConfig `toml:"logging"`
-
Advanced AdvancedConfig `toml:"advanced"`
+
Server ServerConfig `toml:"server"`
+
Logging LoggingConfig `toml:"logging"`
+
Ratelimit RatelimitConfig `toml:"ratelimit"`
+
Advanced AdvancedConfig `toml:"advanced"`
}
func LoadConfig(path string) (Config, error) {
+52
pkg/dns/cache.go
···
+
package dns
+
+
import (
+
"sync"
+
"time"
+
+
"code.kiri.systems/kiri/magna"
+
)
+
+
type CachedResourceRecord struct {
+
Record magna.ResourceRecord
+
ExpireAt time.Time
+
}
+
+
type CacheEntry struct {
+
Answer []CachedResourceRecord
+
}
+
+
type Cache interface {
+
Get(key string) (*CacheEntry, bool)
+
Set(key string, entry *CacheEntry)
+
}
+
+
type MemoryCache struct {
+
entries map[string]*CacheEntry
+
mu sync.RWMutex
+
}
+
+
func NewMemoryCache() *MemoryCache {
+
return &MemoryCache{
+
entries: make(map[string]*CacheEntry),
+
}
+
}
+
+
func (c *MemoryCache) Get(key string) (*CacheEntry, bool) {
+
c.mu.RLock()
+
c.mu.RUnlock()
+
+
entry, exists := c.entries[key]
+
if !exists {
+
return nil, false
+
}
+
+
return entry, true
+
}
+
+
func (c *MemoryCache) Set(key string, entry *CacheEntry) {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
+
c.entries[key] = entry
+
}
-300
pkg/dns/dns.go
···
package dns
-
-
import (
-
"context"
-
"encoding/binary"
-
"fmt"
-
"io"
-
"log/slog"
-
"math/rand/v2"
-
"net"
-
"time"
-
-
"code.kiri.systems/kiri/magna"
-
)
-
-
type Server struct {
-
Address string
-
Port int
-
Timeout int
-
RootServers []string
-
-
Logger *slog.Logger
-
}
-
-
type queryResponse struct {
-
MSG magna.Message
-
Server string
-
Error error
-
}
-
-
func (s *Server) UDPListenAndServe() error {
-
addr := net.UDPAddr{
-
Port: s.Port,
-
IP: net.ParseIP(s.Address),
-
}
-
server, err := net.ListenUDP("udp", &addr)
-
if err != nil {
-
return err
-
}
-
defer server.Close()
-
-
for {
-
b := make([]byte, 512)
-
_, remote_addr, err := server.ReadFromUDP(b)
-
if err != nil {
-
s.Logger.Warn(err.Error())
-
continue
-
}
-
-
start := time.Now()
-
msg := s.processQuery(b)
-
s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", remote_addr.IP, "time_taken", time.Since(start).Nanoseconds())
-
if err != nil {
-
s.Logger.Warn(err.Error())
-
continue
-
}
-
-
ans := msg.Encode()
-
// xxx: set the TC bit if the message is over 512 bytes
-
if len(ans) > 512 {
-
ans[3] |= 1 << 6
-
}
-
-
if _, err := server.WriteToUDP(ans, remote_addr); err != nil {
-
s.Logger.Warn("sending response", "err", err.Error())
-
}
-
}
-
}
-
-
func (s *Server) TCPListenAndServe() error {
-
addr := net.TCPAddr{
-
Port: s.Port,
-
IP: net.ParseIP(s.Address),
-
}
-
-
server, err := net.ListenTCP("tcp", &addr)
-
if err != nil {
-
return err
-
}
-
defer server.Close()
-
-
for {
-
conn, err := server.Accept()
-
if err != nil {
-
s.Logger.Warn("conn error:", err)
-
continue
-
}
-
-
sizeBuffer := make([]byte, 2)
-
if _, err := io.ReadFull(conn, sizeBuffer); err != nil {
-
s.Logger.Warn("tcp-error", err)
-
continue
-
}
-
-
size := binary.BigEndian.Uint16(sizeBuffer)
-
-
data := make([]byte, size)
-
if _, err := io.ReadFull(conn, data); err != nil {
-
s.Logger.Warn("tcp-error", err)
-
continue
-
}
-
-
start := time.Now()
-
msg := s.processQuery(data)
-
s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", conn.RemoteAddr(), "time_taken", time.Since(start).Nanoseconds())
-
-
ans := msg.Encode()
-
conn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans))))
-
if _, err := conn.Write(ans); err != nil {
-
s.Logger.Error("tcp-error", err)
-
}
-
}
-
}
-
-
func (s *Server) processQuery(messageBuffer []byte) (msg magna.Message) {
-
var query magna.Message
-
if err := query.Decode(messageBuffer); err != nil {
-
slog.Warn("decode", err)
-
return
-
}
-
-
msg = magna.Message{
-
Header: magna.Header{
-
ID: query.Header.ID,
-
QR: true,
-
OPCode: 0,
-
AA: false,
-
TC: false,
-
RD: query.Header.RD,
-
RA: true,
-
Z: 0,
-
RCode: magna.NOERROR,
-
QDCount: 1,
-
ANCount: 0,
-
NSCount: 0,
-
ARCount: 0,
-
},
-
Question: []magna.Question{},
-
Answer: []magna.ResourceRecord{},
-
Additional: []magna.ResourceRecord{},
-
Authority: []magna.ResourceRecord{},
-
}
-
-
if len(query.Question) < 0 {
-
msg.Header.RCode = magna.FORMERR
-
return
-
}
-
question := query.Question[0]
-
msg.Question = append(msg.Question, question)
-
-
if question.QClass != magna.IN {
-
msg.Header.RCode = magna.NOTIMP
-
return
-
} else {
-
answer, err := s.resolveQuestion(question, s.RootServers)
-
if err != nil {
-
slog.Warn("resolve-question", err)
-
msg.Header.RCode = magna.SERVFAIL
-
return
-
}
-
-
msg.Header.ANCount = uint16(len(answer))
-
msg.Answer = answer
-
-
if msg.Header.ANCount == 0 {
-
msg.Header.RCode = magna.NXDOMAIN
-
return
-
}
-
}
-
-
return
-
}
-
-
func (s *Server) resolveQuestion(question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
-
ctx, cancel := context.WithCancel(context.Background())
-
defer cancel()
-
-
ch := make(chan queryResponse, len(servers))
-
-
for _, s := range servers {
-
go queryServer(ctx, question, s, ch)
-
}
-
-
for i := 0; i < len(servers); i++ {
-
select {
-
case res := <-ch:
-
if res.Error != nil {
-
slog.Warn("error", "question", question, "server", res.Server, "error", res.Error)
-
break
-
}
-
-
msg := res.MSG
-
if msg.Header.ANCount > 0 {
-
if msg.Answer[0].RType == magna.CNAMEType {
-
cname_answers, err := s.resolveQuestion(magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, s.RootServers)
-
if err != nil {
-
slog.Warn("error with cname request", err)
-
continue
-
}
-
msg.Answer = append(msg.Answer, cname_answers...)
-
}
-
-
return msg.Answer, nil
-
}
-
-
if msg.Header.ARCount > 0 {
-
var nextZone []string
-
for _, ans := range msg.Additional {
-
if ans.RType == magna.AType {
-
nextZone = append(nextZone, ans.RData.String())
-
}
-
}
-
-
return s.resolveQuestion(question, nextZone)
-
}
-
-
if msg.Header.NSCount > 0 {
-
var ns []string
-
for _, a := range msg.Authority {
-
if a.RType == magna.NSType {
-
ans, err := s.resolveQuestion(magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, s.RootServers)
-
if err != nil {
-
slog.Warn("error with ns request", err)
-
break
-
}
-
for _, x := range ans {
-
ns = append(ns, x.RData.String())
-
}
-
}
-
}
-
-
return s.resolveQuestion(question, ns)
-
}
-
-
return []magna.ResourceRecord{}, nil
-
case <-time.After(time.Duration(s.Timeout) * time.Millisecond):
-
cancel()
-
}
-
}
-
-
return []magna.ResourceRecord{}, nil
-
}
-
-
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse) {
-
done := make(chan struct{}, 1)
-
-
go func() {
-
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
-
if err != nil {
-
ch <- queryResponse{Error: err}
-
return
-
}
-
defer conn.Close()
-
-
query := magna.Message{
-
Header: magna.Header{
-
ID: uint16(rand.Int() % 65535),
-
QR: false,
-
OPCode: 0,
-
AA: false,
-
TC: false,
-
RD: false,
-
RA: false,
-
Z: 0,
-
RCode: magna.NOERROR,
-
QDCount: 1,
-
ARCount: 0,
-
NSCount: 0,
-
ANCount: 0,
-
},
-
Question: []magna.Question{question},
-
}
-
if _, err := conn.Write(query.Encode()); err != nil {
-
ch <- queryResponse{Server: server, Error: err}
-
return
-
}
-
-
p := make([]byte, 512)
-
nn, err := conn.Read(p)
-
-
// TODO: retry request with TCP
-
if err != nil || nn > 512 {
-
if err == nil {
-
err = fmt.Errorf("truncated response")
-
}
-
ch <- queryResponse{Server: server, Error: err}
-
return
-
}
-
-
var response magna.Message
-
err = response.Decode(p)
-
ch <- queryResponse{MSG: response, Server: server, Error: err}
-
}()
-
-
select {
-
case <-ctx.Done():
-
ch <- queryResponse{Server: server, Error: ctx.Err()}
-
case <-done:
-
// goroutine finished with no cancellation
-
}
-
}
+46
pkg/dns/logging.go
···
+
package dns
+
+
import (
+
"log/slog"
+
"os"
+
"time"
+
)
+
+
type LogConfig struct {
+
Logger *slog.Logger
+
Level slog.Level
+
}
+
+
func NewDefaultLogConfig() *LogConfig {
+
return &LogConfig{
+
Logger: slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
+
Level: slog.LevelInfo,
+
})),
+
Level: slog.LevelInfo,
+
}
+
}
+
+
func LoggingMiddleware(config *LogConfig) func(Handler) Handler {
+
if config == nil {
+
config = NewDefaultLogConfig()
+
}
+
+
return func(next Handler) Handler {
+
return HandlerFunc(func(w ResponseWriter, r *Request) {
+
start := time.Now()
+
+
next.ServeDNS(w, r)
+
+
duration := time.Since(start)
+
question := r.Message.Question[0]
+
config.Logger.Info("query",
+
"class", question.QClass.String(),
+
"type", question.QType.String(),
+
"name", question.QName,
+
"rcode", r.Message.Header.RCode.String(),
+
"remote_addr", r.RemoteAddr,
+
"time_taken", duration.Nanoseconds(),
+
)
+
})
+
}
+
}
+137
pkg/dns/ratelimit.go
···
+
package dns
+
+
import (
+
"net"
+
"sync"
+
"time"
+
+
"code.kiri.systems/kiri/magna"
+
)
+
+
type RateLimitConfig struct {
+
Rate float64
+
Burst int
+
WindowLength time.Duration
+
ExpirationTime time.Duration
+
}
+
+
type rateLimiter struct {
+
config RateLimitConfig
+
ipData map[string]*ipRateData
+
mu sync.RWMutex
+
}
+
+
type ipRateData struct {
+
time time.Time
+
}
+
+
func NewDefaultRateLimitConfig() *RateLimitConfig {
+
return &RateLimitConfig{
+
Rate: 1,
+
Burst: 1,
+
WindowLength: time.Hour,
+
ExpirationTime: time.Hour,
+
}
+
}
+
+
func newRateLimiter(config RateLimitConfig) *rateLimiter {
+
return &rateLimiter{
+
config: config,
+
ipData: make(map[string]*ipRateData),
+
}
+
}
+
+
func (rl *rateLimiter) allow(ip string) bool {
+
rl.mu.Lock()
+
defer rl.mu.Unlock()
+
+
now := time.Now()
+
cost := time.Duration(float64(time.Second) / rl.config.Rate)
+
+
data, exists := rl.ipData[ip]
+
if !exists {
+
data = &ipRateData{time: now.Add(-rl.config.WindowLength)}
+
rl.ipData[ip] = data
+
}
+
+
if data.time.Before(now.Add(-rl.config.WindowLength)) {
+
data.time = now.Add(-rl.config.WindowLength)
+
}
+
+
nextTime := data.time.Add(cost)
+
if now.Before(nextTime) {
+
return false
+
}
+
+
if nextTime.Sub(now.Add(-rl.config.WindowLength)) > time.Duration(rl.config.Burst)*cost {
+
nextTime = now.Add(cost)
+
}
+
+
data.time = nextTime
+
return true
+
}
+
+
func (rl *rateLimiter) cleanup() {
+
rl.mu.Lock()
+
defer rl.mu.Unlock()
+
+
now := time.Now()
+
for ip, data := range rl.ipData {
+
if data.time.Before(now.Add(-rl.config.WindowLength)) {
+
delete(rl.ipData, ip)
+
}
+
}
+
}
+
+
func extractIP(addr net.Addr) string {
+
switch v := addr.(type) {
+
case *net.UDPAddr:
+
return v.IP.String()
+
case *net.TCPAddr:
+
return v.IP.String()
+
default:
+
host, _, err := net.SplitHostPort(addr.String())
+
if err != nil {
+
return addr.String()
+
}
+
return host
+
}
+
}
+
+
func RateLimitMiddleware(config *RateLimitConfig) func(Handler) Handler {
+
if config == nil {
+
config = NewDefaultRateLimitConfig()
+
}
+
+
rl := newRateLimiter(*config)
+
+
go func() {
+
ticker := time.NewTicker(config.ExpirationTime)
+
for range ticker.C {
+
rl.cleanup()
+
}
+
}()
+
+
return func(next Handler) Handler {
+
return HandlerFunc(func(w ResponseWriter, r *Request) {
+
if !rl.allow(extractIP(r.RemoteAddr)) {
+
r.Message.Header.RA = true
+
msg := r.Message.CreateReply(r.Message)
+
msg = r.Message.SetRCode(magna.REFUSED)
+
+
// XXX: dont support edns yet and these get copied over on responses
+
msg.Header.ANCount = 0
+
msg.Header.NSCount = 0
+
msg.Header.ARCount = 0
+
msg.Answer = []magna.ResourceRecord{}
+
msg.Additional = []magna.ResourceRecord{}
+
msg.Authority = []magna.ResourceRecord{}
+
+
w.WriteMsg(msg)
+
return
+
}
+
+
next.ServeDNS(w, r)
+
})
+
}
+
}
+240
pkg/dns/resolve.go
···
+
package dns
+
+
import (
+
"context"
+
"fmt"
+
"net"
+
"strings"
+
"time"
+
+
"code.kiri.systems/kiri/magna"
+
)
+
+
type QueryHandler struct {
+
RootServers []string
+
Timeout time.Duration
+
Cache *Cache
+
}
+
+
type queryResponse struct {
+
MSG magna.Message
+
Server string
+
Error error
+
}
+
+
func (h *QueryHandler) ServeDNS(w ResponseWriter, r *Request) {
+
msg := h.processQuery(r.Message.Encode())
+
w.WriteMsg(msg)
+
}
+
+
func (h *QueryHandler) processQuery(messageBuffer []byte) *magna.Message {
+
var query magna.Message
+
if err := query.Decode(messageBuffer); err != nil {
+
return nil
+
}
+
+
msg := new(magna.Message)
+
msg = msg.CreateReply(&query)
+
+
if len(query.Question) < 1 {
+
return msg.SetRCode(magna.FORMERR)
+
}
+
+
question := query.Question[0]
+
msg = msg.AddQuestion(question)
+
+
if question.QClass != magna.IN {
+
return msg.SetRCode(magna.NOTIMP)
+
}
+
+
answer, err := h.resolveWithCache(question)
+
if err != nil {
+
return msg.SetRCode(magna.SERVFAIL)
+
}
+
+
if len(answer) == 0 {
+
return msg.SetRCode(magna.NXDOMAIN)
+
}
+
+
msg.Header.ANCount = uint16(len(answer))
+
msg.Answer = answer
+
return msg.SetRCode(magna.NOERROR)
+
}
+
+
func (h *QueryHandler) resolveWithCache(question magna.Question) ([]magna.ResourceRecord, error) {
+
cacheKey := fmt.Sprintf("%s:%s:%s", strings.ToLower(question.QName), question.QType.String(), question.QClass.String())
+
+
if e, found := (*h.Cache).Get(cacheKey); found {
+
now := time.Now()
+
var updatedAnswer []magna.ResourceRecord
+
var cname *magna.ResourceRecord
+
hasAddressRecord := false
+
+
for _, cachedRR := range e.Answer {
+
if now.Before(cachedRR.ExpireAt) {
+
updatedRR := cachedRR.Record
+
updatedRR.TTL = uint32(cachedRR.ExpireAt.Sub(now).Seconds())
+
updatedAnswer = append(updatedAnswer, updatedRR)
+
+
if updatedRR.RType == magna.CNAMEType && cname == nil {
+
cname = &updatedRR
+
} else if updatedRR.RType == question.QType {
+
hasAddressRecord = true
+
}
+
}
+
}
+
+
if len(updatedAnswer) > 0 {
+
// add AAAA types when magna supports those record types
+
if cname != nil && !hasAddressRecord && (question.QType == magna.AType) {
+
cnameTarget := cname.RData.String()
+
aRecords, err := h.resolveWithCache(magna.Question{QName: cnameTarget, QType: question.QType, QClass: question.QClass})
+
if err == nil && len(aRecords) > 0 {
+
updatedAnswer = append(updatedAnswer, aRecords...)
+
}
+
}
+
return updatedAnswer, nil
+
}
+
}
+
+
answer, err := h.resolveQuestion(question, h.RootServers)
+
if err != nil {
+
return nil, err
+
}
+
+
now := time.Now()
+
cachedAnswer := make([]CachedResourceRecord, len(answer))
+
for i, rr := range answer {
+
cachedAnswer[i] = CachedResourceRecord{
+
Record: rr,
+
ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second),
+
}
+
}
+
+
entry := &CacheEntry{
+
Answer: cachedAnswer,
+
}
+
(*h.Cache).Set(cacheKey, entry)
+
+
if len(answer) > 0 && answer[0].RType == magna.CNAMEType && question.QType == magna.AType {
+
cnameTarget := answer[len(answer)-1].RData.String()
+
addressRecords, err := h.resolveWithCache(magna.Question{QName: cnameTarget, QType: question.QType, QClass: question.QClass})
+
if err == nil && len(addressRecords) > 0 {
+
answer = append(answer, addressRecords...)
+
}
+
}
+
+
return answer, nil
+
}
+
+
func (h *QueryHandler) resolveQuestion(question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
+
ctx, cancel := context.WithCancel(context.Background())
+
defer cancel()
+
+
ch := make(chan queryResponse, len(servers))
+
+
for _, s := range servers {
+
go queryServer(ctx, question, s, ch, h.Timeout)
+
}
+
+
for i := 0; i < len(servers); i++ {
+
select {
+
case res := <-ch:
+
if res.Error != nil {
+
break
+
}
+
+
msg := res.MSG
+
if msg.Header.ANCount > 0 {
+
if msg.Answer[0].RType == magna.CNAMEType {
+
cname_answers, err := h.resolveQuestion(magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers)
+
if err != nil {
+
continue
+
}
+
msg.Answer = append(msg.Answer, cname_answers...)
+
}
+
+
return msg.Answer, nil
+
}
+
+
if msg.Header.ARCount > 0 {
+
var nextZone []string
+
for _, ans := range msg.Additional {
+
if ans.RType == magna.AType {
+
nextZone = append(nextZone, ans.RData.String())
+
}
+
}
+
+
return h.resolveQuestion(question, nextZone)
+
}
+
+
if msg.Header.NSCount > 0 {
+
var ns []string
+
for _, a := range msg.Authority {
+
if a.RType == magna.NSType {
+
ans, err := h.resolveQuestion(magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, h.RootServers)
+
if err != nil {
+
break
+
}
+
for _, x := range ans {
+
ns = append(ns, x.RData.String())
+
}
+
}
+
}
+
+
return h.resolveQuestion(question, ns)
+
}
+
+
return []magna.ResourceRecord{}, nil
+
case <-time.After(h.Timeout):
+
cancel()
+
}
+
}
+
+
return []magna.ResourceRecord{}, nil
+
}
+
+
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) {
+
done := make(chan struct{}, 1)
+
+
go func() {
+
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
+
if err != nil {
+
ch <- queryResponse{Error: err}
+
return
+
}
+
defer conn.Close()
+
+
query := magna.CreateRequest(0, false)
+
query = query.AddQuestion(question)
+
if _, err := conn.Write(query.Encode()); err != nil {
+
ch <- queryResponse{Server: server, Error: err}
+
return
+
}
+
+
p := make([]byte, 512)
+
nn, err := conn.Read(p)
+
+
// TODO: retry request with TCP
+
if err != nil || nn > 512 {
+
if err == nil {
+
err = fmt.Errorf("truncated response")
+
}
+
ch <- queryResponse{Server: server, Error: err}
+
return
+
}
+
+
var response magna.Message
+
err = response.Decode(p)
+
ch <- queryResponse{MSG: response, Server: server, Error: err}
+
}()
+
+
select {
+
case <-ctx.Done():
+
ch <- queryResponse{Server: server, Error: ctx.Err()}
+
case <-done:
+
// goroutine finished with no cancellation
+
case <-time.After(timeout):
+
ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")}
+
}
+
}
+239
pkg/dns/server.go
···
+
package dns
+
+
import (
+
"encoding/binary"
+
"fmt"
+
"io"
+
"log/slog"
+
"net"
+
"sync"
+
"time"
+
+
"code.kiri.systems/kiri/magna"
+
)
+
+
type Handler interface {
+
ServeDNS(ResponseWriter, *Request)
+
}
+
+
type HandlerFunc func(ResponseWriter, *Request)
+
+
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Request) {
+
f(w, r)
+
}
+
+
type Request struct {
+
RemoteAddr net.Addr
+
Message *magna.Message
+
}
+
+
type ResponseWriter interface {
+
WriteMsg(*magna.Message)
+
}
+
+
type udpResponseWriter struct {
+
udpConn *net.UDPConn
+
addr *net.UDPAddr
+
logger *slog.Logger
+
writeTimeout time.Duration
+
}
+
+
func (w *udpResponseWriter) WriteMsg(msg *magna.Message) {
+
ans := msg.Encode()
+
if len(ans) > 512 {
+
ans[3] |= 1 << 6 // set the truncated bit
+
}
+
+
err := w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
+
if err != nil {
+
w.logger.Warn("error setting write deadline for UDP", "error", err)
+
}
+
+
_, err = w.udpConn.WriteToUDP(ans, w.addr)
+
if err != nil {
+
w.logger.Error("error writing UDP response", "error", err)
+
}
+
}
+
+
type tcpResponseWriter struct {
+
tcpConn net.Conn
+
logger *slog.Logger
+
writeTiemout time.Duration
+
}
+
+
func (w *tcpResponseWriter) WriteMsg(msg *magna.Message) {
+
ans := msg.Encode()
+
+
err := w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout))
+
if err != nil {
+
w.logger.Warn("error setting write deadline for TCP", "error", err)
+
}
+
+
_, err = w.tcpConn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans))))
+
if err != nil {
+
w.logger.Error("error writing TCP message length", "error", err)
+
return
+
}
+
+
_, err = w.tcpConn.Write(ans)
+
if err != nil {
+
w.logger.Error("error writing TCP response", "error", err)
+
}
+
}
+
+
type Server struct {
+
Address string
+
Port int
+
Handler Handler
+
UDPSize int
+
ReadTimeout time.Duration
+
WriteTimeout time.Duration
+
Logger *slog.Logger
+
Cache Cache
+
}
+
+
func (srv *Server) ListenAndServe() error {
+
var wg sync.WaitGroup
+
errChan := make(chan error, 2)
+
+
wg.Add(2)
+
+
go func() {
+
defer wg.Done()
+
if err := srv.serveTCP(); err != nil {
+
errChan <- fmt.Errorf("TCP server error: %w", err)
+
}
+
}()
+
+
go func() {
+
defer wg.Done()
+
if err := srv.serveUDP(); err != nil {
+
errChan <- fmt.Errorf("TCP server error: %w", err)
+
}
+
}()
+
+
go func() {
+
wg.Wait()
+
close(errChan)
+
}()
+
+
for err := range errChan {
+
return err
+
}
+
+
return nil
+
}
+
+
func (srv *Server) serveUDP() error {
+
addr := net.UDPAddr{
+
Port: srv.Port,
+
IP: net.ParseIP(srv.Address),
+
}
+
conn, err := net.ListenUDP("udp", &addr)
+
if err != nil {
+
return err
+
}
+
defer conn.Close()
+
+
for {
+
buf := make([]byte, srv.UDPSize)
+
+
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+
if err != nil {
+
return fmt.Errorf("error setting read deadline: %w", err)
+
}
+
+
n, remoteAddr, err := conn.ReadFromUDP(buf)
+
if err != nil {
+
// skip logging timeout errors
+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+
continue
+
}
+
+
srv.Logger.Warn(err.Error())
+
continue
+
}
+
+
go srv.handleUDPQuery(conn, buf[:n], remoteAddr)
+
}
+
}
+
+
func (srv *Server) handleUDPQuery(conn *net.UDPConn, query []byte, remoteAddr *net.UDPAddr) {
+
w := &udpResponseWriter{
+
udpConn: conn,
+
addr: remoteAddr,
+
logger: srv.Logger,
+
writeTimeout: srv.WriteTimeout,
+
}
+
+
srv.handleQuery(query, w, remoteAddr)
+
}
+
+
func (srv *Server) serveTCP() error {
+
addr := net.TCPAddr{
+
Port: srv.Port,
+
IP: net.ParseIP(srv.Address),
+
}
+
+
listener, err := net.ListenTCP("tcp", &addr)
+
if err != nil {
+
return err
+
}
+
defer listener.Close()
+
+
for {
+
conn, err := listener.Accept()
+
if err != nil {
+
srv.Logger.Warn("tcp accept error:", err)
+
continue
+
}
+
+
go srv.handleTCPQuery(conn)
+
}
+
}
+
+
func (srv *Server) handleTCPQuery(conn net.Conn) {
+
defer conn.Close()
+
+
err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
+
if err != nil {
+
srv.Logger.Error("error setting read deadline", "error", err)
+
return
+
}
+
+
sizeBuffer := make([]byte, 2)
+
if _, err := io.ReadFull(conn, sizeBuffer); err != nil {
+
srv.Logger.Warn("tcp-error", err)
+
return
+
}
+
+
size := binary.BigEndian.Uint16(sizeBuffer)
+
data := make([]byte, size)
+
if _, err := io.ReadFull(conn, data); err != nil {
+
srv.Logger.Warn("tcp-error", err)
+
return
+
}
+
+
w := &tcpResponseWriter{
+
tcpConn: conn,
+
logger: srv.Logger,
+
writeTiemout: srv.WriteTimeout,
+
}
+
+
srv.handleQuery(data, w, conn.RemoteAddr())
+
}
+
+
func (srv *Server) handleQuery(messageBuffer []byte, w ResponseWriter, remoteAddr net.Addr) {
+
var query magna.Message
+
if err := query.Decode(messageBuffer); err != nil {
+
srv.Logger.Warn("decode error", err)
+
return
+
}
+
+
r := &Request{
+
Message: &query,
+
RemoteAddr: remoteAddr,
+
}
+
+
srv.Handler.ServeDNS(w, r)
+
}