A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "net/http" 5 "sync" 6 "time" 7) 8 9// RateLimiter implements a simple in-memory rate limiter 10// For production, consider using Redis or a distributed rate limiter 11type RateLimiter struct { 12 clients map[string]*clientLimit 13 requests int 14 window time.Duration 15 mu sync.Mutex 16} 17 18type clientLimit struct { 19 resetTime time.Time 20 count int 21} 22 23// NewRateLimiter creates a new rate limiter 24// requests: maximum number of requests allowed per window 25// window: time window duration (e.g., 1 minute) 26func NewRateLimiter(requests int, window time.Duration) *RateLimiter { 27 rl := &RateLimiter{ 28 clients: make(map[string]*clientLimit), 29 requests: requests, 30 window: window, 31 } 32 33 // Cleanup old entries every window duration 34 go rl.cleanup() 35 36 return rl 37} 38 39// Middleware returns a rate limiting middleware 40func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { 41 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 42 // Use IP address as client identifier 43 // In production, consider using authenticated user ID if available 44 clientID := getClientIP(r) 45 46 if !rl.allow(clientID) { 47 http.Error(w, "Rate limit exceeded. 
Please try again later.", http.StatusTooManyRequests) 48 return 49 } 50 51 next.ServeHTTP(w, r) 52 }) 53} 54 55// allow checks if a client is allowed to make a request 56func (rl *RateLimiter) allow(clientID string) bool { 57 rl.mu.Lock() 58 defer rl.mu.Unlock() 59 60 now := time.Now().UTC() 61 62 // Get or create client limit 63 client, exists := rl.clients[clientID] 64 if !exists { 65 rl.clients[clientID] = &clientLimit{ 66 count: 1, 67 resetTime: now.Add(rl.window), 68 } 69 return true 70 } 71 72 // Check if window has expired 73 if now.After(client.resetTime) { 74 client.count = 1 75 client.resetTime = now.Add(rl.window) 76 return true 77 } 78 79 // Check if under limit 80 if client.count < rl.requests { 81 client.count++ 82 return true 83 } 84 85 // Rate limit exceeded 86 return false 87} 88 89// cleanup removes expired client entries periodically 90func (rl *RateLimiter) cleanup() { 91 ticker := time.NewTicker(rl.window) 92 defer ticker.Stop() 93 94 for range ticker.C { 95 rl.mu.Lock() 96 now := time.Now().UTC() 97 for clientID, client := range rl.clients { 98 if now.After(client.resetTime) { 99 delete(rl.clients, clientID) 100 } 101 } 102 rl.mu.Unlock() 103 } 104} 105 106// getClientIP extracts the client IP from the request 107func getClientIP(r *http.Request) string { 108 // Check X-Forwarded-For header (if behind proxy) 109 forwarded := r.Header.Get("X-Forwarded-For") 110 if forwarded != "" { 111 return forwarded 112 } 113 114 // Check X-Real-IP header 115 realIP := r.Header.Get("X-Real-IP") 116 if realIP != "" { 117 return realIP 118 } 119 120 // Fall back to RemoteAddr 121 return r.RemoteAddr 122}