Files
web/server/middleware/admin_post_security.go

187 lines
4.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// AdminPOSTSecurity returns middleware that hardens POST requests under
// /api/admin. For each POST request it, in order:
//
//  1. requires the X-Request-Timestamp header (Unix time in milliseconds) and
//     answers 400 if it is missing, unparsable, or further than
//     ADMIN_REQUEST_TS_SKEW_SEC seconds (default 300) from the server clock;
//  2. limits each client IP to ADMIN_POST_IP_PER_MIN requests (default 120)
//     per sliding minute, answering 429 when exceeded;
//  3. for non-multipart requests, buffers the body (restoring it for later
//     handlers) and answers 429 if the same IP sent the same matched route,
//     raw query and body within the last ADMIN_DEDUPE_SEC seconds (default 3).
//
// multipart/form-data requests are only rate limited; their bodies are never
// read or de-duplicated. Non-POST requests pass straight through. The
// environment variables are read once, when the middleware is built.
func AdminPOSTSecurity() gin.HandlerFunc {
	ipLimit := getIntEnv("ADMIN_POST_IP_PER_MIN", 120)
	dedupeSec := getIntEnv("ADMIN_DEDUPE_SEC", 3)
	tsSkew := time.Duration(getIntEnv("ADMIN_REQUEST_TS_SKEW_SEC", 300)) * time.Second
	if tsSkew < 0 {
		// The skew check below is symmetric (±tsSkew); a negative setting
		// would make it reject every request, so use its magnitude instead.
		tsSkew = -tsSkew
	}
	dedupeGap := time.Duration(dedupeSec) * time.Second
	// Built from the configured window so the message matches what is
	// actually enforced.
	dedupeMsg := "相同请求请勿在 " + strconv.Itoa(dedupeSec) + " 秒内重复提交"

	return func(c *gin.Context) {
		if c.Request.Method != http.MethodPost {
			c.Next()
			return
		}

		tsStr := c.GetHeader("X-Request-Timestamp")
		if tsStr == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "缺少请求头 X-Request-Timestamp毫秒时间戳"})
			return
		}
		tsMs, err := strconv.ParseInt(tsStr, 10, 64)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "X-Request-Timestamp 格式无效"})
			return
		}
		// time.Since saturates rather than overflowing, so absurd timestamps
		// simply land outside the allowed window.
		if d := time.Since(time.UnixMilli(tsMs)); d > tsSkew || d < -tsSkew {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "请求时间戳无效或时钟偏差过大"})
			return
		}

		ip := c.ClientIP()
		if !ipPostLimiter.allow("ip:"+ip, ipLimit, time.Minute) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "该 IP 请求过于频繁,请稍后再试"})
			return
		}

		// Multipart bodies are neither buffered nor de-duplicated; only the
		// per-IP limit above applies to them.
		if strings.Contains(strings.ToLower(c.GetHeader("Content-Type")), "multipart/form-data") {
			c.Next()
			return
		}

		// NOTE(review): the body is read without a size cap; consider
		// http.MaxBytesReader if request size is not bounded elsewhere
		// (e.g. by a reverse proxy).
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
			return
		}
		// Put the consumed body back so downstream handlers can still read it.
		c.Request.Body = io.NopCloser(bytes.NewReader(body))

		key := ip + "|" + hashSig(c.FullPath(), c.Request.URL.RawQuery, body)
		if !dedupeStore.try(key, dedupeGap) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": dedupeMsg})
			return
		}
		c.Next()
	}
}
// AdminPOSTUserRateLimit returns middleware that limits POST requests per
// account to ADMIN_POST_USER_PER_MIN (default 80) per sliding minute,
// answering 429 when the limit is exceeded. It must run after AuthRequired
// (presumably the middleware that stores the "user_id" context value);
// requests without a usable "user_id" pass through unlimited.
//
// The ID may be stored as a string or as a Go integer type. Any other type,
// or an empty string, is treated as absent.
func AdminPOSTUserRateLimit() gin.HandlerFunc {
	userLimit := getIntEnv("ADMIN_POST_USER_PER_MIN", 80)
	return func(c *gin.Context) {
		if c.Request.Method != http.MethodPost {
			c.Next()
			return
		}

		// A missing key yields a nil interface, which matches no case below.
		uid, _ := c.Get("user_id")
		var suid string
		switch v := uid.(type) {
		case string:
			suid = v
		case int:
			suid = strconv.Itoa(v)
		case int32:
			suid = strconv.FormatInt(int64(v), 10)
		case int64:
			suid = strconv.FormatInt(v, 10)
		case uint:
			suid = strconv.FormatUint(uint64(v), 10)
		case uint32:
			suid = strconv.FormatUint(uint64(v), 10)
		case uint64:
			suid = strconv.FormatUint(v, 10)
		}
		if suid == "" {
			c.Next()
			return
		}

		// Shares ipPostLimiter with the per-IP limit; the "uid:" prefix keeps
		// account keys separate from "ip:" keys.
		if !ipPostLimiter.allow("uid:"+suid, userLimit, time.Minute) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "该账号请求过于频繁,请稍后再试"})
			return
		}
		c.Next()
	}
}
// hashSig returns the lowercase hex SHA-256 digest of path, query and body.
// A NUL byte follows both path and query, so that shifting bytes between the
// two (path "ab" vs path "a" + query "b") yields a different signature.
func hashSig(path, query string, body []byte) string {
	digest := sha256.New()
	for _, part := range []string{path, query} {
		digest.Write([]byte(part))
		digest.Write([]byte{0})
	}
	digest.Write(body)
	return hex.EncodeToString(digest.Sum(nil))
}
// slidingLimiter is a sliding-window rate limiter keyed by arbitrary strings.
// For each key it keeps the UnixNano timestamps of the events it admitted.
// Access to the map is serialized by mu.
type slidingLimiter struct {
	mu sync.Mutex
	// key -> admitted event timestamps (UnixNano), in admission order
	m map[string][]int64
	// lastSweep is the UnixNano time of the last pass that dropped idle keys.
	lastSweep int64
}

