package auth import ( "fmt" "net/http" "sync" "time" "github.com/rs/zerolog/log" ) // RateLimiter implements a simple in-memory rate limiter type RateLimiter struct { mu sync.RWMutex buckets map[string]*bucket requests int window time.Duration cleanup time.Duration } type bucket struct { count int lastReset time.Time } // NewRateLimiter creates a new rate limiter func NewRateLimiter(requests int, window time.Duration) *RateLimiter { rl := &RateLimiter{ buckets: make(map[string]*bucket), requests: requests, window: window, cleanup: window * 2, } // Start cleanup goroutine go rl.cleanupRoutine() return rl } // Allow checks if a request should be allowed func (rl *RateLimiter) Allow(key string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() // Get or create bucket b, exists := rl.buckets[key] if !exists { rl.buckets[key] = &bucket{ count: 1, lastReset: now, } return true } // Check if window has expired if now.Sub(b.lastReset) > rl.window { b.count = 1 b.lastReset = now return true } // Check if limit exceeded if b.count >= rl.requests { return false } // Increment counter b.count++ return true } // cleanupRoutine periodically removes old buckets func (rl *RateLimiter) cleanupRoutine() { ticker := time.NewTicker(rl.cleanup) defer ticker.Stop() for range ticker.C { rl.mu.Lock() now := time.Now() for key, bucket := range rl.buckets { if now.Sub(bucket.lastReset) > rl.cleanup { delete(rl.buckets, key) } } rl.mu.Unlock() } } // RateLimitMiddleware creates a rate limiting middleware func (rl *RateLimiter) RateLimitMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Use IP address as the key key := getClientIP(r) if !rl.Allow(key) { log.Warn(). Str("client_ip", key). Str("path", r.URL.Path). Msg("Rate limit exceeded") w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", rl.requests)) w.Header().Set("X-RateLimit-Window", rl.window.String()) w.Header().Set("Retry-After", rl.window.String()) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // getClientIP extracts the real client IP address func getClientIP(r *http.Request) string { // Check X-Forwarded-For header (when behind proxy) xff := r.Header.Get("X-Forwarded-For") if xff != "" { // Take the first IP in case of multiple if idx := len(xff); idx > 0 { if commaIdx := 0; commaIdx < idx { for i, char := range xff { if char == ',' { commaIdx = i break } } if commaIdx > 0 { return xff[:commaIdx] } } return xff } } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return xri } // Fall back to RemoteAddr return r.RemoteAddr }