"""
Configuration for the direct batch-import scripts.

All tunable parameters live here so the importer and the database service
share a single source of truth.
"""
import os
from pathlib import Path

# Project root: three levels up from this file (upload_app/direct_import/config.py).
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute()

# --- Data file locations -------------------------------------------------
DATA_ROOT = PROJECT_ROOT / "upload_app" / "data"  # root directory of parquet files
DEFAULT_ACCOUNT_ID = 1  # account id applied when a record carries none

# --- Batch sizes per data type (tune to database performance) ------------
BATCH_SIZE = {
    "section": 500,
    "checkpoint": 500,
    "settlement": 500,
    "level": 500,
    "original": 1000,  # original data tolerates larger batches
}

# --- Resume (checkpoint/restart) support ---------------------------------
RESUME_PROGRESS_FILE = PROJECT_ROOT / "data_import_progress.json"
RESUME_ENABLE = False  # whether resumable imports are enabled

# data type -> (directory keyword, file keyword, endpoint name, extra fields)
DATA_TYPE_MAPPING = {
    "section": ("断面数据表", "section_", "sections", ["account_id"]),
    "checkpoint": ("观测点数据表", "point_", "checkpoints", []),
    "settlement": ("沉降数据表", "settlement_", "settlements", []),
    "level": ("水准数据表", "level_", "levels", []),
    "original": ("原始数据表", "original_", "originals", []),
}

# Import order honouring data dependencies (disabled entries kept for reference).
DATA_TYPE_ORDER = [
    # ("section", "断面数据"),
    # ("checkpoint", "观测点数据"),
    # ("settlement", "沉降数据"),
    # ("level", "水准数据"),
    ("original", "原始数据"),
]

# Required (critical) columns per data type; files missing any are skipped.
CRITICAL_FIELDS = {
    "section": ["section_id", "account_id", "mileage", "work_site"],
    "checkpoint": ["point_id", "section_id", "aname", "burial_date"],
    "settlement": ["NYID", "point_id", "sjName"],
    "level": ["NYID", "linecode", "wsphigh", "createDate"],
    "original": ["NYID", "bfpcode", "mtime", "bfpvalue", "sort"],
}

# Forced type conversions applied to each record before import.
TYPE_CONVERSION = {
    "section": {"section_id": int, "account_id": int},
    "checkpoint": {"point_id": int},
    "settlement": {"NYID": str},  # settlement NYID is stored as a string
    "original": {"sort": int},
}

