接口优化
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import base64
|
||||
from ..core.database import get_db
|
||||
from ..schemas.database import (
|
||||
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
|
||||
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
|
||||
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportFormData
|
||||
)
|
||||
from ..services.database import DatabaseService
|
||||
|
||||
@@ -53,19 +54,64 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
|
||||
result = DatabaseService.import_data(db, request.table_name, request.data)
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
@router.post("/tables", response_model=List[str])
|
||||
@router.post("/tables", response_model=SQLExecuteResponse)
|
||||
def get_table_list():
|
||||
"""获取所有表名"""
|
||||
return DatabaseService.get_table_list()
|
||||
result = DatabaseService.get_table_list()
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
@router.post("/import-file", response_model=SQLExecuteResponse)
|
||||
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
|
||||
async def import_file(
|
||||
file: UploadFile = File(..., description="上传的Excel/CSV文件"),
|
||||
table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
|
||||
force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""导入Excel/CSV文件到数据库"""
|
||||
result = DatabaseService.import_file_to_database(
|
||||
db,
|
||||
request.filename,
|
||||
request.file_content,
|
||||
request.table_name,
|
||||
request.force_overwrite
|
||||
)
|
||||
return SQLExecuteResponse(**result)
|
||||
try:
|
||||
# 检查文件类型
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
# 处理中文文件名和后缀名 - 更可靠的方法
|
||||
try:
|
||||
# FastAPI的UploadFile已经正确处理了文件名编码
|
||||
filename = file.filename
|
||||
except:
|
||||
filename = "unknown_file"
|
||||
|
||||
if not filename or filename == "":
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
# 提取文件扩展名
|
||||
if '.' not in filename:
|
||||
raise HTTPException(status_code=400, detail="文件必须有扩展名")
|
||||
|
||||
file_ext = filename.lower().split('.')[-1]
|
||||
|
||||
if file_ext not in ['csv', 'xlsx', 'xls']:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls"
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 将文件内容转换为base64(保持与现有FileImportUtils兼容)
|
||||
file_content_base64 = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
# 调用服务层方法
|
||||
result = DatabaseService.import_file_to_database(
|
||||
db, filename, file_content_base64, table_name, force_overwrite
|
||||
)
|
||||
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return SQLExecuteResponse(
|
||||
success=False,
|
||||
message=f"文件导入失败: {str(e)}"
|
||||
)
|
||||
@@ -32,7 +32,11 @@ class AccountListRequest(BaseModel):
|
||||
limit: Optional[int] = 100
|
||||
|
||||
class AccountGetRequest(BaseModel):
|
||||
account_id: int
|
||||
account_id: Optional[int] = None
|
||||
account: Optional[str] = None
|
||||
section: Optional[str] = None
|
||||
status: Optional[int] = None
|
||||
today_updated: Optional[int] = None
|
||||
|
||||
class AccountUpdateRequest(BaseModel):
|
||||
account_id: int
|
||||
|
||||
@@ -30,8 +30,7 @@ class ImportDataRequest(BaseModel):
|
||||
table_name: str
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
class FileImportRequest(BaseModel):
|
||||
filename: str
|
||||
file_content: str # base64编码的文件内容
|
||||
class FileImportFormData(BaseModel):
|
||||
"""文件导入表单数据"""
|
||||
table_name: Optional[str] = None # 可选的自定义表名
|
||||
force_overwrite: bool = False # 是否强制覆盖已存在的表
|
||||
@@ -192,9 +192,17 @@ class DatabaseService:
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
# 排除mysql的系统表和accounts表
|
||||
return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
|
||||
data = [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts' and table != 'apscheduler_jobs']
|
||||
return {
|
||||
"success": True,
|
||||
"message": "获取表名成功",
|
||||
"data": data
|
||||
}
|
||||
except Exception as e:
|
||||
return []
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"获取表名失败: {str(e)}"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def import_file_to_database(db: Session, filename: str, file_content: str,
|
||||
|
||||
@@ -1,364 +0,0 @@
|
||||
"""
|
||||
数据库导入工具模块
|
||||
用于通过HTTP接口导入文件到数据库
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Tuple, List
|
||||
import openpyxl
|
||||
import io
|
||||
|
||||
# 添加sqlite-mcp-server到路径以使用其工具函数
|
||||
current_dir = os.path.dirname(__file__)
|
||||
project_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
sqlite_mcp_path = os.path.join(project_root, 'sqlite-mcp-server')
|
||||
if os.path.exists(sqlite_mcp_path) and sqlite_mcp_path not in sys.path:
|
||||
sys.path.insert(0, sqlite_mcp_path)
|
||||
|
||||
from tools.common import (
|
||||
clean_table_name, clean_column_name, infer_data_types,
|
||||
CHUNK_SIZE, SQLITE_MAX_VARIABLES
|
||||
)
|
||||
|
||||
def _clean_cell_value(value: Any) -> Any:
|
||||
"""
|
||||
清理单元格内容:去除空格、换行、制表符等
|
||||
|
||||
Args:
|
||||
value: 单元格的值
|
||||
|
||||
Returns:
|
||||
清理后的值
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
# 去除各种空白字符:空格、换行、回车、制表符、全角空格
|
||||
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
|
||||
return cleaned if cleaned else None
|
||||
|
||||
return value
|
||||
|
||||
def _process_merged_cells(ws: openpyxl.worksheet.worksheet.Worksheet) -> Dict[Tuple[int, int], Any]:
|
||||
"""
|
||||
处理合并单元格,将所有合并单元格打散并填充原值
|
||||
|
||||
Args:
|
||||
ws: openpyxl worksheet对象
|
||||
|
||||
Returns:
|
||||
单元格坐标到值的映射
|
||||
"""
|
||||
merged_map = {}
|
||||
|
||||
# 处理所有合并单元格范围
|
||||
for merged_range in ws.merged_cells.ranges:
|
||||
# 获取左上角单元格的值
|
||||
min_row, min_col = merged_range.min_row, merged_range.min_col
|
||||
value = ws.cell(row=min_row, column=min_col).value
|
||||
|
||||
# 填充整个合并范围
|
||||
for row in range(merged_range.min_row, merged_range.max_row + 1):
|
||||
for col in range(merged_range.min_col, merged_range.max_col + 1):
|
||||
merged_map[(row, col)] = value
|
||||
|
||||
return merged_map
|
||||
|
||||
|
||||
def _worksheet_is_empty(ws: openpyxl.worksheet.worksheet.Worksheet) -> bool:
|
||||
"""
|
||||
判断工作表是否为空(无任何有效数据)
|
||||
"""
|
||||
max_row = ws.max_row or 0
|
||||
max_col = ws.max_column or 0
|
||||
if max_row == 0 or max_col == 0:
|
||||
return True
|
||||
# 快速路径:单个单元且为空
|
||||
if max_row == 1 and max_col == 1:
|
||||
return _clean_cell_value(ws.cell(row=1, column=1).value) is None
|
||||
|
||||
# 扫描是否存在任一非空单元格
|
||||
for row in ws.iter_rows(min_row=1, max_row=max_row, min_col=1, max_col=max_col, values_only=True):
|
||||
if any(_clean_cell_value(v) is not None for v in row):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _read_excel_with_merged_cells(table_name: str, file_content: bytes) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从内存中读取Excel文件,处理合并单元格
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
file_content: Excel文件的二进制内容
|
||||
|
||||
Returns:
|
||||
处理后的DataFrame
|
||||
"""
|
||||
# 从字节流加载Excel文件
|
||||
wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
|
||||
wss = wb.sheetnames
|
||||
data = []
|
||||
df = None
|
||||
|
||||
for ws_name in wss:
|
||||
ws = wb[ws_name]
|
||||
# 跳过空工作表
|
||||
if _worksheet_is_empty(ws):
|
||||
continue
|
||||
# 处理合并单元格
|
||||
# merged_cells_map = _process_merged_cells(ws)
|
||||
# 读取所有数据
|
||||
all_data = []
|
||||
# 跳过空列
|
||||
none_raw = []
|
||||
max_row = ws.max_row
|
||||
max_col = ws.max_column
|
||||
for row_idx in range(1, max_row + 1):
|
||||
row_data = []
|
||||
for col_idx in range(1, max_col + 1):
|
||||
|
||||
# 检查是否在合并单元格映射中
|
||||
# if (row_idx, col_idx) in merged_cells_map:
|
||||
# value = merged_cells_map[(row_idx, col_idx)]
|
||||
# else:
|
||||
value = ws.cell(row=row_idx, column=col_idx).value
|
||||
if (value is None and row_idx == 1) or (col_idx in none_raw):
|
||||
none_raw.append(col_idx)
|
||||
continue
|
||||
# 清理单元格内容
|
||||
value = _clean_cell_value(value)
|
||||
row_data.append(value)
|
||||
# 跳过完全空白的行
|
||||
if any(v is not None for v in row_data):
|
||||
all_data.append(row_data)
|
||||
# 转换为DataFrame
|
||||
if not all_data:
|
||||
# 空表,跳过
|
||||
continue
|
||||
# 第一行作为列名
|
||||
columns = [str(col) if col is not None else f"列{i+1}"
|
||||
for i, col in enumerate(all_data[0])]
|
||||
# 其余行作为数据
|
||||
data_rows = all_data[1:] if len(all_data) > 1 else []
|
||||
# 创建DataFrame
|
||||
df = pd.DataFrame(data_rows, columns=columns)
|
||||
if len(wss) > 1:
|
||||
normalized_sheet = ws_name.replace(' ', '').replace('-', '_')
|
||||
data.append({"table_name": f"{table_name}_{normalized_sheet}", "df": df})
|
||||
else:
|
||||
data.append({"table_name": table_name, "df": df})
|
||||
if not data:
|
||||
return []
|
||||
return data
|
||||
|
||||
def import_to_database(
|
||||
db_path: str,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
table_name: Optional[str] = None,
|
||||
force_overwrite: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
导入文件内容到指定的SQLite数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径
|
||||
file_content: 文件的二进制内容
|
||||
filename: 原始文件名
|
||||
table_name: 指定的表名(可选,默认从文件名生成)
|
||||
force_overwrite: 是否强制覆盖已存在的表
|
||||
|
||||
Returns:
|
||||
导入结果信息
|
||||
"""
|
||||
try:
|
||||
# 检查数据库是否存在
|
||||
if not os.path.exists(db_path):
|
||||
return {"success": False, "error": f"数据库不存在: {db_path}"}
|
||||
|
||||
|
||||
# 读取文件
|
||||
file_ext = Path(filename).suffix.lower()
|
||||
table_name = filename.split('.')[0]
|
||||
|
||||
if file_ext == '.csv':
|
||||
# 检测编码
|
||||
encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
|
||||
df = None
|
||||
data = []
|
||||
for encoding in encodings:
|
||||
try:
|
||||
# 尝试从内存读取
|
||||
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding, nrows=5)
|
||||
# 如果成功,用同样的编码读取完整内容
|
||||
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
|
||||
data.append({"table_name": table_name, "df": df})
|
||||
break
|
||||
except:
|
||||
continue
|
||||
if df is None:
|
||||
conn.close()
|
||||
return {"success": False, "error": "无法识别文件编码"}
|
||||
elif file_ext in ['.xlsx', '.xls']:
|
||||
# 使用自定义函数读取Excel,处理合并单元格
|
||||
data = _read_excel_with_merged_cells(table_name, file_content)
|
||||
else:
|
||||
conn.close()
|
||||
return {"success": False, "error": f"不支持的文件类型: {file_ext}"}
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
for item in data:
|
||||
table_name = item["table_name"]
|
||||
df = item["df"]
|
||||
# 检查表是否存在
|
||||
cursor.execute(f"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='{table_name}'
|
||||
""")
|
||||
|
||||
table_exists = cursor.fetchone() is not None
|
||||
|
||||
if table_exists and not force_overwrite:
|
||||
# 获取表的基本信息
|
||||
cursor.execute(f"SELECT COUNT(*) FROM \"{table_name}\"")
|
||||
row_count = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute(f"PRAGMA table_info(\"{table_name}\")")
|
||||
columns = [col[1] for col in cursor.fetchall() if col[1] != '_row_id']
|
||||
|
||||
conn.close()
|
||||
return {
|
||||
"success": False,
|
||||
"error": "table_exists",
|
||||
"error_type": "TABLE_EXISTS",
|
||||
"message": f"表 {table_name} 已存在",
|
||||
"table_info": {
|
||||
"table_name": table_name,
|
||||
"row_count": row_count,
|
||||
"columns": columns,
|
||||
"column_count": len(columns)
|
||||
},
|
||||
"suggestion": "您可以选择:1) 覆盖现有表(force_overwrite=true)2) 使用其他表名 3) 取消操作"
|
||||
}
|
||||
|
||||
total_rows_all = 0
|
||||
imported_rows_all = 0
|
||||
error_count_all = 0
|
||||
table_name_all = []
|
||||
for item in data:
|
||||
df = item["df"]
|
||||
table_name = item["table_name"]
|
||||
# 清理列名
|
||||
df.columns = [clean_column_name(col) for col in df.columns]
|
||||
|
||||
# 检查是否有重复列名
|
||||
if len(df.columns) != len(set(df.columns)):
|
||||
# 为重复列名添加序号
|
||||
new_cols = []
|
||||
col_count = {}
|
||||
for col in df.columns:
|
||||
if col in col_count:
|
||||
col_count[col] += 1
|
||||
new_cols.append(f"{col}_{col_count[col]}")
|
||||
else:
|
||||
col_count[col] = 0
|
||||
new_cols.append(col)
|
||||
df.columns = new_cols
|
||||
|
||||
total_rows = len(df)
|
||||
|
||||
# 如果表存在且需要覆盖,先删除表
|
||||
if table_exists and force_overwrite:
|
||||
cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
|
||||
conn.commit()
|
||||
|
||||
# 创建新表
|
||||
data_types = infer_data_types(df)
|
||||
|
||||
columns_def = []
|
||||
for col in df.columns:
|
||||
col_type = data_types.get(col, 'TEXT')
|
||||
columns_def.append(f'"{col}" {col_type}')
|
||||
|
||||
create_table_sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{table_name}" (
|
||||
_row_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
{', '.join(columns_def)}
|
||||
)
|
||||
"""
|
||||
cursor.execute(create_table_sql)
|
||||
conn.commit()
|
||||
|
||||
# 数据清洗和类型转换
|
||||
for col in df.columns:
|
||||
# 对于CSV文件,去除前后空格
|
||||
if file_ext == '.csv' and df[col].dtype == 'object':
|
||||
df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)
|
||||
|
||||
# 将 NaN 替换为 None (SQLite NULL)
|
||||
df[col] = df[col].where(pd.notna(df[col]), None)
|
||||
|
||||
# 日期类型转换为 ISO 格式
|
||||
# try:
|
||||
# # 尝试识别日期列
|
||||
# if df[col].dtype == 'object' and df[col].notna().any():
|
||||
# sample = df[col].dropna().head(100)
|
||||
# dates = pd.to_datetime(sample, errors='coerce')
|
||||
# if dates.notna().sum() / len(sample) > 0.9:
|
||||
# df[col] = pd.to_datetime(df[col], errors='coerce')
|
||||
# df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
# except:
|
||||
# pass
|
||||
|
||||
# 计算安全的批次大小(考虑 SQLite 变量限制)
|
||||
num_columns = len(df.columns)
|
||||
safe_batch_size = min(CHUNK_SIZE, SQLITE_MAX_VARIABLES // num_columns - 1)
|
||||
|
||||
# 批量导入数据
|
||||
error_count = 0
|
||||
imported_rows = 0
|
||||
|
||||
for batch_start in range(0, total_rows, safe_batch_size):
|
||||
batch_end = min(batch_start + safe_batch_size, total_rows)
|
||||
batch_df = df.iloc[batch_start:batch_end]
|
||||
|
||||
try:
|
||||
# 批量插入
|
||||
batch_df.to_sql(table_name, conn, if_exists='append',
|
||||
index=False, method='multi', chunksize=100)
|
||||
|
||||
imported_rows = batch_end
|
||||
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
print(f"批次导入错误 [{batch_start}-{batch_end}]: {str(e)}")
|
||||
|
||||
# 优化表
|
||||
cursor.execute(f'ANALYZE "{table_name}"')
|
||||
total_rows_all += total_rows
|
||||
imported_rows_all += imported_rows
|
||||
error_count_all += error_count
|
||||
table_name_all.append(table_name)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
table_name_all = set(table_name_all)
|
||||
return {
|
||||
"success": True,
|
||||
"table_name": table_name,
|
||||
"total_rows": total_rows_all,
|
||||
"imported_rows": imported_rows_all,
|
||||
"error_count": error_count_all,
|
||||
"message": f"成功导入 {imported_rows}/{total_rows} 行数据到表 {table_name_all}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
169
start_daemon.sh
Normal file
169
start_daemon.sh
Normal file
@@ -0,0 +1,169 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 铁路项目管理系统 - 后台运行启动脚本
|
||||
# 使用方法: ./start_daemon.sh [端口号]
|
||||
# 默认端口: 8000
|
||||
|
||||
# 设置端口号,默认为8000
|
||||
PORT=${1:-8000}
|
||||
|
||||
# 设置日志文件路径
|
||||
LOG_DIR="logs"
|
||||
ACCESS_LOG="$LOG_DIR/access.log"
|
||||
ERROR_LOG="$LOG_DIR/error.log"
|
||||
APP_LOG="$LOG_DIR/app.log"
|
||||
PID_FILE="$LOG_DIR/app.pid"
|
||||
|
||||
# 创建日志目录
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
echo "=== 铁路项目管理系统后台启动脚本 ==="
|
||||
echo "端口: $PORT"
|
||||
echo "日志目录: $LOG_DIR"
|
||||
|
||||
# 清理端口函数
|
||||
kill_process_on_port() {
|
||||
local port=$1
|
||||
echo "检查端口 ${port} 的进程..."
|
||||
|
||||
# Linux系统使用不同的命令查找占用端口的进程
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
# 使用lsof
|
||||
pid=$(lsof -ti :${port})
|
||||
elif command -v netstat >/dev/null 2>&1; then
|
||||
# 使用netstat
|
||||
pid=$(netstat -tlnp 2>/dev/null | grep ":${port} " | awk '{print $7}' | cut -d'/' -f1)
|
||||
elif command -v ss >/dev/null 2>&1; then
|
||||
# 使用ss
|
||||
pid=$(ss -tlnp | grep ":${port} " | sed 's/.*pid=\([0-9]*\).*/\1/')
|
||||
else
|
||||
echo "警告: 无法找到 lsof、netstat 或 ss 命令,跳过端口检查"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "$pid" ] && [ "$pid" != "" ]; then
|
||||
echo "发现端口 ${port} 被占用,PID: ${pid},正在杀死进程..."
|
||||
kill -9 $pid 2>/dev/null
|
||||
sleep 2
|
||||
echo "已杀死端口 ${port} 的进程"
|
||||
else
|
||||
echo "端口 ${port} 未被占用"
|
||||
fi
|
||||
}
|
||||
|
||||
# 停止已运行的服务函数
|
||||
stop_service() {
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
local old_pid=$(cat "$PID_FILE")
|
||||
if ps -p $old_pid > /dev/null 2>&1; then
|
||||
echo "停止已运行的服务 (PID: $old_pid)..."
|
||||
kill $old_pid
|
||||
sleep 3
|
||||
if ps -p $old_pid > /dev/null 2>&1; then
|
||||
echo "强制停止服务..."
|
||||
kill -9 $old_pid
|
||||
fi
|
||||
fi
|
||||
rm -f "$PID_FILE"
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查Python环境
|
||||
check_python() {
|
||||
if ! command -v python3 >/dev/null 2>&1; then
|
||||
echo "错误: 未找到 python3"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查虚拟环境
|
||||
if [ -d ".venv" ]; then
|
||||
echo "激活虚拟环境..."
|
||||
source .venv/bin/activate
|
||||
elif [ -d "venv" ]; then
|
||||
echo "激活虚拟环境..."
|
||||
source venv/bin/activate
|
||||
else
|
||||
echo "警告: 未找到虚拟环境,使用系统Python"
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查依赖
|
||||
check_dependencies() {
|
||||
echo "检查依赖包..."
|
||||
python3 -c "import fastapi, uvicorn" 2>/dev/null
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "错误: 缺少必要的依赖包 (fastapi, uvicorn)"
|
||||
echo "请运行: pip install -r requirements.txt"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 主函数
|
||||
main() {
|
||||
# 停止已运行的服务
|
||||
stop_service
|
||||
|
||||
# 清理端口
|
||||
kill_process_on_port $PORT
|
||||
|
||||
# 检查环境
|
||||
check_python
|
||||
check_dependencies
|
||||
|
||||
echo "启动服务..."
|
||||
echo "启动时间: $(date)" > "$APP_LOG"
|
||||
echo "端口: $PORT" >> "$APP_LOG"
|
||||
echo "=================================" >> "$APP_LOG"
|
||||
|
||||
# 后台启动服务
|
||||
nohup python3 -m uvicorn app.main:app \
|
||||
--host 0.0.0.0 \
|
||||
--port $PORT \
|
||||
--access-log \
|
||||
--log-level info \
|
||||
>> "$APP_LOG" 2>&1 &
|
||||
|
||||
# 保存PID
|
||||
echo $! > "$PID_FILE"
|
||||
|
||||
# 等待服务启动
|
||||
sleep 3
|
||||
|
||||
# 检查服务是否启动成功
|
||||
if ps -p $(cat "$PID_FILE") > /dev/null 2>&1; then
|
||||
echo "✅ 服务启动成功!"
|
||||
echo "🌐 访问地址: http://localhost:$PORT"
|
||||
echo "📖 API文档: http://localhost:$PORT/docs"
|
||||
echo "📋 ReDoc文档: http://localhost:$PORT/redoc"
|
||||
echo "📄 日志文件: $APP_LOG"
|
||||
echo "🔍 进程ID: $(cat $PID_FILE)"
|
||||
echo ""
|
||||
echo "查看日志: tail -f $APP_LOG"
|
||||
echo "停止服务: ./stop.sh 或 kill $(cat $PID_FILE)"
|
||||
|
||||
# 创建停止脚本
|
||||
cat > stop.sh << EOF
|
||||
#!/bin/bash
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
PID=\$(cat "$PID_FILE")
|
||||
echo "停止服务 (PID: \$PID)..."
|
||||
kill \$PID
|
||||
rm -f "$PID_FILE"
|
||||
echo "服务已停止"
|
||||
else
|
||||
echo "未找到运行中的服务"
|
||||
fi
|
||||
EOF
|
||||
chmod +x stop.sh
|
||||
|
||||
else
|
||||
echo "❌ 服务启动失败,请查看日志: $APP_LOG"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 捕获退出信号
|
||||
trap 'echo "脚本被中断"; exit 1' INT TERM
|
||||
|
||||
# 运行主函数
|
||||
main
|
||||
122
start_server.sh
Normal file
122
start_server.sh
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 铁路项目管理系统 - 直接运行启动脚本
|
||||
# 使用方法: ./start_server.sh [端口号]
|
||||
# 默认端口: 8000
|
||||
|
||||
# 设置端口号,默认为8000
|
||||
PORT=${1:-8000}
|
||||
|
||||
echo "=== 铁路项目管理系统启动脚本 ==="
|
||||
echo "端口: $PORT"
|
||||
echo "按 Ctrl+C 停止服务"
|
||||
|
||||
# 清理端口函数
|
||||
kill_process_on_port() {
|
||||
local port=$1
|
||||
echo "检查端口 ${port} 的进程..."
|
||||
|
||||
# Linux系统使用不同的命令查找占用端口的进程
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
# 使用lsof
|
||||
pid=$(lsof -ti :${port})
|
||||
elif command -v netstat >/dev/null 2>&1; then
|
||||
# 使用netstat
|
||||
pid=$(netstat -tlnp 2>/dev/null | grep ":${port} " | awk '{print $7}' | cut -d'/' -f1)
|
||||
elif command -v ss >/dev/null 2>&1; then
|
||||
# 使用ss
|
||||
pid=$(ss -tlnp | grep ":${port} " | sed 's/.*pid=\([0-9]*\).*/\1/')
|
||||
else
|
||||
echo "警告: 无法找到 lsof、netstat 或 ss 命令,跳过端口检查"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "$pid" ] && [ "$pid" != "" ]; then
|
||||
echo "发现端口 ${port} 被占用,PID: ${pid},正在杀死进程..."
|
||||
kill -9 $pid 2>/dev/null
|
||||
sleep 2
|
||||
echo "已杀死端口 ${port} 的进程"
|
||||
else
|
||||
echo "端口 ${port} 未被占用"
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查Python环境
|
||||
check_python() {
|
||||
if ! command -v python3 >/dev/null 2>&1; then
|
||||
echo "错误: 未找到 python3"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查虚拟环境
|
||||
if [ -d ".venv" ]; then
|
||||
echo "激活虚拟环境..."
|
||||
source .venv/bin/activate
|
||||
elif [ -d "venv" ]; then
|
||||
echo "激活虚拟环境..."
|
||||
source venv/bin/activate
|
||||
else
|
||||
echo "警告: 未找到虚拟环境,使用系统Python"
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查依赖
|
||||
check_dependencies() {
|
||||
echo "检查依赖包..."
|
||||
python3 -c "import fastapi, uvicorn" 2>/dev/null
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "错误: 缺少必要的依赖包 (fastapi, uvicorn)"
|
||||
echo "请运行: pip install -r requirements.txt"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 显示服务信息
|
||||
show_service_info() {
|
||||
echo ""
|
||||
echo "=================================="
|
||||
echo "✅ 服务启动成功!"
|
||||
echo "🌐 访问地址: http://localhost:$PORT"
|
||||
echo "📖 API文档: http://localhost:$PORT/docs"
|
||||
echo "📋 ReDoc文档: http://localhost:$PORT/redoc"
|
||||
echo "🔍 健康检查: http://localhost:$PORT/health"
|
||||
echo "=================================="
|
||||
echo ""
|
||||
}
|
||||
|
||||
# 清理函数
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo "正在停止服务..."
|
||||
# uvicorn会自动处理SIGTERM信号
|
||||
exit 0
|
||||
}
|
||||
|
||||
# 主函数
|
||||
main() {
|
||||
# 清理端口
|
||||
kill_process_on_port $PORT
|
||||
|
||||
# 检查环境
|
||||
check_python
|
||||
check_dependencies
|
||||
|
||||
# 捕获退出信号
|
||||
trap cleanup SIGINT SIGTERM
|
||||
|
||||
echo "启动服务..."
|
||||
|
||||
# 显示服务信息(在后台执行,延迟3秒显示)
|
||||
(sleep 3 && show_service_info) &
|
||||
|
||||
# 直接启动服务(前台运行)
|
||||
python3 -m uvicorn app.main:app \
|
||||
--host 0.0.0.0 \
|
||||
--port $PORT \
|
||||
--reload \
|
||||
--access-log \
|
||||
--log-level info
|
||||
}
|
||||
|
||||
# 运行主函数
|
||||
main
|
||||
Reference in New Issue
Block a user