Interface optimization (接口优化)

lhx
2025-09-27 09:30:46 +08:00
parent 52d2753824
commit 0b1e9851dd
7 changed files with 368 additions and 384 deletions


@@ -1,10 +1,11 @@
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
 from sqlalchemy.orm import Session
-from typing import List
+from typing import List, Optional
+import base64
 from ..core.database import get_db
 from ..schemas.database import (
     SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
-    TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
+    TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportFormData
 )
 from ..services.database import DatabaseService
@@ -53,19 +54,64 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
     result = DatabaseService.import_data(db, request.table_name, request.data)
     return SQLExecuteResponse(**result)

-@router.post("/tables", response_model=List[str])
+@router.post("/tables", response_model=SQLExecuteResponse)
 def get_table_list():
     """获取所有表名"""
-    return DatabaseService.get_table_list()
+    result = DatabaseService.get_table_list()
+    return SQLExecuteResponse(**result)

 @router.post("/import-file", response_model=SQLExecuteResponse)
-def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
+async def import_file(
+    file: UploadFile = File(..., description="上传的Excel/CSV文件"),
+    table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
+    force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
+    db: Session = Depends(get_db)
+):
     """导入Excel/CSV文件到数据库"""
-    result = DatabaseService.import_file_to_database(
-        db,
-        request.filename,
-        request.file_content,
-        request.table_name,
-        request.force_overwrite
-    )
-    return SQLExecuteResponse(**result)
+    try:
+        # 检查文件类型
+        if not file.filename:
+            raise HTTPException(status_code=400, detail="文件名不能为空")
+        # 处理中文文件名和后缀名 - 更可靠的方法
+        try:
+            # FastAPI的UploadFile已经正确处理了文件名编码
+            filename = file.filename
+        except:
+            filename = "unknown_file"
+        if not filename or filename == "":
+            raise HTTPException(status_code=400, detail="文件名不能为空")
+        # 提取文件扩展名
+        if '.' not in filename:
+            raise HTTPException(status_code=400, detail="文件必须有扩展名")
+        file_ext = filename.lower().split('.')[-1]
+        if file_ext not in ['csv', 'xlsx', 'xls']:
+            raise HTTPException(
+                status_code=400,
+                detail=f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls"
+            )
+        # 读取文件内容
+        file_content = await file.read()
+        # 将文件内容转换为base64保持与现有FileImportUtils兼容
+        file_content_base64 = base64.b64encode(file_content).decode('utf-8')
+        # 调用服务层方法
+        result = DatabaseService.import_file_to_database(
+            db, filename, file_content_base64, table_name, force_overwrite
+        )
+        return SQLExecuteResponse(**result)
+    except HTTPException:
+        raise
+    except Exception as e:
+        return SQLExecuteResponse(
+            success=False,
+            message=f"文件导入失败: {str(e)}"
+        )
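For reference, a hypothetical client-side sketch of the new upload flow: the endpoint now takes multipart/form-data instead of a base64-encoded JSON body. The base URL and the "/api/database" prefix are assumptions; only the field names `file`, `table_name`, and `force_overwrite` come from the diff above.

```python
# Hypothetical usage sketch for the reworked /import-file endpoint.
from typing import Optional

import requests


def upload_table(path: str, table_name: Optional[str] = None,
                 force_overwrite: bool = False) -> dict:
    # Form fields mirror the endpoint's Form(...) parameters.
    data = {"force_overwrite": str(force_overwrite).lower()}
    if table_name:
        data["table_name"] = table_name
    with open(path, "rb") as f:
        resp = requests.post(
            "http://localhost:8000/api/database/import-file",  # assumed prefix
            files={"file": f},
            data=data,
            timeout=60,
        )
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    print(upload_table("data.xlsx", table_name="projects", force_overwrite=True))
```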


@@ -32,7 +32,11 @@ class AccountListRequest(BaseModel):
     limit: Optional[int] = 100

 class AccountGetRequest(BaseModel):
-    account_id: int
+    account_id: Optional[int] = None
+    account: Optional[str] = None
+    section: Optional[str] = None
+    status: Optional[int] = None
+    today_updated: Optional[int] = None

 class AccountUpdateRequest(BaseModel):
     account_id: int

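Illustrative sketch of what the loosened schema allows (pydantic v2 assumed, filter values made up): with every field optional, the same request model can express both an id lookup and an attribute filter.

```python
from typing import Optional

from pydantic import BaseModel


class AccountGetRequest(BaseModel):  # mirrors the updated schema above
    account_id: Optional[int] = None
    account: Optional[str] = None
    section: Optional[str] = None
    status: Optional[int] = None
    today_updated: Optional[int] = None


by_id = AccountGetRequest(account_id=42)                    # old-style lookup still works
by_filter = AccountGetRequest(section="某工区", status=1)    # new: filter without an id
print(by_filter.model_dump(exclude_none=True))              # {'section': '某工区', 'status': 1}
```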

@@ -30,8 +30,7 @@ class ImportDataRequest(BaseModel):
     table_name: str
     data: List[Dict[str, Any]]

-class FileImportRequest(BaseModel):
-    filename: str
-    file_content: str  # base64编码的文件内容
+class FileImportFormData(BaseModel):
+    """文件导入表单数据"""
+    table_name: Optional[str] = None  # 可选的自定义表名
+    force_overwrite: bool = False  # 是否强制覆盖已存在的表


@@ -192,9 +192,17 @@ class DatabaseService:
         try:
             inspector = inspect(engine)
             # 排除mysql的系统表和accounts表
-            return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
+            data = [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts' and table != 'apscheduler_jobs']
+            return {
+                "success": True,
+                "message": "获取表名成功",
+                "data": data
+            }
         except Exception as e:
-            return []
+            return {
+                "success": False,
+                "message": f"获取表名失败: {str(e)}"
+            }

     @staticmethod
     def import_file_to_database(db: Session, filename: str, file_content: str,

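For comparison, the `/tables` endpoint now returns the same success/message/data envelope the other endpoints use. A rough sketch of the before/after payloads (table names are made up; the exact SQLExecuteResponse fields are an assumption based on the envelope built above):

```python
# Illustrative only: what a client sees from POST /tables before vs. after this commit.
old_response = ["projects", "budgets"]  # previously: bare List[str], [] on error
new_response = {
    "success": True,
    "message": "获取表名成功",
    "data": ["projects", "budgets"],  # mysql system tables, accounts, apscheduler_jobs excluded
}
```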

@@ -1,364 +0,0 @@
"""
数据库导入工具模块
用于通过HTTP接口导入文件到数据库
"""
import os
import sys
import sqlite3
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, List
import openpyxl
import io
# 添加sqlite-mcp-server到路径以使用其工具函数
current_dir = os.path.dirname(__file__)
project_root = os.path.dirname(os.path.dirname(current_dir))
sqlite_mcp_path = os.path.join(project_root, 'sqlite-mcp-server')
if os.path.exists(sqlite_mcp_path) and sqlite_mcp_path not in sys.path:
sys.path.insert(0, sqlite_mcp_path)
from tools.common import (
clean_table_name, clean_column_name, infer_data_types,
CHUNK_SIZE, SQLITE_MAX_VARIABLES
)
def _clean_cell_value(value: Any) -> Any:
"""
清理单元格内容:去除空格、换行、制表符等
Args:
value: 单元格的值
Returns:
清理后的值
"""
if value is None:
return None
if isinstance(value, str):
# 去除各种空白字符:空格、换行、回车、制表符、全角空格
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
return cleaned if cleaned else None
return value
def _process_merged_cells(ws: openpyxl.worksheet.worksheet.Worksheet) -> Dict[Tuple[int, int], Any]:
"""
处理合并单元格,将所有合并单元格打散并填充原值
Args:
ws: openpyxl worksheet对象
Returns:
单元格坐标到值的映射
"""
merged_map = {}
# 处理所有合并单元格范围
for merged_range in ws.merged_cells.ranges:
# 获取左上角单元格的值
min_row, min_col = merged_range.min_row, merged_range.min_col
value = ws.cell(row=min_row, column=min_col).value
# 填充整个合并范围
for row in range(merged_range.min_row, merged_range.max_row + 1):
for col in range(merged_range.min_col, merged_range.max_col + 1):
merged_map[(row, col)] = value
return merged_map
def _worksheet_is_empty(ws: openpyxl.worksheet.worksheet.Worksheet) -> bool:
"""
判断工作表是否为空(无任何有效数据)
"""
max_row = ws.max_row or 0
max_col = ws.max_column or 0
if max_row == 0 or max_col == 0:
return True
# 快速路径:单个单元且为空
if max_row == 1 and max_col == 1:
return _clean_cell_value(ws.cell(row=1, column=1).value) is None
# 扫描是否存在任一非空单元格
for row in ws.iter_rows(min_row=1, max_row=max_row, min_col=1, max_col=max_col, values_only=True):
if any(_clean_cell_value(v) is not None for v in row):
return False
return True
def _read_excel_with_merged_cells(table_name: str, file_content: bytes) -> List[Dict[str, Any]]:
"""
从内存中读取Excel文件处理合并单元格
Args:
table_name: 表名
file_content: Excel文件的二进制内容
Returns:
处理后的DataFrame
"""
# 从字节流加载Excel文件
wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
wss = wb.sheetnames
data = []
df = None
for ws_name in wss:
ws = wb[ws_name]
# 跳过空工作表
if _worksheet_is_empty(ws):
continue
# 处理合并单元格
# merged_cells_map = _process_merged_cells(ws)
# 读取所有数据
all_data = []
# 跳过空列
none_raw = []
max_row = ws.max_row
max_col = ws.max_column
for row_idx in range(1, max_row + 1):
row_data = []
for col_idx in range(1, max_col + 1):
# 检查是否在合并单元格映射中
# if (row_idx, col_idx) in merged_cells_map:
# value = merged_cells_map[(row_idx, col_idx)]
# else:
value = ws.cell(row=row_idx, column=col_idx).value
if (value is None and row_idx == 1) or (col_idx in none_raw):
none_raw.append(col_idx)
continue
# 清理单元格内容
value = _clean_cell_value(value)
row_data.append(value)
# 跳过完全空白的行
if any(v is not None for v in row_data):
all_data.append(row_data)
# 转换为DataFrame
if not all_data:
# 空表,跳过
continue
# 第一行作为列名
columns = [str(col) if col is not None else f"{i+1}"
for i, col in enumerate(all_data[0])]
# 其余行作为数据
data_rows = all_data[1:] if len(all_data) > 1 else []
# 创建DataFrame
df = pd.DataFrame(data_rows, columns=columns)
if len(wss) > 1:
normalized_sheet = ws_name.replace(' ', '').replace('-', '_')
data.append({"table_name": f"{table_name}_{normalized_sheet}", "df": df})
else:
data.append({"table_name": table_name, "df": df})
if not data:
return []
return data
def import_to_database(
db_path: str,
file_content: bytes,
filename: str,
table_name: Optional[str] = None,
force_overwrite: bool = False
) -> Dict[str, Any]:
"""
导入文件内容到指定的SQLite数据库
Args:
db_path: 数据库文件路径
file_content: 文件的二进制内容
filename: 原始文件名
table_name: 指定的表名(可选,默认从文件名生成)
force_overwrite: 是否强制覆盖已存在的表
Returns:
导入结果信息
"""
try:
# 检查数据库是否存在
if not os.path.exists(db_path):
return {"success": False, "error": f"数据库不存在: {db_path}"}
# 读取文件
file_ext = Path(filename).suffix.lower()
table_name = filename.split('.')[0]
if file_ext == '.csv':
# 检测编码
encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
df = None
data = []
for encoding in encodings:
try:
# 尝试从内存读取
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding, nrows=5)
# 如果成功,用同样的编码读取完整内容
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
data.append({"table_name": table_name, "df": df})
break
except:
continue
if df is None:
conn.close()
return {"success": False, "error": "无法识别文件编码"}
elif file_ext in ['.xlsx', '.xls']:
# 使用自定义函数读取Excel处理合并单元格
data = _read_excel_with_merged_cells(table_name, file_content)
else:
conn.close()
return {"success": False, "error": f"不支持的文件类型: {file_ext}"}
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
for item in data:
table_name = item["table_name"]
df = item["df"]
# 检查表是否存在
cursor.execute(f"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='{table_name}'
""")
table_exists = cursor.fetchone() is not None
if table_exists and not force_overwrite:
# 获取表的基本信息
cursor.execute(f"SELECT COUNT(*) FROM \"{table_name}\"")
row_count = cursor.fetchone()[0]
cursor.execute(f"PRAGMA table_info(\"{table_name}\")")
columns = [col[1] for col in cursor.fetchall() if col[1] != '_row_id']
conn.close()
return {
"success": False,
"error": "table_exists",
"error_type": "TABLE_EXISTS",
"message": f"{table_name} 已存在",
"table_info": {
"table_name": table_name,
"row_count": row_count,
"columns": columns,
"column_count": len(columns)
},
"suggestion": "您可以选择1) 覆盖现有表force_overwrite=true2) 使用其他表名 3) 取消操作"
}
total_rows_all = 0
imported_rows_all = 0
error_count_all = 0
table_name_all = []
for item in data:
df = item["df"]
table_name = item["table_name"]
# 清理列名
df.columns = [clean_column_name(col) for col in df.columns]
# 检查是否有重复列名
if len(df.columns) != len(set(df.columns)):
# 为重复列名添加序号
new_cols = []
col_count = {}
for col in df.columns:
if col in col_count:
col_count[col] += 1
new_cols.append(f"{col}_{col_count[col]}")
else:
col_count[col] = 0
new_cols.append(col)
df.columns = new_cols
total_rows = len(df)
# 如果表存在且需要覆盖,先删除表
if table_exists and force_overwrite:
cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
conn.commit()
# 创建新表
data_types = infer_data_types(df)
columns_def = []
for col in df.columns:
col_type = data_types.get(col, 'TEXT')
columns_def.append(f'"{col}" {col_type}')
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS "{table_name}" (
_row_id INTEGER PRIMARY KEY AUTOINCREMENT,
{', '.join(columns_def)}
)
"""
cursor.execute(create_table_sql)
conn.commit()
# 数据清洗和类型转换
for col in df.columns:
# 对于CSV文件去除前后空格
if file_ext == '.csv' and df[col].dtype == 'object':
df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)
# 将 NaN 替换为 None (SQLite NULL)
df[col] = df[col].where(pd.notna(df[col]), None)
# 日期类型转换为 ISO 格式
# try:
# # 尝试识别日期列
# if df[col].dtype == 'object' and df[col].notna().any():
# sample = df[col].dropna().head(100)
# dates = pd.to_datetime(sample, errors='coerce')
# if dates.notna().sum() / len(sample) > 0.9:
# df[col] = pd.to_datetime(df[col], errors='coerce')
# df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
# except:
# pass
# 计算安全的批次大小(考虑 SQLite 变量限制)
num_columns = len(df.columns)
safe_batch_size = min(CHUNK_SIZE, SQLITE_MAX_VARIABLES // num_columns - 1)
# 批量导入数据
error_count = 0
imported_rows = 0
for batch_start in range(0, total_rows, safe_batch_size):
batch_end = min(batch_start + safe_batch_size, total_rows)
batch_df = df.iloc[batch_start:batch_end]
try:
# 批量插入
batch_df.to_sql(table_name, conn, if_exists='append',
index=False, method='multi', chunksize=100)
imported_rows = batch_end
except Exception as e:
error_count += 1
print(f"批次导入错误 [{batch_start}-{batch_end}]: {str(e)}")
# 优化表
cursor.execute(f'ANALYZE "{table_name}"')
total_rows_all += total_rows
imported_rows_all += imported_rows
error_count_all += error_count
table_name_all.append(table_name)
conn.commit()
conn.close()
table_name_all = set(table_name_all)
return {
"success": True,
"table_name": table_name,
"total_rows": total_rows_all,
"imported_rows": imported_rows_all,
"error_count": error_count_all,
"message": f"成功导入 {imported_rows}/{total_rows} 行数据到表 {table_name_all}"
}
except Exception as e:
return {"success": False, "error": str(e)}

start_daemon.sh (new file, 169 lines)

@@ -0,0 +1,169 @@
#!/bin/bash
# 铁路项目管理系统 - 后台运行启动脚本
# 使用方法: ./start_daemon.sh [端口号]
# 默认端口: 8000
# 设置端口号,默认为8000
PORT=${1:-8000}
# 设置日志文件路径
LOG_DIR="logs"
ACCESS_LOG="$LOG_DIR/access.log"
ERROR_LOG="$LOG_DIR/error.log"
APP_LOG="$LOG_DIR/app.log"
PID_FILE="$LOG_DIR/app.pid"
# 创建日志目录
mkdir -p "$LOG_DIR"
echo "=== 铁路项目管理系统后台启动脚本 ==="
echo "端口: $PORT"
echo "日志目录: $LOG_DIR"
# 清理端口函数
kill_process_on_port() {
local port=$1
echo "检查端口 ${port} 的进程..."
# Linux系统使用不同的命令查找占用端口的进程
if command -v lsof >/dev/null 2>&1; then
# 使用lsof
pid=$(lsof -ti :${port})
elif command -v netstat >/dev/null 2>&1; then
# 使用netstat
pid=$(netstat -tlnp 2>/dev/null | grep ":${port} " | awk '{print $7}' | cut -d'/' -f1)
elif command -v ss >/dev/null 2>&1; then
# 使用ss
pid=$(ss -tlnp | grep ":${port} " | sed 's/.*pid=\([0-9]*\).*/\1/')
else
echo "警告: 无法找到 lsof、netstat 或 ss 命令,跳过端口检查"
return
fi
if [ -n "$pid" ] && [ "$pid" != "" ]; then
echo "发现端口 ${port} 被占用PID: ${pid},正在杀死进程..."
kill -9 $pid 2>/dev/null
sleep 2
echo "已杀死端口 ${port} 的进程"
else
echo "端口 ${port} 未被占用"
fi
}
# 停止已运行的服务函数
stop_service() {
if [ -f "$PID_FILE" ]; then
local old_pid=$(cat "$PID_FILE")
if ps -p $old_pid > /dev/null 2>&1; then
echo "停止已运行的服务 (PID: $old_pid)..."
kill $old_pid
sleep 3
if ps -p $old_pid > /dev/null 2>&1; then
echo "强制停止服务..."
kill -9 $old_pid
fi
fi
rm -f "$PID_FILE"
fi
}
# 检查Python环境
check_python() {
if ! command -v python3 >/dev/null 2>&1; then
echo "错误: 未找到 python3"
exit 1
fi
# 检查虚拟环境
if [ -d ".venv" ]; then
echo "激活虚拟环境..."
source .venv/bin/activate
elif [ -d "venv" ]; then
echo "激活虚拟环境..."
source venv/bin/activate
else
echo "警告: 未找到虚拟环境使用系统Python"
fi
}
# 检查依赖
check_dependencies() {
echo "检查依赖包..."
python3 -c "import fastapi, uvicorn" 2>/dev/null
if [ $? -ne 0 ]; then
echo "错误: 缺少必要的依赖包 (fastapi, uvicorn)"
echo "请运行: pip install -r requirements.txt"
exit 1
fi
}
# 主函数
main() {
# 停止已运行的服务
stop_service
# 清理端口
kill_process_on_port $PORT
# 检查环境
check_python
check_dependencies
echo "启动服务..."
echo "启动时间: $(date)" > "$APP_LOG"
echo "端口: $PORT" >> "$APP_LOG"
echo "=================================" >> "$APP_LOG"
# 后台启动服务
nohup python3 -m uvicorn app.main:app \
--host 0.0.0.0 \
--port $PORT \
--access-log \
--log-level info \
>> "$APP_LOG" 2>&1 &
# 保存PID
echo $! > "$PID_FILE"
# 等待服务启动
sleep 3
# 检查服务是否启动成功
if ps -p $(cat "$PID_FILE") > /dev/null 2>&1; then
echo "✅ 服务启动成功!"
echo "🌐 访问地址: http://localhost:$PORT"
echo "📖 API文档: http://localhost:$PORT/docs"
echo "📋 ReDoc文档: http://localhost:$PORT/redoc"
echo "📄 日志文件: $APP_LOG"
echo "🔍 进程ID: $(cat $PID_FILE)"
echo ""
echo "查看日志: tail -f $APP_LOG"
echo "停止服务: ./stop.sh 或 kill $(cat $PID_FILE)"
# 创建停止脚本
cat > stop.sh << EOF
#!/bin/bash
if [ -f "$PID_FILE" ]; then
PID=\$(cat "$PID_FILE")
echo "停止服务 (PID: \$PID)..."
kill \$PID
rm -f "$PID_FILE"
echo "服务已停止"
else
echo "未找到运行中的服务"
fi
EOF
chmod +x stop.sh
else
echo "❌ 服务启动失败,请查看日志: $APP_LOG"
exit 1
fi
}
# 捕获退出信号
trap 'echo "脚本被中断"; exit 1' INT TERM
# 运行主函数
main

start_server.sh (new file, 122 lines)

@@ -0,0 +1,122 @@
#!/bin/bash
# 铁路项目管理系统 - 直接运行启动脚本
# 使用方法: ./start_server.sh [端口号]
# 默认端口: 8000
# 设置端口号,默认为8000
PORT=${1:-8000}
echo "=== 铁路项目管理系统启动脚本 ==="
echo "端口: $PORT"
echo "按 Ctrl+C 停止服务"
# 清理端口函数
kill_process_on_port() {
local port=$1
echo "检查端口 ${port} 的进程..."
# Linux系统使用不同的命令查找占用端口的进程
if command -v lsof >/dev/null 2>&1; then
# 使用lsof
pid=$(lsof -ti :${port})
elif command -v netstat >/dev/null 2>&1; then
# 使用netstat
pid=$(netstat -tlnp 2>/dev/null | grep ":${port} " | awk '{print $7}' | cut -d'/' -f1)
elif command -v ss >/dev/null 2>&1; then
# 使用ss
pid=$(ss -tlnp | grep ":${port} " | sed 's/.*pid=\([0-9]*\).*/\1/')
else
echo "警告: 无法找到 lsof、netstat 或 ss 命令,跳过端口检查"
return
fi
if [ -n "$pid" ] && [ "$pid" != "" ]; then
echo "发现端口 ${port} 被占用PID: ${pid},正在杀死进程..."
kill -9 $pid 2>/dev/null
sleep 2
echo "已杀死端口 ${port} 的进程"
else
echo "端口 ${port} 未被占用"
fi
}
# 检查Python环境
check_python() {
if ! command -v python3 >/dev/null 2>&1; then
echo "错误: 未找到 python3"
exit 1
fi
# 检查虚拟环境
if [ -d ".venv" ]; then
echo "激活虚拟环境..."
source .venv/bin/activate
elif [ -d "venv" ]; then
echo "激活虚拟环境..."
source venv/bin/activate
else
echo "警告: 未找到虚拟环境使用系统Python"
fi
}
# 检查依赖
check_dependencies() {
echo "检查依赖包..."
python3 -c "import fastapi, uvicorn" 2>/dev/null
if [ $? -ne 0 ]; then
echo "错误: 缺少必要的依赖包 (fastapi, uvicorn)"
echo "请运行: pip install -r requirements.txt"
exit 1
fi
}
# 显示服务信息
show_service_info() {
echo ""
echo "=================================="
echo "✅ 服务启动成功!"
echo "🌐 访问地址: http://localhost:$PORT"
echo "📖 API文档: http://localhost:$PORT/docs"
echo "📋 ReDoc文档: http://localhost:$PORT/redoc"
echo "🔍 健康检查: http://localhost:$PORT/health"
echo "=================================="
echo ""
}
# 清理函数
cleanup() {
echo ""
echo "正在停止服务..."
# uvicorn会自动处理SIGTERM信号
exit 0
}
# 主函数
main() {
# 清理端口
kill_process_on_port $PORT
# 检查环境
check_python
check_dependencies
# 捕获退出信号
trap cleanup SIGINT SIGTERM
echo "启动服务..."
# 显示服务信息(在后台执行,延迟3秒显示)
(sleep 3 && show_service_info) &
# 直接启动服务(前台运行)
python3 -m uvicorn app.main:app \
--host 0.0.0.0 \
--port $PORT \
--reload \
--access-log \
--log-level info
}
# 运行主函数
main
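
As a quick follow-up once either script reports the service as started, a small probe like the sketch below can confirm the API is actually reachable (the port is the scripts' default and the /health path is the one advertised in start_server.sh; neither is guaranteed for a given deployment):

```python
# Minimal readiness probe for the started service.
import sys
import time

import requests


def wait_for_api(url: str = "http://localhost:8000/health", attempts: int = 10) -> bool:
    # Poll the assumed health endpoint until it answers 200 or attempts run out.
    for _ in range(attempts):
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(1)
    return False


if __name__ == "__main__":
    sys.exit(0 if wait_for_api() else 1)
```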