RateLimiterAspect.java 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package com.fdkankan.site.aspectj;
  2. import cn.hutool.core.util.ObjectUtil;
  3. import cn.hutool.core.util.StrUtil;
  4. import com.fdkankan.site.annotation.RateLimiter;
  5. import com.fdkankan.site.common.ResultCode;
  6. import com.fdkankan.site.enums.LimitType;
  7. import com.fdkankan.site.exception.BusinessException;
  8. import com.fdkankan.site.util.IpUtils;
  9. import com.fdkankan.site.util.ServletUtils;
  10. import org.aspectj.lang.JoinPoint;
  11. import org.aspectj.lang.annotation.Aspect;
  12. import org.aspectj.lang.annotation.Before;
  13. import org.aspectj.lang.reflect.MethodSignature;
  14. import org.slf4j.Logger;
  15. import org.slf4j.LoggerFactory;
  16. import org.springframework.beans.factory.annotation.Autowired;
  17. import org.springframework.context.annotation.Bean;
  18. import org.springframework.context.annotation.Configuration;
  19. import org.springframework.core.io.ClassPathResource;
  20. import org.springframework.data.redis.core.RedisTemplate;
  21. import org.springframework.data.redis.core.script.DefaultRedisScript;
  22. import org.springframework.data.redis.core.script.RedisScript;
  23. import org.springframework.scripting.support.ResourceScriptSource;
  24. import org.springframework.stereotype.Component;
  25. import java.lang.reflect.Method;
  26. import java.util.Collections;
  27. import java.util.List;
  28. /**
  29. * 限流处理
  30. *
  31. * @author fdkk
  32. */
  33. @Aspect
  34. @Configuration
  35. public class RateLimiterAspect
  36. {
  37. private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);
  38. @Autowired
  39. private RedisTemplate redisTemplate;
  40. private RedisScript<Long> limitScript;
  41. @Autowired
  42. public void setLimitScript(RedisScript<Long> limitScript)
  43. {
  44. this.limitScript = limitScript;
  45. }
  46. @Before("@annotation(rateLimiter)")
  47. public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable
  48. {
  49. String key = rateLimiter.key();
  50. int time = rateLimiter.time();
  51. int count = rateLimiter.count();
  52. String combineKey = getCombineKey(rateLimiter, point);
  53. List<Object> keys = Collections.singletonList(combineKey);
  54. try
  55. {
  56. Long number = (Long) redisTemplate.execute(limitScript, keys, String.valueOf(count), String.valueOf(time));
  57. if (ObjectUtil.isNull(number) || number.intValue() > count)
  58. {
  59. throw new BusinessException(ResultCode.RATE_LIMITER);
  60. }
  61. log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
  62. }
  63. catch (BusinessException e)
  64. {
  65. throw e;
  66. }
  67. catch (Exception e)
  68. {
  69. throw new RuntimeException("服务器限流异常,请稍后再试");
  70. }
  71. }
  72. public String getCombineKey(RateLimiter rateLimiter, JoinPoint point)
  73. {
  74. StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
  75. if (rateLimiter.limitType() == LimitType.IP)
  76. {
  77. stringBuffer.append(IpUtils.getIpAddr(ServletUtils.getRequest())).append("-");
  78. }
  79. MethodSignature signature = (MethodSignature) point.getSignature();
  80. Method method = signature.getMethod();
  81. Class<?> targetClass = method.getDeclaringClass();
  82. stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
  83. return stringBuffer.toString();
  84. }
  85. /**
  86. * 限流脚本
  87. */
  88. @Bean("limitRedisScript")
  89. public RedisScript<Long> limitRedisScript() {
  90. DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
  91. redisScript.setScriptText(limitScriptText());
  92. redisScript.setResultType(Long.class);
  93. return redisScript;
  94. }
  95. /**
  96. * 限流脚本
  97. */
  98. private String limitScriptText()
  99. {
  100. return "local key = KEYS[1]\n" +
  101. "local count = tonumber(ARGV[1])\n" +
  102. "local time = tonumber(ARGV[2])\n" +
  103. "local current = redis.call('get', key);\n" +
  104. "if current and tonumber(current) > count then\n" +
  105. " return tonumber(current);\n" +
  106. "end\n" +
  107. "current = redis.call('incr', key)\n" +
  108. "if tonumber(current) == 1 then\n" +
  109. " redis.call('expire', key, time)\n" +
  110. "end\n" +
  111. "return tonumber(current);";
  112. }
  113. }