@router.post("/import-file", response_model=SQLExecuteResponse)
def import_file(request: FileImportRequest, db: Session = Depends(get_db)):
    """Import an uploaded Excel/CSV file (base64 payload) into the database."""
    outcome = DatabaseService.import_file_to_database(
        db,
        request.filename,
        request.file_content,
        request.table_name,
        request.force_overwrite,
    )
    return SQLExecuteResponse(**outcome)
a/app/schemas/database.py +++ b/app/schemas/database.py @@ -28,4 +28,10 @@ class CreateTableRequest(BaseModel): class ImportDataRequest(BaseModel): table_name: str - data: List[Dict[str, Any]] \ No newline at end of file + data: List[Dict[str, Any]] + +class FileImportRequest(BaseModel): + filename: str + file_content: str # base64编码的文件内容 + table_name: Optional[str] = None # 可选的自定义表名 + force_overwrite: bool = False # 是否强制覆盖已存在的表 \ No newline at end of file diff --git a/app/services/database.py b/app/services/database.py index d4d7ce2..9950009 100644 --- a/app/services/database.py +++ b/app/services/database.py @@ -3,6 +3,7 @@ from sqlalchemy import text, MetaData, Table, Column, create_engine, inspect from sqlalchemy.exc import SQLAlchemyError from typing import List, Dict, Any, Optional from ..core.database import engine +from ..utils.file_import import FileImportUtils import pandas as pd class DatabaseService: @@ -193,4 +194,114 @@ class DatabaseService: # 排除mysql的系统表和accounts表 return [table for table in inspector.get_table_names() if not table.startswith('mysql') and table != 'accounts'] except Exception as e: - return [] \ No newline at end of file + return [] + + @staticmethod + def import_file_to_database(db: Session, filename: str, file_content: str, + table_name: Optional[str] = None, + force_overwrite: bool = False) -> Dict[str, Any]: + """从文件导入数据到数据库""" + try: + # 解析文件内容 + file_data = FileImportUtils.parse_file(filename, file_content) + + if not file_data: + return { + "success": False, + "message": "文件中没有找到有效数据" + } + + results = [] + inspector = inspect(engine) + existing_tables = inspector.get_table_names() + + for item in file_data: + final_table_name = table_name if table_name else item["table_name"] + data = item["data"] + columns = item["columns"] + + if not data: + continue + + # 检查表是否存在 + if final_table_name in existing_tables and not force_overwrite: + return { + "success": False, + "message": f"表 {final_table_name} 已存在,请使用 force_overwrite=true 
覆盖或选择其他表名" + } + + # 如果需要覆盖,先删除表 + if final_table_name in existing_tables and force_overwrite: + db.execute(text(f"DROP TABLE IF EXISTS `{final_table_name}`")) + db.commit() + + # 准备列类型定义 + column_types = FileImportUtils.prepare_table_columns(data, columns) + + # 创建表 + column_definitions = [] + for col in columns: + col_type = column_types.get(col, "TEXT") + column_definitions.append(f"`{col}` {col_type}") + + create_table_sql = f""" + CREATE TABLE `{final_table_name}` ( + id INT AUTO_INCREMENT PRIMARY KEY, + {', '.join(column_definitions)} + ) + """ + + db.execute(text(create_table_sql)) + db.commit() + + # 导入数据 + df_clean = pd.DataFrame(data) + + # 使用pandas的to_sql方法批量导入 + df_clean.to_sql(final_table_name, engine, if_exists='append', + index=False, method='multi', chunksize=1000) + + results.append({ + "table_name": final_table_name, + "rows_imported": len(data), + "columns": columns + }) + + # 返回结果 + if len(results) == 1: + result = results[0] + return { + "success": True, + "message": f"成功导入 {result['rows_imported']} 行数据到表 {result['table_name']}", + "table_name": result["table_name"], + "rows_imported": result["rows_imported"], + "columns": result["columns"] + } + else: + total_rows = sum(r["rows_imported"] for r in results) + table_names = [r["table_name"] for r in results] + return { + "success": True, + "message": f"成功导入 {total_rows} 行数据到 {len(results)} 个表", + "tables": results, + "total_rows": total_rows, + "table_names": table_names + } + + except ValueError as e: + return { + "success": False, + "message": str(e) + } + except SQLAlchemyError as e: + db.rollback() + return { + "success": False, + "message": f"数据库操作失败: {str(e)}" + } + except Exception as e: + db.rollback() + return { + "success": False, + "message": f"导入失败: {str(e)}" + } \ No newline at end of file diff --git a/app/utils/file_import.py b/app/utils/file_import.py new file mode 100644 index 0000000..10dab01 --- /dev/null +++ b/app/utils/file_import.py @@ -0,0 +1,236 @@ +""" +文件导入工具类 
class FileImportUtils:
    """Helpers for parsing uploaded Excel/CSV files into table payloads."""

    @staticmethod
    def clean_cell_value(value: Any) -> Any:
        """Strip spaces, newlines, tabs and full-width spaces from string cells.

        Returns None for None input and for strings that become empty after
        cleaning; non-string values are returned unchanged.
        """
        if value is None:
            return None

        if isinstance(value, str):
            cleaned = (value.replace(' ', '').replace('\n', '').replace('\r', '')
                       .replace('\t', '').replace('\u3000', ''))
            return cleaned if cleaned else None

        return value

    @staticmethod
    def clean_column_name(col_name: str) -> str:
        """Normalize a column (or table) name for MySQL.

        Characters other than word characters and CJK are replaced with
        underscores, a leading digit gets a "col_" prefix, and the result is
        truncated to MySQL's 64-character identifier limit.
        """
        if not col_name:
            return "column_1"

        cleaned = re.sub(r'[^\w\u4e00-\u9fff]', '_', str(col_name))

        if cleaned and cleaned[0].isdigit():
            cleaned = f"col_{cleaned}"

        return cleaned[:64]

    @staticmethod
    def infer_mysql_column_type(series: pd.Series) -> str:
        """Infer a MySQL column type from a pandas Series of sample values."""
        non_null_series = series.dropna()
        if len(non_null_series) == 0:
            return "TEXT"

        # Numeric detection. BUG FIX: the original isinstance-checked the RAW
        # values after pd.to_numeric already succeeded, so numeric strings
        # ("1", "2") fell through to DECIMAL; classify the converted values.
        try:
            numeric = pd.to_numeric(non_null_series)
        except (ValueError, TypeError):
            numeric = None

        if numeric is not None:
            if (numeric % 1 == 0).all():
                # BUG FIX: abs(series.max()) missed large NEGATIVE values;
                # use the maximum absolute value for the INT-range check.
                max_val = numeric.abs().max()
                return "INT" if max_val < 2147483648 else "BIGINT"
            return "DECIMAL(10,2)"

        # String fallback: size from the longest sample, doubled for headroom,
        # capped at 255; guard against VARCHAR(0) for empty strings.
        max_length = int(non_null_series.astype(str).str.len().max())
        if max_length <= 255:
            return f"VARCHAR({max(1, min(max_length * 2, 255))})"
        return "TEXT"

    @staticmethod
    def decode_file_content(file_content: str) -> bytes:
        """Decode the base64 payload; raise ValueError with a readable message."""
        try:
            return base64.b64decode(file_content)
        except Exception as e:
            raise ValueError(f"文件内容解码失败: {str(e)}") from e

    @staticmethod
    def read_excel_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
        """Read every non-empty sheet of an Excel workbook.

        Row 1 supplies the column names (cleaned; empty headers become
        "column_<n>"). With multiple sheets, each table name is suffixed with
        its sheet name. Rows with no surviving value are skipped.
        """
        try:
            wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
            table_name_base = Path(filename).stem
            results = []

            for sheet_name in wb.sheetnames:
                ws = wb[sheet_name]
                # A sheet with only a header row (or nothing) carries no data.
                if ws.max_row is None or ws.max_row <= 1:
                    continue

                headers = []
                for col in range(1, ws.max_column + 1):
                    header = FileImportUtils.clean_cell_value(
                        ws.cell(row=1, column=col).value)
                    headers.append(FileImportUtils.clean_column_name(header)
                                   if header else f"column_{col}")

                data = []
                for row in range(2, ws.max_row + 1):
                    row_data = {}
                    has_data = False
                    for col, header in enumerate(headers, 1):
                        cell_value = FileImportUtils.clean_cell_value(
                            ws.cell(row=row, column=col).value)
                        row_data[header] = cell_value
                        if cell_value is not None:
                            has_data = True
                    if has_data:
                        data.append(row_data)

                if data:
                    table_name = (f"{table_name_base}_{sheet_name}"
                                  if len(wb.sheetnames) > 1 else table_name_base)
                    results.append({
                        "table_name": FileImportUtils.clean_column_name(table_name),
                        "data": data,
                        "columns": headers
                    })

            return results
        except Exception as e:
            raise Exception(f"读取Excel文件失败: {str(e)}") from e

    @staticmethod
    def read_csv_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
        """Read a CSV file, auto-detecting common UTF-8/GBK-family encodings."""
        try:
            encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
            df = None

            for encoding in encodings:
                try:
                    df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
                    break
                except Exception:
                    # Wrong encoding (or parse failure under it): try the next.
                    continue

            if df is None:
                raise Exception("无法识别文件编码")

            df.columns = [FileImportUtils.clean_column_name(col) for col in df.columns]

            # Deduplicate column names by suffixing repeats with a counter.
            if len(df.columns) != len(set(df.columns)):
                new_cols = []
                col_count = {}
                for col in df.columns:
                    if col in col_count:
                        col_count[col] += 1
                        new_cols.append(f"{col}_{col_count[col]}")
                    else:
                        col_count[col] = 0
                        new_cols.append(col)
                df.columns = new_cols

            # Clean every cell and map NaN to None (SQL NULL).
            for col in df.columns:
                df[col] = df[col].apply(FileImportUtils.clean_cell_value)
                df[col] = df[col].where(pd.notna(df[col]), None)

            data = df.to_dict('records')
            table_name = FileImportUtils.clean_column_name(Path(filename).stem)

            return [{
                "table_name": table_name,
                "data": data,
                "columns": list(df.columns)
            }]
        except Exception as e:
            raise Exception(f"读取CSV文件失败: {str(e)}") from e

    @staticmethod
    def parse_file(filename: str, file_content: str) -> List[Dict[str, Any]]:
        """Parse a base64-encoded file, dispatching on its extension.

        Args:
            filename: original file name (extension decides the parser).
            file_content: base64-encoded file content.

        Returns:
            A list of {"table_name", "data", "columns"} dicts, one per table.

        Raises:
            ValueError: for undecodable content or an unsupported extension.
        """
        file_bytes = FileImportUtils.decode_file_content(file_content)
        file_ext = Path(filename).suffix.lower()

        if file_ext == '.csv':
            return FileImportUtils.read_csv_file(file_bytes, filename)
        elif file_ext in ['.xlsx', '.xls']:
            return FileImportUtils.read_excel_file(file_bytes, filename)
        else:
            raise ValueError(f"不支持的文件类型: {file_ext}")

    @staticmethod
    def prepare_table_columns(data: List[Dict[str, Any]], columns: List[str]) -> Dict[str, str]:
        """Map each column name to an inferred MySQL type.

        Columns absent from the data default to TEXT.
        """
        if not data:
            return {}

        df = pd.DataFrame(data)
        column_types = {}

        for col in columns:
            if col in df.columns:
                column_types[col] = FileImportUtils.infer_mysql_column_type(df[col])
            else:
                column_types[col] = "TEXT"

        return column_types
