提交 337c7c0f authored 作者: mooncake9527's avatar mooncake9527

滑动窗口ip限速器

上级 c63111d3
......@@ -193,8 +193,8 @@ func Logging(opts ...Option) gin.HandlerFunc {
newWriter := &bodyLogWriter{body: &bytes.Buffer{}, ResponseWriter: c.Writer}
c.Writer = newWriter
ip := ips.GetIP(c)
c.Set(ctxUtil.KeyIP, ip)
ip := ips.GetClientIP(c)
c.Set(ctxUtil.KeyClientIP, ip)
// processing requests
c.Next()
......
package ratelimiter
import (
"sync"
"time"
)
type Window struct {
Start time.Time
Requests int
}
type IPRateLimiter struct {
mu sync.Mutex
ips map[string][]*Window
windowSize time.Duration // 滑动窗口时间长度
maxRequests int // 窗口内最大请求数
}
func NewIPRateLimiter(windowSize time.Duration, maxRequests int) *IPRateLimiter {
return &IPRateLimiter{
ips: make(map[string][]*Window),
windowSize: windowSize,
maxRequests: maxRequests,
}
}
func (rl *IPRateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
cutoff := now.Add(-rl.windowSize)
if windows, exists := rl.ips[ip]; exists {
var validWindows []*Window
for _, w := range windows {
if w.Start.After(cutoff) {
validWindows = append(validWindows, w)
}
}
rl.ips[ip] = validWindows
}
total := 0
for _, w := range rl.ips[ip] {
total += w.Requests
}
if total >= rl.maxRequests {
return false
}
if len(rl.ips[ip]) == 0 || now.Sub(rl.ips[ip][len(rl.ips[ip])-1].Start) > time.Second {
rl.ips[ip] = append(rl.ips[ip], &Window{
Start: now,
Requests: 1,
})
} else {
rl.ips[ip][len(rl.ips[ip])-1].Requests++
}
return true
}
package ratelimiter
import (
"testing"
"time"
)
import (
"fmt"
)
func TestRateLimiter(t *testing.T) {
rl := NewIPRateLimiter(1*time.Second, 5) // 每分钟10次限制
ip := "192.168.1.1"
for i := 0; i < 100; i++ {
allowed := rl.Allow(ip)
fmt.Printf("Request %d: %t\n", i+1, allowed)
time.Sleep(100 * time.Millisecond)
}
}
......@@ -22,7 +22,7 @@ var (
HeaderXRequestIDKey = "X-Request-ID"
KeyTid = "tid"
KeyIP = "IP"
KeyClientIP = "client_ip"
KeyUID = "userId"
KeyUType = "userType"
KeyCompanyID = "companyId"
......@@ -139,10 +139,10 @@ func GetCtxTid(ctx context.Context) string {
return tid
}
// GetIP get ip from context.Context, the client ip is set in gin.Context
func GetIP(ctx context.Context) string {
// GetClientIP get ip from context.Context, the client ip is set in gin.Context
func GetClientIP(ctx context.Context) string {
ip := ""
ipVal := ctx.Value(KeyIP)
ipVal := ctx.Value(KeyClientIP)
if ipVal != nil {
if str, ok := ipVal.(string); ok {
ip = str
......
......@@ -12,7 +12,7 @@ import (
)
// get ip of client from context
func GetIP(c *gin.Context) string {
func GetClientIP(c *gin.Context) string {
ip := c.Request.Header.Get("X-Forwarded-For")
if strings.Contains(ip, "127.0.0.1") || ip == "" {
ip = c.Request.Header.Get("X-real-ip")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论