""" 文件导入工具类 处理Excel/CSV文件的读取、解析和数据清洗 """ import base64 import io import openpyxl import pandas as pd import re from pathlib import Path from typing import Any, Dict, List, Optional class FileImportUtils: """文件导入工具类""" @staticmethod def clean_cell_value(value: Any) -> Any: """清理单元格内容:去除空格、换行、制表符等""" if value is None: return None if isinstance(value, str): cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '') return cleaned if cleaned else None return value @staticmethod def clean_column_name(col_name: str) -> str: """清理列名,使其符合MySQL命名规范""" if not col_name: return "column_1" # 去除特殊字符,只保留字母、数字、下划线、中文 cleaned = re.sub(r'[^\w\u4e00-\u9fff]', '_', str(col_name)) # 如果以数字开头,添加前缀 if cleaned and cleaned[0].isdigit(): cleaned = f"col_{cleaned}" return cleaned[:64] # MySQL列名最大长度限制 @staticmethod def infer_mysql_column_type(series: pd.Series) -> str: """推断MySQL列类型""" # 去除空值 non_null_series = series.dropna() if len(non_null_series) == 0: return "TEXT" # 检查是否为数字 try: pd.to_numeric(non_null_series) # 检查是否为整数 if all(isinstance(x, (int, float)) and (pd.isna(x) or x == int(x)) for x in non_null_series): max_val = abs(non_null_series.max()) if max_val < 2147483648: # INT范围 return "INT" else: return "BIGINT" else: return "DECIMAL(10,2)" except: pass # 检查字符串长度 max_length = non_null_series.astype(str).str.len().max() if max_length <= 255: return f"VARCHAR({min(max_length * 2, 255)})" else: return "TEXT" @staticmethod def decode_file_content(file_content: str) -> bytes: """解码base64文件内容""" try: return base64.b64decode(file_content) except Exception as e: raise ValueError(f"文件内容解码失败: {str(e)}") @staticmethod def read_excel_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]: """读取Excel文件内容""" try: wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True) table_name_base = Path(filename).stem results = [] for sheet_name in wb.sheetnames: ws = wb[sheet_name] if ws.max_row is None or ws.max_row <= 1: continue # 读取数据 data = [] headers = [] # 读取第一行作为列名 for col in range(1, ws.max_column + 1): header_cell = ws.cell(row=1, column=col) header = FileImportUtils.clean_cell_value(header_cell.value) if header: headers.append(FileImportUtils.clean_column_name(header)) else: headers.append(f"column_{col}") # 读取数据行 for row in range(2, ws.max_row + 1): row_data = {} has_data = False for col, header in enumerate(headers, 1): cell_value = FileImportUtils.clean_cell_value(ws.cell(row=row, column=col).value) row_data[header] = cell_value if cell_value is not None: has_data = True if has_data: data.append(row_data) if data: table_name = f"{table_name_base}_{sheet_name}" if len(wb.sheetnames) > 1 else table_name_base results.append({ "table_name": FileImportUtils.clean_column_name(table_name), "data": data, "columns": headers }) return results except Exception as e: raise Exception(f"读取Excel文件失败: {str(e)}") @staticmethod def read_csv_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]: """读取CSV文件内容""" try: # 尝试不同编码 encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030'] df = None for encoding in encodings: try: df = pd.read_csv(io.BytesIO(file_content), encoding=encoding) break except: continue if df is None: raise Exception("无法识别文件编码") # 清理列名 df.columns = [FileImportUtils.clean_column_name(col) for col in df.columns] # 处理重复列名 if len(df.columns) != len(set(df.columns)): new_cols = [] col_count = {} for col in df.columns: if col in col_count: col_count[col] += 1 new_cols.append(f"{col}_{col_count[col]}") else: col_count[col] = 0 new_cols.append(col) df.columns = new_cols # 清理数据 for col in df.columns: df[col] = df[col].apply(lambda x: FileImportUtils.clean_cell_value(x)) df[col] = df[col].where(pd.notna(df[col]), None) # 转换为字典列表 data = df.to_dict('records') table_name = FileImportUtils.clean_column_name(Path(filename).stem) return [{ "table_name": table_name, "data": data, "columns": list(df.columns) }] except Exception as e: raise Exception(f"读取CSV文件失败: {str(e)}") @staticmethod def parse_file(filename: str, file_content: str) -> List[Dict[str, Any]]: """ 解析文件内容,根据文件扩展名选择相应的解析方法 Args: filename: 文件名 file_content: base64编码的文件内容 Returns: 解析后的文件数据列表 """ # 解码文件内容 file_bytes = FileImportUtils.decode_file_content(file_content) # 根据文件扩展名处理文件 file_ext = Path(filename).suffix.lower() if file_ext == '.csv': return FileImportUtils.read_csv_file(file_bytes, filename) elif file_ext in ['.xlsx', '.xls']: return FileImportUtils.read_excel_file(file_bytes, filename) else: raise ValueError(f"不支持的文件类型: {file_ext}") @staticmethod def prepare_table_columns(data: List[Dict[str, Any]], columns: List[str]) -> Dict[str, str]: """ 分析数据并准备表列定义 Args: data: 数据列表 columns: 列名列表 Returns: 列名到MySQL数据类型的映射 """ if not data: return {} # 创建DataFrame用于类型推断 df = pd.DataFrame(data) column_types = {} for col in columns: if col in df.columns: column_types[col] = FileImportUtils.infer_mysql_column_type(df[col]) else: column_types[col] = "TEXT" return column_types