187 lines
4.1 KiB
Go
187 lines
4.1 KiB
Go
package middleware
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// AdminPOSTSecurity 对 /api/admin 下 POST 校验时间戳、IP 频率、重复请求;multipart 上传仅做限流不做 body 去重
|
||
func AdminPOSTSecurity() gin.HandlerFunc {
|
||
ipLimit := getIntEnv("ADMIN_POST_IP_PER_MIN", 120)
|
||
dedupeSec := getIntEnv("ADMIN_DEDUPE_SEC", 3)
|
||
tsSkew := time.Duration(getIntEnv("ADMIN_REQUEST_TS_SKEW_SEC", 300)) * time.Second
|
||
|
||
return func(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
tsStr := c.GetHeader("X-Request-Timestamp")
|
||
if tsStr == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少请求头 X-Request-Timestamp(毫秒时间戳)"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
tsMs, err := strconv.ParseInt(tsStr, 10, 64)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "X-Request-Timestamp 格式无效"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
clientT := time.UnixMilli(tsMs)
|
||
if d := time.Since(clientT); d > tsSkew || d < -tsSkew {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求时间戳无效或时钟偏差过大"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
ip := c.ClientIP()
|
||
if !ipPostLimiter.allow("ip:"+ip, ipLimit, time.Minute) {
|
||
c.JSON(http.StatusTooManyRequests, gin.H{"error": "该 IP 请求过于频繁,请稍后再试"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
ct := c.GetHeader("Content-Type")
|
||
if strings.Contains(strings.ToLower(ct), "multipart/form-data") {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
body, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||
|
||
sig := hashSig(c.FullPath(), c.Request.URL.RawQuery, body)
|
||
key := ip + "|" + sig
|
||
if !dedupeStore.try(key, time.Duration(dedupeSec)*time.Second) {
|
||
c.JSON(http.StatusTooManyRequests, gin.H{"error": "相同请求请勿在 3 秒内重复提交"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// AdminPOSTUserRateLimit 需在 AuthRequired 之后:按账号限制 POST 频率
|
||
func AdminPOSTUserRateLimit() gin.HandlerFunc {
|
||
userLimit := getIntEnv("ADMIN_POST_USER_PER_MIN", 80)
|
||
return func(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
c.Next()
|
||
return
|
||
}
|
||
uid, ok := c.Get("user_id")
|
||
if !ok {
|
||
c.Next()
|
||
return
|
||
}
|
||
suid, _ := uid.(string)
|
||
if suid == "" {
|
||
c.Next()
|
||
return
|
||
}
|
||
if !ipPostLimiter.allow("uid:"+suid, userLimit, time.Minute) {
|
||
c.JSON(http.StatusTooManyRequests, gin.H{"error": "该账号请求过于频繁,请稍后再试"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// hashSig returns the lowercase hex SHA-256 digest of path, query and body,
// written in that order with a single NUL byte between consecutive fields so
// that, for example, ("ab", "") and ("a", "b") produce different digests.
func hashSig(path, query string, body []byte) string {
	digest := sha256.New()
	nul := []byte{0}
	for _, part := range [][]byte{[]byte(path), nul, []byte(query), nul, body} {
		digest.Write(part) // hash.Hash.Write never returns an error
	}
	return hex.EncodeToString(digest.Sum(nil))
}
|
||
|
||
// slidingLimiter is a sliding-window-log rate limiter. For every key it keeps
// the timestamps (Unix nanoseconds) of the requests it admitted, and it admits
// a new one only while fewer than max of them fall inside the trailing window.
// Rejected requests are not recorded. All access is serialized by mu.
type slidingLimiter struct {
	mu sync.Mutex
	// key -> admitted request timestamps (Unix nanoseconds), in admission order
	m map[string][]int64
	// Largest window (ns) ever passed to allow. Eviction uses it so a key
	// limited with a long window is never dropped early.
	maxWindow int64
	// Unix nanoseconds of the last eviction pass.
	lastSweep int64
}

// ipPostLimiter is the process-wide limiter used for both the per-IP ("ip:")
// and the per-account ("uid:") POST limits.
var ipPostLimiter = &slidingLimiter{m: make(map[string][]int64)}

// allow reports whether one more request for key fits within max requests per
// window, and records it if so. With max <= 0 every request is rejected.
func (s *slidingLimiter) allow(key string, max int, window time.Duration) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Read the clock under the lock so that, barring wall-clock adjustments,
	// each list stays ordered and its last element is its newest.
	now := time.Now().UnixNano()
	win := window.Nanoseconds()
	if win > s.maxWindow {
		s.maxWindow = win
	}
	s.evictIdleLocked(now)

	// Drop expired timestamps in place, reusing the backing array.
	cutoff := now - win
	list := s.m[key]
	live := list[:0]
	for _, t := range list {
		if t >= cutoff {
			live = append(live, t)
		}
	}
	if len(live) >= max {
		s.m[key] = live
		return false
	}
	s.m[key] = append(live, now)
	return true
}

// evictIdleLocked deletes keys with no timestamp inside maxWindow. Without
// it, the map would keep one entry for every IP / account ever seen. The full
// scan runs at most once per max(maxWindow, 1s), which keeps its amortized
// cost low. s.mu must be held.
func (s *slidingLimiter) evictIdleLocked(now int64) {
	interval := s.maxWindow
	if interval < int64(time.Second) {
		interval = int64(time.Second)
	}
	if now-s.lastSweep < interval {
		return
	}
	s.lastSweep = now
	cutoff := now - s.maxWindow
	for k, ts := range s.m {
		if len(ts) == 0 || ts[len(ts)-1] < cutoff {
			delete(s.m, k)
		}
	}
}
|
||
|
||
// deduper remembers, per key, when a request was last accepted, so that an
// identical request repeated within a short gap can be rejected. All access
// is serialized by mu.
type deduper struct {
	mu sync.Mutex
	m  map[string]int64 // key -> Unix nanoseconds of the last accepted request
	// Unix nanoseconds of the last eviction pass; throttles the O(len(m)) sweep.
	lastSweep int64
}

// dedupeStore backs the duplicate-request check in AdminPOSTSecurity.
var dedupeStore = &deduper{m: make(map[string]int64)}

// try reports whether a request with the given key may proceed, i.e. no
// request with the same key was accepted within the last minGap. On success
// the key's timestamp is refreshed. A rejected attempt does not extend the gap.
func (d *deduper) try(key string, minGap time.Duration) bool {
	const (
		sweepThreshold = 10000             // start evicting above this many keys
		sweepInterval  = int64(time.Second) // at most one eviction pass per interval
	)

	now := time.Now().UnixNano()
	gap := minGap.Nanoseconds()

	d.mu.Lock()
	defer d.mu.Unlock()

	if last, ok := d.m[key]; ok && now-last < gap {
		return false
	}
	d.m[key] = now

	// Evict old entries once the map grows large. The pass is throttled.
	// Otherwise, while more than sweepThreshold keys are still fresh, every
	// call would rescan the whole map under the lock.
	if len(d.m) > sweepThreshold && now-d.lastSweep >= sweepInterval {
		d.lastSweep = now
		cutoff := now - gap*10 // retain entries for 10x the gap, presumably as a safety margin
		for k, v := range d.m {
			if v < cutoff {
				delete(d.m, k)
			}
		}
	}
	return true
}
|
||
|
||
// getIntEnv returns the environment variable key parsed as a base-10 int,
// after trimming surrounding whitespace. It falls back to def when the
// variable is unset, blank, or not a valid integer.
func getIntEnv(key string, def int) int {
	if n, err := strconv.Atoi(strings.TrimSpace(os.Getenv(key))); err == nil {
		return n
	}
	// Atoi rejects "" too, so unset and blank values land here as well.
	return def
}
|