导入文件建表接口
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import List
|
|||||||
from ..core.database import get_db
|
from ..core.database import get_db
|
||||||
from ..schemas.database import (
|
from ..schemas.database import (
|
||||||
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
|
SQLExecuteRequest, SQLExecuteResponse, TableDataRequest,
|
||||||
TableDataResponse, CreateTableRequest, ImportDataRequest
|
TableDataResponse, CreateTableRequest, ImportDataRequest, FileImportRequest
|
||||||
)
|
)
|
||||||
from ..services.database import DatabaseService
|
from ..services.database import DatabaseService
|
||||||
|
|
||||||
@@ -56,4 +56,16 @@ def import_data(request: ImportDataRequest, db: Session = Depends(get_db)):
|
|||||||
@router.post("/tables", response_model=List[str])
|
@router.post("/tables", response_model=List[str])
|
||||||
def get_table_list():
|
def get_table_list():
|
||||||
"""获取所有表名"""
|
"""获取所有表名"""
|
||||||
return DatabaseService.get_table_list()
|
return DatabaseService.get_table_list()
|
||||||
|
|
||||||
|
@router.post("/import-file", response_model=SQLExecuteResponse)
|
||||||
|
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
|
||||||
|
"""导入Excel/CSV文件到数据库"""
|
||||||
|
result = DatabaseService.import_file_to_database(
|
||||||
|
db,
|
||||||
|
request.filename,
|
||||||
|
request.file_content,
|
||||||
|
request.table_name,
|
||||||
|
request.force_overwrite
|
||||||
|
)
|
||||||
|
return SQLExecuteResponse(**result)
|
||||||
@@ -28,4 +28,10 @@ class CreateTableRequest(BaseModel):
|
|||||||
|
|
||||||
class ImportDataRequest(BaseModel):
|
class ImportDataRequest(BaseModel):
|
||||||
table_name: str
|
table_name: str
|
||||||
data: List[Dict[str, Any]]
|
data: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
class FileImportRequest(BaseModel):
|
||||||
|
filename: str
|
||||||
|
file_content: str # base64编码的文件内容
|
||||||
|
table_name: Optional[str] = None # 可选的自定义表名
|
||||||
|
force_overwrite: bool = False # 是否强制覆盖已存在的表
|
||||||
@@ -3,6 +3,7 @@ from sqlalchemy import text, MetaData, Table, Column, create_engine, inspect
|
|||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from ..core.database import engine
|
from ..core.database import engine
|
||||||
|
from ..utils.file_import import FileImportUtils
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
class DatabaseService:
|
class DatabaseService:
|
||||||
@@ -193,4 +194,114 @@ class DatabaseService:
|
|||||||
# 排除mysql的系统表和accounts表
|
# 排除mysql的系统表和accounts表
|
||||||
return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
|
return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def import_file_to_database(db: Session, filename: str, file_content: str,
|
||||||
|
table_name: Optional[str] = None,
|
||||||
|
force_overwrite: bool = False) -> Dict[str, Any]:
|
||||||
|
"""从文件导入数据到数据库"""
|
||||||
|
try:
|
||||||
|
# 解析文件内容
|
||||||
|
file_data = FileImportUtils.parse_file(filename, file_content)
|
||||||
|
|
||||||
|
if not file_data:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "文件中没有找到有效数据"
|
||||||
|
}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
inspector = inspect(engine)
|
||||||
|
existing_tables = inspector.get_table_names()
|
||||||
|
|
||||||
|
for item in file_data:
|
||||||
|
final_table_name = table_name if table_name else item["table_name"]
|
||||||
|
data = item["data"]
|
||||||
|
columns = item["columns"]
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查表是否存在
|
||||||
|
if final_table_name in existing_tables and not force_overwrite:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"表 {final_table_name} 已存在,请使用 force_overwrite=true 覆盖或选择其他表名"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果需要覆盖,先删除表
|
||||||
|
if final_table_name in existing_tables and force_overwrite:
|
||||||
|
db.execute(text(f"DROP TABLE IF EXISTS `{final_table_name}`"))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# 准备列类型定义
|
||||||
|
column_types = FileImportUtils.prepare_table_columns(data, columns)
|
||||||
|
|
||||||
|
# 创建表
|
||||||
|
column_definitions = []
|
||||||
|
for col in columns:
|
||||||
|
col_type = column_types.get(col, "TEXT")
|
||||||
|
column_definitions.append(f"`{col}` {col_type}")
|
||||||
|
|
||||||
|
create_table_sql = f"""
|
||||||
|
CREATE TABLE `{final_table_name}` (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
{', '.join(column_definitions)}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
db.execute(text(create_table_sql))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# 导入数据
|
||||||
|
df_clean = pd.DataFrame(data)
|
||||||
|
|
||||||
|
# 使用pandas的to_sql方法批量导入
|
||||||
|
df_clean.to_sql(final_table_name, engine, if_exists='append',
|
||||||
|
index=False, method='multi', chunksize=1000)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"table_name": final_table_name,
|
||||||
|
"rows_imported": len(data),
|
||||||
|
"columns": columns
|
||||||
|
})
|
||||||
|
|
||||||
|
# 返回结果
|
||||||
|
if len(results) == 1:
|
||||||
|
result = results[0]
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"成功导入 {result['rows_imported']} 行数据到表 {result['table_name']}",
|
||||||
|
"table_name": result["table_name"],
|
||||||
|
"rows_imported": result["rows_imported"],
|
||||||
|
"columns": result["columns"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
total_rows = sum(r["rows_imported"] for r in results)
|
||||||
|
table_names = [r["table_name"] for r in results]
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"成功导入 {total_rows} 行数据到 {len(results)} 个表",
|
||||||
|
"tables": results,
|
||||||
|
"total_rows": total_rows,
|
||||||
|
"table_names": table_names
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.rollback()
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"数据库操作失败: {str(e)}"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"导入失败: {str(e)}"
|
||||||
|
}
|
||||||
236
app/utils/file_import.py
Normal file
236
app/utils/file_import.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
文件导入工具类
|
||||||
|
处理Excel/CSV文件的读取、解析和数据清洗
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import openpyxl
|
||||||
|
import pandas as pd
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class FileImportUtils:
|
||||||
|
"""文件导入工具类"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clean_cell_value(value: Any) -> Any:
|
||||||
|
"""清理单元格内容:去除空格、换行、制表符等"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
|
||||||
|
return cleaned if cleaned else None
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clean_column_name(col_name: str) -> str:
|
||||||
|
"""清理列名,使其符合MySQL命名规范"""
|
||||||
|
if not col_name:
|
||||||
|
return "column_1"
|
||||||
|
|
||||||
|
# 去除特殊字符,只保留字母、数字、下划线、中文
|
||||||
|
cleaned = re.sub(r'[^\w\u4e00-\u9fff]', '_', str(col_name))
|
||||||
|
|
||||||
|
# 如果以数字开头,添加前缀
|
||||||
|
if cleaned and cleaned[0].isdigit():
|
||||||
|
cleaned = f"col_{cleaned}"
|
||||||
|
|
||||||
|
return cleaned[:64] # MySQL列名最大长度限制
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer_mysql_column_type(series: pd.Series) -> str:
|
||||||
|
"""推断MySQL列类型"""
|
||||||
|
# 去除空值
|
||||||
|
non_null_series = series.dropna()
|
||||||
|
|
||||||
|
if len(non_null_series) == 0:
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
# 检查是否为数字
|
||||||
|
try:
|
||||||
|
pd.to_numeric(non_null_series)
|
||||||
|
# 检查是否为整数
|
||||||
|
if all(isinstance(x, (int, float)) and (pd.isna(x) or x == int(x)) for x in non_null_series):
|
||||||
|
max_val = abs(non_null_series.max())
|
||||||
|
if max_val < 2147483648: # INT范围
|
||||||
|
return "INT"
|
||||||
|
else:
|
||||||
|
return "BIGINT"
|
||||||
|
else:
|
||||||
|
return "DECIMAL(10,2)"
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 检查字符串长度
|
||||||
|
max_length = non_null_series.astype(str).str.len().max()
|
||||||
|
if max_length <= 255:
|
||||||
|
return f"VARCHAR({min(max_length * 2, 255)})"
|
||||||
|
else:
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode_file_content(file_content: str) -> bytes:
|
||||||
|
"""解码base64文件内容"""
|
||||||
|
try:
|
||||||
|
return base64.b64decode(file_content)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"文件内容解码失败: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_excel_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
|
||||||
|
"""读取Excel文件内容"""
|
||||||
|
try:
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
|
||||||
|
table_name_base = Path(filename).stem
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for sheet_name in wb.sheetnames:
|
||||||
|
ws = wb[sheet_name]
|
||||||
|
if ws.max_row is None or ws.max_row <= 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 读取数据
|
||||||
|
data = []
|
||||||
|
headers = []
|
||||||
|
|
||||||
|
# 读取第一行作为列名
|
||||||
|
for col in range(1, ws.max_column + 1):
|
||||||
|
header_cell = ws.cell(row=1, column=col)
|
||||||
|
header = FileImportUtils.clean_cell_value(header_cell.value)
|
||||||
|
if header:
|
||||||
|
headers.append(FileImportUtils.clean_column_name(header))
|
||||||
|
else:
|
||||||
|
headers.append(f"column_{col}")
|
||||||
|
|
||||||
|
# 读取数据行
|
||||||
|
for row in range(2, ws.max_row + 1):
|
||||||
|
row_data = {}
|
||||||
|
has_data = False
|
||||||
|
for col, header in enumerate(headers, 1):
|
||||||
|
cell_value = FileImportUtils.clean_cell_value(ws.cell(row=row, column=col).value)
|
||||||
|
row_data[header] = cell_value
|
||||||
|
if cell_value is not None:
|
||||||
|
has_data = True
|
||||||
|
|
||||||
|
if has_data:
|
||||||
|
data.append(row_data)
|
||||||
|
|
||||||
|
if data:
|
||||||
|
table_name = f"{table_name_base}_{sheet_name}" if len(wb.sheetnames) > 1 else table_name_base
|
||||||
|
results.append({
|
||||||
|
"table_name": FileImportUtils.clean_column_name(table_name),
|
||||||
|
"data": data,
|
||||||
|
"columns": headers
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"读取Excel文件失败: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_csv_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
|
||||||
|
"""读取CSV文件内容"""
|
||||||
|
try:
|
||||||
|
# 尝试不同编码
|
||||||
|
encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
|
||||||
|
df = None
|
||||||
|
|
||||||
|
for encoding in encodings:
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if df is None:
|
||||||
|
raise Exception("无法识别文件编码")
|
||||||
|
|
||||||
|
# 清理列名
|
||||||
|
df.columns = [FileImportUtils.clean_column_name(col) for col in df.columns]
|
||||||
|
|
||||||
|
# 处理重复列名
|
||||||
|
if len(df.columns) != len(set(df.columns)):
|
||||||
|
new_cols = []
|
||||||
|
col_count = {}
|
||||||
|
for col in df.columns:
|
||||||
|
if col in col_count:
|
||||||
|
col_count[col] += 1
|
||||||
|
new_cols.append(f"{col}_{col_count[col]}")
|
||||||
|
else:
|
||||||
|
col_count[col] = 0
|
||||||
|
new_cols.append(col)
|
||||||
|
df.columns = new_cols
|
||||||
|
|
||||||
|
# 清理数据
|
||||||
|
for col in df.columns:
|
||||||
|
df[col] = df[col].apply(lambda x: FileImportUtils.clean_cell_value(x))
|
||||||
|
df[col] = df[col].where(pd.notna(df[col]), None)
|
||||||
|
|
||||||
|
# 转换为字典列表
|
||||||
|
data = df.to_dict('records')
|
||||||
|
|
||||||
|
table_name = FileImportUtils.clean_column_name(Path(filename).stem)
|
||||||
|
|
||||||
|
return [{
|
||||||
|
"table_name": table_name,
|
||||||
|
"data": data,
|
||||||
|
"columns": list(df.columns)
|
||||||
|
}]
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"读取CSV文件失败: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_file(filename: str, file_content: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
解析文件内容,根据文件扩展名选择相应的解析方法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: 文件名
|
||||||
|
file_content: base64编码的文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的文件数据列表
|
||||||
|
"""
|
||||||
|
# 解码文件内容
|
||||||
|
file_bytes = FileImportUtils.decode_file_content(file_content)
|
||||||
|
|
||||||
|
# 根据文件扩展名处理文件
|
||||||
|
file_ext = Path(filename).suffix.lower()
|
||||||
|
|
||||||
|
if file_ext == '.csv':
|
||||||
|
return FileImportUtils.read_csv_file(file_bytes, filename)
|
||||||
|
elif file_ext in ['.xlsx', '.xls']:
|
||||||
|
return FileImportUtils.read_excel_file(file_bytes, filename)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的文件类型: {file_ext}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_table_columns(data: List[Dict[str, Any]], columns: List[str]) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
分析数据并准备表列定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 数据列表
|
||||||
|
columns: 列名列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
列名到MySQL数据类型的映射
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 创建DataFrame用于类型推断
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
column_types = {}
|
||||||
|
|
||||||
|
for col in columns:
|
||||||
|
if col in df.columns:
|
||||||
|
column_types[col] = FileImportUtils.infer_mysql_column_type(df[col])
|
||||||
|
else:
|
||||||
|
column_types[col] = "TEXT"
|
||||||
|
|
||||||
|
return column_types
|
||||||
364
import_database_utils.py
Normal file
364
import_database_utils.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
数据库导入工具模块
|
||||||
|
用于通过HTTP接口导入文件到数据库
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import sqlite3
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, Tuple, List
|
||||||
|
import openpyxl
|
||||||
|
import io
|
||||||
|
|
||||||
|
# 添加sqlite-mcp-server到路径以使用其工具函数
|
||||||
|
current_dir = os.path.dirname(__file__)
|
||||||
|
project_root = os.path.dirname(os.path.dirname(current_dir))
|
||||||
|
sqlite_mcp_path = os.path.join(project_root, 'sqlite-mcp-server')
|
||||||
|
if os.path.exists(sqlite_mcp_path) and sqlite_mcp_path not in sys.path:
|
||||||
|
sys.path.insert(0, sqlite_mcp_path)
|
||||||
|
|
||||||
|
from tools.common import (
|
||||||
|
clean_table_name, clean_column_name, infer_data_types,
|
||||||
|
CHUNK_SIZE, SQLITE_MAX_VARIABLES
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_cell_value(value: Any) -> Any:
|
||||||
|
"""
|
||||||
|
清理单元格内容:去除空格、换行、制表符等
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 单元格的值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的值
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
# 去除各种空白字符:空格、换行、回车、制表符、全角空格
|
||||||
|
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
|
||||||
|
return cleaned if cleaned else None
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _process_merged_cells(ws: openpyxl.worksheet.worksheet.Worksheet) -> Dict[Tuple[int, int], Any]:
|
||||||
|
"""
|
||||||
|
处理合并单元格,将所有合并单元格打散并填充原值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ws: openpyxl worksheet对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
单元格坐标到值的映射
|
||||||
|
"""
|
||||||
|
merged_map = {}
|
||||||
|
|
||||||
|
# 处理所有合并单元格范围
|
||||||
|
for merged_range in ws.merged_cells.ranges:
|
||||||
|
# 获取左上角单元格的值
|
||||||
|
min_row, min_col = merged_range.min_row, merged_range.min_col
|
||||||
|
value = ws.cell(row=min_row, column=min_col).value
|
||||||
|
|
||||||
|
# 填充整个合并范围
|
||||||
|
for row in range(merged_range.min_row, merged_range.max_row + 1):
|
||||||
|
for col in range(merged_range.min_col, merged_range.max_col + 1):
|
||||||
|
merged_map[(row, col)] = value
|
||||||
|
|
||||||
|
return merged_map
|
||||||
|
|
||||||
|
|
||||||
|
def _worksheet_is_empty(ws: openpyxl.worksheet.worksheet.Worksheet) -> bool:
|
||||||
|
"""
|
||||||
|
判断工作表是否为空(无任何有效数据)
|
||||||
|
"""
|
||||||
|
max_row = ws.max_row or 0
|
||||||
|
max_col = ws.max_column or 0
|
||||||
|
if max_row == 0 or max_col == 0:
|
||||||
|
return True
|
||||||
|
# 快速路径:单个单元且为空
|
||||||
|
if max_row == 1 and max_col == 1:
|
||||||
|
return _clean_cell_value(ws.cell(row=1, column=1).value) is None
|
||||||
|
|
||||||
|
# 扫描是否存在任一非空单元格
|
||||||
|
for row in ws.iter_rows(min_row=1, max_row=max_row, min_col=1, max_col=max_col, values_only=True):
|
||||||
|
if any(_clean_cell_value(v) is not None for v in row):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _read_excel_with_merged_cells(table_name: str, file_content: bytes) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
从内存中读取Excel文件,处理合并单元格
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
file_content: Excel文件的二进制内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的DataFrame
|
||||||
|
"""
|
||||||
|
# 从字节流加载Excel文件
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
|
||||||
|
wss = wb.sheetnames
|
||||||
|
data = []
|
||||||
|
df = None
|
||||||
|
|
||||||
|
for ws_name in wss:
|
||||||
|
ws = wb[ws_name]
|
||||||
|
# 跳过空工作表
|
||||||
|
if _worksheet_is_empty(ws):
|
||||||
|
continue
|
||||||
|
# 处理合并单元格
|
||||||
|
# merged_cells_map = _process_merged_cells(ws)
|
||||||
|
# 读取所有数据
|
||||||
|
all_data = []
|
||||||
|
# 跳过空列
|
||||||
|
none_raw = []
|
||||||
|
max_row = ws.max_row
|
||||||
|
max_col = ws.max_column
|
||||||
|
for row_idx in range(1, max_row + 1):
|
||||||
|
row_data = []
|
||||||
|
for col_idx in range(1, max_col + 1):
|
||||||
|
|
||||||
|
# 检查是否在合并单元格映射中
|
||||||
|
# if (row_idx, col_idx) in merged_cells_map:
|
||||||
|
# value = merged_cells_map[(row_idx, col_idx)]
|
||||||
|
# else:
|
||||||
|
value = ws.cell(row=row_idx, column=col_idx).value
|
||||||
|
if (value is None and row_idx == 1) or (col_idx in none_raw):
|
||||||
|
none_raw.append(col_idx)
|
||||||
|
continue
|
||||||
|
# 清理单元格内容
|
||||||
|
value = _clean_cell_value(value)
|
||||||
|
row_data.append(value)
|
||||||
|
# 跳过完全空白的行
|
||||||
|
if any(v is not None for v in row_data):
|
||||||
|
all_data.append(row_data)
|
||||||
|
# 转换为DataFrame
|
||||||
|
if not all_data:
|
||||||
|
# 空表,跳过
|
||||||
|
continue
|
||||||
|
# 第一行作为列名
|
||||||
|
columns = [str(col) if col is not None else f"列{i+1}"
|
||||||
|
for i, col in enumerate(all_data[0])]
|
||||||
|
# 其余行作为数据
|
||||||
|
data_rows = all_data[1:] if len(all_data) > 1 else []
|
||||||
|
# 创建DataFrame
|
||||||
|
df = pd.DataFrame(data_rows, columns=columns)
|
||||||
|
if len(wss) > 1:
|
||||||
|
normalized_sheet = ws_name.replace(' ', '').replace('-', '_')
|
||||||
|
data.append({"table_name": f"{table_name}_{normalized_sheet}", "df": df})
|
||||||
|
else:
|
||||||
|
data.append({"table_name": table_name, "df": df})
|
||||||
|
if not data:
|
||||||
|
return []
|
||||||
|
return data
|
||||||
|
|
||||||
|
def import_to_database(
|
||||||
|
db_path: str,
|
||||||
|
file_content: bytes,
|
||||||
|
filename: str,
|
||||||
|
table_name: Optional[str] = None,
|
||||||
|
force_overwrite: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
导入文件内容到指定的SQLite数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
file_content: 文件的二进制内容
|
||||||
|
filename: 原始文件名
|
||||||
|
table_name: 指定的表名(可选,默认从文件名生成)
|
||||||
|
force_overwrite: 是否强制覆盖已存在的表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
导入结果信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查数据库是否存在
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
return {"success": False, "error": f"数据库不存在: {db_path}"}
|
||||||
|
|
||||||
|
|
||||||
|
# 读取文件
|
||||||
|
file_ext = Path(filename).suffix.lower()
|
||||||
|
table_name = filename.split('.')[0]
|
||||||
|
|
||||||
|
if file_ext == '.csv':
|
||||||
|
# 检测编码
|
||||||
|
encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
|
||||||
|
df = None
|
||||||
|
data = []
|
||||||
|
for encoding in encodings:
|
||||||
|
try:
|
||||||
|
# 尝试从内存读取
|
||||||
|
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding, nrows=5)
|
||||||
|
# 如果成功,用同样的编码读取完整内容
|
||||||
|
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
|
||||||
|
data.append({"table_name": table_name, "df": df})
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
if df is None:
|
||||||
|
conn.close()
|
||||||
|
return {"success": False, "error": "无法识别文件编码"}
|
||||||
|
elif file_ext in ['.xlsx', '.xls']:
|
||||||
|
# 使用自定义函数读取Excel,处理合并单元格
|
||||||
|
data = _read_excel_with_merged_cells(table_name, file_content)
|
||||||
|
else:
|
||||||
|
conn.close()
|
||||||
|
return {"success": False, "error": f"不支持的文件类型: {file_ext}"}
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
table_name = item["table_name"]
|
||||||
|
df = item["df"]
|
||||||
|
# 检查表是否存在
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT name FROM sqlite_master
|
||||||
|
WHERE type='table' AND name='{table_name}'
|
||||||
|
""")
|
||||||
|
|
||||||
|
table_exists = cursor.fetchone() is not None
|
||||||
|
|
||||||
|
if table_exists and not force_overwrite:
|
||||||
|
# 获取表的基本信息
|
||||||
|
cursor.execute(f"SELECT COUNT(*) FROM \"{table_name}\"")
|
||||||
|
row_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
cursor.execute(f"PRAGMA table_info(\"{table_name}\")")
|
||||||
|
columns = [col[1] for col in cursor.fetchall() if col[1] != '_row_id']
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "table_exists",
|
||||||
|
"error_type": "TABLE_EXISTS",
|
||||||
|
"message": f"表 {table_name} 已存在",
|
||||||
|
"table_info": {
|
||||||
|
"table_name": table_name,
|
||||||
|
"row_count": row_count,
|
||||||
|
"columns": columns,
|
||||||
|
"column_count": len(columns)
|
||||||
|
},
|
||||||
|
"suggestion": "您可以选择:1) 覆盖现有表(force_overwrite=true)2) 使用其他表名 3) 取消操作"
|
||||||
|
}
|
||||||
|
|
||||||
|
total_rows_all = 0
|
||||||
|
imported_rows_all = 0
|
||||||
|
error_count_all = 0
|
||||||
|
table_name_all = []
|
||||||
|
for item in data:
|
||||||
|
df = item["df"]
|
||||||
|
table_name = item["table_name"]
|
||||||
|
# 清理列名
|
||||||
|
df.columns = [clean_column_name(col) for col in df.columns]
|
||||||
|
|
||||||
|
# 检查是否有重复列名
|
||||||
|
if len(df.columns) != len(set(df.columns)):
|
||||||
|
# 为重复列名添加序号
|
||||||
|
new_cols = []
|
||||||
|
col_count = {}
|
||||||
|
for col in df.columns:
|
||||||
|
if col in col_count:
|
||||||
|
col_count[col] += 1
|
||||||
|
new_cols.append(f"{col}_{col_count[col]}")
|
||||||
|
else:
|
||||||
|
col_count[col] = 0
|
||||||
|
new_cols.append(col)
|
||||||
|
df.columns = new_cols
|
||||||
|
|
||||||
|
total_rows = len(df)
|
||||||
|
|
||||||
|
# 如果表存在且需要覆盖,先删除表
|
||||||
|
if table_exists and force_overwrite:
|
||||||
|
cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# 创建新表
|
||||||
|
data_types = infer_data_types(df)
|
||||||
|
|
||||||
|
columns_def = []
|
||||||
|
for col in df.columns:
|
||||||
|
col_type = data_types.get(col, 'TEXT')
|
||||||
|
columns_def.append(f'"{col}" {col_type}')
|
||||||
|
|
||||||
|
create_table_sql = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS "{table_name}" (
|
||||||
|
_row_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
{', '.join(columns_def)}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cursor.execute(create_table_sql)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# 数据清洗和类型转换
|
||||||
|
for col in df.columns:
|
||||||
|
# 对于CSV文件,去除前后空格
|
||||||
|
if file_ext == '.csv' and df[col].dtype == 'object':
|
||||||
|
df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)
|
||||||
|
|
||||||
|
# 将 NaN 替换为 None (SQLite NULL)
|
||||||
|
df[col] = df[col].where(pd.notna(df[col]), None)
|
||||||
|
|
||||||
|
# 日期类型转换为 ISO 格式
|
||||||
|
# try:
|
||||||
|
# # 尝试识别日期列
|
||||||
|
# if df[col].dtype == 'object' and df[col].notna().any():
|
||||||
|
# sample = df[col].dropna().head(100)
|
||||||
|
# dates = pd.to_datetime(sample, errors='coerce')
|
||||||
|
# if dates.notna().sum() / len(sample) > 0.9:
|
||||||
|
# df[col] = pd.to_datetime(df[col], errors='coerce')
|
||||||
|
# df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
# except:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# 计算安全的批次大小(考虑 SQLite 变量限制)
|
||||||
|
num_columns = len(df.columns)
|
||||||
|
safe_batch_size = min(CHUNK_SIZE, SQLITE_MAX_VARIABLES // num_columns - 1)
|
||||||
|
|
||||||
|
# 批量导入数据
|
||||||
|
error_count = 0
|
||||||
|
imported_rows = 0
|
||||||
|
|
||||||
|
for batch_start in range(0, total_rows, safe_batch_size):
|
||||||
|
batch_end = min(batch_start + safe_batch_size, total_rows)
|
||||||
|
batch_df = df.iloc[batch_start:batch_end]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 批量插入
|
||||||
|
batch_df.to_sql(table_name, conn, if_exists='append',
|
||||||
|
index=False, method='multi', chunksize=100)
|
||||||
|
|
||||||
|
imported_rows = batch_end
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
print(f"批次导入错误 [{batch_start}-{batch_end}]: {str(e)}")
|
||||||
|
|
||||||
|
# 优化表
|
||||||
|
cursor.execute(f'ANALYZE "{table_name}"')
|
||||||
|
total_rows_all += total_rows
|
||||||
|
imported_rows_all += imported_rows
|
||||||
|
error_count_all += error_count
|
||||||
|
table_name_all.append(table_name)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
table_name_all = set(table_name_all)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"table_name": table_name,
|
||||||
|
"total_rows": total_rows_all,
|
||||||
|
"imported_rows": imported_rows_all,
|
||||||
|
"error_count": error_count_all,
|
||||||
|
"message": f"成功导入 {imported_rows}/{total_rows} 行数据到表 {table_name_all}"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
@@ -7,4 +7,5 @@ pydantic==2.5.0
|
|||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
apscheduler==3.10.4
|
apscheduler==3.10.4
|
||||||
pandas==2.1.3
|
pandas==2.1.3
|
||||||
python-multipart==0.0.6
|
python-multipart==0.0.6
|
||||||
|
openpyxl==3.1.5
|
||||||
Reference in New Issue
Block a user