跳转到主要内容

Gin 基于令牌桶的限流中间件

代码
// Author: 10935336
// Version: 2024-09-30.1
// Function: A token bucket-based current limiting Gin middleware
// License: The Unlicense


package ratelimiter

import (
	"context"
	"time"

	"github.com/gin-gonic/gin"
	"golang.org/x/time/rate"
	"sync"
)

type Limiters struct {
	limiters *sync.Map
	cancel   context.CancelFunc // Used to stop the cleanup goroutine
}

type Limiter struct {
	limiter *rate.Limiter
	lastGet time.Time // Last time a token was requested
	key     string    // Rate limiting identifier, e.g., context.ClientIP() is rate limiting by IP
}

var GlobalLimiters = &Limiters{
	limiters: &sync.Map{},
}

var once sync.Once

// NewLimiter creates a new or retrieves an existing rate limiter, with cleanup support
func NewLimiter(r rate.Limit, b int, key string, clearInterval time.Duration, expireAfter time.Duration) *Limiter {
	once.Do(func() {
		ctx, cancel := context.WithCancel(context.Background())
		GlobalLimiters.cancel = cancel
		go GlobalLimiters.clearLimiter(ctx, clearInterval, expireAfter)
	})

	keyLimiter := GlobalLimiters.getLimiter(r, b, key)

	return keyLimiter
}

// Allow checks if a token can be acquired, updating the lastGet timestamp
func (l *Limiter) Allow() bool {
	l.lastGet = time.Now()
	return l.limiter.Allow()
}

// RemainingTokens returns the number of tokens left in the bucket
func (l *Limiter) RemainingTokens() int {
	return int(l.limiter.Tokens())
}

// Limit returns the rate limit configuration
func (l *Limiter) Limit() rate.Limit {
	return l.limiter.Limit()
}

// getLimiter retrieves or creates a new rate limiter for a specific key
func (ls *Limiters) getLimiter(r rate.Limit, b int, key string) *Limiter {
	limiter, ok := ls.limiters.Load(key)

	if ok {
		return limiter.(*Limiter)
	}

	l := &Limiter{
		limiter: rate.NewLimiter(r, b),
		lastGet: time.Now(),
		key:     key,
	}

	ls.limiters.Store(key, l)

	return l
}

// clearLimiter removes rate limiters that have been idle for a specified duration
func (ls *Limiters) clearLimiter(ctx context.Context, clearInterval time.Duration, expireAfter time.Duration) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-time.After(clearInterval):
			ls.limiters.Range(func(key, value interface{}) bool {
				limiter := value.(*Limiter)
				if time.Since(limiter.lastGet) > expireAfter {
					ls.limiters.Delete(key)
				}
				return true
			})
		}
	}
}

// Stop stops the cleanup goroutine and clears all stored limiters
func (ls *Limiters) Stop() {
	if ls.cancel != nil {
		ls.cancel() // Stop the cleanup goroutine
	}
	ls.limiters.Range(func(key, value interface{}) bool {
		ls.limiters.Delete(key)
		return true
	})
}

var rateLimitResponse = `{"code":429,"msg":"Too many requests, please try again later.","detail":""}`

// RateLimiterMiddleware Gin middleware for rate limiting, with a dynamic key function
func RateLimiterMiddleware(r rate.Limit, b int, clearInterval time.Duration, expireAfter time.Duration, keyFunc func(c *gin.Context) string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Use keyFunc to get the dynamic rate limit key (e.g., client IP, user ID, etc.)
		identifier := keyFunc(c)

		// Create or get the rate limiter for this identifier
		limiter := NewLimiter(r, b, identifier, clearInterval, expireAfter)

		// Check if the request is allowed
		if !limiter.Allow() {
			// If not allowed, return HTTP 429 Too Many Requests
			c.Data(429, "application/json", []byte(rateLimitResponse))
			c.Abort() // Stop further processing
			return
		}

		// Continue processing the request
		c.Next()
	}
}

使用示例

func (srv *GinServer) InjectRoutes() *gin.Engine {
    router := gin.Default()
    
    // allow 1 request per second, with a bucket size of 20, cleanup every 1 minute, and expire limiters after 1 minute of inactivity
    // Use IP as the rate limit identifier, this means that all API can only respond to 1 request per second for each IP address, and can burst to 20 requests per second.
    router.Use(ratelimiter.RateLimiterMiddleware(rate.Every(1*time.Second), 20, 1*time.Minute, 1*time.Minute, func(c *gin.Context) string {
      return c.ClientIP()
    }))
  
  
    authApi := router.Group("/api/auth/oidc")
    // allow 1 request per second, with a bucket size of 5, cleanup every 1 minute, and expire limiters after 1 minute of inactivity
    // Use "global" as the rate limit identifier, this means that this API can only respond to 1 request per second, and can burst to 5
    authApi.Use(ratelimiter.RateLimiterMiddleware(rate.Every(1*time.Second), 5, 1*time.Minute, 1*time.Minute, func(c *gin.Context) string {
    return "global"
    }))
    {
        authApi.GET("/session", auth.GetSession)
        authApi.POST("/login/start", auth.StartLogin)
        authApi.GET("/login/end", auth.EndLogin)
        authApi.POST("/refresh", auth.RefreshToken)
        authApi.POST("/logout", auth.Logout)
    }
}


