接口优化

This commit is contained in:
lhx
2025-09-27 09:30:46 +08:00
parent 52d2753824
commit 0b1e9851dd
7 changed files with 368 additions and 384 deletions

View File

@@ -1,10 +1,11 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
from sqlalchemy.orm import Session
from typing import List
from typing import List, Optional
import base64
from ..core.database import get_db
from ..schemas.database import (
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportFormData
)
from ..services.database import DatabaseService
@@ -53,19 +54,64 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
result = DatabaseService.import_data(db, request.table_name, request.data)
return SQLExecuteResponse(**result)
@router.post("/tables", response_model=List[str])
@router.post("/tables", response_model=SQLExecuteResponse)
def get_table_list():
"""获取所有表名"""
return DatabaseService.get_table_list()
result = DatabaseService.get_table_list()
return SQLExecuteResponse(**result)
@router.post("/import-file", response_model=SQLExecuteResponse)
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
async def import_file(
file: UploadFile = File(..., description="上传的Excel/CSV文件"),
table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
db: Session = Depends(get_db)
):
"""导入Excel/CSV文件到数据库"""
result = DatabaseService.import_file_to_database(
db,
request.filename,
request.file_content,
request.table_name,
request.force_overwrite
)
return SQLExecuteResponse(**result)
try:
# 检查文件类型
if not file.filename:
raise HTTPException(status_code=400, detail="文件名不能为空")
# 处理中文文件名和后缀名 - 更可靠的方法
try:
# FastAPI的UploadFile已经正确处理了文件名编码
filename = file.filename
except:
filename = "unknown_file"
if not filename or filename == "":
raise HTTPException(status_code=400, detail="文件名不能为空")
# 提取文件扩展名
if '.' not in filename:
raise HTTPException(status_code=400, detail="文件必须有扩展名")
file_ext = filename.lower().split('.')[-1]
if file_ext not in ['csv', 'xlsx', 'xls']:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls"
)
# 读取文件内容
file_content = await file.read()
# 将文件内容转换为base64保持与现有FileImportUtils兼容
file_content_base64 = base64.b64encode(file_content).decode('utf-8')
# 调用服务层方法
result = DatabaseService.import_file_to_database(
db, filename, file_content_base64, table_name, force_overwrite
)
return SQLExecuteResponse(**result)
except HTTPException:
raise
except Exception as e:
return SQLExecuteResponse(
success=False,
message=f"文件导入失败: {str(e)}"
)