导入文件建表接口

This commit is contained in:
lhx
2025-09-26 17:19:58 +08:00
parent 18478b148a
commit 52d2753824
6 changed files with 735 additions and 5 deletions

View File

@@ -4,7 +4,7 @@ from typing import List
from ..core.database import get_db
from ..schemas.database import (
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
TableDataResponse, CreateTableRequest, ImportDataRequest
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
)
from ..services.database import DatabaseService
@@ -56,4 +56,16 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
@router.post("/tables", response_model=List[str])
def get_table_list():
    """Return the table names reported by ``DatabaseService.get_table_list``."""
    # The statement used to appear twice; the second copy was unreachable.
    return DatabaseService.get_table_list()
@router.post("/import-file", response_model=SQLExecuteResponse)
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
    """Import an uploaded Excel/CSV file into the database.

    The request fields are forwarded unchanged to
    ``DatabaseService.import_file_to_database`` and the dict it returns is
    expanded into an ``SQLExecuteResponse``.
    """
    outcome = DatabaseService.import_file_to_database(
        db,
        filename=request.filename,
        file_content=request.file_content,
        table_name=request.table_name,
        force_overwrite=request.force_overwrite,
    )
    return SQLExecuteResponse(**outcome)

View File

@@ -28,4 +28,10 @@ class CreateTableRequest(BaseModel):
class ImportDataRequest(BaseModel):
    """Request body for inserting rows into a table."""

    # Name of the table the rows are written to.
    table_name: str
    # Rows to import, one dict per row (declared once; it used to be duplicated).
    data: List[Dict[str, Any]]
class FileImportRequest(BaseModel):
    """Request body for importing a file into the database."""

    # Name of the uploaded file.
    filename: str
    # Base64-encoded file content.
    file_content: str
    # Optional custom table name.
    table_name: Optional[str] = None
    # Whether an already existing table is forcibly overwritten.
    force_overwrite: bool = False

View File

