diff --git a/app/services/original_data.py b/app/services/original_data.py
index 8b2ef88..71f7fd2 100644
--- a/app/services/original_data.py
+++ b/app/services/original_data.py
@@ -337,101 +337,64 @@ class OriginalDataService(BaseService[OriginalData]):
 
         table_name = self._get_table_name(account_id)
 
-        for attempt in range(2):  # 最多重试1次
+        # Initialise the result counters before branching so that every path
+        # below, including the caller-managed-transaction branch, can build the
+        # result dict without hitting an unbound local.
+        success_count = 0
+        failed_count = 0
+        failed_items = []
+
+        # Check whether the session is already inside a transaction so that a
+        # second one is never begun on top of it.
+        in_transaction = db.in_transaction()
+
+        # Only manage the transaction here when the caller has not opened one.
+        if not in_transaction:
+            for attempt in range(2):  # retry at most once
+                try:
+                    db.begin()
+
+                    # Run the actual import.
+                    success_count = self._execute_import(db, table_name, data, account_id)
+                    db.commit()
+                    logger.info(f"Batch import original data completed. Success: {success_count}, Failed: {failed_count}")
+                    break
+
+                except Exception as e:
+                    db.rollback()
+                    logger.warning(f"Batch import attempt {attempt + 1} failed: {str(e)}")
+                    if attempt == 1:  # the last attempt failed as well
+                        logger.error("Batch import original data failed after retries")
+                        return {
+                            'success': False,
+                            'message': f'批量导入失败: {str(e)}',
+                            'total_count': total_count,
+                            'success_count': 0,
+                            'failed_count': total_count,
+                            'failed_items': failed_items
+                        }
+        else:
+            # Already inside a transaction: run the import and leave commit and
+            # rollback to whoever opened it.
             try:
-                db.begin()
-                success_count = 0
-                failed_count = 0
-                failed_items = []
-
-                nyid = str(data[0].get('NYID'))  # 统一转换为字符串
-                # 检查该期数数据是否已存在
-                check_query = text(f"SELECT COUNT(*) as cnt FROM `{table_name}` WHERE NYID = :nyid")
-                is_exists = db.execute(check_query, {"nyid": nyid}).fetchone()[0]
-
-                if is_exists > 0:
-                    db.rollback()
-                    return {
-                        'success': True,
-                        'message': '数据已存在',
-                        'total_count': 0,
-                        'success_count': success_count,
-                        'failed_count': failed_count,
-                        'failed_items': failed_items
-                    }
-
-                # ===== 性能优化:批量查询沉降数据 =====
-                # 统一转换为字符串处理(数据库NYID字段是VARCHAR类型)
-                nyid_list = list(set(str(item.get('NYID')) for item in data if item.get('NYID')))
-                settlements = db.query(SettlementData).filter(SettlementData.NYID.in_(nyid_list)).all()
-                settlement_map = {s.NYID: s for s in settlements}
-                missing_nyids = set(nyid_list) - set(settlement_map.keys())
-
-                if missing_nyids:
-                    logger.warning(f"[批量导入原始数据] 批量查询settlement数据失败, Nyid: {list(missing_nyids)}")
-                    db.rollback()
-                    return {
-                        'success': False,
-                        'message': f'以下期数在沉降表中不存在: {list(missing_nyids)}',
-                        'total_count': total_count,
-                        'success_count': 0,
-                        'failed_count': total_count,
-                        'failed_items': []
-                    }
-
-                # ===== 性能优化:使用批量插入 =====
-                # 将数据分组,每组1000条(MySQL默认支持)
-                batch_size = 1000
-                for i in range(0, len(data), batch_size):
-                    batch_data = data[i:i + batch_size]
-
-                    # 构建批量参数
-                    values_list = []
-                    params = {}
-                    for idx, item_data in enumerate(batch_data):
-                        values_list.append(
-                            f"(:account_id_{idx}, :bfpcode_{idx}, :mtime_{idx}, :bffb_{idx}, "
-                            f":bfpl_{idx}, :bfpvalue_{idx}, :NYID_{idx}, :sort_{idx})"
-                        )
-                        params.update({
-                            f"account_id_{idx}": account_id,
-                            f"bfpcode_{idx}": item_data.get('bfpcode'),
-                            f"mtime_{idx}": item_data.get('mtime'),
-                            f"bffb_{idx}": item_data.get('bffb'),
-                            f"bfpl_{idx}": item_data.get('bfpl'),
-                            f"bfpvalue_{idx}": item_data.get('bfpvalue'),
-                            f"NYID_{idx}": item_data.get('NYID'),
-                            f"sort_{idx}": item_data.get('sort')
-                        })
-
-                    # 批量插入SQL - 使用字符串拼接(修复TextClause拼接问题)
-                    insert_sql = f"""
-                        INSERT INTO `{table_name}`
-                        (account_id, bfpcode, mtime, bffb, bfpl, bfpvalue, NYID, sort)
-                        VALUES {", ".join(values_list)}
-                    """
-                    final_sql = text(insert_sql)
-                    db.execute(final_sql, params)
-                    success_count += len(batch_data)
-                    logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch_data)} records")
-
-                db.commit()
-                logger.info(f"Batch import original data completed. Success: {success_count}, Failed: {failed_count}")
-                break
-
+                success_count = self._execute_import(db, table_name, data, account_id)
+                logger.info(f"Batch import original data completed in existing transaction. Success: {success_count}")
             except Exception as e:
