package middleware import ( "bytes" "crypto/sha256" "encoding/hex" "io" "net/http" "os" "strconv" "strings" "sync" "time" "github.com/gin-gonic/gin" ) // AdminPOSTSecurity 对 /api/admin 下 POST 校验时间戳、IP 频率、重复请求;multipart 上传仅做限流不做 body 去重 func AdminPOSTSecurity() gin.HandlerFunc { ipLimit := getIntEnv("ADMIN_POST_IP_PER_MIN", 120) dedupeSec := getIntEnv("ADMIN_DEDUPE_SEC", 3) tsSkew := time.Duration(getIntEnv("ADMIN_REQUEST_TS_SKEW_SEC", 300)) * time.Second return func(c *gin.Context) { if c.Request.Method != http.MethodPost { c.Next() return } tsStr := c.GetHeader("X-Request-Timestamp") if tsStr == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "缺少请求头 X-Request-Timestamp(毫秒时间戳)"}) c.Abort() return } tsMs, err := strconv.ParseInt(tsStr, 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "X-Request-Timestamp 格式无效"}) c.Abort() return } clientT := time.UnixMilli(tsMs) if d := time.Since(clientT); d > tsSkew || d < -tsSkew { c.JSON(http.StatusBadRequest, gin.H{"error": "请求时间戳无效或时钟偏差过大"}) c.Abort() return } ip := c.ClientIP() if !ipPostLimiter.allow("ip:"+ip, ipLimit, time.Minute) { c.JSON(http.StatusTooManyRequests, gin.H{"error": "该 IP 请求过于频繁,请稍后再试"}) c.Abort() return } ct := c.GetHeader("Content-Type") if strings.Contains(strings.ToLower(ct), "multipart/form-data") { c.Next() return } body, err := io.ReadAll(c.Request.Body) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) c.Abort() return } c.Request.Body = io.NopCloser(bytes.NewReader(body)) sig := hashSig(c.FullPath(), c.Request.URL.RawQuery, body) key := ip + "|" + sig if !dedupeStore.try(key, time.Duration(dedupeSec)*time.Second) { c.JSON(http.StatusTooManyRequests, gin.H{"error": "相同请求请勿在 3 秒内重复提交"}) c.Abort() return } c.Next() } } // AdminPOSTUserRateLimit 需在 AuthRequired 之后:按账号限制 POST 频率 func AdminPOSTUserRateLimit() gin.HandlerFunc { userLimit := getIntEnv("ADMIN_POST_USER_PER_MIN", 80) return func(c *gin.Context) { if c.Request.Method != http.MethodPost { c.Next() return } uid, ok := c.Get("user_id") if !ok { c.Next() return } suid, _ := uid.(string) if suid == "" { c.Next() return } if !ipPostLimiter.allow("uid:"+suid, userLimit, time.Minute) { c.JSON(http.StatusTooManyRequests, gin.H{"error": "该账号请求过于频繁,请稍后再试"}) c.Abort() return } c.Next() } } func hashSig(path, query string, body []byte) string { h := sha256.New() h.Write([]byte(path)) h.Write([]byte{0}) h.Write([]byte(query)) h.Write([]byte{0}) h.Write(body) return hex.EncodeToString(h.Sum(nil)) } type slidingLimiter struct { mu sync.Mutex // key -> 时间戳列表(纳秒) m map[string][]int64 } var ipPostLimiter = &slidingLimiter{m: make(map[string][]int64)} func (s *slidingLimiter) allow(key string, max int, window time.Duration) bool { now := time.Now().UnixNano() cutoff := now - window.Nanoseconds() s.mu.Lock() defer s.mu.Unlock() list := s.m[key] out := list[:0] for _, t := range list { if t >= cutoff { out = append(out, t) } } if len(out) >= max { s.m[key] = out return false } out = append(out, now) s.m[key] = out return true } type deduper struct { mu sync.Mutex m map[string]int64 // key -> last unix nano } var dedupeStore = &deduper{m: make(map[string]int64)} func (d *deduper) try(key string, minGap time.Duration) bool { now := time.Now().UnixNano() gap := minGap.Nanoseconds() d.mu.Lock() defer d.mu.Unlock() if last, ok := d.m[key]; ok && now-last < gap { return false } d.m[key] = now if len(d.m) > 10000 { // 简单清理过期项 cutoff := now - gap*10 for k, v := range d.m { if v < cutoff { delete(d.m, k) } } } return true } func getIntEnv(key string, def int) int { s := strings.TrimSpace(os.Getenv(key)) if s == "" { return def } n, err := strconv.Atoi(s) if err != nil { return def } return n }