@@ -3,6 +3,7 @@ from sqlalchemy import text, MetaData, Table, Column, create_engine, inspect
from sqlalchemy.exc import SQLAlchemyError
from typing import List, Dict, Any, Optional
from ..core.database import engine
from ..utils.file_import import FileImportUtils
import pandas as pd
class DatabaseService:
@@ -193,4 +194,114 @@ class DatabaseService:
# 排除mysql的系统表和accounts表
return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
except Exception as e:
return []
return []
@staticmethod
def import_file_to_database(db: Session, filename: str, file_content: str,
                            table_name: Optional[str] = None,
                            force_overwrite: bool = False) -> Dict[str, Any]:
    """Create one table per dataset found in an uploaded file and load its rows.

    Args:
        db: Session used for the DROP/CREATE statements.
        filename: Name of the uploaded file; passed to
            ``FileImportUtils.parse_file``.
        file_content: Base64-encoded file content.
        table_name: Optional custom table name. When the file yields more
            than one dataset the 1-based dataset position is appended
            (``name_1``, ``name_2``, ...) so the datasets do not overwrite
            each other.
        force_overwrite: Drop and recreate tables that already exist.

    Returns:
        A dict with ``success`` and ``message``. On success with a single
        table it also carries ``table_name``, ``rows_imported`` and
        ``columns``; with several tables it carries ``tables``,
        ``total_rows`` and ``table_names``.
    """

    def quote_identifier(name: Any) -> str:
        # A backtick inside a quoted MySQL identifier is escaped by doubling
        # it; a colon is backslash-escaped so text() does not read it as a
        # bind parameter. Only use the result inside text().
        escaped = str(name).replace("`", "``").replace(":", "\\:")
        return f"`{escaped}`"

    try:
        # Parse the file and keep only the datasets that contain rows.
        file_data = FileImportUtils.parse_file(filename, file_content)
        datasets = [item for item in (file_data or []) if item["data"]]
        if not datasets:
            return {
                "success": False,
                "message": "文件中没有找到有效数据"
            }

        # Decide every target table name up front so that no two datasets
        # share a table.
        plan = []
        planned_names = set()
        for position, item in enumerate(datasets, 1):
            if not table_name:
                base_name = item["table_name"]
            elif len(datasets) > 1:
                base_name = f"{table_name}_{position}"
            else:
                base_name = table_name
            target = base_name
            bump = 1
            while target.lower() in planned_names:
                bump += 1
                target = f"{base_name}_{bump}"
            planned_names.add(target.lower())
            plan.append((target, item))

        existing_tables = set(inspect(engine).get_table_names())

        # Report conflicts before anything is dropped or created, so a
        # rejected request never leaves a partial import behind.
        if not force_overwrite:
            for target, _ in plan:
                if target in existing_tables:
                    return {
                        "success": False,
                        "message": f"{target} 已存在,请使用 force_overwrite=true 覆盖或选择其他表名"
                    }

        # NOTE(review): MySQL DDL commits implicitly and to_sql writes through
        # `engine`, not `db`, so db.rollback() in the handlers below cannot
        # undo tables that were already dropped or created.
        results = []
        for target, item in plan:
            data = item["data"]
            columns = item["columns"]

            if target in existing_tables:
                db.execute(text(f"DROP TABLE IF EXISTS {quote_identifier(target)}"))
                db.commit()

            column_types = FileImportUtils.prepare_table_columns(data, columns)

            # The surrogate key must not collide with a file column named
            # "id" (MySQL column names are case-insensitive).
            taken = {str(col).lower() for col in columns}
            key_column = "id"
            while key_column.lower() in taken:
                key_column = f"_{key_column}"

            column_definitions = [
                f"{quote_identifier(key_column)} INT AUTO_INCREMENT PRIMARY KEY"
            ]
            column_definitions.extend(
                f"{quote_identifier(col)} {column_types.get(col, 'TEXT')}"
                for col in columns
            )
            create_table_sql = (
                f"CREATE TABLE {quote_identifier(target)} "
                f"({', '.join(column_definitions)})"
            )
            db.execute(text(create_table_sql))
            db.commit()

            # Bulk insert with pandas; the key column is filled by MySQL.
            frame = pd.DataFrame(data)
            frame.to_sql(target, engine, if_exists='append',
                         index=False, method='multi', chunksize=1000)

            results.append({
                "table_name": target,
                "rows_imported": len(data),
                "columns": columns
            })

        if len(results) == 1:
            result = results[0]
            return {
                "success": True,
                "message": f"成功导入 {result['rows_imported']} 行数据到表 {result['table_name']}",
                "table_name": result["table_name"],
                "rows_imported": result["rows_imported"],
                "columns": result["columns"]
            }

        total_rows = sum(r["rows_imported"] for r in results)
        table_names = [r["table_name"] for r in results]
        return {
            "success": True,
            "message": f"成功导入 {total_rows} 行数据到 {len(results)} 个表",
            "tables": results,
            "total_rows": total_rows,
            "table_names": table_names
        }

    except ValueError as e:
        # Raised by FileImportUtils for undecodable or unsupported files.
        return {
            "success": False,
            "message": str(e)
        }
    except SQLAlchemyError as e:
        db.rollback()
        return {
            "success": False,
            "message": f"数据库操作失败: {str(e)}"
        }
    except Exception as e:
        db.rollback()
        return {
            "success": False,
            "message": f"导入失败: {str(e)}"
        }

236
app/utils/file_import.py Normal file
View File

