Files
railway_cloud/app/api/database.py

162 lines
5.7 KiB
Python

import base64
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
from sqlalchemy.orm import Session

from ..core.database import get_db
from ..core.response_code import ResponseCode, ResponseMessage
from ..schemas.database import (
    SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
    TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportFormData
)
from ..services.database import DatabaseService
# Router for the database-management endpoints. Every route below is mounted
# under the "/database" prefix and grouped under a single tag
# ("数据库管理" = "database management").
router = APIRouter(
    tags=["数据库管理"],
    prefix="/database",
)
@router.post("/execute", response_model=SQLExecuteResponse)
def execute_sql(request: SQLExecuteRequest, db: Session = Depends(get_db)):
    """Execute a SQL statement supplied by the client.

    ``request.sql`` is handed unchanged to ``DatabaseService.execute_sql``.
    The service's result dict is then mapped onto the ``SQLExecuteResponse``
    envelope: ``SUCCESS`` when ``result['success']`` is truthy, otherwise
    ``DATABASE_ERROR``.

    NOTE(review): this forwards caller-supplied SQL text as-is. Make sure the
    route is only reachable by trusted/admin users.
    """
    outcome = DatabaseService.execute_sql(db, request.sql)
    if outcome.get('success'):
        response_code = ResponseCode.SUCCESS
    else:
        response_code = ResponseCode.DATABASE_ERROR
    return SQLExecuteResponse(
        code=response_code,
        message=outcome['message'],
        data=outcome.get('data'),
    )
@router.post("/table-data", response_model=TableDataResponse)
def get_table_data(request: TableDataRequest, db: Session = Depends(get_db)):
    """Return one page of rows from ``request.table_name``.

    Both pagination values are coalesced with ``or``, so a falsy ``limit``
    (``None`` or ``0``) becomes 100 and a falsy ``offset`` becomes 0.

    The response carries ``total`` (from the service's ``total_count``) and
    ``data``. The code is ``SUCCESS`` or ``DATABASE_ERROR`` depending on the
    service's ``success`` flag.
    """
    page_size = request.limit or 100
    start_row = request.offset or 0
    outcome = DatabaseService.get_table_data(
        db,
        request.table_name,
        page_size,
        start_row,
    )
    if outcome.get('success'):
        response_code = ResponseCode.SUCCESS
    else:
        response_code = ResponseCode.DATABASE_ERROR
    return TableDataResponse(
        code=response_code,
        message=outcome['message'],
        total=outcome.get('total_count'),
        data=outcome.get('data'),
    )
@router.post("/create-table", response_model=SQLExecuteResponse)
def create_table(request: CreateTableRequest, db: Session = Depends(get_db)):
    """Create a table from the column and primary-key definitions in the request.

    Delegates to ``DatabaseService.create_table``. The response never carries a
    payload: ``data`` is always ``None``, and only the code and message reflect
    the outcome.
    """
    outcome = DatabaseService.create_table(
        db,
        request.table_name,
        request.columns,
        request.primary_key,
    )
    if outcome.get('success'):
        response_code = ResponseCode.SUCCESS
    else:
        response_code = ResponseCode.DATABASE_ERROR
    return SQLExecuteResponse(
        code=response_code,
        message=outcome['message'],
        data=None,
    )
@router.post("/drop-table", response_model=SQLExecuteResponse)
def drop_table(request: dict, db: Session = Depends(get_db)):
    """Drop the table named by ``request["table_name"]``.

    The request body is an untyped dict, so ``table_name`` is validated here
    before it reaches the service layer:

    * missing, falsy, or whitespace-only -> ``BAD_REQUEST`` "table_name is required"
    * any non-string value (int, list, ...) -> ``BAD_REQUEST`` "table_name must be a string"

    Otherwise the call goes to ``DatabaseService.drop_table``, and the result
    maps to ``SUCCESS`` / ``DATABASE_ERROR``. ``data`` is always ``None``.
    """
    table_name = request.get("table_name")
    # A blank string is as unusable as a missing key; report both the same way.
    if not table_name or (isinstance(table_name, str) and not table_name.strip()):
        return SQLExecuteResponse(
            code=ResponseCode.BAD_REQUEST,
            message="table_name is required",
            data=None
        )
    # The body is a plain dict, so a JSON number/array/object could arrive here;
    # never forward a non-string identifier to the DROP logic.
    if not isinstance(table_name, str):
        return SQLExecuteResponse(
            code=ResponseCode.BAD_REQUEST,
            message="table_name must be a string",
            data=None
        )
    result = DatabaseService.drop_table(db, table_name)
    return SQLExecuteResponse(
        code=ResponseCode.SUCCESS if result.get('success') else ResponseCode.DATABASE_ERROR,
        message=result['message'],
        data=None
    )
