zhangqian
2023-08-26 5193dcb9336e853502baf8a539d3f45efebe2f86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package middleware
 
import (
    "context"
    "errors"
    "net/http"
    "time"
 
    "go.uber.org/zap"
 
    "github.com/gin-gonic/gin"
    "srm/global"
    "srm/model/common/response"
)
 
type LimitConfig struct {
    // GenerationKey 根据业务生成key 下面CheckOrMark查询生成
    GenerationKey func(c *gin.Context) string
    // 检查函数,用户可修改具体逻辑,更加灵活
    CheckOrMark func(key string, expire int, limit int) error
    // Expire key 过期时间
    Expire int
    // Limit 周期时间
    Limit int
}
 
func (l LimitConfig) LimitWithTime() gin.HandlerFunc {
    return func(c *gin.Context) {
        if err := l.CheckOrMark(l.GenerationKey(c), l.Expire, l.Limit); err != nil {
            c.JSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err})
            c.Abort()
            return
        } else {
            c.Next()
        }
    }
}
 
// DefaultGenerationKey 默认生成key
func DefaultGenerationKey(c *gin.Context) string {
    return "GVA_Limit" + c.ClientIP()
}
 
func DefaultCheckOrMark(key string, expire int, limit int) (err error) {
    // 判断是否开启redis
    if global.GVA_REDIS == nil {
        return err
    }
    if err = SetLimitWithTime(key, limit, time.Duration(expire)*time.Second); err != nil {
        global.GVA_LOG.Error("limit", zap.Error(err))
    }
    return err
}
 
func DefaultLimit() gin.HandlerFunc {
    return LimitConfig{
        GenerationKey: DefaultGenerationKey,
        CheckOrMark:   DefaultCheckOrMark,
        Expire:        global.GVA_CONFIG.System.LimitTimeIP,
        Limit:         global.GVA_CONFIG.System.LimitCountIP,
    }.LimitWithTime()
}
 
// SetLimitWithTime 设置访问次数
func SetLimitWithTime(key string, limit int, expiration time.Duration) error {
    count, err := global.GVA_REDIS.Exists(context.Background(), key).Result()
    if err != nil {
        return err
    }
    if count == 0 {
        pipe := global.GVA_REDIS.TxPipeline()
        pipe.Incr(context.Background(), key)
        pipe.Expire(context.Background(), key, expiration)
        _, err = pipe.Exec(context.Background())
        return err
    } else {
        // 次数
        if times, err := global.GVA_REDIS.Get(context.Background(), key).Int(); err != nil {
            return err
        } else {
            if times >= limit {
                if t, err := global.GVA_REDIS.PTTL(context.Background(), key).Result(); err != nil {
                    return errors.New("请求太过频繁,请稍后再试")
                } else {
                    return errors.New("请求太过频繁, 请 " + t.String() + " 秒后尝试")
                }
            } else {
                return global.GVA_REDIS.Incr(context.Background(), key).Err()
            }
        }
    }
}