导入文件建表接口

This commit is contained in:
lhx
2025-09-26 17:19:58 +08:00
parent 18478b148a
commit 52d2753824
6 changed files with 735 additions and 5 deletions

364
import_database_utils.py Normal file
View File

@@ -0,0 +1,364 @@
"""
数据库导入工具模块
用于通过HTTP接口导入文件到数据库
"""
import os
import sys
import sqlite3
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, List
import openpyxl
import io
# 添加sqlite-mcp-server到路径以使用其工具函数
current_dir = os.path.dirname(__file__)
project_root = os.path.dirname(os.path.dirname(current_dir))
sqlite_mcp_path = os.path.join(project_root, 'sqlite-mcp-server')
if os.path.exists(sqlite_mcp_path) and sqlite_mcp_path not in sys.path:
sys.path.insert(0, sqlite_mcp_path)
from tools.common import (
clean_table_name, clean_column_name, infer_data_types,
CHUNK_SIZE, SQLITE_MAX_VARIABLES
)
def _clean_cell_value(value: Any) -> Any:
"""
清理单元格内容:去除空格、换行、制表符等
Args:
value: 单元格的值
Returns:
清理后的值
"""
if value is None:
return None
if isinstance(value, str):
# 去除各种空白字符:空格、换行、回车、制表符、全角空格
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
return cleaned if cleaned else None
return value
def _process_merged_cells(ws: openpyxl.worksheet.worksheet.Worksheet) -> Dict[Tuple[int, int], Any]:
    """Expand merged-cell regions into a flat coordinate-to-value map.

    Every (row, col) coordinate inside a merged region is mapped to the
    value of that region's top-left anchor cell, effectively "unmerging"
    the sheet.

    Args:
        ws: openpyxl worksheet to inspect.

    Returns:
        Mapping from 1-based (row, column) coordinates to the anchor value.
    """
    filled: Dict[Tuple[int, int], Any] = {}
    for region in ws.merged_cells.ranges:
        # The anchor (top-left) cell is the only one that stores the value.
        anchor_value = ws.cell(row=region.min_row, column=region.min_col).value
        filled.update({
            (r, c): anchor_value
            for r in range(region.min_row, region.max_row + 1)
            for c in range(region.min_col, region.max_col + 1)
        })
    return filled
def _worksheet_is_empty(ws: openpyxl.worksheet.worksheet.Worksheet) -> bool:
    """Return True when the worksheet contains no meaningful data.

    A cell counts as empty when _clean_cell_value() reduces it to None
    (i.e. it is None or whitespace-only).
    """
    rows = ws.max_row or 0
    cols = ws.max_column or 0
    if not rows or not cols:
        return True
    if rows == 1 and cols == 1:
        # Fast path: a single cell needs no row iterator.
        return _clean_cell_value(ws.cell(row=1, column=1).value) is None
    # Empty iff every cell in the used range normalizes to None.
    return all(
        _clean_cell_value(cell) is None
        for row in ws.iter_rows(min_row=1, max_row=rows,
                                min_col=1, max_col=cols, values_only=True)
        for cell in row
    )