@router.post("/import-data", response_model=SQLExecuteResponse)
def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
    """Insert the rows in ``request.data`` into ``request.table_name``.

    Thin wrapper over ``DatabaseService.import_data``. Only the status code and
    message are returned; ``data`` is always ``None``.
    """
    outcome = DatabaseService.import_data(db, request.table_name, request.data)
    if outcome.get('success'):
        response_code = ResponseCode.SUCCESS
    else:
        response_code = ResponseCode.DATABASE_ERROR
    return SQLExecuteResponse(
        code=response_code,
        message=outcome['message'],
        data=None,
    )
@router.post("/tables", response_model=SQLExecuteResponse)
def get_table_list():
    """Return the names of all tables.

    Unlike the sibling endpoints, this one takes no ``db`` session and calls
    ``DatabaseService.get_table_list()`` with no arguments. Presumably the
    service opens its own connection/engine; confirm against the service layer.
    """
    outcome = DatabaseService.get_table_list()
    if outcome.get('success'):
        response_code = ResponseCode.SUCCESS
    else:
        response_code = ResponseCode.DATABASE_ERROR
    return SQLExecuteResponse(
        code=response_code,
        message=outcome['message'],
        data=outcome.get('data'),
    )
def _import_filename_error(filename: Optional[str]) -> Optional[str]:
    """Return a user-facing error message if *filename* cannot be imported, else ``None``.

    Accepted files have a non-empty name whose last extension (case-insensitive)
    is ``csv``, ``xlsx`` or ``xls``.
    """
    if not filename:
        return "文件名不能为空"
    # rpartition splits on the LAST dot, so "report.2024.CSV" -> "CSV".
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return "文件必须有扩展名"
    file_ext = extension.lower()
    if file_ext not in ('csv', 'xlsx', 'xls'):
        return f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls"
    return None


@router.post("/import-file", response_model=SQLExecuteResponse)
async def import_file(
    file: UploadFile = File(..., description="上传的Excel/CSV文件"),
    table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
    force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
    db: Session = Depends(get_db)
):
    """Import an uploaded Excel/CSV file into the database.

    The upload is validated first. It needs a non-empty filename with a
    ``.csv``, ``.xlsx`` or ``.xls`` extension, and non-empty content. The bytes
    are then base64-encoded and passed to
    ``DatabaseService.import_file_to_database`` along with ``table_name`` and
    ``force_overwrite``.

    Args:
        file: The uploaded spreadsheet.
        table_name: Optional target table name. Per the form description, the
            file name is used when it is omitted (handled by the service).
        force_overwrite: Whether an existing table may be overwritten.
        db: SQLAlchemy session injected via ``get_db``.

    Returns:
        ``SQLExecuteResponse``: ``BAD_REQUEST`` for an invalid upload,
        ``SUCCESS`` (with the service's ``data``) or ``IMPORT_FAILED``
        otherwise. Any unexpected exception is logged with its traceback and
        reported as ``IMPORT_FAILED``.
    """
    filename = file.filename
    try:
        error_message = _import_filename_error(filename)
        if error_message is not None:
            return SQLExecuteResponse(
                code=ResponseCode.BAD_REQUEST,
                message=error_message,
                data=None
            )

        # NOTE(review): the whole upload is read into memory with no size cap.
        file_content = await file.read()
        if not file_content:
            # Reject empty uploads up front instead of letting the service fail on them.
            return SQLExecuteResponse(
                code=ResponseCode.BAD_REQUEST,
                message="上传的文件内容为空",
                data=None
            )

        file_content_base64 = base64.b64encode(file_content).decode('utf-8')

        # NOTE(review): this synchronous service call runs on the event loop of an
        # `async def` endpoint and blocks other requests while the import runs.
        result = DatabaseService.import_file_to_database(
            db, filename, file_content_base64, table_name, force_overwrite
        )
        succeeded = result.get('success')
        return SQLExecuteResponse(
            code=ResponseCode.SUCCESS if succeeded else ResponseCode.IMPORT_FAILED,
            message=result['message'],
            data=result.get('data') if succeeded else None
        )
    except Exception as e:
        # Top-level boundary: keep returning a structured error to the client, but
        # record the traceback instead of silently discarding it.
        logging.getLogger(__name__).exception(
            "Failed to import uploaded file %r", filename
        )
        return SQLExecuteResponse(
            code=ResponseCode.IMPORT_FAILED,
            message=f"文件导入失败: {str(e)}",
            data=None
        )