package com.fdkankan.site.aspectj; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; import com.fdkankan.site.annotation.RateLimiter; import com.fdkankan.site.common.ResultCode; import com.fdkankan.site.enums.LimitType; import com.fdkankan.site.exception.BusinessException; import com.fdkankan.site.util.IpUtils; import com.fdkankan.site.util.ServletUtils; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.data.redis.core.script.RedisScript; import org.springframework.scripting.support.ResourceScriptSource; import org.springframework.stereotype.Component; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; /** * 限流处理 * * @author fdkk */ @Aspect @Configuration public class RateLimiterAspect { private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class); @Autowired private RedisTemplate redisTemplate; private RedisScript limitScript; @Autowired public void setLimitScript(RedisScript limitScript) { this.limitScript = limitScript; } @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { String key = rateLimiter.key(); int time = rateLimiter.time(); int count = rateLimiter.count(); String combineKey = getCombineKey(rateLimiter, point); List keys = Collections.singletonList(combineKey); try { Long number = (Long) redisTemplate.execute(limitScript, keys, String.valueOf(count), String.valueOf(time)); if (ObjectUtil.isNull(number) || number.intValue() > count) { throw new BusinessException(ResultCode.RATE_LIMITER); } log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key); } catch (BusinessException e) { throw e; } catch (Exception e) { throw new RuntimeException("服务器限流异常,请稍后再试"); } } public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) { StringBuffer stringBuffer = new StringBuffer(rateLimiter.key()); if (rateLimiter.limitType() == LimitType.IP) { stringBuffer.append(IpUtils.getIpAddr(ServletUtils.getRequest())).append("-"); } MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Class targetClass = method.getDeclaringClass(); stringBuffer.append(targetClass.getName()).append("-").append(method.getName()); return stringBuffer.toString(); } /** * 限流脚本 */ @Bean("limitRedisScript") public RedisScript limitRedisScript() { DefaultRedisScript redisScript = new DefaultRedisScript<>(); redisScript.setScriptText(limitScriptText()); redisScript.setResultType(Long.class); return redisScript; } /** * 限流脚本 */ private String limitScriptText() { return "local key = KEYS[1]\n" + "local count = tonumber(ARGV[1])\n" + "local time = tonumber(ARGV[2])\n" + "local current = redis.call('get', key);\n" + "if current and tonumber(current) > count then\n" + " return tonumber(current);\n" + "end\n" + "current = redis.call('incr', key)\n" + "if tonumber(current) == 1 then\n" + " redis.call('expire', key, time)\n" + "end\n" + "return tonumber(current);"; } }