def _read_excel_with_merged_cells(table_name: str, file_content: bytes) -> List[Dict[str, Any]]:
    """Read an in-memory Excel workbook into one DataFrame per non-empty sheet.

    Columns whose header cell (row 1) is empty are dropped entirely; blank
    rows are skipped; cell values are normalized via _clean_cell_value().
    The first remaining row becomes the column names.

    NOTE(review): merged-cell expansion (_process_merged_cells) was disabled
    in the original code; it stays disabled here to preserve behavior.

    Args:
        table_name: base table name; with multiple sheets each table is
            named "<table_name>_<sheet>" (sheet name with spaces removed
            and '-' replaced by '_').
        file_content: raw bytes of the .xlsx/.xls file.

    Returns:
        List of {"table_name": str, "df": pandas.DataFrame} dicts; empty
        list when the workbook holds no usable data.
    """
    wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
    tables: List[Dict[str, Any]] = []
    try:
        multi_sheet = len(wb.sheetnames) > 1
        for sheet_name in wb.sheetnames:
            ws = wb[sheet_name]
            if _worksheet_is_empty(ws):
                continue
            # Set (not list): the original appended the same column index once
            # per data row, growing without bound and making membership O(n).
            skipped_cols = set()
            rows: List[List[Any]] = []
            for row_idx in range(1, ws.max_row + 1):
                row_values = []
                for col_idx in range(1, ws.max_column + 1):
                    if col_idx in skipped_cols:
                        continue
                    value = ws.cell(row=row_idx, column=col_idx).value
                    # A raw-None header cell marks the whole column as skipped.
                    if value is None and row_idx == 1:
                        skipped_cols.add(col_idx)
                        continue
                    row_values.append(_clean_cell_value(value))
                # Keep only rows with at least one non-empty cell.
                if any(v is not None for v in row_values):
                    rows.append(row_values)
            if not rows:
                continue
            # First surviving row supplies column names; a header cell that
            # cleaned to None falls back to its 1-based position as a string.
            columns = [str(col) if col is not None else f"{i+1}"
                       for i, col in enumerate(rows[0])]
            df = pd.DataFrame(rows[1:], columns=columns)
            if multi_sheet:
                normalized_sheet = sheet_name.replace(' ', '').replace('-', '_')
                tables.append({"table_name": f"{table_name}_{normalized_sheet}", "df": df})
            else:
                tables.append({"table_name": table_name, "df": df})
    finally:
        # The original leaked the workbook handle; always release it.
        wb.close()
    return tables
def _read_csv_tables(base_name: str, file_content: bytes) -> Optional[List[Dict[str, Any]]]:
    """Parse CSV bytes, probing common Chinese encodings in order.

    Returns a single-element table list, or None when no encoding works.
    """
    for encoding in ('utf-8', 'gbk', 'gb2312', 'gb18030'):
        try:
            # Cheap probe on the first rows before parsing the whole file.
            pd.read_csv(io.BytesIO(file_content), encoding=encoding, nrows=5)
            df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
            return [{"table_name": base_name, "df": df}]
        except Exception:
            # Wrong encoding (or malformed under it) — try the next one.
            continue
    return None


def _dedupe_columns(columns: List[str]) -> List[str]:
    """Disambiguate repeated column names by appending _1, _2, ... suffixes."""
    seen: Dict[str, int] = {}
    deduped: List[str] = []
    for col in columns:
        if col in seen:
            seen[col] += 1
            deduped.append(f"{col}_{seen[col]}")
        else:
            seen[col] = 0
            deduped.append(col)
    return deduped


