Commit 0541d00b authored by mooncake9527

rl (the rate limiter) can now be given an existing redis conn directly, saving an extra redis conn

Parent 02999ead
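A minimal usage sketch of what this change enables, assuming the `sw` package alias (its exact import path is not shown in this diff) and a reachable redis instance; the option names are the ones added in this commit:

```go
// rdb is the connection the application already holds; WithRedisCli reuses it,
// so the limiter no longer dials its own redis client.
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

rl := sw.NewIPRateLimiter(
	sw.WithRedisCli(rdb),        // pass in the existing redis conn
	sw.WithLogger(zap.NewNop()), // optional: log window syncs
)

if rl.Allow(context.Background(), "192.168.1.1") {
	// handle the request
}
```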
......@@ -2,9 +2,10 @@ package middleware
import (
"context"
"net/http"
ctxUtil "gitlab.wanzhuangkj.com/tush/xpkg/gin/xctx"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
......@@ -92,6 +93,8 @@ func RequestID(opts ...RequestIDOption) gin.HandlerFunc {
c.Writer.Header().Set(ctxUtil.HeaderXRequestIDKey, requestID)
c.Next()
c.Writer.Header().Set(ctxUtil.HeaderXTimestampKey, strconv.FormatInt(time.Now().Unix(), 10))
}
}
......
package sw
import (
"context"
"sync"
"time"
)
// Window represents a fixed window in the sliding-window limiter.
type Window interface {
Start() time.Time
Count() int64
AddCount(n int64)
Reset(s time.Time, c int64)
Sync(ctx context.Context, now time.Time)
}
type StopFunc func()
type NewWindow func() (Window, StopFunc)
type Limiter struct {
size time.Duration
limit int64
mu sync.Mutex
curr Window
prev Window
}
func NewLimiter(size time.Duration, limit int64, newWindow NewWindow) (*Limiter, StopFunc) {
currWin, currStop := newWindow()
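// The previous window is always a plain local window; only the current window
// is synced against the shared datastore.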
prevWin, _ := NewLocalWindow()
lim := &Limiter{
size: size,
limit: limit,
curr: currWin,
prev: prevWin,
}
return lim, currStop
}
func (lim *Limiter) Size() time.Duration {
return lim.size
}
func (lim *Limiter) Limit() int64 {
lim.mu.Lock()
defer lim.mu.Unlock()
return lim.limit
}
func (lim *Limiter) SetLimit(newLimit int64) {
lim.mu.Lock()
defer lim.mu.Unlock()
lim.limit = newLimit
}
func (lim *Limiter) Allow(ctx context.Context) bool {
return lim.AllowN(ctx, time.Now(), 1)
}
func (lim *Limiter) AllowN(ctx context.Context, now time.Time, n int64) bool {
lim.mu.Lock()
defer lim.mu.Unlock()
lim.advance(now)
elapsed := now.Sub(lim.curr.Start())
weight := float64(lim.size-elapsed) / float64(lim.size)
count := int64(weight*float64(lim.prev.Count())) + lim.curr.Count()
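// Illustrative example: size=1s, limit=10, prev.Count()=8, curr.Count()=4, and the
// request arrives 300ms into the current window -> weight=(1s-300ms)/1s=0.7,
// count=int64(0.7*8)+4=9, so AllowN(ctx, now, 1) passes (9+1 <= 10) while
// AllowN(ctx, now, 2) is rejected.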
defer lim.curr.Sync(ctx, now)
if count+n > lim.limit {
return false
}
lim.curr.AddCount(n)
return true
}
// advance updates the current/previous windows as time passes.
func (lim *Limiter) advance(now time.Time) {
// Compute the start boundary of the current window.
newCurrStart := now.Truncate(lim.size)
diffSize := newCurrStart.Sub(lim.curr.Start()) / lim.size
if diffSize >= 1 {
// The current window is at least one window size behind the expected one.
newPrevCount := int64(0)
if diffSize == 1 {
// The new previous window overlaps the old current window, so it inherits its count.
// Note that this count may be inaccurate: it is only a snapshot of the current window's
// count, which itself tends to be inaccurate due to the asynchronous nature of the sync behaviour.
newPrevCount = lim.curr.Count()
}
lim.prev.Reset(newCurrStart.Add(-lim.size), newPrevCount)
// The count of the new current window always starts at zero.
lim.curr.Reset(newCurrStart, 0)
}
}
package sync_ratelimiter
package sw
import (
"context"
"fmt"
"github.com/RussellLuo/slidingwindow"
"github.com/redis/go-redis/v9"
"gitlab.wanzhuangkj.com/tush/xpkg/gin/xctx"
"go.uber.org/zap"
"strconv"
"sync"
"time"
......@@ -13,43 +14,58 @@ import (
type RedisDatastore struct {
client redis.Cmdable
ttl time.Duration
logger *zap.Logger
}
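// fullKey namespaces the shared counter per window; e.g. (illustrative)
// "SyncRateLimiter_192.168.1.1@1700000000000000000", where the suffix is the
// window's start time in unix nanoseconds.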
func (d *RedisDatastore) fullKey(key string, start int64) string {
return fmt.Sprintf("%s@%d", key, start)
}
func (d *RedisDatastore) Add(key string, start, value int64) (int64, error) {
func (d *RedisDatastore) Add(ctx context.Context, key string, start, value int64) (int64, error) {
k := d.fullKey(key, start)
startTime := time.Now()
c, err := d.client.IncrBy(context.Background(), k, value).Result()
if err != nil {
if d.logger != nil {
d.logger.Error(fmt.Sprintf("[SlideWindow] add fail,[k:%s add:%d]", k, value), zap.String("err", err.Error()), zap.String("cost", time.Since(startTime).String()), xctx.CtxTraceIDField(ctx))
}
return 0, err
}
_, _ = d.client.Expire(context.Background(), k, d.ttl).Result()
if d.logger != nil {
d.logger.Info(fmt.Sprintf("[SlideWindow] add success,[k:%s add:%d -> %d]", k, value, c), zap.String("cost", time.Since(startTime).String()), xctx.CtxTraceIDField(ctx))
}
return c, err
}
func (d *RedisDatastore) Get(key string, start int64) (int64, error) {
func (d *RedisDatastore) Get(ctx context.Context, key string, start int64) (int64, error) {
k := d.fullKey(key, start)
startTime := time.Now()
value, err := d.client.Get(context.Background(), k).Result()
if err != nil {
if err == redis.Nil {
// a missing key just means nothing has been recorded for this window yet
return 0, nil
}
if d.logger != nil {
d.logger.Error(fmt.Sprintf("[SlideWindow] get fail [k:%s]", k), zap.String("err", err.Error()), zap.String("cost", time.Since(startTime).String()), xctx.CtxTraceIDField(ctx))
}
return 0, err
}
if d.logger != nil {
d.logger.Info(fmt.Sprintf("[SlideWindow] get success [k:%s v:%s]", k, value), zap.String("cost", time.Since(startTime).String()), xctx.CtxTraceIDField(ctx))
}
return strconv.ParseInt(value, 10, 64)
}
func NewRedisDatastore(client redis.Cmdable, ttl time.Duration) *RedisDatastore {
return &RedisDatastore{client: client, ttl: ttl}
func NewRedisDatastore(client redis.Cmdable, ttl time.Duration, logger *zap.Logger) *RedisDatastore {
return &RedisDatastore{client: client, ttl: ttl, logger: logger}
}
type IPRateLimiter struct {
mu sync.Mutex
options *IPRateLimiterOptions
store *RedisDatastore
ips map[string]*Limiter
ips map[string]*SLimiter
}
type IPRateLimiterOptions struct {
......@@ -58,6 +74,8 @@ type IPRateLimiterOptions struct {
redisAddr string
redisDB int
redisPassword string
logger *zap.Logger
redisCli redis.Cmdable
}
func defaultOptions() *IPRateLimiterOptions {
......@@ -105,45 +123,63 @@ func WithRedisPassword(pwd string) Option {
}
}
func WithLogger(log *zap.Logger) Option {
return func(o *IPRateLimiterOptions) {
o.logger = log
}
}
func WithRedisCli(redisCli redis.Cmdable) Option {
return func(o *IPRateLimiterOptions) {
o.redisCli = redisCli
}
}
type Option func(o *IPRateLimiterOptions)
func NewIPRateLimiter(opts ...Option) *IPRateLimiter {
options := defaultOptions()
options.apply(opts...)
return &IPRateLimiter{
rl := &IPRateLimiter{
mu: sync.Mutex{},
options: options,
store: NewRedisDatastore(
ips: make(map[string]*SLimiter),
}
if options.redisCli != nil {
rl.store = NewRedisDatastore(options.redisCli, 2*options.windowSize, options.logger)
} else {
rl.store = NewRedisDatastore(
redis.NewClient(&redis.Options{
Addr: options.redisAddr,
Password: options.redisPassword,
DB: options.redisDB,
}),
2*options.windowSize, // twice of window-size is just enough.
),
ips: make(map[string]*Limiter),
2*options.windowSize,
options.logger,
)
}
return rl
}
type Limiter struct {
L *slidingwindow.Limiter
type SLimiter struct {
L *Limiter
Stop func()
}
func (x *IPRateLimiter) Allow(ip string) bool {
func (x *IPRateLimiter) Allow(ctx context.Context, ip string) bool {
x.mu.Lock()
defer x.mu.Unlock()
if v, ok := x.ips[ip]; ok {
return v.L.Allow()
return v.L.Allow(ctx)
} else {
lim, stop := slidingwindow.NewLimiter(x.options.windowSize, x.options.maxRequests, func() (slidingwindow.Window, slidingwindow.StopFunc) {
return slidingwindow.NewSyncWindow(fmt.Sprintf("SyncRateLimiter_%s", ip), slidingwindow.NewBlockingSynchronizer(x.store, 500*time.Millisecond))
lim, stop := NewLimiter(x.options.windowSize, x.options.maxRequests, func() (Window, StopFunc) {
return NewSyncWindow(fmt.Sprintf("SyncRateLimiter_%s", ip), NewBlockingSynchronizer(x.store, 500*time.Millisecond))
})
x.ips[ip] = &Limiter{
x.ips[ip] = &SLimiter{
L: lim,
Stop: stop,
}
return lim.Allow()
return lim.Allow(ctx)
}
}
......
package sync_ratelimiter
package sw
import (
"context"
"gitlab.wanzhuangkj.com/tush/xpkg/utils/xtime"
"testing"
"time"
......@@ -21,7 +22,7 @@ func TestRateLimiter(t *testing.T) {
)
ip := "192.168.1.1"
for i := 0; i < 100; i++ {
allowed := rl.Allow(ip)
allowed := rl.Allow(context.Background(), ip)
fmt.Printf("%s Request %d: %t\n", time.Now().Format(xtime.Layout_DateTime), i+1, allowed)
time.Sleep(100 * time.Millisecond)
}
......
package sw
import (
"context"
"log"
"time"
)
type Datastore interface {
Add(ctx context.Context, key string, start, delta int64) (int64, error)
Get(ctx context.Context, key string, start int64) (int64, error)
}
type syncHelper struct {
store Datastore
syncInterval time.Duration
inProgress bool
lastSynced time.Time
}
func newSyncHelper(store Datastore, syncInterval time.Duration) *syncHelper {
return &syncHelper{store: store, syncInterval: syncInterval}
}
func (h *syncHelper) IsTimeUp(now time.Time) bool {
return !h.inProgress && now.Sub(h.lastSynced) >= h.syncInterval
}
func (h *syncHelper) InProgress() bool {
return h.inProgress
}
func (h *syncHelper) Begin(now time.Time) {
h.inProgress = true
h.lastSynced = now
}
func (h *syncHelper) End() {
h.inProgress = false
}
func (h *syncHelper) Sync(ctx context.Context, req SyncRequest) (resp SyncResponse, err error) {
var newCount int64
if req.Changes > 0 {
newCount, err = h.store.Add(ctx, req.Key, req.Start, req.Changes)
} else {
newCount, err = h.store.Get(ctx, req.Key, req.Start)
}
if err != nil {
return SyncResponse{}, err
}
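// Illustrative example: local Count=7 with Changes=3 not yet pushed; Add returns a
// shared total of 12, so OtherChanges=12-7=5 is the traffic other limiter instances
// contributed since the last sync, and handleSyncResponse folds it back in.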
return SyncResponse{
OK: true,
Start: req.Start,
Changes: req.Changes,
OtherChanges: newCount - req.Count,
}, nil
}
type BlockingSynchronizer struct {
helper *syncHelper
}
func NewBlockingSynchronizer(store Datastore, syncInterval time.Duration) *BlockingSynchronizer {
return &BlockingSynchronizer{
helper: newSyncHelper(store, syncInterval),
}
}
func (s *BlockingSynchronizer) Start() {}
func (s *BlockingSynchronizer) Stop() {}
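// Sync performs the datastore round-trip inline, so the limiter's Allow path blocks
// for that round-trip at most once per syncInterval.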
func (s *BlockingSynchronizer) Sync(ctx context.Context, now time.Time, makeReq MakeFunc, handleResp HandleFunc) {
if s.helper.IsTimeUp(now) {
s.helper.Begin(now)
resp, err := s.helper.Sync(ctx, makeReq(ctx))
if err != nil {
log.Printf("err: %v\n", err)
}
handleResp(resp)
s.helper.End()
}
}
type NonblockingSynchronizer struct {
reqC chan SyncRequest
respC chan SyncResponse
stopC chan struct{}
exitC chan struct{}
helper *syncHelper
}
func NewNonblockingSynchronizer(store Datastore, syncInterval time.Duration) *NonblockingSynchronizer {
return &NonblockingSynchronizer{
reqC: make(chan SyncRequest),
respC: make(chan SyncResponse),
stopC: make(chan struct{}),
exitC: make(chan struct{}),
helper: newSyncHelper(store, syncInterval),
}
}
func (s *NonblockingSynchronizer) Start() {
go s.syncLoop()
}
func (s *NonblockingSynchronizer) Stop() {
close(s.stopC)
<-s.exitC
}
func (s *NonblockingSynchronizer) syncLoop() {
for {
select {
case req := <-s.reqC:
resp, err := s.helper.Sync(req.ctx, req)
if err != nil {
log.Printf("err: %v\n", err)
}
select {
case s.respC <- resp:
case <-s.stopC:
goto exit
}
case <-s.stopC:
goto exit
}
}
exit:
close(s.exitC)
}
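// Sync never blocks the caller: it hands the request to syncLoop with a non-blocking
// send, then (typically on a later call) collects any pending response with a
// non-blocking receive, folding remote changes into the window when one arrives.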
func (s *NonblockingSynchronizer) Sync(ctx context.Context, now time.Time, makeReq MakeFunc, handleResp HandleFunc) {
if s.helper.IsTimeUp(now) {
select {
case s.reqC <- makeReq(ctx):
s.helper.Begin(now)
default:
}
}
if s.helper.InProgress() {
select {
case resp := <-s.respC:
handleResp(resp)
s.helper.End()
default:
}
}
}
package sw
import (
"context"
"time"
)
type LocalWindow struct {
start int64
count int64
}
func NewLocalWindow() (*LocalWindow, StopFunc) {
return &LocalWindow{}, func() {}
}
func (w *LocalWindow) Start() time.Time {
return time.Unix(0, w.start)
}
func (w *LocalWindow) Count() int64 {
return w.count
}
func (w *LocalWindow) AddCount(n int64) {
w.count += n
}
func (w *LocalWindow) Reset(s time.Time, c int64) {
w.start = s.UnixNano()
w.count = c
}
func (w *LocalWindow) Sync(ctx context.Context, now time.Time) {}
type (
SyncRequest struct {
Key string
Start int64
Count int64
Changes int64
ctx context.Context
}
SyncResponse struct {
OK bool
Start int64
Changes int64
OtherChanges int64
}
MakeFunc func(ctx context.Context) SyncRequest
HandleFunc func(SyncResponse)
)
type Synchronizer interface {
Start()
Stop()
Sync(context.Context, time.Time, MakeFunc, HandleFunc)
}
type SyncWindow struct {
LocalWindow
changes int64
key string
syncer Synchronizer
}
func NewSyncWindow(key string, syncer Synchronizer) (*SyncWindow, StopFunc) {
w := &SyncWindow{
key: key,
syncer: syncer,
}
w.syncer.Start()
return w, w.syncer.Stop
}
func (w *SyncWindow) AddCount(n int64) {
w.changes += n
w.LocalWindow.AddCount(n)
}
func (w *SyncWindow) Reset(s time.Time, c int64) {
w.changes = 0
w.LocalWindow.Reset(s, c)
}
func (w *SyncWindow) makeSyncRequest(ctx context.Context) SyncRequest {
return SyncRequest{
Key: w.key,
Start: w.LocalWindow.start,
Count: w.LocalWindow.count,
Changes: w.changes,
ctx: ctx,
}
}
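// handleSyncResponse applies a sync result only if the window has not been reset since
// the request was made (same start); a stale response for an older window is dropped.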
func (w *SyncWindow) handleSyncResponse(resp SyncResponse) {
if resp.OK && resp.Start == w.LocalWindow.start {
w.LocalWindow.count += resp.OtherChanges
w.changes -= resp.Changes
}
}
func (w *SyncWindow) Sync(ctx context.Context, now time.Time) {
w.syncer.Sync(ctx, now, w.makeSyncRequest, w.handleSyncResponse)
}
......@@ -21,6 +21,8 @@ var (
// HeaderXRequestIDKey header request id key
HeaderXRequestIDKey = "X-Request-ID"
HeaderXTimestampKey = "Timestamp"
KeyTid = "tid"
KeyClientIP = "client_ip"
KeyUID = "userId"
......