def _clean_cell_value(value: Any) -> Any:
    """Strip whitespace-like characters from a string cell.

    Args:
        value: raw cell value of any type.

    Returns:
        None for None input or for strings that become empty after cleaning;
        the cleaned string otherwise; non-string values unchanged.
    """
    if not isinstance(value, str):
        return value

    stripped = value
    # Remove plain spaces, newlines, carriage returns, tabs and full-width spaces.
    for ch in (' ', '\n', '\r', '\t', '\u3000'):
        stripped = stripped.replace(ch, '')
    return stripped if stripped else None


def _process_merged_cells(ws: "openpyxl.worksheet.worksheet.Worksheet") -> Dict[Tuple[int, int], Any]:
    """Expand merged ranges into a {(row, col): value} map.

    Every cell inside a merged range is mapped to the value of the range's
    top-left (anchor) cell.

    Args:
        ws: openpyxl worksheet.

    Returns:
        Mapping from 1-based (row, column) coordinates to the anchor value.
    """
    merged_map: Dict[Tuple[int, int], Any] = {}

    for rng in ws.merged_cells.ranges:
        anchor_value = ws.cell(row=rng.min_row, column=rng.min_col).value
        for r in range(rng.min_row, rng.max_row + 1):
            for c in range(rng.min_col, rng.max_col + 1):
                merged_map[(r, c)] = anchor_value

    return merged_map