@@ -0,0 +1,236 @@
"""
文件导入工具类
处理Excel/CSV文件的读取、解析和数据清洗
"""
import base64
import io
import openpyxl
import pandas as pd
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
class FileImportUtils:
    """Static helpers for decoding, parsing and typing Excel/CSV uploads.

    ``parse_file`` is the entry point: it decodes the base64 payload and
    returns one ``{"table_name", "data", "columns"}`` dict per CSV file or
    per worksheet that contains data rows.
    """

    # MySQL identifiers (table and column names) hold at most 64 characters.
    _MAX_IDENTIFIER_LENGTH = 64
    # Exclusive magnitude bound below which an integer column is typed INT.
    _INT_MAGNITUDE_LIMIT = 2147483648
    # Exclusive magnitude bound for DECIMAL(10,2): eight integer digits.
    _DECIMAL_MAGNITUDE_LIMIT = 10 ** 8
    # Characters removed from every text cell: space, newline, carriage
    # return, tab and the full-width (ideographic) space.
    _STRIP_TABLE = str.maketrans('', '', ' \n\r\t\u3000')

    @staticmethod
    def clean_cell_value(value: Any) -> Any:
        """Strip spaces, line breaks and tabs from a text cell.

        All such characters are removed, including those inside the text.
        Strings that become empty are returned as ``None``; non-string
        values are returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, str):
            cleaned = value.translate(FileImportUtils._STRIP_TABLE)
            return cleaned if cleaned else None
        return value

    @staticmethod
    def clean_column_name(col_name: str) -> str:
        """Turn an arbitrary header into a name usable as a MySQL identifier.

        Every character that is not a word character is replaced by ``_``,
        names starting with a digit get a ``col_`` prefix, and the result is
        cut to 64 characters. A falsy input yields ``"column_1"``.
        """
        if not col_name:
            return "column_1"
        # Keep letters, digits, underscores and CJK characters only.
        cleaned = re.sub(r'[^\w\u4e00-\u9fff]', '_', str(col_name))
        if cleaned and cleaned[0].isdigit():
            cleaned = f"col_{cleaned}"
        return cleaned[:FileImportUtils._MAX_IDENTIFIER_LENGTH]

    @staticmethod
    def _dedupe_column_names(names: List[str]) -> List[str]:
        """Make column names unique by appending ``_1``, ``_2``, ...

        Names are compared case-insensitively because MySQL column names
        are case-insensitive. Suffixed names stay within the identifier
        length limit.
        """
        limit = FileImportUtils._MAX_IDENTIFIER_LENGTH
        seen = set()
        unique = []
        for name in names:
            candidate = name
            suffix = 0
            while candidate.lower() in seen:
                suffix += 1
                tail = f"_{suffix}"
                candidate = f"{name[:limit - len(tail)]}{tail}"
            seen.add(candidate.lower())
            unique.append(candidate)
        return unique

    @staticmethod
    def infer_mysql_column_type(series: pd.Series) -> str:
        """Infer a MySQL column type from the non-null values of a column.

        Returns:
            ``TEXT`` for an all-null column; ``DATETIME``/``DATE`` for
            date-like values; ``INT``/``BIGINT`` for whole numbers;
            ``DECIMAL(10,2)`` for fractional numbers that fit it exactly,
            otherwise ``DOUBLE``; ``VARCHAR(n)`` or ``TEXT`` for anything
            else.
        """
        non_null = series.dropna()
        if len(non_null) == 0:
            return "TEXT"

        # Date-like columns must be detected first: pd.to_numeric would
        # otherwise convert datetime64 values to integers.
        kind = pd.api.types.infer_dtype(non_null, skipna=True)
        if kind in ("datetime", "datetime64"):
            return "DATETIME"
        if kind == "date":
            return "DATE"

        try:
            # NOTE: numeric-looking text (e.g. "00123") is treated as a
            # number, so leading zeros are not preserved by the column type.
            values = pd.to_numeric(non_null).tolist()
            magnitudes = [abs(v) for v in values]
            if all(v == int(v) for v in values):
                # Use the largest magnitude so large negative values are
                # not squeezed into INT.
                if max(magnitudes) < FileImportUtils._INT_MAGNITUDE_LIMIT:
                    return "INT"
                return "BIGINT"
            fits_decimal = all(
                m < FileImportUtils._DECIMAL_MAGNITUDE_LIMIT and round(m, 2) == m
                for m in magnitudes
            )
            # DOUBLE avoids overflow and silent rounding where DECIMAL(10,2)
            # cannot represent the values.
            return "DECIMAL(10,2)" if fits_decimal else "DOUBLE"
        except (ValueError, TypeError, OverflowError):
            # Not numeric (or not finite): fall through to the text types.
            pass

        max_length = int(non_null.astype(str).str.len().max())
        if max_length <= 255:
            # Reserve twice the observed length, capped at 255.
            return f"VARCHAR({min(max(max_length, 1) * 2, 255)})"
        return "TEXT"

    @staticmethod
    def decode_file_content(file_content: str) -> bytes:
        """Decode base64 file content, accepting an optional data-URL prefix.

        Raises:
            ValueError: If the content cannot be decoded.
        """
        try:
            payload = file_content
            # Browsers often send "data:<mime>;base64,<payload>"; neither ':'
            # nor ',' occurs in plain base64, so this check is unambiguous.
            if isinstance(payload, str) and payload.startswith("data:") and "," in payload:
                payload = payload.split(",", 1)[1]
            return base64.b64decode(payload)
        except (ValueError, TypeError) as e:
            raise ValueError(f"文件内容解码失败: {str(e)}") from e

    @staticmethod
    def read_excel_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
        """Read every worksheet that has a header row and at least one data row.

        The first row supplies the column names; rows whose cells are all
        empty after cleaning are skipped. With several sheets the table name
        is ``<file stem>_<sheet name>``, otherwise the file stem.

        Raises:
            Exception: If the workbook cannot be read; the original error is
                chained as the cause.
        """
        clean_value = FileImportUtils.clean_cell_value
        try:
            wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
            table_name_base = Path(filename).stem
            multiple_sheets = len(wb.sheetnames) > 1
            results = []

            for sheet_name in wb.sheetnames:
                ws = wb[sheet_name]
                if ws.max_row is None or ws.max_row <= 1:
                    continue

                # First row -> column names; empty headers get a positional name.
                headers = []
                for col in range(1, ws.max_column + 1):
                    header = clean_value(ws.cell(row=1, column=col).value)
                    if header:
                        headers.append(FileImportUtils.clean_column_name(header))
                    else:
                        headers.append(f"column_{col}")
                # Duplicate headers would overwrite each other in the row
                # dicts and break CREATE TABLE.
                headers = FileImportUtils._dedupe_column_names(headers)

                data = []
                for row in range(2, ws.max_row + 1):
                    row_data = {
                        header: clean_value(ws.cell(row=row, column=col).value)
                        for col, header in enumerate(headers, 1)
                    }
                    if any(value is not None for value in row_data.values()):
                        data.append(row_data)

                if data:
                    table_name = (
                        f"{table_name_base}_{sheet_name}" if multiple_sheets
                        else table_name_base
                    )
                    results.append({
                        "table_name": FileImportUtils.clean_column_name(table_name),
                        "data": data,
                        "columns": headers
                    })
            return results
        except Exception as e:
            raise Exception(f"读取Excel文件失败: {str(e)}") from e

    @staticmethod
    def read_csv_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
        """Read a CSV file, trying utf-8, gbk, gb2312 and gb18030 in turn.

        Column names are cleaned and made unique, text cells are cleaned,
        missing values become ``None`` and rows without any value are
        dropped.

        Raises:
            Exception: If the file cannot be read; the original error is
                chained as the cause.
        """
        try:
            df = None
            for encoding in ('utf-8', 'gbk', 'gb2312', 'gb18030'):
                try:
                    df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
                    break
                except (UnicodeError, pd.errors.ParserError):
                    # Wrong encoding guess: try the next candidate.
                    continue
            if df is None:
                raise Exception("无法识别文件编码")

            cleaned_names = [FileImportUtils.clean_column_name(col) for col in df.columns]
            df.columns = FileImportUtils._dedupe_column_names(cleaned_names)

            for col in df.columns:
                df[col] = df[col].map(FileImportUtils.clean_cell_value)

            data = []
            for record in df.to_dict('records'):
                # Convert NaN/NaT explicitly so missing values are always None.
                row = {
                    key: (None if pd.isna(value) else value)
                    for key, value in record.items()
                }
                if any(value is not None for value in row.values()):
                    data.append(row)

            return [{
                "table_name": FileImportUtils.clean_column_name(Path(filename).stem),
                "data": data,
                "columns": list(df.columns)
            }]
        except Exception as e:
            raise Exception(f"读取CSV文件失败: {str(e)}") from e

    @staticmethod
    def parse_file(filename: str, file_content: str) -> List[Dict[str, Any]]:
        """Decode the payload and dispatch on the file extension.

        Args:
            filename: File name; its extension selects the parser.
            file_content: Base64-encoded file content.

        Returns:
            A list of ``{"table_name", "data", "columns"}`` dicts.

        Raises:
            ValueError: If the content cannot be decoded or the extension
                is not ``.csv``, ``.xlsx`` or ``.xls``.
        """
        file_bytes = FileImportUtils.decode_file_content(file_content)
        file_ext = Path(filename).suffix.lower()
        if file_ext == '.csv':
            return FileImportUtils.read_csv_file(file_bytes, filename)
        if file_ext in ('.xlsx', '.xls'):
            # NOTE(review): both extensions go through openpyxl; legacy
            # binary .xls workbooks are presumably rejected by it - confirm.
            return FileImportUtils.read_excel_file(file_bytes, filename)
        raise ValueError(f"不支持的文件类型: {file_ext}")

    @staticmethod
    def prepare_table_columns(data: List[Dict[str, Any]], columns: List[str]) -> Dict[str, str]:
        """Map each column name to an inferred MySQL type.

        Args:
            data: Row dicts.
            columns: Column names to type.

        Returns:
            ``{column: mysql_type}``; columns absent from the data are
            ``TEXT``. Empty ``data`` yields an empty dict.
        """
        if not data:
            return {}
        df = pd.DataFrame(data)
        return {
            col: (
                FileImportUtils.infer_mysql_column_type(df[col])
                if col in df.columns else "TEXT"
            )
            for col in columns
        }