最近需要自己做一个限流功能,其他业务代码都好说。唯一的终点就是限流实现,想到redis可以实现令牌桶。一拍脑门,就用它了!

场景描述

真实开发中才发现想的太简单,如果是基于redis提供的命令在代码中调用的话,效率是小事。原子性根本无法保证!如下:

  1. 线程1 获取到令牌桶获取总数为10
  2. 线程1 消耗1个令牌,剩余9个令牌
  3. 线程2 获取到令牌桶获取总数为10
  4. 线程1 刷新令牌为9个
  5. 线程2 消耗1个令牌,剩余9个令牌
  6. 线程2 刷新令牌为9个

原子性完全无法保证。加锁? 那多节点岂不是需要在引入分布式锁?看来服务器中实现不可取,保证原子性的话,势必要写LUA代码执行了。网上翻阅了一些demo,没找到想要的。想到自己搭建的spring cloud gateway是有内置令牌桶实现的。开始翻阅源码,终于找到最好的方案。

gateway中redis令牌桶实现类是:org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter

public Mono<Response> isAllowed(String routeId, String id) {
			...
			// 这行是通过lua判断是否被限流
			Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
			...
	}

顺着这个类找到lua代码是从org.springframework.cloud.gateway.config.GatewayRedisAutoConfiguration注入进来的

public RedisScript redisRequestRateLimiterScript() {
		DefaultRedisScript redisScript = new DefaultRedisScript<>();
		redisScript.setScriptSource(
				new ResourceScriptSource(new ClassPathResource("META-INF/scripts/request_rate_limiter.lua")));
		redisScript.setResultType(List.class);
		return redisScript;
	}

核心在request_rate_limiter.lua这个文件中

-- 获取到限流资源令牌数的key和响应时间戳的key
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
-- 分别获取填充速率、令牌桶容量、当前时间戳、消耗令牌数
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
-- 计算出失效时间,大概是令牌桶填满时间的两倍
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
-- 获取到最近一次的剩余令牌数,如果不存在说明令牌桶是满的
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
-- 上次消耗令牌的时间戳,不存在视为0
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
-- 计算出间隔时间
local delta = math.max(0, now-last_refreshed)
-- 剩余令牌数量 =  “令牌桶容量” 和 “最后令牌数+(填充速率*时间间隔)”之间的最小值
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- 如果剩余令牌数量大于等于消耗令牌的数量则流量通过,否则不通过
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
-- 最后保存数据现场
if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }

解决方案

好了,可以开始写自己的限流工具类了。

import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;

/**
 * redis限流器
 *
 */
public class MyRedisLimiter{

    private RedisTemplate redisTemplate;

    private static final Long SUCCESS_FLAG = 1L;


    /**
     * 判断是否允许访问
     *@param id 这次获取令牌桶的id
     *@param rate 每秒填充速率
     *@param capacity 令牌桶最大容量
     *@param tokens 每次访问消耗几个令牌
     *@return true 允许访问 false 不允许访问
     */
    public boolean isAllowed(String id,int rate,int capacity,int tokens){
        RedisScript<Long> redisScript = new DefaultRedisScript<>(SCRIPT,Long.class);
        
        Object result = redisTemplate.execute(redisScript,
                getKey(id),rate, capacity,
                Instant.now().getEpochSecond(), tokens);
        return SUCCESS_FLAG.equals(result);
    }

    private List<String> getKey(String id){
        String prefix = "limiter:"+id;
        String tokenKey = prefix + ":tokens";
        String timestampKey = prefix + ":timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    private static final String SCRIPT = "local tokens_key = KEYS[1]\n" +
            "local timestamp_key = KEYS[2]\n" +
            "local rate = tonumber(ARGV[1])\n" +
            "local capacity = tonumber(ARGV[2])\n" +
            "local now = tonumber(ARGV[3])\n" +
            "local requested = tonumber(ARGV[4])\n" +
            "local fill_time = capacity/rate\n" +
            "local ttl = math.floor(fill_time*2)\n" +
            "local last_tokens = tonumber(redis.call('get', tokens_key))\n" +
            "if last_tokens == nil then\n" +
            "  last_tokens = capacity\n" +
            "end\n" +
            "local last_refreshed = tonumber(redis.call('get', timestamp_key))\n" +
            "if last_refreshed == nil then\n" +
            "  last_refreshed = 0\n" +
            "end\n" +
            "local diff_time = math.max(0, now-last_refreshed)\n" +
            "local filled_tokens = math.min(capacity, last_tokens+(diff_time*rate))\n" +
            "local allowed = filled_tokens >= requested\n" +
            "local new_tokens = filled_tokens\n" +
            "local allowed_num = 0\n" +
            "if allowed then\n" +
            "  new_tokens = filled_tokens - requested\n" +
            "  allowed_num = 1\n" +
            "end\n" +
            "if ttl > 0 then\n" +
            "  redis.call('setex', tokens_key, ttl, new_tokens)\n" +
            "  redis.call('setex', timestamp_key, ttl, now)\n" +
            "end\n" +
            "return allowed_num\n";
}

完事儿,这个令牌桶实现完全是参考的spring cloud gateway,食用起来也比较放心。
简单描述一下这个工具类使用的效果吧。

