"""Database-management HTTP endpoints mounted under ``/database``.

Every endpoint delegates to ``DatabaseService`` and converts the returned
result dict (read here via its ``success``, ``message``, ``data`` and
``total_count`` keys) into a response model carrying a ``ResponseCode``.
"""
import base64
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

from ..core.database import get_db
from ..core.response_code import ResponseCode, ResponseMessage
from ..schemas.database import (
    SQLExecuteRequest,
    SQLExecuteResponse,
    TableDataRequest,
    TableDataResponse,
    CreateTableRequest,
    ImportDataRequest,
    FileImportFormData,
)
from ..services.database import DatabaseService

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/database", tags=["数据库管理"])

# File extensions accepted by /import-file (compared after lower-casing).
_ALLOWED_IMPORT_EXTENSIONS = ("csv", "xlsx", "xls")


def _result_code(result: dict, failure_code=ResponseCode.DATABASE_ERROR):
    """Map a service result dict to ``ResponseCode.SUCCESS`` or *failure_code*.

    The result counts as successful only when its ``success`` entry is truthy;
    a missing ``success`` key is treated as failure.
    """
    return ResponseCode.SUCCESS if result.get("success") else failure_code


@router.post("/execute", response_model=SQLExecuteResponse)
def execute_sql(request: SQLExecuteRequest, db: Session = Depends(get_db)) -> SQLExecuteResponse:
    """Execute the SQL statement in ``request.sql`` and return the service result."""
    # NOTE(review): caller-supplied SQL is executed as-is; presumably this
    # router is restricted to trusted/admin users — confirm access control.
    result = DatabaseService.execute_sql(db, request.sql)
    return SQLExecuteResponse(
        code=_result_code(result),
        message=result["message"],
        data=result.get("data"),
    )


@router.post("/table-data", response_model=TableDataResponse)
def get_table_data(request: TableDataRequest, db: Session = Depends(get_db)) -> TableDataResponse:
    """Return one page of rows from ``request.table_name``.

    A falsy ``limit`` (missing or 0) falls back to 100 and a falsy ``offset``
    falls back to 0. ``total`` is taken from the service's ``total_count``.
    """
    result = DatabaseService.get_table_data(
        db,
        request.table_name,
        request.limit or 100,
        request.offset or 0,
    )
    return TableDataResponse(
        code=_result_code(result),
        message=result["message"],
        total=result.get("total_count"),
        data=result.get("data"),
    )


@router.post("/create-table", response_model=SQLExecuteResponse)
def create_table(request: CreateTableRequest, db: Session = Depends(get_db)) -> SQLExecuteResponse:
    """Create a table from the requested name, column definitions and primary key."""
    result = DatabaseService.create_table(
        db,
        request.table_name,
        request.columns,
        request.primary_key,
    )
    return SQLExecuteResponse(
        code=_result_code(result),
        message=result["message"],
        data=None,
    )


@router.post("/drop-table", response_model=SQLExecuteResponse)
def drop_table(request: dict, db: Session = Depends(get_db)) -> SQLExecuteResponse:
    """Drop the table named by ``request["table_name"]``.

    Returns a BAD_REQUEST response, without calling the service, when the name
    is missing, empty, or not a string.
    """
    table_name = request.get("table_name")
    # The body is an untyped dict, so non-string values (numbers, lists, ...)
    # must be rejected here instead of being forwarded to the service layer.
    if not isinstance(table_name, str) or not table_name:
        return SQLExecuteResponse(
            code=ResponseCode.BAD_REQUEST,
            message="table_name is required",
            data=None,
        )

    result = DatabaseService.drop_table(db, table_name)
    return SQLExecuteResponse(
        code=_result_code(result),
        message=result["message"],
        data=None,
    )


@router.post("/import-data", response_model=SQLExecuteResponse)
def import_data(request: ImportDataRequest, db: Session = Depends(get_db)) -> SQLExecuteResponse:
    """Insert ``request.data`` into ``request.table_name``."""
    result = DatabaseService.import_data(db, request.table_name, request.data)
    return SQLExecuteResponse(
        code=_result_code(result),
        message=result["message"],
        data=None,
    )


@router.post("/tables", response_model=SQLExecuteResponse)
def get_table_list() -> SQLExecuteResponse:
    """Return the names of all tables, as provided by the service layer."""
    # NOTE(review): unlike the other endpoints no session is injected here;
    # DatabaseService.get_table_list presumably manages its own connection.
    result = DatabaseService.get_table_list()
    return SQLExecuteResponse(
        code=_result_code(result),
        message=result["message"],
        data=result.get("data"),
    )


def _encode_and_import(db, filename, file_content, table_name, force_overwrite):
    """Base64-encode the raw upload and pass it to the service-layer importer.

    Both steps are synchronous, so ``import_file`` runs this helper in a worker
    thread to keep the event loop responsive during large imports.
    """
    file_content_base64 = base64.b64encode(file_content).decode("utf-8")
    return DatabaseService.import_file_to_database(
        db, filename, file_content_base64, table_name, force_overwrite
    )


@router.post("/import-file", response_model=SQLExecuteResponse)
async def import_file(
    file: UploadFile = File(..., description="上传的Excel/CSV文件"),
    table_name: Optional[str] = Form(None, description="自定义表名,如果不提供则使用文件名"),
    force_overwrite: bool = Form(False, description="是否强制覆盖已存在的表"),
    db: Session = Depends(get_db),
) -> SQLExecuteResponse:
    """Import an uploaded Excel/CSV file into the database.

    Validates that the upload has a file name with a ``.csv``, ``.xlsx`` or
    ``.xls`` extension (case-insensitive), then hands the content to the
    service layer. Validation failures yield BAD_REQUEST responses; service
    failures and unexpected exceptions yield IMPORT_FAILED responses.
    ``data`` is only returned when the import succeeded.
    """
    try:
        filename = file.filename
        if not filename:
            return SQLExecuteResponse(
                code=ResponseCode.BAD_REQUEST,
                message="文件名不能为空",
                data=None,
            )

        if "." not in filename:
            return SQLExecuteResponse(
                code=ResponseCode.BAD_REQUEST,
                message="文件必须有扩展名",
                data=None,
            )

        file_ext = filename.lower().rsplit(".", 1)[-1]
        if file_ext not in _ALLOWED_IMPORT_EXTENSIONS:
            return SQLExecuteResponse(
                code=ResponseCode.BAD_REQUEST,
                message=f"不支持的文件类型: .{file_ext},仅支持 .csv, .xlsx, .xls",
                data=None,
            )

        file_content = await file.read()

        # The importer is synchronous; calling it directly inside this
        # coroutine would block the event loop for the whole import.
        result = await run_in_threadpool(
            _encode_and_import, db, filename, file_content, table_name, force_overwrite
        )

        succeeded = bool(result.get("success"))
        return SQLExecuteResponse(
            code=_result_code(result, ResponseCode.IMPORT_FAILED),
            message=result["message"],
            data=result.get("data") if succeeded else None,
        )
    except Exception as e:
        # Top-level boundary: keep the client-facing response unchanged, but
        # record the traceback so failures are diagnosable server-side.
        logger.exception("Failed to import uploaded file %r", file.filename)
        return SQLExecuteResponse(
            code=ResponseCode.IMPORT_FAILED,
            message=f"文件导入失败: {str(e)}",
            data=None,
        )