在业务开发中,有时需要对某个操作在整个集群中限制并发度,例如限制大模型对话的并行数。基于redis zset实现计数锁,做个笔记。
关键词:并行流量控制、计数锁
package redisutil
import (
"context"
"fmt"
"math"
"time"
"github.com/go-redis/redis/v9"
)
// AcquireZSetLock 借助redis zset数据结构实现分布式计数锁。可用于计数任务运行数,防止超限。返回值:zset大小、释放锁的函数、错误信息
func AcquireZSetLock(ctx context.Context, c redis.Client, key string, element string, zsetMaxSize int,
expiresIn time.Duration, syncWait time.Duration) (int, func() error, error) {
ctx, cancel := context.WithTimeout(ctx, syncWait)
defer cancel()
for i := 0; ; i++ {
select {
case <-ctx.Done(): // 接到取消信号,按插入失败处理
return -1, func() error { return nil }, ctx.Err()
default:
}
size, err := insertElementToZsetLock(ctx, c, key, element, zsetMaxSize, expiresIn)
if err != nil {
second := 0.4 + 0.6*math.Exp(-0.17*float64(i)) // f(i=0) = 1.0; f(i=10) = 0.5096,即第10次就会衰减到0.5096秒
second = max(second, 0.5) // 最小间隔0.5秒,防止过于频繁的请求
time.Sleep(time.Duration(second*1000) * time.Millisecond)
}
releaseFunc := func() error {
result, err := c.ZRem(context.Background(), key, element).Result()
if err != nil {
return fmt.Errorf("redis zrem error: %v. return=%d", err, result)
}
return nil
}
return size, releaseFunc, nil
}
}
// insertElementToZsetLock 插入元素到zset,并删除已过期的元素
func insertElementToZsetLock(ctx context.Context, c redis.Client, key string, element string, zsetMaxSize int, expiresIn time.Duration) (int, error) {
luaScript := `
local zsetName = KEYS[1]
local memberName = ARGV[1]
local currentTime = tonumber(ARGV[2])
local deadTime = tonumber(ARGV[3])
local sizeLimit = tonumber(ARGV[4])
-- 删除已过期的元素
redis.call("ZREMRANGEBYSCORE", zsetName, "-inf", currentTime)
-- 获取集合的大小
local setSize = redis.call('ZCard', zsetName)
-- 如果集合大小小于限制值,则添加元素,并返回集合大小
if setSize < sizeLimit then
redis.call('ZAdd', zsetName, deadTime, memberName)
local expireTime = deadTime - currentTime
if expireTime > 0 then
redis.call('EXPIRE', zsetName, expireTime)
end
return setSize+1
end
return -1
`
currentTime := time.Now().Unix()
deadTime := time.Now().Add(expiresIn).Unix() // 过期时间 Unix秒
ret, err := c.Do(ctx, "EVAL", luaScript, 1, key, element, currentTime, deadTime, zsetMaxSize).Result()
if err != nil {
return -1, err
}
if ret.(int64) < 0 {
return zsetMaxSize, fmt.Errorf("zset size reach max size: %d", zsetMaxSize)
}
return int(ret.(int64)), nil
}
使用示例:
size, release, err := AcquireZSetLock(ctx, client, key, element, 10, 10*time.Second, 3*time.Second)
defer release()
if err != nil {
fmt.Println(err)
}