// ipPostLimiter backs both the per-IP ("ip:" prefix) and the per-account
// ("uid:" prefix) POST limits.
var ipPostLimiter = &slidingLimiter{m: make(map[string][]int64)}

// allow reports whether one more event for key fits in the sliding window.
// If fewer than max events were admitted during the last window, it records
// the event and returns true; otherwise it returns false without recording.
//
// At most once per window it also deletes keys that have no event inside the
// window, so the map does not grow without bound as new keys (client IPs,
// accounts) appear. Such keys would be emptied on their next allow call
// anyway, so no decision changes. The sweep uses the current call's window,
// which assumes all callers pass the same window (every caller in this file
// passes time.Minute).
func (s *slidingLimiter) allow(key string, max int, window time.Duration) bool {
	now := time.Now().UnixNano()
	cutoff := now - window.Nanoseconds()
	s.mu.Lock()
	defer s.mu.Unlock()

	if now-s.lastSweep >= window.Nanoseconds() {
		for k, ts := range s.m {
			stale := true
			// Newest entries are appended last, so scan from the end.
			for i := len(ts) - 1; i >= 0; i-- {
				if ts[i] >= cutoff {
					stale = false
					break
				}
			}
			if stale {
				delete(s.m, k)
			}
		}
		s.lastSweep = now
	}

	// Drop expired timestamps in place, reusing the backing array.
	list := s.m[key]
	out := list[:0]
	for _, t := range list {
		if t >= cutoff {
			out = append(out, t)
		}
	}
	if len(out) >= max {
		s.m[key] = out
		return false
	}
	s.m[key] = append(out, now)
	return true
}
// deduper remembers, per key, when that key was last accepted, so that
// identical requests arriving too close together can be rejected.
// Access to the map is serialized by mu.
type deduper struct {
	mu sync.Mutex
	m  map[string]int64 // key -> UnixNano time the key was last accepted
	// lastSweep is the UnixNano time of the last pass that dropped expired keys.
	lastSweep int64
}

// dedupeStore backs the duplicate-request check in AdminPOSTSecurity.
var dedupeStore = &deduper{m: make(map[string]int64)}

// try reports whether key may proceed. It returns false if the same key was
// accepted less than minGap ago; otherwise it records the current time for key
// and returns true.
//
// Once the map holds more than 10000 keys, entries at least minGap old are
// deleted. Such entries can no longer cause a rejection, so results are
// unaffected. The sweep runs at most once per minGap, so a large set of live
// keys does not force a full scan on every call. It assumes all callers use
// the same minGap.
func (d *deduper) try(key string, minGap time.Duration) bool {
	const sweepThreshold = 10000
	now := time.Now().UnixNano()
	gap := minGap.Nanoseconds()
	d.mu.Lock()
	defer d.mu.Unlock()

	if last, ok := d.m[key]; ok && now-last < gap {
		return false
	}
	d.m[key] = now

	if len(d.m) > sweepThreshold && now-d.lastSweep >= gap {
		for k, last := range d.m {
			if now-last >= gap {
				delete(d.m, k)
			}
		}
		d.lastSweep = now
	}
	return true
}
// getIntEnv returns the environment variable key parsed as a base-10 int,
// ignoring surrounding whitespace. It falls back to def when the variable is
// unset, blank, or not a valid integer.
func getIntEnv(key string, def int) int {
	if raw := strings.TrimSpace(os.Getenv(key)); raw != "" {
		if n, err := strconv.Atoi(raw); err == nil {
			return n
		}
	}
	return def
}