接口优化
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import base64
|
||||
from ..core.database import get_db
|
||||
from ..schemas.database import (
|
||||
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
|
||||
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
|
||||
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportFormData
|
||||
)
|
||||
from ..services.database import DatabaseService
|
||||
|
||||
@@ -53,19 +54,64 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
|
||||
result = DatabaseService.import_data(db, request.table_name, request.data)
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
@router.post("/tables", response_model=List[str])
|
||||
@router.post("/tables", response_model=SQLExecuteResponse)
|
||||
def get_table_list():
|
||||
"""获取所有表名"""
|
||||
return DatabaseService.get_table_list()
|
||||
result = DatabaseService.get_table_list()
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
@router.post("/import-file", response_model=SQLExecuteResponse)
|
||||
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
|
||||
async def import_file(
|
||||
file: UploadFile = File(..., description="上传的Excel/CSV文件"),
|
||||
table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
|
||||
force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""导入Excel/CSV文件到数据库"""
|
||||
result = DatabaseService.import_file_to_database(
|
||||
db,
|
||||
request.filename,
|
||||
request.file_content,
|
||||
request.table_name,
|
||||
request.force_overwrite
|
||||
)
|
||||
return SQLExecuteResponse(**result)
|
||||
try:
|
||||
# 检查文件类型
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
# 处理中文文件名和后缀名 - 更可靠的方法
|
||||
try:
|
||||
# FastAPI的UploadFile已经正确处理了文件名编码
|
||||
filename = file.filename
|
||||
except:
|
||||
filename = "unknown_file"
|
||||
|
||||
if not filename or filename == "":
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
# 提取文件扩展名
|
||||
if '.' not in filename:
|
||||
raise HTTPException(status_code=400, detail="文件必须有扩展名")
|
||||
|
||||
file_ext = filename.lower().split('.')[-1]
|
||||
|
||||
if file_ext not in ['csv', 'xlsx', 'xls']:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls"
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 将文件内容转换为base64(保持与现有FileImportUtils兼容)
|
||||
file_content_base64 = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
# 调用服务层方法
|
||||
result = DatabaseService.import_file_to_database(
|
||||
db, filename, file_content_base64, table_name, force_overwrite
|
||||
)
|
||||
|
||||
return SQLExecuteResponse(**result)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return SQLExecuteResponse(
|
||||
success=False,
|
||||
message=f"文件导入失败: {str(e)}"
|
||||
)
|
||||
@@ -32,7 +32,11 @@ class AccountListRequest(BaseModel):
|
||||
limit: Optional[int] = 100
|
||||
|
||||
class AccountGetRequest(BaseModel):
|
||||
account_id: int
|
||||
account_id: Optional[int] = None
|
||||
account: Optional[str] = None
|
||||
section: Optional[str] = None
|
||||
status: Optional[int] = None
|
||||
today_updated: Optional[int] = None
|
||||
|
||||
class AccountUpdateRequest(BaseModel):
|
||||
account_id: int
|
||||
|
||||
@@ -30,8 +30,7 @@ class ImportDataRequest(BaseModel):
|
||||
table_name: str
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
class FileImportRequest(BaseModel):
|
||||
filename: str
|
||||
file_content: str # base64编码的文件内容
|
||||
class FileImportFormData(BaseModel):
|
||||
"""文件导入表单数据"""
|
||||
table_name: Optional[str] = None # 可选的自定义表名
|
||||
force_overwrite: bool = False # 是否强制覆盖已存在的表
|
||||
@@ -192,9 +192,17 @@ class DatabaseService:
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
# 排除mysql的系统表和accounts表
|
||||
return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
|
||||
data = [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts' and table != 'apscheduler_jobs']
|
||||
return {
|
||||
"success": True,
|
||||
"message": "获取表名成功",
|
||||
"data": data
|
||||
}
|
||||
except Exception as e:
|
||||
return []
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"获取表名失败: {str(e)}"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def import_file_to_database(db: Session, filename: str, file_content: str,
|
||||
|
||||
Reference in New Issue
Block a user