Java: preventing API request flooding with a custom annotation
First, define the rate-limit annotation:
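import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)  // kept at runtime so the interceptor can read it via reflection
@Target(ElementType.METHOD)          // only applies to (controller) methods
@Documented
public @interface RateLimit {
    int cycle() default 5;           // time window in seconds
    int number() default 1;          // maximum number of requests allowed per window
    String msg() default "Duplicate request"; // message returned when the limit is exceeded
}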
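Because the annotation is retained at RUNTIME, its attributes can be read reflectively while a request is being handled. Spring's `HandlerMethod.getMethodAnnotation` performs a more elaborate lookup (it also resolves meta-annotations), but it depends on the same retention. The following is a self-contained sketch; `DemoController`, `sendCode`, `listItems` and `limitOf` are hypothetical names, and the annotation is re-declared as a nested type only so the file compiles on its own:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

public class RateLimitReflectionDemo {
    // Copy of the article's annotation, nested so this file is self-contained
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface RateLimit {
        int cycle() default 5;
        int number() default 1;
        String msg() default "Duplicate request";
    }

    // Hypothetical controller: one rate-limited method, one unrestricted method
    public static class DemoController {
        @RateLimit(cycle = 60, number = 3)
        public String sendCode() { return "ok"; }

        public String listItems() { return "ok"; }
    }

    // Returns the method's @RateLimit, or null if the method is not annotated
    public static RateLimit limitOf(String methodName) {
        try {
            Method method = DemoController.class.getMethod(methodName);
            return method.getAnnotation(RateLimit.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("No such method: " + methodName, e);
        }
    }

    public static void main(String[] args) {
        RateLimit limit = limitOf("sendCode");
        System.out.println(limit.number() + " requests per " + limit.cycle() + " s, msg: " + limit.msg());
        System.out.println(limitOf("listItems"));
    }
}
```

Attributes that are not set explicitly (here `msg`) fall back to their declared defaults, and an unannotated method yields `null`, which is exactly the case the interceptor lets through.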
Next, define the service that performs the check. Because it is injected through its interface, start by declaring a RateLimitService interface:
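public interface RateLimitService {
    // Returns true if this request from ip to url is allowed, false if it exceeds the limit in rateLimit
    Boolean selectLimit(String ip, String url, RateLimit rateLimit);
}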
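Depending on an interface keeps the caller (the interceptor) independent of Redis, so the default implementation can be swapped, for example for a test double. A minimal plain-Java sketch of that idea; `Limiter`, `BlockListLimiter` and `handle` are hypothetical, simplified stand-ins for `RateLimitService` and the interceptor, not part of the article's code:

```java
import java.util.Set;

public class SwappableLimiterDemo {
    // Hypothetical, simplified stand-in for RateLimitService
    public interface Limiter {
        boolean allow(String ip, String url);
    }

    // Hypothetical test double: rejects a fixed set of IPs, allows everyone else
    public static class BlockListLimiter implements Limiter {
        private final Set<String> blocked;

        public BlockListLimiter(Set<String> blocked) {
            this.blocked = blocked;
        }

        @Override
        public boolean allow(String ip, String url) {
            return !blocked.contains(ip);
        }
    }

    // The caller depends only on the interface, so any implementation can be plugged in
    public static String handle(Limiter limiter, String ip, String url) {
        return limiter.allow(ip, url) ? "ok" : "rejected";
    }

    public static void main(String[] args) {
        Limiter allowAll = (ip, url) -> true; // a lambda works because Limiter has a single method
        Limiter blockOne = new BlockListLimiter(Set.of("10.0.0.9"));
        System.out.println(handle(allowAll, "10.0.0.9", "/api/sms"));
        System.out.println(handle(blockOne, "10.0.0.9", "/api/sms"));
    }
}
```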
Implementation, backed by Redis and a Lua script:
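import java.util.Collections;

import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

@Slf4j
public class DefaultRateLimitServiceImpl implements RateLimitService {
    // KEYS[1] = counter key, ARGV[1] = max requests, ARGV[2] = window length in seconds
    private static final String SCRIPT = "local limit = tonumber(ARGV[1]);"  // request limit
            + "local expire_time = ARGV[2];"                                  // expiry (window) in seconds
            + "local result = redis.call('setNX',KEYS[1],1);"                 // key absent: set it to 1 and return 1, otherwise return 0
            + "if result == 1 then"                                           // key was just created, so give it an expiry
            + " redis.call('expire',KEYS[1],expire_time);"
            + " return 1; "                                                   // allow
            + "else"                                                          // key already exists
            + " if tonumber(redis.call('GET', KEYS[1])) >= limit then"        // compare the current count with the limit
            + " return 0;"                                                    // limit reached: reject
            + " else"
            + " redis.call('incr', KEYS[1]);"                                 // increment the counter
            + " return 1;"                                                    // allow
            + " end "
            + "end";

    // Build the script object once; the Lua integer reply 1/0 is converted to true/false
    private static final DefaultRedisScript<Boolean> LIMIT_SCRIPT = new DefaultRedisScript<>(SCRIPT, Boolean.class);

    private StringRedisTemplate redisTemplate;

    public void setRedisTemplate(RedisTemplate<?, ?> redisTemplate) {
        // Use a dedicated String-serializing template on the same connection factory instead of
        // changing the shared bean's serializers on every request (which is not thread-safe)
        this.redisTemplate = new StringRedisTemplate(redisTemplate.getConnectionFactory());
    }

    @Override
    public Boolean selectLimit(String ip, String url, RateLimit rateLimit) {
        String key = "custom:rate:" + ip + ":" + url;
        try {
            Boolean allowed = redisTemplate.execute(LIMIT_SCRIPT, Collections.singletonList(key),
                    String.valueOf(rateLimit.number()), String.valueOf(rateLimit.cycle()));
            return Boolean.TRUE.equals(allowed);
        } catch (Exception ex) {
            // Redis unavailable: log and let the request through (fail-open);
            // return false here instead if you prefer to fail closed
            log.error("Rate-limit check failed for key {}", key, ex);
            return true;
        }
    }
}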
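A Lua script is used because Redis executes the whole script atomically in a single round trip: the SETNX, the EXPIRE, the GET comparison and the INCR cannot interleave with commands from other clients. Issuing the same steps as separate commands from Java would cost several network round trips per request and leave a gap between the check and the increment, so a crawler firing requests in a tight, millisecond-scale loop could get extra requests past the check before the counter catches up.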
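To make the script's fixed-window semantics concrete, here is a single-JVM Java equivalent. It is a sketch, not a replacement for Redis: counters live in one process (so they are not shared across instances), and expired entries are only replaced, never evicted. `InMemoryFixedWindow` and `tryAcquire` are hypothetical names, and `synchronized` stands in for the atomicity Redis gives the Lua script:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.LongSupplier;

// Hypothetical in-memory mirror of the Lua script (single JVM only)
public class InMemoryFixedWindow {
    private static final class Counter {
        long expiresAtMillis;
        int count;
    }

    private final Map<String, Counter> counters = new HashMap<>();
    private final LongSupplier clock; // injectable millisecond clock so the window can be tested

    public InMemoryFixedWindow(LongSupplier clock) {
        this.clock = clock;
    }

    // synchronized plays the role of Redis running the script atomically
    public synchronized boolean tryAcquire(String key, int limit, int windowSeconds) {
        long now = clock.getAsLong();
        Counter c = counters.get(key);
        if (c == null || now >= c.expiresAtMillis) {
            // Like SETNX + EXPIRE: first request of a new window, counter starts at 1
            c = new Counter();
            c.count = 1;
            c.expiresAtMillis = now + windowSeconds * 1000L;
            counters.put(key, c);
            return true;
        }
        if (c.count >= limit) {
            return false; // like GET >= limit: reject
        }
        c.count++;        // like INCR
        return true;
    }
}
```

With a limit of 2 and a 5-second window, the first two calls for a key succeed, the third is rejected, and the key is accepted again once the window has elapsed. This mirrors the script's fixed-window behaviour, including its main weakness: a client can send up to twice the limit across a window boundary.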
Configuration class (registers the Redis-backed service as a Spring bean):
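import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;

@Configuration
public class RateLimitConfig {
    @Autowired
    private RedisTemplate redisTemplate; // resolved by the field name "redisTemplate" if several templates exist

    // Registers the Redis-backed default only if the application has not defined its own RateLimitService.
    // Note: Spring Boot recommends this condition for auto-configuration classes only; in a regular
    // @Configuration class the outcome depends on the order in which bean definitions are processed.
    @Bean("rateLimitService")
    @ConditionalOnMissingBean(RateLimitService.class)
    public RateLimitService defaultRateLimitService() {
        DefaultRateLimitServiceImpl defaultRateLimitService = new DefaultRateLimitServiceImpl();
        defaultRateLimitService.setRedisTemplate(redisTemplate);
        return defaultRateLimitService;
    }
}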
The rate-limit interceptor:
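import javax.servlet.http.HttpServletRequest;   // use jakarta.servlet.http.* on Spring Boot 3
import javax.servlet.http.HttpServletResponse;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

@Component
public class RateLimitInterInterceptor implements HandlerInterceptor {
    private static final ObjectMapper objectMapper = new ObjectMapper();
    private RateLimitService rateLimitService;

    // @Autowired is required: without it the setter is never called and rateLimitService stays null
    @Autowired
    public void setRateLimitService(RateLimitService rateLimitService) {
        this.rateLimitService = rateLimitService;
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Only controller methods are checked; static resources and the like pass straight through
        if (handler instanceof HandlerMethod) {
            HandlerMethod handlerMethod = (HandlerMethod) handler;
            RateLimit methodAnnotation = handlerMethod.getMethodAnnotation(RateLimit.class);
            if (methodAnnotation == null) {
                return true; // method is not rate-limited
            }
            String ip = request.getRemoteAddr();
            String url = request.getRequestURI();
            Boolean allowed = rateLimitService.selectLimit(ip, url, methodAnnotation);
            if (!Boolean.TRUE.equals(allowed)) {
                // Set status and content type before writing the body. The article keeps HTTP 200 and
                // signals rejection via the body's code; HttpStatus.TOO_MANY_REQUESTS (429) is the standard alternative
                response.setStatus(HttpStatus.OK.value());
                response.setContentType("application/json;charset=UTF-8");
                // Result is the project's own response wrapper (code, message, data), not shown here
                response.getWriter().write(objectMapper.writeValueAsString(new Result(500, methodAnnotation.msg(), null)));
                return false; // stop processing this request
            }
        }
        return true;
    }
}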
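Declaring the interceptor as a `@Component` does not by itself put it into Spring MVC's handler chain; it is normally registered through a `WebMvcConfigurer`. A minimal configuration sketch, assuming a Spring Boot MVC application (`RateLimitWebConfig` is a hypothetical class name):

```java
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

// Hypothetical class name: registers the rate-limit interceptor with Spring MVC
@Configuration
public class RateLimitWebConfig implements WebMvcConfigurer {
    private final RateLimitInterInterceptor rateLimitInterceptor;

    // Constructor injection of the @Component interceptor bean
    public RateLimitWebConfig(RateLimitInterInterceptor rateLimitInterceptor) {
        this.rateLimitInterceptor = rateLimitInterceptor;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // Apply to every path; methods without @RateLimit are skipped inside preHandle
        registry.addInterceptor(rateLimitInterceptor).addPathPatterns("/**");
    }
}
```

Matching all paths is harmless here because `preHandle` returns early for methods without the annotation; narrower `addPathPatterns` values only save that lookup.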
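The interceptor keys the counter on `request.getRemoteAddr()`. Behind a reverse proxy or load balancer that is the proxy's address, so all clients would share one counter. A common workaround is to read `X-Forwarded-For`, but the left-most entries can be forged by the client; only the entries appended by your own proxies are trustworthy. A plain-Java sketch under that assumption; `ClientIpResolver`, `resolve` and the `trustedProxies` parameter are hypothetical, and in the interceptor you would pass `request.getHeader("X-Forwarded-For")` and `request.getRemoteAddr()`:

```java
public class ClientIpResolver {
    // Hypothetical helper: picks the client address to use in the rate-limit key.
    // trustedProxies = number of reverse proxies in front of the app that append to X-Forwarded-For.
    public static String resolve(String forwardedFor, String remoteAddr, int trustedProxies) {
        if (trustedProxies <= 0 || forwardedFor == null || forwardedFor.isBlank()) {
            return remoteAddr; // no trusted proxy or no header: use the socket peer address
        }
        String[] hops = forwardedFor.split(",");
        // Each trusted proxy appends the address it received the request from, so the entry
        // trustedProxies positions from the right was written by the outermost trusted proxy
        int index = hops.length - trustedProxies;
        if (index < 0) {
            return remoteAddr; // header shorter than expected: do not trust it
        }
        String candidate = hops[index].trim();
        return candidate.isEmpty() ? remoteAddr : candidate;
    }
}
```

With one proxy, a forged header such as `1.2.3.4` sent by the client ends up as `1.2.3.4, <real client>`, and the resolver still picks the real client. Alternatively, Spring Boot's `server.forward-headers-strategy` setting can be configured so that `getRemoteAddr()` itself reports the forwarded client address.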
The code is adapted from an approach shared by an experienced developer on CSDN.