A recursive DNS resolver.
at main 3.0 kB view raw
1package dns 2 3import ( 4 "net" 5 "sync" 6 "time" 7 8 "golang.org/x/time/rate" 9 "tangled.sh/seiso.moe/magna" 10) 11 12type RateLimitConfig struct { 13 Rate float64 14 Burst int 15 WindowLength time.Duration 16 ExpirationTime time.Duration 17} 18 19type rateLimiter struct { 20 config RateLimitConfig 21 limiters map[string]*ipRateLimiterEntry 22 mu sync.RWMutex 23 stopCleanup chan struct{} 24} 25 26type ipRateLimiterEntry struct { 27 limiter *rate.Limiter 28 lastAccess time.Time 29} 30 31func NewDefaultRateLimitConfig() *RateLimitConfig { 32 return &RateLimitConfig{ 33 Rate: 1, 34 Burst: 1, 35 WindowLength: time.Second, 36 ExpirationTime: 5 * time.Minute, 37 } 38} 39 40func newRateLimiter(config RateLimitConfig) *rateLimiter { 41 rl := &rateLimiter{ 42 config: config, 43 limiters: make(map[string]*ipRateLimiterEntry), 44 stopCleanup: make(chan struct{}), 45 } 46 47 go rl.cleanupLoop() 48 return rl 49} 50 51func (rl *rateLimiter) allow(ip string) bool { 52 rl.mu.Lock() 53 defer rl.mu.Unlock() 54 55 entry, exists := rl.limiters[ip] 56 now := time.Now() 57 58 if !exists { 59 limiter := rate.NewLimiter(rate.Limit(rl.config.Rate), rl.config.Burst) 60 entry := &ipRateLimiterEntry{ 61 limiter: limiter, 62 lastAccess: now, 63 } 64 65 rl.limiters[ip] = entry 66 return entry.limiter.Allow() 67 } 68 69 entry.lastAccess = now 70 return entry.limiter.Allow() 71} 72 73func (rl *rateLimiter) cleanupLoop() { 74 ticker := time.NewTicker(rl.config.ExpirationTime) 75 defer ticker.Stop() 76 77 for { 78 select { 79 case <-ticker.C: 80 rl.performCleanup() 81 case <-rl.stopCleanup: 82 return 83 } 84 } 85} 86 87func (rl *rateLimiter) performCleanup() { 88 rl.mu.Lock() 89 defer rl.mu.Unlock() 90 91 now := time.Now() 92 expirationCutoff := now.Add(-rl.config.ExpirationTime) 93 94 for ip, entry := range rl.limiters { 95 if entry.lastAccess.Before(expirationCutoff) { 96 delete(rl.limiters, ip) 97 } 98 } 99} 100 101func (rl *rateLimiter) Stop() { 102 close(rl.stopCleanup) 103} 104 105func 
extractIP(addr net.Addr) string { 106 switch v := addr.(type) { 107 case *net.UDPAddr: 108 return v.IP.String() 109 case *net.TCPAddr: 110 return v.IP.String() 111 default: 112 host, _, err := net.SplitHostPort(addr.String()) 113 if err != nil { 114 return addr.String() 115 } 116 return host 117 } 118} 119 120func RateLimitMiddleware(config *RateLimitConfig) func(Handler) Handler { 121 if config == nil { 122 config = NewDefaultRateLimitConfig() 123 } 124 125 if config.Rate <= 0 { 126 config.Rate = 1 127 if config.Burst <= 0 { 128 config.Burst = 1 129 } 130 } 131 132 rl := newRateLimiter(*config) 133 134 return func(next Handler) Handler { 135 return HandlerFunc(func(w ResponseWriter, r *Request) { 136 clientIP := extractIP(r.RemoteAddr) 137 if !rl.allow(clientIP) { 138 msg := r.Message.CreateReply(r.Message) 139 msg = r.Message.SetRCode(magna.REFUSED) 140 msg.Header.RA = true 141 142 msg.Header.ANCount = 0 143 msg.Header.NSCount = 0 144 msg.Header.ARCount = 0 145 msg.Answer = []magna.ResourceRecord{} 146 msg.Additional = []magna.ResourceRecord{} 147 msg.Authority = []magna.ResourceRecord{} 148 149 w.WriteMsg(msg) 150 return 151 } 152 153 next.ServeDNS(w, r) 154 }) 155 } 156}