def _worksheet_is_empty(ws: "openpyxl.worksheet.worksheet.Worksheet") -> bool:
    """Return True when the sheet holds no cell that survives cleaning."""
    rows = ws.max_row or 0
    cols = ws.max_column or 0
    if rows == 0 or cols == 0:
        return True

    # Fast path for a 1x1 sheet: inspect the single cell directly.
    if rows == 1 and cols == 1:
        return _clean_cell_value(ws.cell(row=1, column=1).value) is None

    # Any single non-empty cell makes the sheet non-empty.
    return not any(
        _clean_cell_value(v) is not None
        for row in ws.iter_rows(min_row=1, max_row=rows,
                                min_col=1, max_col=cols, values_only=True)
        for v in row
    )
def _read_excel_with_merged_cells(table_name: str, file_content: bytes) -> List[Dict[str, Any]]:
    """Read every non-empty worksheet of an in-memory Excel file.

    Row 1 supplies the column names; columns whose header cell is empty are
    skipped entirely, and fully blank rows are dropped. When the workbook has
    several sheets, each table name is suffixed with the (normalized) sheet
    name.

    Args:
        table_name: base table name (usually derived from the filename).
        file_content: raw bytes of the .xlsx/.xls file.

    Returns:
        A list of {"table_name": str, "df": DataFrame} entries, one per
        non-empty worksheet; empty list when no sheet has data.
    """
    wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
    sheet_names = wb.sheetnames
    data: List[Dict[str, Any]] = []

    for ws_name in sheet_names:
        ws = wb[ws_name]
        if _worksheet_is_empty(ws):
            continue

        # NOTE(review): merged-cell expansion (_process_merged_cells) was
        # deliberately disabled in the original; confirm whether merged cells
        # should be flattened before re-enabling it.
        all_rows = []
        skipped_cols = set()  # columns whose header (row 1) is empty
        max_row = ws.max_row
        max_col = ws.max_column

        for row_idx in range(1, max_row + 1):
            row_data = []
            for col_idx in range(1, max_col + 1):
                if col_idx in skipped_cols:
                    continue
                value = ws.cell(row=row_idx, column=col_idx).value
                # An empty header cell marks the whole column as skipped.
                if value is None and row_idx == 1:
                    skipped_cols.add(col_idx)
                    continue
                row_data.append(_clean_cell_value(value))
            if any(v is not None for v in row_data):
                all_rows.append(row_data)

        if not all_rows:
            continue

        # First surviving row is the header; the rest are data rows.
        columns = [str(col) if col is not None else f"列{i + 1}"
                   for i, col in enumerate(all_rows[0])]
        df = pd.DataFrame(all_rows[1:], columns=columns)

        if len(sheet_names) > 1:
            normalized_sheet = ws_name.replace(' ', '').replace('-', '_')
            data.append({"table_name": f"{table_name}_{normalized_sheet}", "df": df})
        else:
            data.append({"table_name": table_name, "df": df})

    return data