-                db.rollback()
-                logger.warning(f"Batch import attempt {attempt + 1} failed: {str(e)}")
-                if attempt == 1:  # 最后一次重试失败
-                    logger.error("Batch import original data failed after retries")
-                    return {
-                        'success': False,
-                        'message': f'批量导入失败: {str(e)}',
-                        'total_count': total_count,
-                        'success_count': 0,
-                        'failed_count': total_count,
-                        'failed_items': failed_items
-                    }
+                logger.error(f"Batch import failed in existing transaction: {str(e)}")
+                # Re-raise so the owner of the transaction can roll it back.
+                raise
+
+        if success_count == 0:
+            # _execute_import inserts nothing only when this period (NYID) is
+            # already stored; report that exactly as the pre-refactor code did.
+            return {
+                'success': True,
+                'message': '数据已存在',
+                'total_count': 0,
+                'success_count': success_count,
+                'failed_count': failed_count,
+                'failed_items': failed_items
+            }
 
         return {
             'success': True,
@@ -440,4 +403,80 @@ class OriginalDataService(BaseService[OriginalData]):
             'success_count': success_count,
             'failed_count': failed_count,
             'failed_items': failed_items
-        }
\ No newline at end of file
+        }
+
+    def _execute_import(self, db: Session, table_name: str, data: List, account_id: int) -> int:
+        """Insert one period of original data rows into the account's table.
+
+        Shared by both transaction modes of the batch import; it never begins,
+        commits or rolls back a transaction itself.
+
+        Args:
+            db: Session used for every query and insert.
+            table_name: Target table, interpolated into the SQL text.
+            data: Non-empty list of row dicts; the first row's NYID identifies
+                the period used for the duplicate check.
+            account_id: Account id written into every inserted row.
+
+        Returns:
+            Number of rows inserted, or 0 when rows for the period already exist.
+
+        Raises:
+            ValueError: If any NYID in ``data`` has no settlement record.
+        """
+        nyid = str(data[0].get('NYID'))  # NYID is always compared as a string
+        # Skip the whole import when this period has already been stored.
+        check_query = text(f"SELECT COUNT(*) as cnt FROM `{table_name}` WHERE NYID = :nyid")
+        is_exists = db.execute(check_query, {"nyid": nyid}).fetchone()[0]
+
+        if is_exists > 0:
+            return 0
+
+        # Look up all settlement rows in a single query instead of one per item.
+        # NYIDs are normalised to str because the NYID column is VARCHAR.
+        nyid_list = list(set(str(item.get('NYID')) for item in data if item.get('NYID')))
+        from ..models.settlement_data import SettlementData
+        settlements = db.query(SettlementData).filter(SettlementData.NYID.in_(nyid_list)).all()
+        settlement_map = {s.NYID: s for s in settlements}
+        missing_nyids = set(nyid_list) - set(settlement_map.keys())
+
+        if missing_nyids:
+            raise ValueError(f'以下期数在沉降表中不存在: {list(missing_nyids)}')
+
+        # Insert in batches of 1000 rows, one multi-row INSERT per batch.
+        batch_size = 1000
+        total_inserted = 0
+        for i in range(0, len(data), batch_size):
+            batch_data = data[i:i + batch_size]
+
+            # Build one placeholder tuple and one set of bound parameters per row.
+            values_list = []
+            params = {}
+            for idx, item_data in enumerate(batch_data):
+                values_list.append(
+                    f"(:account_id_{idx}, :bfpcode_{idx}, :mtime_{idx}, :bffb_{idx}, "
+                    f":bfpl_{idx}, :bfpvalue_{idx}, :NYID_{idx}, :sort_{idx})"
+                )
+                params.update({
+                    f"account_id_{idx}": account_id,
+                    f"bfpcode_{idx}": item_data.get('bfpcode'),
+                    f"mtime_{idx}": item_data.get('mtime'),
+                    f"bffb_{idx}": item_data.get('bffb'),
+                    f"bfpl_{idx}": item_data.get('bfpl'),
+                    f"bfpvalue_{idx}": item_data.get('bfpvalue'),
+                    f"NYID_{idx}": item_data.get('NYID'),
+                    f"sort_{idx}": item_data.get('sort')
+                })
+
+            # Build the SQL as a plain string and wrap it in text() afterwards
+            # (works around a TextClause concatenation problem).
+            insert_sql = f"""
+                INSERT INTO `{table_name}`
+                (account_id, bfpcode, mtime, bffb, bfpl, bfpvalue, NYID, sort)
+                VALUES {", ".join(values_list)}
+            """
+            final_sql = text(insert_sql)
+            db.execute(final_sql, params)
+            total_inserted += len(batch_data)
+            logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch_data)} records")
+
+        return total_inserted
\ No newline at end of file