def _import_dataframe(conn, cursor, name: str, df, already_exists: bool,
                      force_overwrite: bool, file_ext: str) -> Tuple[int, int, int]:
    """Create (or overwrite) table `name` in `conn` and bulk-insert `df`.

    Returns:
        (total_rows, imported_rows, error_count) for this one table.
    """
    df.columns = [clean_column_name(col) for col in df.columns]
    if len(df.columns) != len(set(df.columns)):
        df.columns = _dedupe_columns(list(df.columns))
    total_rows = len(df)
    if already_exists and force_overwrite:
        cursor.execute(f'DROP TABLE IF EXISTS "{name}"')
        conn.commit()
    # Build the schema from inferred column types; _row_id is the surrogate PK.
    data_types = infer_data_types(df)
    columns_def = [f'"{col}" {data_types.get(col, "TEXT")}' for col in df.columns]
    cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS "{name}" (
            _row_id INTEGER PRIMARY KEY AUTOINCREMENT,
            {', '.join(columns_def)}
        )
    """)
    conn.commit()
    for col in df.columns:
        # CSV cells were never whitespace-normalized upstream, so strip here.
        if file_ext == '.csv' and df[col].dtype == 'object':
            df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)
        # NaN -> None so SQLite stores NULL.
        df[col] = df[col].where(pd.notna(df[col]), None)
    num_columns = len(df.columns)
    # Respect SQLite's bound-variable limit; clamp to >= 1 so very wide
    # tables cannot produce a zero/negative range step (original bug).
    safe_batch_size = max(1, min(CHUNK_SIZE, SQLITE_MAX_VARIABLES // num_columns - 1))
    error_count = 0
    imported_rows = 0
    for batch_start in range(0, total_rows, safe_batch_size):
        batch_end = min(batch_start + safe_batch_size, total_rows)
        batch_df = df.iloc[batch_start:batch_end]
        try:
            batch_df.to_sql(name, conn, if_exists='append',
                            index=False, method='multi', chunksize=100)
            # Count only rows that actually landed (the original set
            # imported_rows = batch_end, overcounting after a failed batch).
            imported_rows += len(batch_df)
        except Exception as e:
            error_count += 1
            print(f"批次导入错误 [{batch_start}-{batch_end}]: {str(e)}")
    cursor.execute(f'ANALYZE "{name}"')
    return total_rows, imported_rows, error_count


def import_to_database(
    db_path: str,
    file_content: bytes,
    filename: str,
    table_name: Optional[str] = None,
    force_overwrite: bool = False
) -> Dict[str, Any]:
    """Import a CSV/Excel file (given as bytes) into a SQLite database.

    Args:
        db_path: path to an existing SQLite database file.
        file_content: raw bytes of the uploaded file.
        filename: original file name; its extension selects the parser.
        table_name: explicit base table name; when None it is derived from
            the filename (text before the first dot). The original code
            ignored this parameter — it is now honored.
        force_overwrite: drop and recreate any table that already exists.

    Returns:
        Result dict: on success {"success": True, "table_name", "total_rows",
        "imported_rows", "error_count", "message"}; on failure
        {"success": False, "error", ...} (with "error_type"/"table_info"
        when an existing table blocks the import).
    """
    try:
        if not os.path.exists(db_path):
            return {"success": False, "error": f"数据库不存在: {db_path}"}
        file_ext = Path(filename).suffix.lower()
        # Honor the caller's table_name; fall back to the filename prefix.
        base_name = table_name if table_name else filename.split('.')[0]
        if file_ext == '.csv':
            data = _read_csv_tables(base_name, file_content)
            if data is None:
                # No premature conn.close() here: the connection is not
                # opened yet (the original raised NameError on this path).
                return {"success": False, "error": "无法识别文件编码"}
        elif file_ext in ('.xlsx', '.xls'):
            data = _read_excel_with_merged_cells(base_name, file_content)
        else:
            return {"success": False, "error": f"不支持的文件类型: {file_ext}"}
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            # Pass 1: conflict check, recording existence PER TABLE (the
            # original reused the last iteration's flag for every table).
            existing: Dict[str, bool] = {}
            for item in data:
                name = item["table_name"]
                # Parameterized query — the name came from user input.
                cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                    (name,)
                )
                exists = cursor.fetchone() is not None
                existing[name] = exists
                if exists and not force_overwrite:
                    cursor.execute(f'SELECT COUNT(*) FROM "{name}"')
                    row_count = cursor.fetchone()[0]
                    cursor.execute(f'PRAGMA table_info("{name}")')
                    columns = [col[1] for col in cursor.fetchall() if col[1] != '_row_id']
                    return {
                        "success": False,
                        "error": "table_exists",
                        "error_type": "TABLE_EXISTS",
                        "message": f"{name} 已存在",
                        "table_info": {
                            "table_name": name,
                            "row_count": row_count,
                            "columns": columns,
                            "column_count": len(columns)
                        },
                        "suggestion": "您可以选择:1) 覆盖现有表(force_overwrite=true);2) 使用其他表名;3) 取消操作"
                    }
            # Pass 2: create tables and bulk-insert, aggregating stats.
            total_rows_all = 0
            imported_rows_all = 0
            error_count_all = 0
            imported_names: List[str] = []
            last_name = base_name
            for item in data:
                last_name = item["table_name"]
                t, i, e = _import_dataframe(
                    conn, cursor, last_name, item["df"],
                    existing[last_name], force_overwrite, file_ext
                )
                total_rows_all += t
                imported_rows_all += i
                error_count_all += e
                imported_names.append(last_name)
            conn.commit()
        finally:
            # Connection is released on every path, including errors above.
            conn.close()
        unique_names = set(imported_names)
        return {
            "success": True,
            "table_name": last_name,
            "total_rows": total_rows_all,
            "imported_rows": imported_rows_all,
            "error_count": error_count_all,
            # Aggregated counts (the original reported last-table locals only).
            "message": f"成功导入 {imported_rows_all}/{total_rows_all} 行数据到表 {unique_names}"
        }
    except Exception as e:
        return {"success": False, "error": str(e)}