	isAllowed("testId1",1,60,1);

上面这个描述代表:令牌桶testId1。每分钟可通过访问60次。

当然这是理想情况。极限情况的话,应该是可以访问120次的。极限场景如下

  • testId1令牌桶未被使用的时间>=60秒
  • testId1令牌桶在某一刻开始被使用
  • 令牌被消耗的同时也在被填充
  • 在最初被使用的60秒内,令牌桶初始有60个令牌,使用期限有填充了60个

我个人理解,令牌桶主要是为了保证使用速率。对于上面这个场景,到底算不算bug。看每个人的使用情况。不过我已经对上述情况做了解决,只需要如下的几点小改动。

  1. 传入周期时间内最大流量数(周期时间:桶容量60,填充速率1/s,那么周期时间=60s)
  2. 获取key的方法增加一个记录周期时间内数量的key
  3. 修改lua脚本

改动如下:

/**
     * 判断是否允许访问
     *@param id 这次获取令牌桶的id
     *@param rate 每秒填充速率
     *@param capacity 令牌桶最大容量
     *@param tokens 每次访问消耗几个令牌
     *@param maxCount 周期时间内最大访问量
     *@return true 允许访问 false 不允许访问
     */
    public boolean isAllowed(String id,int rate,int capacity,int tokens,int maxCount){
        RedisScript<Long> redisScript = new DefaultRedisScript<>(SCRIPT,Long.class);
        Object result = redisTemplate.execute(redisScript,
                getKey(id),rate, capacity,
                Instant.now().getEpochSecond(), tokens,maxCount);
        return SUCCESS_FLAG.equals(result);
    }

    private List<String> getKey(String id){
        String prefix = "limiter:"+id;
        String tokenKey = prefix + ":tokens";
        String timestampKey = prefix + ":timestamp";
        String countKey = prefix + ":count";
        return Arrays.asList(tokenKey, timestampKey,countKey);
    }

    private static final String SCRIPT = "local tokens_key = KEYS[1]\n" +
            "local timestamp_key = KEYS[2]\n" +
            "local count_key = KEYS[3]\n" +
            "local rate = tonumber(ARGV[1])\n" +
            "local capacity = tonumber(ARGV[2])\n" +
            "local now = tonumber(ARGV[3])\n" +
            "local requested = tonumber(ARGV[4])\n" +
            "local min_max = tonumber(ARGV[5])\n" +
            "local fill_time = capacity/rate\n" +
            "local ttl = math.floor(fill_time*2)\n" +
            "local has_count = tonumber(redis.call('get', count_key))\n" +
            "if has_count == nil then\n" +
            "  has_count = 0\n" +
            "end\n" +
            "if has_count >= min_max then\n" +
            "return 0\n" +
            "end\n" +
            "local last_tokens = tonumber(redis.call('get', tokens_key))\n" +
            "if last_tokens == nil then\n" +
            "  last_tokens = capacity\n" +
            "end\n" +
            "local last_refreshed = tonumber(redis.call('get', timestamp_key))\n" +
            "if last_refreshed == nil then\n" +
            "  last_refreshed = 0\n" +
            "end\n" +
            "local diff_time = math.max(0, now-last_refreshed)\n" +
            "local filled_tokens = math.min(capacity, last_tokens+(diff_time*rate))\n" +
            "local allowed = filled_tokens >= requested\n" +
            "local new_tokens = filled_tokens\n" +
            "local allowed_num = 0\n" +
            "if allowed then\n" +
            "  new_tokens = filled_tokens - requested\n" +
            "  allowed_num = 1\n" +
            "end\n" +
            "if ttl > 0 then\n" +
            "  redis.call('setex', tokens_key, ttl, new_tokens)\n" +
            "  redis.call('setex', timestamp_key, ttl, now)\n" +
            "end\n" +
            "local count_ttl = tonumber(redis.call('ttl',count_key))\n" +
            "if count_ttl < 0 then\n" +
            "  count_ttl = fill_time\n" +
            "end\n" +
            "redis.call('setex', count_key,count_ttl , has_count+1)\n" +
            "return allowed_num\n";
}

这样的改动做到了保持访问速率和吞吐量的可控,但是有没有必要这样就看自己的需求了。

最后想说一句,优秀的开源项目真的是宝藏。学会使用它,然后站在巨人的肩膀上前进

PS:常见问题解决

1. 脚本执行失败

控制台打印异常如下:

org.springframework.dao.InvalidDataAccessApiUsageException: 
ERR Error running script (call to f_60a814775b950dd3853efa8ee996aa08537dd41b): @user_script:7: user_script:7: attempt to perform arithmetic on local 'capacity' (a nil value) ; nested exception is redis.clients.jedis.exceptions.JedisDataException: ERR Error running script (call to f_60a814775b950dd3853efa8ee996aa08537dd41b): @user_script:7: user_script:7: attempt to perform arithmetic on local 'capacity' (a nil value) 

可能是系列化问题,我这里是配置过全局序列化的使用的是这样的:

redisTemplate.setValueSerializer(new FastJsonRedisSerializer<>(Object.class));

注意:自己项目里使用了其他方式的话可以在执行LUA之前改一下redisTemplate的序列化方式,执行完再改回来。或者考虑其他解决方案

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