A community-based topic aggregation platform built on atproto
1package middleware
2
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
8
9// RateLimiter implements a simple in-memory rate limiter
10// For production, consider using Redis or a distributed rate limiter
11type RateLimiter struct {
12 clients map[string]*clientLimit
13 requests int
14 window time.Duration
15 mu sync.Mutex
16}
17
// clientLimit is the bookkeeping kept for a single client identifier.
type clientLimit struct {
	// count is the number of requests admitted in the current window.
	count int
	// resetTime is the instant the current window ends; once it has passed,
	// the next request starts a fresh window.
	resetTime time.Time
}
22
23// NewRateLimiter creates a new rate limiter
24// requests: maximum number of requests allowed per window
25// window: time window duration (e.g., 1 minute)
26func NewRateLimiter(requests int, window time.Duration) *RateLimiter {
27 rl := &RateLimiter{
28 clients: make(map[string]*clientLimit),
29 requests: requests,
30 window: window,
31 }
32
33 // Cleanup old entries every window duration
34 go rl.cleanup()
35
36 return rl
37}
38
39// Middleware returns a rate limiting middleware
40func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
41 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
42 // Use IP address as client identifier
43 // In production, consider using authenticated user ID if available
44 clientID := getClientIP(r)
45
46 if !rl.allow(clientID) {
47 http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
48 return
49 }
50
51 next.ServeHTTP(w, r)
52 })
53}
54
55// allow checks if a client is allowed to make a request
56func (rl *RateLimiter) allow(clientID string) bool {
57 rl.mu.Lock()
58 defer rl.mu.Unlock()
59
60 now := time.Now().UTC()
61
62 // Get or create client limit
63 client, exists := rl.clients[clientID]
64 if !exists {
65 rl.clients[clientID] = &clientLimit{
66 count: 1,
67 resetTime: now.Add(rl.window),
68 }
69 return true
70 }
71
72 // Check if window has expired
73 if now.After(client.resetTime) {
74 client.count = 1
75 client.resetTime = now.Add(rl.window)
76 return true
77 }
78
79 // Check if under limit
80 if client.count < rl.requests {
81 client.count++
82 return true
83 }
84
85 // Rate limit exceeded
86 return false
87}
88
89// cleanup removes expired client entries periodically
90func (rl *RateLimiter) cleanup() {
91 ticker := time.NewTicker(rl.window)
92 defer ticker.Stop()
93
94 for range ticker.C {
95 rl.mu.Lock()
96 now := time.Now().UTC()
97 for clientID, client := range rl.clients {
98 if now.After(client.resetTime) {
99 delete(rl.clients, clientID)
100 }
101 }
102 rl.mu.Unlock()
103 }
104}
105
// getClientIP returns the address used to identify the client that sent r.
//
// Sources are tried in this order:
//  1. the first (left-most) entry of the X-Forwarded-For header,
//  2. the X-Real-IP header,
//  3. the host part of r.RemoteAddr.
//
// X-Forwarded-For may carry a comma-separated chain ("client, proxy1, ..."),
// so only the first entry is used and surrounding whitespace is trimmed. The
// port is stripped from RemoteAddr so that a client reconnecting from a new
// source port maps to the same identifier. If RemoteAddr has no port to split
// off, it is returned unchanged.
//
// NOTE(security): both headers are set by the peer. Unless a trusted reverse
// proxy overwrites them, a client can send arbitrary values and obtain a
// fresh identifier on every request.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or an unparseable address): use it as-is.
		return r.RemoteAddr
	}
	return host
}