1package dns
2
3import (
4 "net"
5 "sync"
6 "time"
7
8 "golang.org/x/time/rate"
9 "tangled.sh/seiso.moe/magna"
10)
11
// RateLimitConfig configures the per-client rate limiting performed by
// RateLimitMiddleware.
type RateLimitConfig struct {
	// Rate is converted to rate.Limit and passed to rate.NewLimiter for each
	// client, i.e. the sustained number of events per second allowed per
	// client. RateLimitMiddleware replaces a non-positive Rate with 1.
	Rate float64
	// Burst is the bucket size passed to rate.NewLimiter for each client.
	Burst int
	// WindowLength is not read anywhere in this file.
	// NOTE(review): confirm whether it is used elsewhere or can be retired.
	WindowLength time.Duration
	// ExpirationTime is used both as the interval between cleanup passes and
	// as the idle time after which a client's limiter entry is discarded.
	ExpirationTime time.Duration
}
18
// rateLimiter holds one token-bucket limiter per client key (the ip string
// passed to allow) and is periodically swept of idle entries by cleanupLoop.
type rateLimiter struct {
	// config is the configuration copied in at construction time.
	config RateLimitConfig
	// limiters maps client key to its limiter state; guarded by mu.
	limiters map[string]*ipRateLimiterEntry
	// mu guards limiters. Only Lock/Unlock are used by the methods in this file.
	mu sync.RWMutex
	// stopCleanup is closed by Stop to terminate cleanupLoop.
	stopCleanup chan struct{}
}
25
// ipRateLimiterEntry is the per-client state stored in rateLimiter.limiters.
type ipRateLimiterEntry struct {
	// limiter is the token bucket consulted by allow for this client.
	limiter *rate.Limiter
	// lastAccess is the time of the most recent allow call for this client;
	// performCleanup drops the entry once it is older than ExpirationTime.
	lastAccess time.Time
}
30
31func NewDefaultRateLimitConfig() *RateLimitConfig {
32 return &RateLimitConfig{
33 Rate: 1,
34 Burst: 1,
35 WindowLength: time.Second,
36 ExpirationTime: 5 * time.Minute,
37 }
38}
39
40func newRateLimiter(config RateLimitConfig) *rateLimiter {
41 rl := &rateLimiter{
42 config: config,
43 limiters: make(map[string]*ipRateLimiterEntry),
44 stopCleanup: make(chan struct{}),
45 }
46
47 go rl.cleanupLoop()
48 return rl
49}
50
51func (rl *rateLimiter) allow(ip string) bool {
52 rl.mu.Lock()
53 defer rl.mu.Unlock()
54
55 entry, exists := rl.limiters[ip]
56 now := time.Now()
57
58 if !exists {
59 limiter := rate.NewLimiter(rate.Limit(rl.config.Rate), rl.config.Burst)
60 entry := &ipRateLimiterEntry{
61 limiter: limiter,
62 lastAccess: now,
63 }
64
65 rl.limiters[ip] = entry
66 return entry.limiter.Allow()
67 }
68
69 entry.lastAccess = now
70 return entry.limiter.Allow()
71}
72
73func (rl *rateLimiter) cleanupLoop() {
74 ticker := time.NewTicker(rl.config.ExpirationTime)
75 defer ticker.Stop()
76
77 for {
78 select {
79 case <-ticker.C:
80 rl.performCleanup()
81 case <-rl.stopCleanup:
82 return
83 }
84 }
85}
86
87func (rl *rateLimiter) performCleanup() {
88 rl.mu.Lock()
89 defer rl.mu.Unlock()
90
91 now := time.Now()
92 expirationCutoff := now.Add(-rl.config.ExpirationTime)
93
94 for ip, entry := range rl.limiters {
95 if entry.lastAccess.Before(expirationCutoff) {
96 delete(rl.limiters, ip)
97 }
98 }
99}
100
101func (rl *rateLimiter) Stop() {
102 close(rl.stopCleanup)
103}
104
// extractIP returns the host portion of addr as a string, for use as a
// per-client rate-limiting key.
//
// *net.UDPAddr and *net.TCPAddr yield their IP without the port. Any other
// net.Addr has its String() form split with net.SplitHostPort; if that fails
// (e.g. there is no port), the whole String() value is returned. A nil addr,
// or a nil *net.UDPAddr / *net.TCPAddr, yields "" instead of panicking.
func extractIP(addr net.Addr) string {
	switch v := addr.(type) {
	case nil:
		return ""
	case *net.UDPAddr:
		if v == nil {
			return ""
		}
		return v.IP.String()
	case *net.TCPAddr:
		if v == nil {
			return ""
		}
		return v.IP.String()
	}

	s := addr.String()
	host, _, err := net.SplitHostPort(s)
	if err != nil {
		return s
	}
	return host
}
119
120func RateLimitMiddleware(config *RateLimitConfig) func(Handler) Handler {
121 if config == nil {
122 config = NewDefaultRateLimitConfig()
123 }
124
125 if config.Rate <= 0 {
126 config.Rate = 1
127 if config.Burst <= 0 {
128 config.Burst = 1
129 }
130 }
131
132 rl := newRateLimiter(*config)
133
134 return func(next Handler) Handler {
135 return HandlerFunc(func(w ResponseWriter, r *Request) {
136 clientIP := extractIP(r.RemoteAddr)
137 if !rl.allow(clientIP) {
138 msg := r.Message.CreateReply(r.Message)
139 msg = r.Message.SetRCode(magna.REFUSED)
140 msg.Header.RA = true
141
142 msg.Header.ANCount = 0
143 msg.Header.NSCount = 0
144 msg.Header.ARCount = 0
145 msg.Answer = []magna.ResourceRecord{}
146 msg.Additional = []magna.ResourceRecord{}
147 msg.Authority = []magna.ResourceRecord{}
148
149 w.WriteMsg(msg)
150 return
151 }
152
153 next.ServeDNS(w, r)
154 })
155 }
156}