SpringBoot 限流
自定义注解助力系统保护与高效运行
一、引入依赖
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.6.0</version>
</parent>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
二、创建注解
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
String key() default "rate_limit:";
/**
* 限流时间,单位秒
* @return
*/
int time() default 5;
/**
* 限流次数
* @return
*/
int count() default 10;
}
三、Redis 配置
@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport {
@Bean
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
RedisTemplate<String, Object> template = new RedisTemplate<>();
template.setConnectionFactory(redisConnectionFactory);
// 设置 key 和 value 的序列化方式,可以根据需要进行定制
template.setKeySerializer(new StringRedisSerializer());
template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
return template;
}
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptText(limitScriptText());
redisScript.setResultType(Long.class);
return redisScript;
}
/**
* 限流脚本
*/
private String limitScriptText() {
return "local key = KEYS[1]\n" +
"local count = tonumber(ARGV[1])\n" +
"local time = tonumber(ARGV[2])\n" +
"local current = redis.call('get', key);\n" +
"if current and tonumber(current) > count then\n" +
" return tonumber(current);\n" +
"end\n" +
"current = redis.call('incr', key)\n" +
"if tonumber(current) == 1 then\n" +
" redis.call('expire', key, time)\n" +
"end\n" +
"return tonumber(current);";
}
}
四、创建切面
1.第一种写法:
@Slf4j
@Aspect // 标记为切面类
@Component
public class RateLimiterAspect {
@Autowired // 注入RedisTemplate实例
private RedisTemplate<String, Object> redisTemplate; // Redis操作模板
@Autowired // 注入RedisScript实例
private RedisScript<Long> limitScript; // 用于执行Lua脚本的RedisScript实例
/**
* 在方法执行之前进行限流检查。
*
* @param point 当前JoinPoint(连接点)
*/
@Before("@annotation(org.example.common.annotation.RateLimiter)") // 在带有RateLimiter注解的方法执行前触发
public void doBefore(JoinPoint point) {
MethodSignature signature = (MethodSignature) point.getSignature(); // 获取方法签名
Method method = signature.getMethod(); // 获取被通知的方法
// 在这里,你可以获取方法上的注解
RateLimiter annotation = method.getAnnotation(RateLimiter.class);
if (annotation == null) {
// 注解对象为空,直接返回
return;
}
// 获取RateLimiter注解中的时间窗口长度
int time = annotation.time();
// 获取RateLimiter注解中的请求次数限制
int count = annotation.count();
// 组合限流键名
String combineKey = getCombineKey(annotation,point);
// 将组合后的键名封装成List
List<String> keys = Collections.singletonList(combineKey);
try {
// 使用RedisTemplate执行Lua脚本,传递键名、请求次数限制和时间窗口长度作为参数
Long number = redisTemplate.execute(limitScript, keys, count, time);
// 如果返回的数字大于请求次数限制,则抛出异常提示请求过于频繁
if (number != null && number > count) {
throw new RuntimeException("请求过于频繁,请稍后再试");
}
} catch (Exception ex) {
// 打印异常堆栈信息
ex.printStackTrace();
}
}
/**
* 获取组合的限流键名。
*
* @param point 当前JoinPoint(连接点)
* @return 组合后的限流键名
*/
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
stringBuffer.append(IpUtils.getIpAddr(ServletUtils.getRequest())).append("-");
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
}
2.第二种写法:
@Slf4j
@Aspect // 标记为切面类
@Component
public class RateLimiterAspect {
@Autowired // 注入RedisTemplate实例
private RedisTemplate<String, Object> redisTemplate; // Redis操作模板
@Autowired // 注入RedisScript实例
private RedisScript<Long> limitScript; // 用于执行Lua脚本的RedisScript实例
/**
* 在方法执行之前进行限流检查。
*
* @param point 当前JoinPoint(连接点)
*/
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
// 获取RateLimiter注解中的时间窗口长度
int time = rateLimiter.time();
// 获取RateLimiter注解中的请求次数限制
int count = rateLimiter.count();
// 组合限流键名
String combineKey = getCombineKey(rateLimiter,point);
// 将组合后的键名封装成List
List<String> keys = Collections.singletonList(combineKey);
try {
// 使用RedisTemplate执行Lua脚本,传递键名、请求次数限制和时间窗口长度作为参数
Long number = redisTemplate.execute(limitScript, keys, count, time);
// 如果返回的数字大于请求次数限制,则抛出异常提示请求过于频繁
if (number != null && number > count) {
throw new RuntimeException("请求过于频繁,请稍后再试");
}
} catch (Exception ex) {
// 打印异常堆栈信息
ex.printStackTrace();
}
}
/**
* 获取组合的限流键名。
*
* @param point 当前JoinPoint(连接点)
* @return 组合后的限流键名
*/
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
stringBuffer.append(IpUtils.getIpAddr(ServletUtils.getRequest())).append("-");
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
}
五、配置 Application
server:
port: 8081
spring:
# redis 配置
redis:
# 地址
host: 127.0.0.1
# 端口,默认为6379
port: 6379
# 数据库索引
database: 0
# 密码
password:
# 连接超时时间
timeout: 10s
lettuce:
pool:
# 连接池中的最小空闲连接
min-idle: 0
# 连接池中的最大空闲连接
max-idle: 8
# 连接池的最大数据库连接数
max-active: 8
# #连接池最大阻塞等待时间(使用负值表示没有限制)
max-wait: -1ms
六、工具
public class IpUtils {
public static String getIpAddr(HttpServletRequest request)
{
if (request == null)
{
return "unknown";
}
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip))
{
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip))
{
ip = request.getHeader("X-Forwarded-For");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip))
{
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip))
{
ip = request.getHeader("X-Real-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip))
{
ip = request.getRemoteAddr();
}
return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : clean(ip);
}
public static String clean(String content){
return new HTMLFilter().filter(content);
}
}
public class ServletUtils {
/**
* 获取request
*/
public static HttpServletRequest getRequest(){
return getRequestAttributes().getRequest();
}
public static ServletRequestAttributes getRequestAttributes(){
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
return (ServletRequestAttributes) attributes;
}
}
七、测试 Controller
@RestController
public class TestController {
/**
* 测试接口方法,用于返回一个简单的字符串响应。
*
* @RateLimiter 注解用于限制此方法的访问频率。
* 该方法在10秒的时间窗口内最多只能被调用20次。
*
* @return 返回一个字符串,表示这是测试接口
*/
//限流注解,time为时间窗口长度(秒),count为时间窗口内的最大请求次数
@RateLimiter(time = 60, count = 2)
@GetMapping("/test") // GET请求映射到/test路径
public String test() {
return "测试接口"; // 返回字符串表示这是测试接口
}
}