# --- Logging -------------------------------------------------------------
LOG_LEVEL = "INFO"
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# --- File filtering ------------------------------------------------------
MIN_FILE_SIZE = 1024  # minimum file size in bytes; smaller files are treated as empty
class DatabaseService:
    """Database service layer: writes import data straight to the DB, bypassing HTTP.

    Each ``batch_import_*`` method takes an open SQLAlchemy ``Session`` plus a
    list of record dicts and returns a result dict with keys ``success``,
    ``message``, ``total_count``, ``success_count``, ``failed_count`` and
    ``failed_items``.
    """

    def __init__(self):
        """Build the engine and session factory from environment variables.

        ``DATABASE_URL`` wins when set; otherwise the URL is assembled from the
        individual ``DB_HOST``/``DB_PORT``/``DB_USER``/``DB_PASSWORD``/``DB_NAME``
        variables with localhost/root defaults.
        """
        database_url = os.getenv("DATABASE_URL")
        if not database_url:
            # Assemble the URL from discrete settings.
            db_host = os.getenv("DB_HOST", "localhost")
            db_port = os.getenv("DB_PORT", "3306")
            db_user = os.getenv("DB_USER", "root")
            db_password = os.getenv("DB_PASSWORD", "root")
            db_name = os.getenv("DB_NAME", "railway")
            database_url = f"mysql+pymysql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"

        # NOTE(review): this logs the full URL including the password — consider masking.
        logger.info(f"数据库连接: {database_url}")

        self.engine = create_engine(
            database_url,
            pool_pre_ping=True,   # validate pooled connections before use
            pool_recycle=3600,    # recycle connections hourly (MySQL idle timeouts)
            echo=False            # set True to log emitted SQL
        )
        self.SessionLocal = sessionmaker(bind=self.engine)

    def get_db_session(self) -> Session:
        """Return a new database session from the factory."""
        return self.SessionLocal()

    # ==================== Section data ====================

    def batch_import_sections(self, db: Session, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-insert section rows; section_ids already present are skipped.

        Skipped duplicates are counted as failures in the result dict.
        """
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        if total_count == 0:
            # Empty payload: nothing to do.
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        try:
            db.begin()

            # One IN query to find which section_ids already exist.
            section_id_list = list(set(str(item.get('section_id')) for item in data if item.get('section_id')))
            logger.info(f"Checking {len(section_id_list)} unique section_ids")

            existing_sections = db.query(SectionData).filter(SectionData.section_id.in_(section_id_list)).all()

            # Lookup table keyed by section_id.
            existing_map = {
                section.section_id: section
                for section in existing_sections
            }
            logger.info(f"Found {len(existing_sections)} existing sections")

            # Partition into inserts vs. skips.
            to_insert = []

            for item_data in data:
                section_id = str(item_data.get('section_id'))

                if section_id in existing_map:
                    # Row already present — record as skipped.
                    logger.info(f"Continue section data: {section_id}")
                    failed_count += 1
                    failed_items.append({
                        'data': item_data,
                        'error': '数据已存在,跳过插入操作'
                    })
                else:
                    to_insert.append(item_data)

            # Insert in fixed-size batches.
            if to_insert:
                logger.info(f"Inserting {len(to_insert)} new records")
                batch_size = 500
                for i in range(0, len(to_insert), batch_size):
                    batch = to_insert[i:i + batch_size]
                    try:
                        section_data_list = [
                            SectionData(
                                section_id=str(item.get('section_id')),
                                mileage=item.get('mileage'),
                                work_site=item.get('work_site'),
                                basic_types=item.get('basic_types'),
                                height=item.get('height'),
                                status=item.get('status'),
                                number=str(item.get('number')) if item.get('number') else None,
                                transition_paragraph=item.get('transition_paragraph'),
                                design_fill_height=item.get('design_fill_height'),
                                compression_layer_thickness=item.get('compression_layer_thickness'),
                                treatment_depth=item.get('treatment_depth'),
                                foundation_treatment_method=item.get('foundation_treatment_method'),
                                rock_mass_classification=item.get('rock_mass_classification'),
                                account_id=str(item.get('account_id')) if item.get('account_id') else None
                            )
                            for item in batch
                        ]
                        db.add_all(section_data_list)
                        success_count += len(batch)
                        logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")
                    except Exception as e:
                        # A failed batch aborts the whole import (re-raised below).
                        failed_count += len(batch)
                        failed_items.extend([
                            {
                                'data': item,
                                'error': f'插入失败: {str(e)}'
                            }
                            for item in batch
                        ])
                        logger.error(f"Failed to insert batch: {str(e)}")
                        raise e

            db.commit()
            logger.info(f"Batch import sections completed. Success: {success_count}, Failed: {failed_count}")

            return {
                'success': True,
                'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
                'total_count': total_count,
                'success_count': success_count,
                'failed_count': failed_count,
                'failed_items': failed_items
            }

        except Exception as e:
            db.rollback()
            logger.error(f"Batch import sections failed: {str(e)}")
            return {
                'success': False,
                'message': f'批量导入失败: {str(e)}',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': failed_items
            }

    # ==================== Checkpoint data ====================

    def batch_import_checkpoints(self, db: Session, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-insert checkpoint rows.

        Rows whose section_id is missing from the section table, or whose
        point_id already exists, are recorded as failures and skipped.
        """
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        if total_count == 0:
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        try:
            db.begin()

            # Verify referenced sections with one IN query.
            section_id_list = list(set(str(item.get('section_id')) for item in data if item.get('section_id')))
            logger.info(f"Checking {len(section_id_list)} unique section_ids in section data")

            sections = db.query(SectionData).filter(SectionData.section_id.in_(section_id_list)).all()
            section_map = {s.section_id: s for s in sections}
            missing_section_ids = set(section_id_list) - set(section_map.keys())

            # Mark every record whose parent section is absent.
            for item_data in data:
                section_id = str(item_data.get('section_id'))
                if section_id in missing_section_ids:
                    failed_count += 1
                    failed_items.append({
                        'data': item_data,
                        'error': '断面ID不存在,跳过插入操作'
                    })

            # Short-circuit when nothing can be inserted.
            if failed_count == total_count:
                db.rollback()
                return {
                    'success': False,
                    'message': '所有断面ID都不存在',
                    'total_count': total_count,
                    'success_count': 0,
                    'failed_count': total_count,
                    'failed_items': failed_items
                }

            # Deduplicate against existing checkpoints.
            valid_items = [item for item in data if str(item.get('section_id')) not in missing_section_ids]
            if valid_items:
                point_id_list = list(set(str(item.get('point_id')) for item in valid_items if item.get('point_id')))
                existing_checkpoints = db.query(Checkpoint).filter(Checkpoint.point_id.in_(point_id_list)).all()

                existing_map = {
                    checkpoint.point_id: checkpoint
                    for checkpoint in existing_checkpoints
                }
                logger.info(f"Found {len(existing_checkpoints)} existing checkpoints")

                to_insert = []

                for item_data in valid_items:
                    point_id = str(item_data.get('point_id'))

                    if point_id in existing_map:
                        # Row already present — record as skipped.
                        logger.info(f"Continue checkpoint data: {point_id}")
                        failed_count += 1
                        failed_items.append({
                            'data': item_data,
                            'error': '数据已存在,跳过插入操作'
                        })
                    else:
                        to_insert.append(item_data)

                # Insert in fixed-size batches.
                if to_insert:
                    logger.info(f"Inserting {len(to_insert)} new records")
                    batch_size = 500
                    for i in range(0, len(to_insert), batch_size):
                        batch = to_insert[i:i + batch_size]
                        try:
                            checkpoint_list = [
                                Checkpoint(
                                    point_id=str(item.get('point_id')),
                                    aname=item.get('aname'),
                                    section_id=str(item.get('section_id')),
                                    burial_date=item.get('burial_date')
                                )
                                for item in batch
                            ]
                            db.add_all(checkpoint_list)
                            success_count += len(batch)
                            logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")
                        except Exception as e:
                            failed_count += len(batch)
                            failed_items.extend([
                                {
                                    'data': item,
                                    'error': f'插入失败: {str(e)}'
                                }
                                for item in batch
                            ])
                            logger.error(f"Failed to insert batch: {str(e)}")
                            raise e

            db.commit()
            logger.info(f"Batch import checkpoints completed. Success: {success_count}, Failed: {failed_count}")

            return {
                'success': True,
                'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
                'total_count': total_count,
                'success_count': success_count,
                'failed_count': failed_count,
                'failed_items': failed_items
            }

        except Exception as e:
            db.rollback()
            logger.error(f"Batch import checkpoints failed: {str(e)}")
            return {
                'success': False,
                'message': f'批量导入失败: {str(e)}',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': failed_items
            }

    # ==================== Settlement data ====================

    def batch_import_settlement_data(self, db: Session, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-insert settlement rows.

        Rows whose point_id is missing from the checkpoint table, or whose
        (point_id, NYID) pair already exists, are recorded as failures.
        """
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        if total_count == 0:
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        try:
            db.begin()

            # Verify referenced checkpoints with one IN query.
            point_id_list = list(set(str(item.get('point_id')) for item in data if item.get('point_id')))
            logger.info(f"Checking {len(point_id_list)} unique point_ids in checkpoint data")

            checkpoints = db.query(Checkpoint).filter(Checkpoint.point_id.in_(point_id_list)).all()
            checkpoint_map = {c.point_id: c for c in checkpoints}
            missing_point_ids = set(point_id_list) - set(checkpoint_map.keys())

            # Mark records whose checkpoint is absent.
            for item_data in data:
                point_id = str(item_data.get('point_id'))
                if point_id in missing_point_ids:
                    failed_count += 1
                    failed_items.append({
                        'data': item_data,
                        'error': '测点id不存在,跳过插入操作'
                    })

            if failed_count == total_count:
                db.rollback()
                return {
                    'success': False,
                    'message': '所有观测点ID都不存在',
                    'total_count': total_count,
                    'success_count': 0,
                    'failed_count': total_count,
                    'failed_items': failed_items
                }

            # Deduplicate against existing settlement rows (keyed point_id + NYID).
            valid_items = [item for item in data if str(item.get('point_id')) not in missing_point_ids]
            if valid_items:
                existing_data = db.query(SettlementData).filter(
                    SettlementData.point_id.in_(point_id_list)
                ).all()

                existing_map = {
                    f"{item.point_id}_{item.NYID}": item
                    for item in existing_data
                }
                logger.info(f"Found {len(existing_data)} existing settlement records")

                to_insert = []

                for item_data in valid_items:
                    point_id = str(item_data.get('point_id'))
                    nyid = str(item_data.get('NYID'))

                    key = f"{point_id}_{nyid}"

                    if key in existing_map:
                        # Row already present — record as skipped.
                        logger.info(f"Continue settlement data: {point_id}-{nyid}")
                        failed_count += 1
                        failed_items.append({
                            'data': item_data,
                            'error': '数据已存在,跳过插入操作'
                        })
                    else:
                        to_insert.append(item_data)

                # Insert in fixed-size batches.
                if to_insert:
                    logger.info(f"Inserting {len(to_insert)} new records")
                    batch_size = 500
                    for i in range(0, len(to_insert), batch_size):
                        batch = to_insert[i:i + batch_size]
                        try:
                            settlement_data_list = [
                                SettlementData(
                                    point_id=str(item.get('point_id')),
                                    CVALUE=item.get('CVALUE'),
                                    MAVALUE=item.get('MAVALUE'),
                                    MTIME_W=item.get('MTIME_W'),
                                    NYID=str(item.get('NYID')),
                                    PRELOADH=item.get('PRELOADH'),
                                    PSTATE=item.get('PSTATE'),
                                    REMARK=item.get('REMARK'),
                                    WORKINFO=item.get('WORKINFO'),
                                    createdate=item.get('createdate'),
                                    day=item.get('day'),
                                    day_jg=item.get('day_jg'),
                                    isgzjdxz=item.get('isgzjdxz'),
                                    mavalue_bc=item.get('mavalue_bc'),
                                    mavalue_lj=item.get('mavalue_lj'),
                                    sjName=item.get('sjName'),
                                    useflag=item.get('useflag'),
                                    workinfoname=item.get('workinfoname'),
                                    upd_remark=item.get('upd_remark')
                                )
                                for item in batch
                            ]
                            db.add_all(settlement_data_list)
                            success_count += len(batch)
                            logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")
                        except Exception as e:
                            failed_count += len(batch)
                            failed_items.extend([
                                {
                                    'data': item,
                                    'error': f'插入失败: {str(e)}'
                                }
                                for item in batch
                            ])
                            logger.error(f"Failed to insert batch: {str(e)}")
                            raise e

            db.commit()
            logger.info(f"Batch import settlement data completed. Success: {success_count}, Failed: {failed_count}")

            return {
                'success': True,
                'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
                'total_count': total_count,
                'success_count': success_count,
                'failed_count': failed_count,
                'failed_items': failed_items
            }

        except Exception as e:
            db.rollback()
            logger.error(f"Batch import settlement data failed: {str(e)}")
            return {
                'success': False,
                'message': f'批量导入失败: {str(e)}',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': failed_items
            }

    # ==================== Level data ====================

    def batch_import_level_data(self, db: Session, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-insert level (leveling survey) rows.

        Rows whose NYID is absent from the settlement table, or whose
        (NYID, linecode) pair already exists, are recorded as failures.
        """
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        if total_count == 0:
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        try:
            db.begin()

            # Verify referenced settlement periods (NYIDs) with one IN query.
            nyid_list = list(set(str(item.get('NYID')) for item in data if item.get('NYID')))
            logger.info(f"Checking {len(nyid_list)} unique NYIDs in settlement data")

            settlements = db.query(SettlementData).filter(SettlementData.NYID.in_(nyid_list)).all()
            settlement_map = {s.NYID: s for s in settlements}
            missing_nyids = set(nyid_list) - set(settlement_map.keys())

            # Mark records whose settlement period is absent.
            for item_data in data:
                nyid = str(item_data.get('NYID'))
                if nyid in missing_nyids:
                    failed_count += 1
                    failed_items.append({
                        'data': item_data,
                        'error': '期数ID在沉降表中不存在,跳过插入操作'
                    })

            if failed_count == total_count:
                db.rollback()
                return {
                    'success': False,
                    'message': '所有期数ID在沉降表中都不存在',
                    'total_count': total_count,
                    'success_count': 0,
                    'failed_count': total_count,
                    'failed_items': failed_items
                }

            # Deduplicate against existing level rows (keyed NYID + linecode).
            valid_items = [item for item in data if str(item.get('NYID')) not in missing_nyids]
            if valid_items:
                existing_data = db.query(LevelData).filter(
                    LevelData.NYID.in_(nyid_list)
                ).all()

                existing_map = {
                    f"{item.NYID}_{item.linecode}": item
                    for item in existing_data
                }
                logger.info(f"Found {len(existing_data)} existing level records")

                to_insert = []

                for item_data in valid_items:
                    nyid = str(item_data.get('NYID'))
                    linecode = item_data.get('linecode')

                    key = f"{nyid}_{linecode}"

                    if key in existing_map:
                        # Row already present — record as skipped.
                        logger.info(f"Continue level data: {nyid}-{linecode}")
                        failed_count += 1
                        failed_items.append({
                            'data': item_data,
                            'error': '数据已存在,跳过插入操作'
                        })
                    else:
                        to_insert.append(item_data)

                # Insert in fixed-size batches.
                if to_insert:
                    logger.info(f"Inserting {len(to_insert)} new records")
                    batch_size = 500
                    for i in range(0, len(to_insert), batch_size):
                        batch = to_insert[i:i + batch_size]
                        try:
                            level_data_list = [
                                LevelData(
                                    linecode=str(item.get('linecode')),
                                    benchmarkids=item.get('benchmarkids'),
                                    wsphigh=item.get('wsphigh'),
                                    mtype=item.get('mtype'),
                                    NYID=str(item.get('NYID')),
                                    createDate=item.get('createDate'),
                                    wspversion=item.get('wspversion'),
                                    barometric=str(item.get('barometric')) if item.get('barometric') is not None else None,
                                    equipbrand=item.get('equipbrand'),
                                    instrumodel=item.get('instrumodel'),
                                    serialnum=item.get('serialnum'),
                                    sjname=item.get('sjname'),
                                    temperature=str(item.get('temperature')) if item.get('temperature') is not None else None,
                                    weather=str(item.get('weather')) if item.get('weather') is not None else None
                                )
                                for item in batch
                            ]
                            db.add_all(level_data_list)
                            success_count += len(batch)
                            logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")
                        except Exception as e:
                            failed_count += len(batch)
                            failed_items.extend([
                                {
                                    'data': item,
                                    'error': f'插入失败: {str(e)}'
                                }
                                for item in batch
                            ])
                            logger.error(f"Failed to insert batch: {str(e)}")
                            raise e

            db.commit()
            logger.info(f"Batch import level data completed. Success: {success_count}, Failed: {failed_count}")

            return {
                'success': True,
                'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
                'total_count': total_count,
                'success_count': success_count,
                'failed_count': failed_count,
                'failed_items': failed_items
            }

        except Exception as e:
            db.rollback()
            logger.error(f"Batch import level data failed: {str(e)}")
            return {
                'success': False,
                'message': f'批量导入失败: {str(e)}',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': failed_items
            }

    # ==================== Original (raw) data ====================

    def _ensure_table_exists(self, account_id: int) -> bool:
        """Create the per-account original-data table if it does not exist.

        Returns True when the table exists (or was created), False after
        ``max_retries`` failed creation attempts.
        """
        table_name = get_table_name(account_id)
        inspector = inspect(self.engine)

        # Fast path: table already present.
        if table_name in inspector.get_table_names():
            logger.info(f"Table {table_name} already exists")
            return True

        # Create with retries and linear back-off.
        max_retries = 3
        for attempt in range(max_retries):
            try:
                create_table_sql = f"""
                CREATE TABLE `{table_name}` (
                    `id` INT AUTO_INCREMENT PRIMARY KEY,
                    `account_id` INT NOT NULL COMMENT '账号ID',
                    `bfpcode` VARCHAR(1000) NOT NULL COMMENT '前(后)视点名称',
                    `mtime` DATETIME NOT NULL COMMENT '测点观测时间',
                    `bffb` VARCHAR(1000) NOT NULL COMMENT '前(后)视标记符',
                    `bfpl` VARCHAR(1000) NOT NULL COMMENT '前(后)视距离(m)',
                    `bfpvalue` VARCHAR(1000) NOT NULL COMMENT '前(后)视尺读数(m)',
                    `NYID` VARCHAR(100) NOT NULL COMMENT '期数id',
                    `sort` INT COMMENT '序号',
                    INDEX `idx_nyid` (`NYID`),
                    INDEX `idx_account_id` (`account_id`)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='原始数据表_账号{account_id}'
                """
                # Execute on the engine directly; DDL is auto-committing in MySQL.
                with self.engine.begin() as conn:
                    conn.execute(text(create_table_sql))

                logger.info(f"Table {table_name} created successfully")
                return True

            except Exception as e:
                logger.warning(f"Attempt {attempt + 1} to create table {table_name} failed: {str(e)}")
                if attempt == max_retries - 1:
                    logger.error(f"Failed to create table {table_name} after {max_retries} attempts")
                    return False

                import time
                time.sleep(0.1 * (attempt + 1))

        return False

    def batch_import_original_data(self, db: Session, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-insert raw records into the per-account sharded table.

        Transaction management (begin/commit/rollback) is the CALLER's
        responsibility; this method only executes statements on ``db``.
        All records are assumed to share the account_id of the first record.
        """
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        if total_count == 0:
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        # The account id of the first record determines the target shard.
        account_id = data[0].get('account_id')
        if not account_id:
            return {
                'success': False,
                'message': '数据中缺少account_id字段',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': []
            }

        # Reject unknown accounts.
        account = db.query(Account).filter(Account.id == account_id).first()
        if not account:
            return {
                'success': False,
                'message': f'账号ID {account_id} 不存在',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': []
            }

        # Make sure the shard table exists.
        table_created = self._ensure_table_exists(account_id)
        if not table_created:
            return {
                'success': False,
                'message': '创建原始数据表失败',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': []
            }

        table_name = get_table_name(account_id)

        try:
            # No db.begin() here: the transaction is owned by the caller.

            # All NYIDs must already exist in the settlement table.
            nyid_list = list(set(str(item.get('NYID')) for item in data if item.get('NYID')))
            from app.models.settlement_data import SettlementData
            settlements = db.query(SettlementData).filter(SettlementData.NYID.in_(nyid_list)).all()
            settlement_map = {s.NYID: s for s in settlements}
            missing_nyids = set(nyid_list) - set(settlement_map.keys())

            if missing_nyids:
                raise Exception(f'以下期数在沉降表中不存在: {list(missing_nyids)}')

            # Load existing rows for this account to deduplicate against.
            existing_data = db.query(text("*")).from_statement(
                text(f"SELECT * FROM `{table_name}` WHERE account_id = :account_id")
            ).params(account_id=account_id).all()

            existing_map = {
                # NYID is column index 7, sort is column index 8 in the shard table.
                f"{item[7]}_{item[8]}": item
                for item in existing_data
            }
            logger.info(f"Found {len(existing_data)} existing records in {table_name}")

            # Partition into inserts vs. skips by (NYID, sort) key.
            to_insert = []
            skipped_count = 0

            for item_data in data:
                nyid = str(item_data.get('NYID'))
                sort = item_data.get('sort')

                key = f"{nyid}_{sort}"

                if key in existing_map:
                    # Duplicate — skipped silently (not counted as failure).
                    skipped_count += 1
                else:
                    to_insert.append(item_data)

            logger.info(f"Filtered {skipped_count} duplicate records, {len(to_insert)} new records to insert")

            # Multi-row INSERT with bound parameters, 1000 rows per statement.
            if to_insert:
                logger.info(f"Inserting {len(to_insert)} new records")
                batch_size = 1000
                for i in range(0, len(to_insert), batch_size):
                    batch = to_insert[i:i + batch_size]

                    # Build one parameterized VALUES tuple per row.
                    values_list = []
                    params = {}
                    for idx, item_data in enumerate(batch):
                        values_list.append(
                            f"(:account_id_{idx}, :bfpcode_{idx}, :mtime_{idx}, :bffb_{idx}, "
                            f":bfpl_{idx}, :bfpvalue_{idx}, :NYID_{idx}, :sort_{idx})"
                        )
                        params.update({
                            f"account_id_{idx}": account_id,
                            f"bfpcode_{idx}": item_data.get('bfpcode'),
                            f"mtime_{idx}": item_data.get('mtime'),
                            f"bffb_{idx}": item_data.get('bffb'),
                            f"bfpl_{idx}": item_data.get('bfpl'),
                            f"bfpvalue_{idx}": item_data.get('bfpvalue'),
                            f"NYID_{idx}": item_data.get('NYID'),
                            f"sort_{idx}": item_data.get('sort')
                        })

                    insert_sql = f"""
                    INSERT INTO `{table_name}`
                    (account_id, bfpcode, mtime, bffb, bfpl, bfpvalue, NYID, sort)
                    VALUES {", ".join(values_list)}
                    """
                    final_sql = text(insert_sql)
                    db.execute(final_sql, params)
                    success_count += len(batch)
                    logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")

            # No commit here: the caller commits.
            logger.info(f"Batch import original data completed. Success: {success_count}, Failed: {failed_count}")

            return {
                'success': True,
                'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
                'total_count': total_count,
                'success_count': success_count,
                'failed_count': failed_count,
                'failed_items': failed_items
            }

        except Exception as e:
            # No rollback here: the caller rolls back.
            logger.error(f"Batch import original data failed: {str(e)}")
            return {
                'success': False,
                'message': f'批量导入失败: {str(e)}',
                'total_count': total_count,
                'success_count': 0,
                'failed_count': total_count,
                'failed_items': failed_items
            }
class DirectBatchImporter:
    """Direct batch importer: reads parquet files and writes straight to the DB."""

    def __init__(self):
        """Create the database service and log a start banner."""
        self.db_service = DatabaseService()
        logger.info("=" * 60)
        logger.info("直接批量导入器启动")
        logger.info("=" * 60)

    # ==================== File operations ====================

    def scan_all_parquet(self, root_dir: Path) -> Dict[str, List[str]]:
        """Recursively scan root_dir and classify parquet files by data type.

        Directories are matched by the keyword in DATA_TYPE_MAPPING; files are
        matched by the file keyword and must exceed MIN_FILE_SIZE bytes.
        Returns a dict of data_type -> list of absolute file paths.
        """
        classified_files = {data_type: [] for data_type in DATA_TYPE_MAPPING.keys()}

        logger.info(f"扫描目录:{root_dir}")

        if not root_dir.exists():
            logger.error(f"数据目录不存在:{root_dir}")
            return classified_files

        for root, dirs, files in os.walk(root_dir):
            # Match the directory against each type's directory keyword.
            matched_data_type = None
            for data_type, (dir_keyword, file_keyword, _, _) in DATA_TYPE_MAPPING.items():
                if dir_keyword in root:
                    matched_data_type = data_type
                    logger.debug(f"目录匹配:{root} → 类型:{data_type}")
                    break

            if not matched_data_type:
                logger.debug(f"跳过目录:{root}")
                continue

            # NOTE(review): file_keyword leaks out of the loop above; this is
            # only correct because we break exactly on the matched type.
            for file in files:
                if file.endswith(".parquet") and file_keyword in file:
                    file_path = os.path.abspath(os.path.join(root, file))
                    if os.path.getsize(file_path) > MIN_FILE_SIZE:
                        classified_files[matched_data_type].append(file_path)
                        logger.info(f"发现文件:{file_path}")
                    else:
                        logger.warning(f"跳过空文件:{file_path} (大小: {os.path.getsize(file_path)} bytes)")

        # Scan summary.
        logger.info("\n=== 扫描完成 ===")
        for data_type, paths in classified_files.items():
            logger.info(f"{data_type}: {len(paths)} 个文件")

        return classified_files

    def read_parquet_by_type(self, file_paths: List[str], data_type: str) -> List[Dict[str, Any]]:
        """Read parquet files for one data type into a list of record dicts.

        NaNs become empty strings; files missing any CRITICAL_FIELDS column or
        containing no non-empty record are skipped; each surviving record is
        normalized via _format_record_fields.
        """
        data_list = []
        critical_fields = CRITICAL_FIELDS.get(data_type, [])

        logger.info(f"\n开始读取 {data_type} 数据,共 {len(file_paths)} 个文件")

        for file_path in file_paths:
            try:
                # Read and neutralize NaNs.
                df = pd.read_parquet(file_path)
                df = df.fillna("")
                file_basename = os.path.basename(file_path)

                # Log the actual columns for diagnosis.
                actual_columns = df.columns.tolist()
                logger.info(f"[读取] {file_basename} 实际列名:{actual_columns}")

                # Skip files missing any required column.
                missing_fields = [f for f in critical_fields if f not in actual_columns]
                if missing_fields:
                    logger.warning(f"[读取] {file_basename} 缺失字段:{missing_fields} → 跳过")
                    continue

                # Drop rows where every value is falsy.
                records = df.to_dict("records")
                valid_records = [r for r in records if any(r.values())]

                if not valid_records:
                    logger.warning(f"[读取] {file_basename} 无有效记录 → 跳过")
                    continue

                # Normalize field types in place.
                for record in valid_records:
                    self._format_record_fields(record, data_type)

                data_list.extend(valid_records)
                logger.info(f"[读取] {file_basename} 处理完成 → 有效记录:{len(valid_records)}条,累计:{len(data_list)}条")

            except Exception as e:
                # A bad file is skipped; remaining files are still processed.
                logger.error(f"[读取] {os.path.basename(file_path)} 读取失败:{str(e)} → 跳过")
                continue

        logger.info(f"\n=== {data_type} 数据读取总结 ===")
        logger.info(f"总文件数:{len(file_paths)} 个")
        logger.info(f"有效记录数:{len(data_list)} 条")

        return data_list

    def _format_record_fields(self, record: Dict[str, Any], data_type: str) -> None:
        """Normalize one record in place: fill account_id, coerce field types."""
        # Supply a default account_id where the type declares one.
        if data_type in TYPE_CONVERSION and "account_id" in TYPE_CONVERSION[data_type]:
            if "account_id" not in record or not record["account_id"]:
                record["account_id"] = DEFAULT_ACCOUNT_ID
                logger.debug(f"补充 account_id={DEFAULT_ACCOUNT_ID}")

        # Apply declared type coercions; failures fall back to 0 / "".
        type_conversion = TYPE_CONVERSION.get(data_type, {})
        for field, convert_func in type_conversion.items():
            if field in record and record[field] is not None:
                try:
                    record[field] = convert_func(record[field])
                    logger.debug(f"{field} 转换为 {type(record[field]).__name__}: {record[field]}")
                except (ValueError, TypeError) as e:
                    logger.warning(f"字段 {field} 转换失败:{record[field]} → {e}")
                    # Substitute a type-appropriate default.
                    if convert_func == int:
                        record[field] = 0
                    elif convert_func == str:
                        record[field] = ""

    # ==================== Import dispatch ====================

    def batch_import_data(self, data_list: List[Dict[str, Any]], data_type: str) -> Dict[str, Any]:
        """Dispatch one data type's records to the matching DatabaseService method.

        Opens (and always closes) its own session; returns the service's result dict.
        """
        if not data_list:
            logger.warning(f"无 {data_type} 数据 → 跳过")
            return {'success': True, 'message': '无数据可导入', 'total_count': 0, 'success_count': 0, 'failed_count': 0, 'failed_items': []}

        logger.info(f"\n开始导入 {data_type} 数据,共 {len(data_list)} 条记录")

        db = self.db_service.get_db_session()

        try:
            # Route by data type.
            if data_type == "section":
                result = self.db_service.batch_import_sections(db, data_list)
            elif data_type == "checkpoint":
                result = self.db_service.batch_import_checkpoints(db, data_list)
            elif data_type == "settlement":
                result = self.db_service.batch_import_settlement_data(db, data_list)
            elif data_type == "level":
                result = self.db_service.batch_import_level_data(db, data_list)
            elif data_type == "original":
                result = self.db_service.batch_import_original_data(db, data_list)
            else:
                result = {
                    'success': False,
                    'message': f'不支持的数据类型:{data_type}',
                    'total_count': len(data_list),
                    'success_count': 0,
                    'failed_count': len(data_list),
                    'failed_items': [{'data': {}, 'error': f'不支持的数据类型:{data_type}'}]
                }

            logger.info(f"导入结果:{json.dumps(result, ensure_ascii=False, indent=2)}")
            return result

        except Exception as e:
            logger.error(f"导入 {data_type} 数据时发生异常:{str(e)}", exc_info=True)
            return {
                'success': False,
                'message': f'导入异常:{str(e)}',
                'total_count': len(data_list),
                'success_count': 0,
                'failed_count': len(data_list),
                'failed_items': [{'data': {}, 'error': str(e)}]
            }
        finally:
            db.close()

    def _pre_create_tables_for_original_data(self, data_list: List[Dict[str, Any]]) -> None:
        """Pre-create every per-account shard table needed by data_list.

        Done up front so DDL never runs inside the import transaction.
        Creation failures are logged but do not abort the run.
        """
        # Collect distinct account ids.
        account_ids = set()
        for item in data_list:
            if 'account_id' in item and item['account_id']:
                account_ids.add(int(item['account_id']))

        logger.info(f"[预建表] 需要为 {len(account_ids)} 个账号创建原始数据表")

        # Create each shard table.
        for account_id in sorted(account_ids):
            try:
                logger.info(f"[预建表] 为账号 {account_id} 创建原始数据表...")
                success = self.db_service._ensure_table_exists(account_id)
                if not success:
                    logger.error(f"[预建表] 账号 {account_id} 表创建失败")
                else:
                    logger.info(f"[预建表] 账号 {account_id} 表创建成功")
            except Exception as e:
                logger.error(f"[预建表] 账号 {account_id} 表创建异常:{str(e)}", exc_info=True)

        logger.info(f"[预建表] 所有原始数据表创建完成")

    def _batch_import_original_data_by_account(self, data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Group original-data records by account and import each group.

        Owns the single transaction spanning all accounts (begin/commit/rollback);
        the per-account service call deliberately does not manage the transaction.
        """
        logger.info(f"\n开始按账号分组导入原始数据,共 {len(data_list)} 条记录")

        # Group records by integer account id; records without one are dropped.
        account_groups = {}
        for item in data_list:
            account_id = item.get('account_id')
            if account_id:
                account_id = int(account_id)
                if account_id not in account_groups:
                    account_groups[account_id] = []
                account_groups[account_id].append(item)

        logger.info(f"发现 {len(account_groups)} 个账号的数据需要导入")

        total_success = 0
        total_failed = 0
        failed_items = []

        db = self.db_service.get_db_session()

        try:
            db.begin()

            # Import each account's slice inside the shared transaction.
            for account_id in sorted(account_groups.keys()):
                account_data = account_groups[account_id]
                logger.info(f"\n[分账号导入] 账号 {account_id}:{len(account_data)} 条记录")

                try:
                    result = self.db_service.batch_import_original_data(db, account_data)

                    if result.get('success', False):
                        total_success += result.get('success_count', 0)
                        logger.info(f"[分账号导入] 账号 {account_id} 成功导入 {result.get('success_count', 0)} 条")
                    else:
                        total_failed += result.get('failed_count', 0)
                        logger.error(f"[分账号导入] 账号 {account_id} 导入失败:{result.get('message', '未知错误')}")
                        failed_items.extend(result.get('failed_items', []))

                except Exception as e:
                    # One account's failure does not abort the other accounts.
                    logger.error(f"[分账号导入] 账号 {account_id} 处理异常:{str(e)}", exc_info=True)
                    total_failed += len(account_data)
                    failed_items.append({
                        'account_id': account_id,
                        'error': str(e)
                    })

            # Single commit for all accounts.
            db.commit()
            logger.info(f"\n[分账号导入] 所有账号导入完成,总计:成功 {total_success} 条,失败 {total_failed} 条")

            return {
                'success': True,
                'message': '分账号批量导入完成',
                'total_count': len(data_list),
                'success_count': total_success,
                'failed_count': total_failed,
                'failed_items': failed_items
            }

        except Exception as e:
            logger.error(f"分账号导入原始数据时发生异常:{str(e)}", exc_info=True)
            try:
                db.rollback()
            except:
                pass
            return {
                'success': False,
                'message': f'分账号导入异常:{str(e)}',
                'total_count': len(data_list),
                'success_count': total_success,
                'failed_count': total_failed,
                'failed_items': failed_items
            }
        finally:
            db.close()

    # NOTE(review): the original file continues with a run() orchestration
    # method here, but it is truncated in this chunk and is not reproduced.
开始导入:{len(data_list)} 条数据,分 {len(file_paths)} 个文件") + result = self.batch_import_data(data_list, data_type) + + # 更新汇总结果 + import_summary[data_type]['success'] = result.get('success_count', 0) + import_summary[data_type]['failed'] = result.get('failed_count', 0) + + # 检查导入结果 + if not result.get('success', False): + logger.error(f"[主流程] 【{data_name}】导入失败 → 终止程序") + logger.error(f"失败原因:{result.get('message', '未知错误')}") + break + + logger.info(f"[主流程] 【{data_name}】导入完成:成功 {result.get('success_count', 0)} 条,失败 {result.get('failed_count', 0)} 条") + + # 最终统计 + end_time = time.time() + elapsed = (end_time - start_time) / 60 + + logger.info("\n" + "=" * 60) + logger.info("所有任务完成") + logger.info("=" * 60) + logger.info(f"总耗时:{elapsed:.2f} 分钟") + logger.info("导入统计:") + + total_success = 0 + total_failed = 0 + for data_type, stats in import_summary.items(): + if stats['success'] > 0 or stats['failed'] > 0: + logger.info(f" {data_type}: 成功 {stats['success']} 条,失败 {stats['failed']} 条") + total_success += stats['success'] + total_failed += stats['failed'] + + logger.info(f"总计:成功 {total_success} 条,失败 {total_failed} 条") + + # 保存导入日志 + log_file = PROJECT_ROOT / "import_summary.json" + with open(log_file, 'w', encoding='utf-8') as f: + json.dump({ + 'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)), + 'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time)), + 'elapsed_minutes': round(elapsed, 2), + 'summary': import_summary, + 'total_success': total_success, + 'total_failed': total_failed + }, f, ensure_ascii=False, indent=2) + + logger.info(f"导入日志已保存:{log_file}") + +def main(): + """主函数""" + try: + # 创建导入器实例 + importer = DirectBatchImporter() + + # 运行导入流程 + importer.run() + + except KeyboardInterrupt: + logger.info("\n用户中断程序执行") + except Exception as e: + logger.error(f"程序执行失败:{str(e)}", exc_info=True) + sys.exit(1) + +if __name__ == "__main__": + main()