A recursive DNS resolver.
at main 9.6 kB view raw
1package dns 2 3import ( 4 "context" 5 "encoding/binary" 6 "fmt" 7 "io" 8 "log/slog" 9 "net" 10 "sync" 11 "time" 12 13 "tangled.sh/seiso.moe/magna" 14) 15 16type contextKey string 17 18const ( 19 maxUDPBufferSize = 4096 20 maxResolverUDPBufferSize = 4096 21) 22 23var ( 24 serverUDPBufferPool = sync.Pool{ 25 New: func() any { 26 b := make([]byte, maxUDPBufferSize) 27 return &b 28 }, 29 } 30 31 resolverUDPBufferPool = sync.Pool{ 32 New: func() any { 33 b := make([]byte, maxResolverUDPBufferSize) 34 return &b 35 }, 36 } 37) 38 39type Handler interface { 40 ServeDNS(ResponseWriter, *Request) 41} 42 43type HandlerFunc func(ResponseWriter, *Request) 44 45func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Request) { 46 f(w, r) 47} 48 49type Request struct { 50 Context context.Context 51 RemoteAddr net.Addr 52 Message *magna.Message 53} 54 55type ResponseWriter interface { 56 WriteMsg(*magna.Message) 57} 58 59type udpResponseWriter struct { 60 udpConn *net.UDPConn 61 addr *net.UDPAddr 62 logger *slog.Logger 63 writeTimeout time.Duration 64 udpSize int 65} 66 67func (w *udpResponseWriter) WriteMsg(msg *magna.Message) { 68 ans, err := msg.Encode() 69 if err != nil { 70 w.logger.Warn("err encoding msg", "error", err) 71 72 failMsg := msg.CreateReply(msg) 73 failMsg = failMsg.SetRCode(magna.SERVFAIL) 74 failMsg.Answer = nil 75 failMsg.Authority = nil 76 failMsg.Additional = nil 77 failMsg.Header.ANCount = 0 78 failMsg.Header.NSCount = 0 79 failMsg.Header.ARCount = 0 80 ans, _ = failMsg.Encode() 81 if ans == nil { 82 return 83 } 84 } 85 86 if len(ans) > w.udpSize { 87 w.logger.Debug("Response exceeds UDP size, setting TC bit", "size", len(ans), "limit", w.udpSize, "client", w.addr) 88 89 tcMsg := msg.CreateReply(msg) 90 tcMsg.Header.TC = true 91 tcMsg.Answer = nil 92 tcMsg.Authority = nil 93 tcMsg.Additional = nil 94 tcMsg.Header.ANCount = 0 95 tcMsg.Header.NSCount = 0 96 tcMsg.Header.ARCount = 0 97 tcMsg.Header.RCode = magna.NOERROR 98 99 ans, err = tcMsg.Encode() 100 if err 
!= nil { 101 w.logger.Error("Error encoding truncated UDP response", "error", err, "client", w.addr) 102 return 103 } 104 if len(ans) > w.udpSize { 105 w.logger.Warn("Truncated message still exceeds UDP size limit!", "size", len(ans), "limit", w.udpSize) 106 } 107 } 108 109 if w.udpConn == nil || w.addr == nil { 110 w.logger.Error("UDP response writer used incorrectly (nil conn or addr)") 111 return 112 } 113 114 err = w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) 115 if err != nil { 116 w.logger.Warn("error setting write deadline for UDP", "error", err) 117 } 118 119 _, err = w.udpConn.WriteToUDP(ans, w.addr) 120 if err != nil { 121 w.logger.Error("error writing UDP response", "error", err) 122 } 123} 124 125type tcpResponseWriter struct { 126 tcpConn net.Conn 127 logger *slog.Logger 128 writeTiemout time.Duration 129} 130 131func (w *tcpResponseWriter) WriteMsg(msg *magna.Message) { 132 ans, err := msg.Encode() 133 if err != nil { 134 w.logger.Warn("err encoding msg", "error", err) 135 136 failMsg := msg.CreateReply(msg) 137 failMsg = failMsg.SetRCode(magna.SERVFAIL) 138 failMsg.Answer = nil 139 failMsg.Authority = nil 140 failMsg.Additional = nil 141 failMsg.Header.ANCount = 0 142 failMsg.Header.NSCount = 0 143 failMsg.Header.ARCount = 0 144 ans, _ = failMsg.Encode() 145 return 146 } 147 148 if w.tcpConn == nil { 149 w.logger.Error("TCP response writer used incorrectly (nil conn)") 150 return 151 } 152 153 err = w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout)) 154 if err != nil { 155 w.logger.Warn("error setting write deadline for TCP", "error", err) 156 } 157 158 _, err = w.tcpConn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans)))) 159 if err != nil { 160 w.logger.Error("error writing TCP message length", "error", err) 161 return 162 } 163 164 _, err = w.tcpConn.Write(ans) 165 if err != nil { 166 w.logger.Error("error writing TCP response", "error", err) 167 } 168} 169 170type Server struct { 171 Address string 172 Port int 
173 Handler Handler 174 UDPSize int 175 ReadTimeout time.Duration 176 WriteTimeout time.Duration 177 Logger *slog.Logger 178} 179 180func (srv *Server) ListenAndServe() error { 181 if srv.Logger == nil { 182 srv.Logger = slog.Default() 183 } 184 185 if srv.UDPSize <= 0 { 186 srv.UDPSize = 512 187 } 188 srv.Logger.Info("Starting DNS server", "address", srv.Address, "port", srv.Port, "udp_size", srv.UDPSize) 189 190 var wg sync.WaitGroup 191 errChan := make(chan error, 2) 192 193 wg.Add(2) 194 195 go func() { 196 defer wg.Done() 197 if err := srv.serveTCP(); err != nil { 198 err = fmt.Errorf("TCP server error: %w", err) 199 srv.Logger.Error(err.Error()) 200 select { 201 case errChan <- err: 202 default: 203 srv.Logger.Warn("Error channel full, discarding TCP error") 204 } 205 } 206 }() 207 208 go func() { 209 defer wg.Done() 210 if err := srv.serveUDP(); err != nil { 211 err = fmt.Errorf("UDP server error: %w", err) 212 srv.Logger.Error(err.Error()) 213 select { 214 case errChan <- err: 215 default: 216 srv.Logger.Warn("Error channel full, discarding UDP error") 217 } 218 } 219 }() 220 221 go func() { 222 wg.Wait() 223 close(errChan) 224 }() 225 226 err := <-errChan 227 return err 228} 229 230func (srv *Server) serveUDP() error { 231 addr := &net.UDPAddr{ 232 Port: srv.Port, 233 IP: net.ParseIP(srv.Address), 234 } 235 conn, err := net.ListenUDP("udp", addr) 236 if err != nil { 237 return fmt.Errorf("failed to listen on UDP %s:%d: %w", srv.Address, srv.Port, err) 238 } 239 defer conn.Close() 240 srv.Logger.Info("UDP listener started", "address", conn.LocalAddr()) 241 242 for { 243 bufPtr := serverUDPBufferPool.Get().(*[]byte) 244 buffer := *bufPtr 245 246 readDeadlineSet := false 247 if srv.ReadTimeout > 0 { 248 err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) 249 if err != nil { 250 serverUDPBufferPool.Put(bufPtr) 251 return fmt.Errorf("error setting UDP read deadline: %w", err) 252 } 253 readDeadlineSet = true 254 } 255 256 n, remoteAddr, err := 
conn.ReadFromUDP(buffer) 257 if err != nil { 258 serverUDPBufferPool.Put(bufPtr) 259 if netErr, ok := err.(net.Error); ok && netErr.Timeout() { 260 if readDeadlineSet { 261 continue 262 } 263 264 srv.Logger.Warn("UDP read timeout occurred without explicit deadline set", "error", err) 265 continue 266 } 267 268 if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" { 269 srv.Logger.Info("UDP listener stopping: connection closed") 270 return nil 271 } 272 273 srv.Logger.Warn("UDP read error", "error", err) 274 continue 275 } 276 277 queryData := make([]byte, n) 278 copy(queryData, buffer[:n]) 279 serverUDPBufferPool.Put(bufPtr) 280 281 go srv.handleUDPQuery(conn, queryData, remoteAddr) 282 } 283} 284 285func (srv *Server) handleUDPQuery(conn *net.UDPConn, query []byte, remoteAddr *net.UDPAddr) { 286 w := &udpResponseWriter{ 287 udpConn: conn, 288 addr: remoteAddr, 289 logger: srv.Logger, 290 writeTimeout: srv.WriteTimeout, 291 udpSize: srv.UDPSize, 292 } 293 294 srv.handleQuery(query, w, remoteAddr) 295} 296 297func (srv *Server) serveTCP() error { 298 addr := net.TCPAddr{ 299 Port: srv.Port, 300 IP: net.ParseIP(srv.Address), 301 } 302 303 listener, err := net.ListenTCP("tcp", &addr) 304 if err != nil { 305 return fmt.Errorf("failed to listen on TCP %s:%d: %w", srv.Address, srv.Port, err) 306 } 307 defer listener.Close() 308 srv.Logger.Info("TCP listener started", "address", listener.Addr()) 309 310 for { 311 conn, err := listener.AcceptTCP() 312 if err != nil { 313 if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" { 314 srv.Logger.Info("TCP listener stopping: listener closed") 315 return nil 316 } 317 srv.Logger.Warn("TCP accept error", "error", err) 318 continue 319 } 320 321 conn.SetKeepAlive(true) 322 conn.SetKeepAlivePeriod(3 * time.Minute) 323 go srv.handleTCPQuery(conn) 324 } 325} 326 327func (srv *Server) handleTCPQuery(conn net.Conn) { 328 defer conn.Close() 329 330 if 
srv.ReadTimeout > 0 { 331 err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) 332 if err != nil { 333 srv.Logger.Warn("Error setting TCP initial read deadline", "error", err, "client", conn.RemoteAddr()) 334 return 335 } 336 } 337 338 sizeBuffer := make([]byte, 2) 339 _, err := io.ReadFull(conn, sizeBuffer) 340 if err != nil { 341 if err != io.EOF && err != io.ErrUnexpectedEOF { 342 srv.Logger.Warn("TCP error reading message length", "error", err, "client", conn.RemoteAddr()) 343 } 344 return 345 } 346 347 size := binary.BigEndian.Uint16(sizeBuffer) 348 if size == 0 { 349 srv.Logger.Debug("TCP received zero-length message", "client", conn.RemoteAddr()) 350 return 351 } 352 353 maxTCPMessageSize := 65535 354 if size > uint16(maxTCPMessageSize) { 355 srv.Logger.Warn("TCP message size exceeds limit", "size", size, "limit", maxTCPMessageSize, "client", conn.RemoteAddr()) 356 return 357 } 358 359 data := make([]byte, size) 360 if srv.ReadTimeout > 0 { 361 err = conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) 362 if err != nil { 363 srv.Logger.Warn("Error setting TCP read deadline for body", "error", err, "client", conn.RemoteAddr()) 364 return 365 } 366 } 367 368 _, err = io.ReadFull(conn, data) 369 if err != nil { 370 srv.Logger.Warn("TCP error reading message body", "error", err, "client", conn.RemoteAddr()) 371 return 372 } 373 374 w := &tcpResponseWriter{ 375 tcpConn: conn, 376 logger: srv.Logger, 377 writeTiemout: srv.WriteTimeout, 378 } 379 380 srv.handleQuery(data, w, conn.RemoteAddr()) 381} 382 383func (srv *Server) handleQuery(messageBuffer []byte, w ResponseWriter, remoteAddr net.Addr) { 384 var query magna.Message 385 err := query.Decode(messageBuffer) 386 if err != nil { 387 srv.Logger.Warn("Message decode error", "error", err, "client", remoteAddr) 388 // TODO: find better way to handle failed decode drop for now. 
389 return 390 } 391 392 ctx, cancel := context.WithTimeout(context.Background(), srv.ReadTimeout) 393 defer cancel() 394 395 r := &Request{ 396 Context: ctx, 397 Message: &query, 398 RemoteAddr: remoteAddr, 399 } 400 401 if srv.Handler == nil { 402 srv.Logger.Error("No DNS handler configured!") 403 reply := query.CreateReply(&query) 404 reply = reply.SetRCode(magna.SERVFAIL) 405 w.WriteMsg(reply) 406 return 407 } 408 srv.Handler.ServeDNS(w, r) 409}