a recursive dns resolver

initial commit

Changed files
+504
docs
pkg
config
dns
rootservers
+1
.gitignore
···
+
alky
+5
README.md
···
+
# ALKY
+
## IMPORTANT
+
#### REQUIRED But not currently supported
+
- imperative mode for resolving
+
- i would support this as required by 1034 section 4.3.1 (https://datatracker.ietf.org/doc/html/rfc1034#section-4.3.1) but it does not look like any of the servers i have queried support this (cloudflare, google, quad9)
+17
docs/alky.toml
···
+
[server]
+
# Address to bind tcp/udp to.
+
address = "127.0.0.1"
+
# Port to bind tcp/udp to.
+
port = 2053
+
# Location of root hints file.
+
root_hints_file = "/etc/dns/root.hints"
+
+
[logging]
+
# Logging output: "stdout" or "file".
+
output = "stdout"
+
# This is used only if logging.output is "file".
+
file_path = "/home/blu/git/kiri/alky/log"
+
+
[advanced]
+
# Timeout (in milliseconds) for outgoing queries before being cancelled.
+
query_timeout = 100
+8
go.mod
···
+
module code.kiri.systems/kiri/alky
+
+
go 1.22.5
+
+
require (
+
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84
+
github.com/BurntSushi/toml v1.4.0
+
)
+12
go.sum
···
+
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84 h1:igzBX4k3REg0WZExjGLWW7/wu/X+U6QlbMc8aeO2030=
+
code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84/go.mod h1:gSzCiTKyKlUEjGgl/qTb8rxF0QUVuWOEORAsTXA0qyI=
+
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
+
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+62
main.go
···
+
package main
+
+
import (
+
"log"
+
"log/slog"
+
"os"
+
"flag"
+
+
"code.kiri.systems/kiri/alky/pkg/config"
+
"code.kiri.systems/kiri/alky/pkg/dns"
+
"code.kiri.systems/kiri/alky/pkg/rootservers"
+
)
+
+
var configFlag string
+
+
func init() {
+
flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky")
+
+
flag.Parse()
+
}
+
+
func main() {
+
cfg, err := config.LoadConfig(configFlag)
+
if err != nil {
+
log.Fatal(err)
+
}
+
+
rootServers, err := rootservers.DecodeRootHints(cfg.Server.RootHintsFile)
+
if err != nil {
+
log.Fatal(err)
+
}
+
+
var logger *slog.Logger
+
switch cfg.Logging.Output {
+
case "file":
+
f, err := os.OpenFile(cfg.Logging.FilePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644)
+
if err != nil {
+
log.Fatal(err)
+
}
+
+
logger = slog.New(slog.NewJSONHandler(f, nil))
+
case "stdout":
+
fallthrough
+
default:
+
logger = slog.New(slog.NewJSONHandler(os.Stdout, nil))
+
}
+
+
s := dns.Server{
+
Address: cfg.Server.Address,
+
Port: cfg.Server.Port,
+
Timeout: cfg.Advanced.QueryTimeout,
+
RootServers: rootServers,
+
+
Logger: logger,
+
}
+
+
go s.TCPListenAndServe()
+
go s.UDPListenAndServe()
+
+
for {
+
}
+
}
+57
pkg/config/config.go
···
+
package config
+
+
import (
+
"fmt"
+
+
"github.com/BurntSushi/toml"
+
)
+
+
type ServerConfig struct {
+
Address string `toml:"address"`
+
Port int `toml:"port"`
+
RootHintsFile string `toml:"root_hints_file"`
+
}
+
+
type LoggingConfig struct {
+
Output string `toml:"output"`
+
FilePath string `toml:"file_path"`
+
}
+
+
type AdvancedConfig struct {
+
QueryTimeout int `toml:"query_timeout"`
+
}
+
+
type Config struct {
+
Server ServerConfig `toml:"server"`
+
Logging LoggingConfig `toml:"logging"`
+
Advanced AdvancedConfig `toml:"advanced"`
+
}
+
+
func LoadConfig(path string) (Config, error) {
+
cfg := Config{}
+
if _, err := toml.DecodeFile(path, &cfg); err != nil {
+
return cfg, err
+
}
+
+
if cfg.Server.Address == "" {
+
cfg.Server.Address = "127.0.0.1"
+
}
+
+
if cfg.Server.Port == 0 {
+
cfg.Server.Port = 53
+
}
+
+
if cfg.Server.RootHintsFile == "" {
+
cfg.Server.RootHintsFile = "/etc/dns/root.hints"
+
}
+
+
if cfg.Logging.Output == "file" && cfg.Logging.FilePath == "" {
+
return cfg, fmt.Errorf("If `[logging.output]` is `file` then `file_path` must be set.")
+
}
+
+
if cfg.Advanced.QueryTimeout == 0 {
+
cfg.Advanced.QueryTimeout = 100
+
}
+
+
return cfg, nil
+
}
+301
pkg/dns/dns.go
···
+
package dns
+
+
import (
+
"context"
+
"encoding/binary"
+
"fmt"
+
"io"
+
"log/slog"
+
"math/rand/v2"
+
"net"
+
"time"
+
+
"code.kiri.systems/kiri/magna"
+
)
+
+
type Server struct {
+
Address string
+
Port int
+
Timeout int
+
RootServers []string
+
+
Logger *slog.Logger
+
}
+
+
type queryResponse struct {
+
MSG magna.Message
+
Server string
+
Error error
+
}
+
+
func (s *Server) UDPListenAndServe() error {
+
addr := net.UDPAddr{
+
Port: s.Port,
+
IP: net.ParseIP(s.Address),
+
}
+
server, err := net.ListenUDP("udp", &addr)
+
if err != nil {
+
return err
+
}
+
defer server.Close()
+
+
for {
+
b := make([]byte, 512)
+
_, remote_addr, err := server.ReadFromUDP(b)
+
if err != nil {
+
s.Logger.Warn(err.Error())
+
continue
+
}
+
+
start := time.Now()
+
msg := s.processQuery(b)
+
s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", remote_addr.IP, "time_taken", time.Since(start).Nanoseconds())
+
if err != nil {
+
s.Logger.Warn(err.Error())
+
continue
+
}
+
+
ans := msg.Encode()
+
// xxx: set the TC bit if the message is over 512 bytes
+
if len(ans) > 512 {
+
ans[3] |= 1 << 6
+
}
+
+
if _, err := server.WriteToUDP(ans, remote_addr); err != nil {
+
s.Logger.Warn("sending response", "err", err.Error())
+
}
+
}
+
}
+
+
func (s *Server) TCPListenAndServe() error {
+
addr := net.TCPAddr{
+
Port: s.Port,
+
IP: net.ParseIP(s.Address),
+
}
+
+
server, err := net.ListenTCP("tcp", &addr)
+
if err != nil {
+
return err
+
}
+
defer server.Close()
+
+
for {
+
conn, err := server.Accept()
+
if err != nil {
+
s.Logger.Warn("conn error:", err)
+
continue
+
}
+
+
sizeBuffer := make([]byte, 2)
+
if _, err := io.ReadFull(conn, sizeBuffer); err != nil {
+
s.Logger.Warn("tcp-error", err)
+
continue
+
}
+
+
size := binary.BigEndian.Uint16(sizeBuffer)
+
+
data := make([]byte, size)
+
if _, err := io.ReadFull(conn, data); err != nil {
+
s.Logger.Warn("tcp-error", err)
+
continue
+
}
+
+
start := time.Now()
+
msg := s.processQuery(data)
+
s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", conn.RemoteAddr(), "time_taken", time.Since(start).Nanoseconds())
+
+
ans := msg.Encode()
+
conn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans))))
+
if _, err := conn.Write(ans); err != nil {
+
s.Logger.Error("tcp-error", err)
+
}
+
}
+
}
+
+
func (s *Server) processQuery(messageBuffer []byte) (msg magna.Message) {
+
var query magna.Message
+
if err := query.Decode(messageBuffer); err != nil {
+
slog.Warn("decode", err)
+
return
+
}
+
+
msg = magna.Message{
+
Header: magna.Header{
+
ID: query.Header.ID,
+
QR: true,
+
OPCode: 0,
+
AA: false,
+
TC: false,
+
RD: query.Header.RD,
+
RA: true,
+
Z: 0,
+
RCode: magna.NOERROR,
+
QDCount: 1,
+
ANCount: 0,
+
NSCount: 0,
+
ARCount: 0,
+
},
+
Question: []magna.Question{},
+
Answer: []magna.ResourceRecord{},
+
Additional: []magna.ResourceRecord{},
+
Authority: []magna.ResourceRecord{},
+
}
+
+
if len(query.Question) < 0 {
+
msg.Header.RCode = magna.FORMERR
+
return
+
}
+
question := query.Question[0]
+
msg.Question = append(msg.Question, question)
+
+
if question.QClass != magna.IN {
+
msg.Header.RCode = magna.NOTIMP
+
return
+
} else {
+
answer, err := s.resolveQuestion(question, s.RootServers)
+
if err != nil {
+
slog.Warn("resolve-question", err)
+
msg.Header.RCode = magna.SERVFAIL
+
return
+
}
+
+
msg.Header.ANCount = uint16(len(answer))
+
msg.Answer = answer
+
+
if msg.Header.ANCount == 0 {
+
msg.Header.RCode = magna.NXDOMAIN
+
return
+
}
+
}
+
+
return
+
}
+
+
func (s *Server) resolveQuestion(question magna.Question, servers []string) ([]magna.ResourceRecord, error) {
+
ctx, cancel := context.WithCancel(context.Background())
+
defer cancel()
+
+
ch := make(chan queryResponse, len(servers))
+
+
for _, s := range servers {
+
go queryServer(ctx, question, s, ch)
+
}
+
+
for i := 0; i < len(servers); i++ {
+
select {
+
case res := <-ch:
+
if res.Error != nil {
+
slog.Warn("error", "question", question, "server", res.Server, "error", res.Error)
+
break
+
}
+
+
msg := res.MSG
+
if msg.Header.ANCount > 0 {
+
if msg.Answer[0].RType == magna.CNAMEType {
+
cname_answers, err := s.resolveQuestion(magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, s.RootServers)
+
if err != nil {
+
slog.Warn("error with cname request", err)
+
continue
+
}
+
msg.Answer = append(msg.Answer, cname_answers...)
+
}
+
+
return msg.Answer, nil
+
}
+
+
if msg.Header.ARCount > 0 {
+
var nextZone []string
+
for _, ans := range msg.Additional {
+
if ans.RType == magna.AType {
+
nextZone = append(nextZone, ans.RData.String())
+
}
+
}
+
+
return s.resolveQuestion(question, nextZone)
+
}
+
+
if msg.Header.NSCount > 0 {
+
var ns []string
+
for _, a := range msg.Authority {
+
if a.RType == magna.NSType {
+
ans, err := s.resolveQuestion(magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, s.RootServers)
+
if err != nil {
+
slog.Warn("error with ns request", err)
+
break
+
}
+
for _, x := range ans {
+
ns = append(ns, x.RData.String())
+
}
+
}
+
}
+
+
return s.resolveQuestion(question, ns)
+
}
+
+
return []magna.ResourceRecord{}, nil
+
case <-time.After(time.Duration(s.Timeout) * time.Millisecond):
+
cancel()
+
}
+
}
+
+
return []magna.ResourceRecord{}, nil
+
}
+
+
func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse) {
+
done := make(chan struct{}, 1)
+
+
go func() {
+
conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server))
+
if err != nil {
+
ch <- queryResponse{Error: err}
+
return
+
}
+
defer conn.Close()
+
+
query := magna.Message{
+
Header: magna.Header{
+
ID: uint16(rand.Int() % 65535),
+
QR: false,
+
OPCode: 0,
+
AA: false,
+
TC: false,
+
RD: false,
+
RA: false,
+
Z: 0,
+
RCode: magna.NOERROR,
+
QDCount: 1,
+
ARCount: 0,
+
NSCount: 0,
+
ANCount: 0,
+
},
+
Question: []magna.Question{question},
+
}
+
if _, err := conn.Write(query.Encode()); err != nil {
+
ch <- queryResponse{Server: server, Error: err}
+
return
+
}
+
+
p := make([]byte, 512)
+
nn, err := conn.Read(p)
+
+
// TODO: retry request with TCP
+
if err != nil || nn > 512 {
+
if err == nil {
+
err = fmt.Errorf("truncated response")
+
}
+
ch <- queryResponse{Server: server, Error: err}
+
return
+
}
+
+
var response magna.Message
+
err = response.Decode(p)
+
ch <- queryResponse{MSG: response, Server: server, Error: err}
+
}()
+
+
select {
+
case <-ctx.Done():
+
ch <- queryResponse{Server: server, Error: ctx.Err()}
+
case <-done:
+
// goroutine finished with no cancellation
+
}
+
}
+41
pkg/rootservers/loader.go
···
+
package rootservers
+
+
import (
+
"bytes"
+
"net"
+
"os"
+
)
+
+
// should probally just be a part of magna
+
func DecodeRootHints(path string) ([]string, error) {
+
rootServers := make([]string, 0)
+
+
data, err := os.ReadFile(path)
+
if err != nil {
+
return rootServers, err
+
}
+
+
for _, line := range bytes.Split(data, []byte{'\n'}) {
+
// skip comments
+
if line[0] == ';' {
+
continue
+
}
+
+
// xxx: not a great way to do this should probally just be a zone
+
// custom parser
+
fields := bytes.Fields(line)
+
if len(fields) != 4 {
+
continue
+
}
+
+
// only supports ipv4 for now, need to do testing with ipv6 and support
+
// https://datatracker.ietf.org/doc/html/rfc3596 in magna
+
if bytes.Equal(fields[2], []byte{'A'}) {
+
if address := net.ParseIP(string(fields[3])); address != nil {
+
rootServers = append(rootServers, address.String())
+
}
+
}
+
}
+
+
return rootServers, nil
+
}