func Server(){
  	server := &GinServer{}
	// Inject routes into the Gin engine.
	router := server.InjectRoutes()

	// Start the HTTP server.
	log.Fatal(http.ListenAndServe("0.0.0.0:9999", router))
}

使用解释

不想写,让 GPT 帮我写了一点介绍

// allow 1 request per second, with a bucket size of 5, cleanup every 1 minute, and expire limiters after 1 minute of inactivity
// this means that all API can only respond to 1 request per second for each IP address, and can burst to 5 requests per second.
router.Use(ratelimiter.RateLimiterMiddleware(rate.Every(1*time.Second), 5, 1*time.Minute, 1*time.Minute, func(c *gin.Context) string {
  return c.ClientIP()
}))

这是全局设置了一个基于客户端 IP 的限流中间件(router),限流机制的具体行为如下:

1. 速率限制rate.Every(1 * time.Second)
这个参数 rate.Every(1 * time.Second) 表示令牌生成速率为每秒 1 个。换句话说,每秒钟系统会向令牌桶中加入 1 个令牌。

2. 令牌桶容量5
令牌桶的容量为 5,这意味着令牌桶最多可以存储 5 个令牌。即使没有请求,令牌桶也最多只会积累 5 个令牌,超过这个容量的令牌会被丢弃。
如果有突发请求,系统可以一次性处理最多 5 个请求(假设这些请求到达时桶里有 5 个令牌)。

3. 清理间隔1 * time.Minute
每 1 分钟会清理那些 1 分钟内未被使用的限流器。如果一个限流器与某个 IP 地址关联,并且 1 分钟内没有请求来自该 IP 地址,该限流器将会被自动清除。

4. 过期时间1 * time.Minute
如果一个限流器关联的 IP 地址在 1 分钟内没有请求,意味着该限流器会被认为已过期,接着在下一次清理时会被删除。

5. 标识符func(c *gin.Context) string { return c.ClientIP() }
c.ClientIP() 函数返回请求的客户端 IP,作为限流的标识符。换句话说,每个客户端 IP 都会有自己独立的限流器。
限流行为是基于 IP 地址的,每个 IP 地址有自己独立的速率限制和令牌桶。
当然你可以自定义其他标识符,比如用户 UUID、固定全局字符串等。

行为解释

假设有不同的客户端通过不同的 IP 发送请求,具体行为如下:

速率控制:

  • 每个 IP 地址都有自己的令牌桶,且每秒生成 1 个令牌。如果请求的频率超过每秒 1 个,那么超过速率的请求将会被限制。
  • 如果 IP 地址发起的请求以每秒 1 次的频率发出,则每次请求都会成功,因为令牌桶中会在每秒自动生成 1 个令牌。


处理突发请求:

  • 如果某个 IP 短时间内发起多个请求(例如 5 个请求在 1 秒内到达),系统允许它立即处理前 5 个请求(因为桶的容量是 5),如果这些请求到达时桶里已经存满了令牌。
  • 第 6 个请求会被拒绝,直到 1 秒后新的令牌被生成。


持续高频请求的处理:

  • 如果某个 IP 持续以每秒超过 1 个请求的速率发出请求(例如每秒发出 5 个请求),系统会允许前 5 个请求(因为有 5 个令牌可以立即处理)。
  • 之后,每秒钟系统只会生成 1 个令牌(一秒钟允许一次请求),所以超出速率的请求将会被限流,返回 HTTP 429 "Too Many Requests" 错误。

清理与过期:

  • 如果某个 IP 地址在 1 分钟内没有发送任何请求,它的限流器将会被标记为过期。
  • 每 1 分钟,系统会清理这些过期的限流器以节省资源。

具体示例

场景 1:突发的 5 个请求

  • 假设一个客户端 IP 地址是 192.168.1.1,该客户端在 1 秒内发出 5 个请求,且这些请求正好匹配了令牌桶的最大容量(5 个令牌)。
  • 由于令牌桶中有足够的令牌(假设已累积了 5 个令牌),所有这 5 个请求都会被处理。
  • 如果该客户端立即发出第 6 个请求,它会被拒绝,因为令牌桶已空。
  • 该客户端需要等待 1 秒钟(系统生成新令牌)后,才能再次成功发出请求。

场景 2:持续的高频请求

  • 如果 192.168.1.1 持续以每秒 5 个请求的速率发出请求,系统会立即处理前 5 个请求(因为桶有 5 个令牌)。
  • 接下来的每秒只能处理 1 个请求,因为每秒只生成 1 个令牌。超出速率的请求会被限流,并返回 HTTP 429 错误。


场景 3:清理与过期

  • 如果某个 IP 地址在 1 分钟内没有发送任何请求,限流器将会被清理。
  • 当该 IP 地址再次发送请求时,系统会重新创建一个限流器,并重新开始限流计算。

总结

在这个配置下,基于每个客户端 IP 地址的限流器具有以下行为:

  • 每个 IP 地址 有自己的速率限制,每秒生成 1 个令牌。
  • 每个 IP 地址可以在短时间内发出最多 5 个突发请求(由于令牌桶的容量是 5)。
  • 如果某个 IP 地址持续发出大量请求,超出速率限制的请求将被拒绝,且返回 HTTP 429 错误。
  • 限流器将在 1 分钟内不活跃时被清理,以节省内存资源。
  • 这个限流器配置适合用在防止单个 IP 地址对服务过载的场景,同时允许短暂的突发流量。