导入文件建表接口
This commit is contained in:
236
app/utils/file_import.py
Normal file
236
app/utils/file_import.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
文件导入工具类
|
||||
处理Excel/CSV文件的读取、解析和数据清洗
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
import openpyxl
|
||||
import pandas as pd
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class FileImportUtils:
|
||||
"""文件导入工具类"""
|
||||
|
||||
@staticmethod
|
||||
def clean_cell_value(value: Any) -> Any:
|
||||
"""清理单元格内容:去除空格、换行、制表符等"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
cleaned = value.replace(' ', '').replace('\n', '').replace('\r', '').replace('\t', '').replace('\u3000', '')
|
||||
return cleaned if cleaned else None
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def clean_column_name(col_name: str) -> str:
|
||||
"""清理列名,使其符合MySQL命名规范"""
|
||||
if not col_name:
|
||||
return "column_1"
|
||||
|
||||
# 去除特殊字符,只保留字母、数字、下划线、中文
|
||||
cleaned = re.sub(r'[^\w\u4e00-\u9fff]', '_', str(col_name))
|
||||
|
||||
# 如果以数字开头,添加前缀
|
||||
if cleaned and cleaned[0].isdigit():
|
||||
cleaned = f"col_{cleaned}"
|
||||
|
||||
return cleaned[:64] # MySQL列名最大长度限制
|
||||
|
||||
@staticmethod
|
||||
def infer_mysql_column_type(series: pd.Series) -> str:
|
||||
"""推断MySQL列类型"""
|
||||
# 去除空值
|
||||
non_null_series = series.dropna()
|
||||
|
||||
if len(non_null_series) == 0:
|
||||
return "TEXT"
|
||||
|
||||
# 检查是否为数字
|
||||
try:
|
||||
pd.to_numeric(non_null_series)
|
||||
# 检查是否为整数
|
||||
if all(isinstance(x, (int, float)) and (pd.isna(x) or x == int(x)) for x in non_null_series):
|
||||
max_val = abs(non_null_series.max())
|
||||
if max_val < 2147483648: # INT范围
|
||||
return "INT"
|
||||
else:
|
||||
return "BIGINT"
|
||||
else:
|
||||
return "DECIMAL(10,2)"
|
||||
except:
|
||||
pass
|
||||
|
||||
# 检查字符串长度
|
||||
max_length = non_null_series.astype(str).str.len().max()
|
||||
if max_length <= 255:
|
||||
return f"VARCHAR({min(max_length * 2, 255)})"
|
||||
else:
|
||||
return "TEXT"
|
||||
|
||||
@staticmethod
|
||||
def decode_file_content(file_content: str) -> bytes:
|
||||
"""解码base64文件内容"""
|
||||
try:
|
||||
return base64.b64decode(file_content)
|
||||
except Exception as e:
|
||||
raise ValueError(f"文件内容解码失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def read_excel_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
|
||||
"""读取Excel文件内容"""
|
||||
try:
|
||||
wb = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
|
||||
table_name_base = Path(filename).stem
|
||||
results = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
if ws.max_row is None or ws.max_row <= 1:
|
||||
continue
|
||||
|
||||
# 读取数据
|
||||
data = []
|
||||
headers = []
|
||||
|
||||
# 读取第一行作为列名
|
||||
for col in range(1, ws.max_column + 1):
|
||||
header_cell = ws.cell(row=1, column=col)
|
||||
header = FileImportUtils.clean_cell_value(header_cell.value)
|
||||
if header:
|
||||
headers.append(FileImportUtils.clean_column_name(header))
|
||||
else:
|
||||
headers.append(f"column_{col}")
|
||||
|
||||
# 读取数据行
|
||||
for row in range(2, ws.max_row + 1):
|
||||
row_data = {}
|
||||
has_data = False
|
||||
for col, header in enumerate(headers, 1):
|
||||
cell_value = FileImportUtils.clean_cell_value(ws.cell(row=row, column=col).value)
|
||||
row_data[header] = cell_value
|
||||
if cell_value is not None:
|
||||
has_data = True
|
||||
|
||||
if has_data:
|
||||
data.append(row_data)
|
||||
|
||||
if data:
|
||||
table_name = f"{table_name_base}_{sheet_name}" if len(wb.sheetnames) > 1 else table_name_base
|
||||
results.append({
|
||||
"table_name": FileImportUtils.clean_column_name(table_name),
|
||||
"data": data,
|
||||
"columns": headers
|
||||
})
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
raise Exception(f"读取Excel文件失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def read_csv_file(file_content: bytes, filename: str) -> List[Dict[str, Any]]:
|
||||
"""读取CSV文件内容"""
|
||||
try:
|
||||
# 尝试不同编码
|
||||
encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030']
|
||||
df = None
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
df = pd.read_csv(io.BytesIO(file_content), encoding=encoding)
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if df is None:
|
||||
raise Exception("无法识别文件编码")
|
||||
|
||||
# 清理列名
|
||||
df.columns = [FileImportUtils.clean_column_name(col) for col in df.columns]
|
||||
|
||||
# 处理重复列名
|
||||
if len(df.columns) != len(set(df.columns)):
|
||||
new_cols = []
|
||||
col_count = {}
|
||||
for col in df.columns:
|
||||
if col in col_count:
|
||||
col_count[col] += 1
|
||||
new_cols.append(f"{col}_{col_count[col]}")
|
||||
else:
|
||||
col_count[col] = 0
|
||||
new_cols.append(col)
|
||||
df.columns = new_cols
|
||||
|
||||
# 清理数据
|
||||
for col in df.columns:
|
||||
df[col] = df[col].apply(lambda x: FileImportUtils.clean_cell_value(x))
|
||||
df[col] = df[col].where(pd.notna(df[col]), None)
|
||||
|
||||
# 转换为字典列表
|
||||
data = df.to_dict('records')
|
||||
|
||||
table_name = FileImportUtils.clean_column_name(Path(filename).stem)
|
||||
|
||||
return [{
|
||||
"table_name": table_name,
|
||||
"data": data,
|
||||
"columns": list(df.columns)
|
||||
}]
|
||||
except Exception as e:
|
||||
raise Exception(f"读取CSV文件失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def parse_file(filename: str, file_content: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析文件内容,根据文件扩展名选择相应的解析方法
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
file_content: base64编码的文件内容
|
||||
|
||||
Returns:
|
||||
解析后的文件数据列表
|
||||
"""
|
||||
# 解码文件内容
|
||||
file_bytes = FileImportUtils.decode_file_content(file_content)
|
||||
|
||||
# 根据文件扩展名处理文件
|
||||
file_ext = Path(filename).suffix.lower()
|
||||
|
||||
if file_ext == '.csv':
|
||||
return FileImportUtils.read_csv_file(file_bytes, filename)
|
||||
elif file_ext in ['.xlsx', '.xls']:
|
||||
return FileImportUtils.read_excel_file(file_bytes, filename)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {file_ext}")
|
||||
|
||||
@staticmethod
|
||||
def prepare_table_columns(data: List[Dict[str, Any]], columns: List[str]) -> Dict[str, str]:
|
||||
"""
|
||||
分析数据并准备表列定义
|
||||
|
||||
Args:
|
||||
data: 数据列表
|
||||
columns: 列名列表
|
||||
|
||||
Returns:
|
||||
列名到MySQL数据类型的映射
|
||||
"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# 创建DataFrame用于类型推断
|
||||
df = pd.DataFrame(data)
|
||||
column_types = {}
|
||||
|
||||
for col in columns:
|
||||
if col in df.columns:
|
||||
column_types[col] = FileImportUtils.infer_mysql_column_type(df[col])
|
||||
else:
|
||||
column_types[col] = "TEXT"
|
||||
|
||||
return column_types
|
||||
Reference in New Issue
Block a user