1package dns
2
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"sync"
	"time"

	"tangled.sh/seiso.moe/magna"
)
15
// contextKey is an unexported string type.
// NOTE(review): it is not referenced in the code visible here; presumably it
// keys context.Context values elsewhere in the package — confirm.
type contextKey string

// Lengths, in bytes, of the scratch buffers handed out by the pools below.
const (
	// maxUDPBufferSize is the length of every serverUDPBufferPool buffer,
	// i.e. the largest datagram serveUDP reads in a single call.
	maxUDPBufferSize = 4096

	// maxResolverUDPBufferSize is the length of every resolverUDPBufferPool
	// buffer.
	maxResolverUDPBufferSize = 4096
)

// The pools store *[]byte rather than []byte so that Put does not have to
// allocate when the slice is converted to an interface value.
var (
	// serverUDPBufferPool recycles maxUDPBufferSize-byte read buffers for
	// serveUDP.
	serverUDPBufferPool = sync.Pool{
		New: func() any {
			scratch := make([]byte, maxUDPBufferSize)
			return &scratch
		},
	}

	// resolverUDPBufferPool recycles maxResolverUDPBufferSize-byte buffers;
	// its consumer is outside the code visible here.
	resolverUDPBufferPool = sync.Pool{
		New: func() any {
			scratch := make([]byte, maxResolverUDPBufferSize)
			return &scratch
		},
	}
)
38
39type Handler interface {
40 ServeDNS(ResponseWriter, *Request)
41}
42
43type HandlerFunc func(ResponseWriter, *Request)
44
45func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Request) {
46 f(w, r)
47}
48
49type Request struct {
50 Context context.Context
51 RemoteAddr net.Addr
52 Message *magna.Message
53}
54
55type ResponseWriter interface {
56 WriteMsg(*magna.Message)
57}
58
59type udpResponseWriter struct {
60 udpConn *net.UDPConn
61 addr *net.UDPAddr
62 logger *slog.Logger
63 writeTimeout time.Duration
64 udpSize int
65}
66
67func (w *udpResponseWriter) WriteMsg(msg *magna.Message) {
68 ans, err := msg.Encode()
69 if err != nil {
70 w.logger.Warn("err encoding msg", "error", err)
71
72 failMsg := msg.CreateReply(msg)
73 failMsg = failMsg.SetRCode(magna.SERVFAIL)
74 failMsg.Answer = nil
75 failMsg.Authority = nil
76 failMsg.Additional = nil
77 failMsg.Header.ANCount = 0
78 failMsg.Header.NSCount = 0
79 failMsg.Header.ARCount = 0
80 ans, _ = failMsg.Encode()
81 if ans == nil {
82 return
83 }
84 }
85
86 if len(ans) > w.udpSize {
87 w.logger.Debug("Response exceeds UDP size, setting TC bit", "size", len(ans), "limit", w.udpSize, "client", w.addr)
88
89 tcMsg := msg.CreateReply(msg)
90 tcMsg.Header.TC = true
91 tcMsg.Answer = nil
92 tcMsg.Authority = nil
93 tcMsg.Additional = nil
94 tcMsg.Header.ANCount = 0
95 tcMsg.Header.NSCount = 0
96 tcMsg.Header.ARCount = 0
97 tcMsg.Header.RCode = magna.NOERROR
98
99 ans, err = tcMsg.Encode()
100 if err != nil {
101 w.logger.Error("Error encoding truncated UDP response", "error", err, "client", w.addr)
102 return
103 }
104 if len(ans) > w.udpSize {
105 w.logger.Warn("Truncated message still exceeds UDP size limit!", "size", len(ans), "limit", w.udpSize)
106 }
107 }
108
109 if w.udpConn == nil || w.addr == nil {
110 w.logger.Error("UDP response writer used incorrectly (nil conn or addr)")
111 return
112 }
113
114 err = w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout))
115 if err != nil {
116 w.logger.Warn("error setting write deadline for UDP", "error", err)
117 }
118
119 _, err = w.udpConn.WriteToUDP(ans, w.addr)
120 if err != nil {
121 w.logger.Error("error writing UDP response", "error", err)
122 }
123}
124
125type tcpResponseWriter struct {
126 tcpConn net.Conn
127 logger *slog.Logger
128 writeTiemout time.Duration
129}
130
131func (w *tcpResponseWriter) WriteMsg(msg *magna.Message) {
132 ans, err := msg.Encode()
133 if err != nil {
134 w.logger.Warn("err encoding msg", "error", err)
135
136 failMsg := msg.CreateReply(msg)
137 failMsg = failMsg.SetRCode(magna.SERVFAIL)
138 failMsg.Answer = nil
139 failMsg.Authority = nil
140 failMsg.Additional = nil
141 failMsg.Header.ANCount = 0
142 failMsg.Header.NSCount = 0
143 failMsg.Header.ARCount = 0
144 ans, _ = failMsg.Encode()
145 return
146 }
147
148 if w.tcpConn == nil {
149 w.logger.Error("TCP response writer used incorrectly (nil conn)")
150 return
151 }
152
153 err = w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout))
154 if err != nil {
155 w.logger.Warn("error setting write deadline for TCP", "error", err)
156 }
157
158 _, err = w.tcpConn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans))))
159 if err != nil {
160 w.logger.Error("error writing TCP message length", "error", err)
161 return
162 }
163
164 _, err = w.tcpConn.Write(ans)
165 if err != nil {
166 w.logger.Error("error writing TCP response", "error", err)
167 }
168}
169
170type Server struct {
171 Address string
172 Port int
173 Handler Handler
174 UDPSize int
175 ReadTimeout time.Duration
176 WriteTimeout time.Duration
177 Logger *slog.Logger
178}
179
180func (srv *Server) ListenAndServe() error {
181 if srv.Logger == nil {
182 srv.Logger = slog.Default()
183 }
184
185 if srv.UDPSize <= 0 {
186 srv.UDPSize = 512
187 }
188 srv.Logger.Info("Starting DNS server", "address", srv.Address, "port", srv.Port, "udp_size", srv.UDPSize)
189
190 var wg sync.WaitGroup
191 errChan := make(chan error, 2)
192
193 wg.Add(2)
194
195 go func() {
196 defer wg.Done()
197 if err := srv.serveTCP(); err != nil {
198 err = fmt.Errorf("TCP server error: %w", err)
199 srv.Logger.Error(err.Error())
200 select {
201 case errChan <- err:
202 default:
203 srv.Logger.Warn("Error channel full, discarding TCP error")
204 }
205 }
206 }()
207
208 go func() {
209 defer wg.Done()
210 if err := srv.serveUDP(); err != nil {
211 err = fmt.Errorf("UDP server error: %w", err)
212 srv.Logger.Error(err.Error())
213 select {
214 case errChan <- err:
215 default:
216 srv.Logger.Warn("Error channel full, discarding UDP error")
217 }
218 }
219 }()
220
221 go func() {
222 wg.Wait()
223 close(errChan)
224 }()
225
226 err := <-errChan
227 return err
228}
229
230func (srv *Server) serveUDP() error {
231 addr := &net.UDPAddr{
232 Port: srv.Port,
233 IP: net.ParseIP(srv.Address),
234 }
235 conn, err := net.ListenUDP("udp", addr)
236 if err != nil {
237 return fmt.Errorf("failed to listen on UDP %s:%d: %w", srv.Address, srv.Port, err)
238 }
239 defer conn.Close()
240 srv.Logger.Info("UDP listener started", "address", conn.LocalAddr())
241
242 for {
243 bufPtr := serverUDPBufferPool.Get().(*[]byte)
244 buffer := *bufPtr
245
246 readDeadlineSet := false
247 if srv.ReadTimeout > 0 {
248 err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
249 if err != nil {
250 serverUDPBufferPool.Put(bufPtr)
251 return fmt.Errorf("error setting UDP read deadline: %w", err)
252 }
253 readDeadlineSet = true
254 }
255
256 n, remoteAddr, err := conn.ReadFromUDP(buffer)
257 if err != nil {
258 serverUDPBufferPool.Put(bufPtr)
259 if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
260 if readDeadlineSet {
261 continue
262 }
263
264 srv.Logger.Warn("UDP read timeout occurred without explicit deadline set", "error", err)
265 continue
266 }
267
268 if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" {
269 srv.Logger.Info("UDP listener stopping: connection closed")
270 return nil
271 }
272
273 srv.Logger.Warn("UDP read error", "error", err)
274 continue
275 }
276
277 queryData := make([]byte, n)
278 copy(queryData, buffer[:n])
279 serverUDPBufferPool.Put(bufPtr)
280
281 go srv.handleUDPQuery(conn, queryData, remoteAddr)
282 }
283}
284
285func (srv *Server) handleUDPQuery(conn *net.UDPConn, query []byte, remoteAddr *net.UDPAddr) {
286 w := &udpResponseWriter{
287 udpConn: conn,
288 addr: remoteAddr,
289 logger: srv.Logger,
290 writeTimeout: srv.WriteTimeout,
291 udpSize: srv.UDPSize,
292 }
293
294 srv.handleQuery(query, w, remoteAddr)
295}
296
297func (srv *Server) serveTCP() error {
298 addr := net.TCPAddr{
299 Port: srv.Port,
300 IP: net.ParseIP(srv.Address),
301 }
302
303 listener, err := net.ListenTCP("tcp", &addr)
304 if err != nil {
305 return fmt.Errorf("failed to listen on TCP %s:%d: %w", srv.Address, srv.Port, err)
306 }
307 defer listener.Close()
308 srv.Logger.Info("TCP listener started", "address", listener.Addr())
309
310 for {
311 conn, err := listener.AcceptTCP()
312 if err != nil {
313 if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" {
314 srv.Logger.Info("TCP listener stopping: listener closed")
315 return nil
316 }
317 srv.Logger.Warn("TCP accept error", "error", err)
318 continue
319 }
320
321 conn.SetKeepAlive(true)
322 conn.SetKeepAlivePeriod(3 * time.Minute)
323 go srv.handleTCPQuery(conn)
324 }
325}
326
327func (srv *Server) handleTCPQuery(conn net.Conn) {
328 defer conn.Close()
329
330 if srv.ReadTimeout > 0 {
331 err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
332 if err != nil {
333 srv.Logger.Warn("Error setting TCP initial read deadline", "error", err, "client", conn.RemoteAddr())
334 return
335 }
336 }
337
338 sizeBuffer := make([]byte, 2)
339 _, err := io.ReadFull(conn, sizeBuffer)
340 if err != nil {
341 if err != io.EOF && err != io.ErrUnexpectedEOF {
342 srv.Logger.Warn("TCP error reading message length", "error", err, "client", conn.RemoteAddr())
343 }
344 return
345 }
346
347 size := binary.BigEndian.Uint16(sizeBuffer)
348 if size == 0 {
349 srv.Logger.Debug("TCP received zero-length message", "client", conn.RemoteAddr())
350 return
351 }
352
353 maxTCPMessageSize := 65535
354 if size > uint16(maxTCPMessageSize) {
355 srv.Logger.Warn("TCP message size exceeds limit", "size", size, "limit", maxTCPMessageSize, "client", conn.RemoteAddr())
356 return
357 }
358
359 data := make([]byte, size)
360 if srv.ReadTimeout > 0 {
361 err = conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
362 if err != nil {
363 srv.Logger.Warn("Error setting TCP read deadline for body", "error", err, "client", conn.RemoteAddr())
364 return
365 }
366 }
367
368 _, err = io.ReadFull(conn, data)
369 if err != nil {
370 srv.Logger.Warn("TCP error reading message body", "error", err, "client", conn.RemoteAddr())
371 return
372 }
373
374 w := &tcpResponseWriter{
375 tcpConn: conn,
376 logger: srv.Logger,
377 writeTiemout: srv.WriteTimeout,
378 }
379
380 srv.handleQuery(data, w, conn.RemoteAddr())
381}
382
383func (srv *Server) handleQuery(messageBuffer []byte, w ResponseWriter, remoteAddr net.Addr) {
384 var query magna.Message
385 err := query.Decode(messageBuffer)
386 if err != nil {
387 srv.Logger.Warn("Message decode error", "error", err, "client", remoteAddr)
388 // TODO: find better way to handle failed decode drop for now.
389 return
390 }
391
392 ctx, cancel := context.WithTimeout(context.Background(), srv.ReadTimeout)
393 defer cancel()
394
395 r := &Request{
396 Context: ctx,
397 Message: &query,
398 RemoteAddr: remoteAddr,
399 }
400
401 if srv.Handler == nil {
402 srv.Logger.Error("No DNS handler configured!")
403 reply := query.CreateReply(&query)
404 reply = reply.SetRCode(magna.SERVFAIL)
405 w.WriteMsg(reply)
406 return
407 }
408 srv.Handler.ServeDNS(w, r)
409}