def import_to_database(
    db_path: str,
    file_content: bytes,
    filename: str,
    table_name: Optional[str] = None,
    force_overwrite: bool = False
) -> Dict[str, Any]:
    """Import a file's content into an existing SQLite database.

    Args:
        db_path: path to the SQLite database file (must already exist).
        file_content: raw bytes of the uploaded file.
        filename: original file name; its extension selects the parser.
        table_name: optional target table name; derived from the filename
            when omitted.
        force_overwrite: drop and recreate tables that already exist.

    Returns:
        A result dict: on success, aggregate row counts and table names; on
        failure, an "error" message (and structured table info when the
        failure is an existing table without force_overwrite).
    """
    conn = None
    try:
        if not os.path.exists(db_path):
            return {"success": False, "error": f"数据库不存在: {db_path}"}

        file_ext = Path(filename).suffix.lower()
        # BUG FIX: the original unconditionally overwrote the table_name
        # parameter with the filename stem; honour the caller's choice.
        base_table_name = table_name if table_name else Path(filename).stem

        if file_ext == '.csv':
            encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
            df = None
            data = []
            for encoding in encodings:
                try:
                    # Cheap sniff on the first rows, then a full read with the
                    # same encoding.
                    pd.read_csv(io.BytesIO(file_content), encoding=encoding, nrows=5)
                    df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
                    data.append({"table_name": base_table_name, "df": df})
                    break
                except Exception:
                    continue
            if df is None:
                # BUG FIX: the original called conn.close() here before conn
                # was ever created (NameError masking the real error).
                return {"success": False, "error": "无法识别文件编码"}
        elif file_ext in ['.xlsx', '.xls']:
            data = _read_excel_with_merged_cells(base_table_name, file_content)
        else:
            # BUG FIX: same premature conn.close() removed here.
            return {"success": False, "error": f"不支持的文件类型: {file_ext}"}

        if not data:
            return {"success": False, "error": "文件中没有找到有效数据"}

        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # First pass: check existence per table. BUG FIX: the original kept a
        # single table_exists flag whose value after this loop belonged to the
        # LAST table only; record the flag per table instead.
        table_exists_map = {}
        for item in data:
            target = item["table_name"]
            # Parameterized instead of string-built query.
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (target,)
            )
            exists = cursor.fetchone() is not None
            table_exists_map[target] = exists

            if exists and not force_overwrite:
                cursor.execute(f'SELECT COUNT(*) FROM "{target}"')
                row_count = cursor.fetchone()[0]
                cursor.execute(f'PRAGMA table_info("{target}")')
                columns = [col[1] for col in cursor.fetchall() if col[1] != '_row_id']
                conn.close()
                return {
                    "success": False,
                    "error": "table_exists",
                    "error_type": "TABLE_EXISTS",
                    "message": f"表 {target} 已存在",
                    "table_info": {
                        "table_name": target,
                        "row_count": row_count,
                        "columns": columns,
                        "column_count": len(columns)
                    },
                    "suggestion": "您可以选择:1) 覆盖现有表(force_overwrite=true)2) 使用其他表名 3) 取消操作"
                }

        total_rows_all = 0
        imported_rows_all = 0
        error_count_all = 0
        table_name_all = []

        for item in data:
            df = item["df"]
            target = item["table_name"]

            # Normalize column names, then deduplicate repeats with a counter.
            df.columns = [clean_column_name(col) for col in df.columns]
            if len(df.columns) != len(set(df.columns)):
                new_cols = []
                col_count = {}
                for col in df.columns:
                    if col in col_count:
                        col_count[col] += 1
                        new_cols.append(f"{col}_{col_count[col]}")
                    else:
                        col_count[col] = 0
                        new_cols.append(col)
                df.columns = new_cols

            total_rows = len(df)

            # Per-table overwrite (uses this table's own existence flag).
            if table_exists_map.get(target) and force_overwrite:
                cursor.execute(f'DROP TABLE IF EXISTS "{target}"')
                conn.commit()

            data_types = infer_data_types(df)
            columns_def = [f'"{col}" {data_types.get(col, "TEXT")}' for col in df.columns]
            create_table_sql = f"""
                CREATE TABLE IF NOT EXISTS "{target}" (
                    _row_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    {', '.join(columns_def)}
                )
            """
            cursor.execute(create_table_sql)
            conn.commit()

            # Data cleaning: strip CSV string cells, map NaN to SQL NULL.
            for col in df.columns:
                if file_ext == '.csv' and df[col].dtype == 'object':
                    df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)
                df[col] = df[col].where(pd.notna(df[col]), None)

            # Batch size bounded by SQLite's bound-variable limit.
            num_columns = len(df.columns)
            safe_batch_size = min(CHUNK_SIZE, SQLITE_MAX_VARIABLES // num_columns - 1)

            error_count = 0
            imported_rows = 0
            for batch_start in range(0, total_rows, safe_batch_size):
                batch_end = min(batch_start + safe_batch_size, total_rows)
                batch_df = df.iloc[batch_start:batch_end]
                try:
                    batch_df.to_sql(target, conn, if_exists='append',
                                    index=False, method='multi', chunksize=100)
                    imported_rows = batch_end
                except Exception as e:
                    error_count += 1
                    print(f"批次导入错误 [{batch_start}-{batch_end}]: {str(e)}")

            cursor.execute(f'ANALYZE "{target}"')
            total_rows_all += total_rows
            imported_rows_all += imported_rows
            error_count_all += error_count
            table_name_all.append(target)

        conn.commit()
        conn.close()
        conn = None

        unique_tables = set(table_name_all)
        return {
            "success": True,
            "table_name": table_name_all[-1] if table_name_all else base_table_name,
            "total_rows": total_rows_all,
            "imported_rows": imported_rows_all,
            "error_count": error_count_all,
            # BUG FIX: message previously reported only the LAST table's
            # per-table counters instead of the aggregates.
            "message": f"成功导入 {imported_rows_all}/{total_rows_all} 行数据到表 {unique_tables}"
        }

    except Exception as e:
        if conn is not None:
            conn.close()
        return {"success": False, "error": str(e)}