Files
railway_cloud/app/services/checkpoint.py
2025-11-17 16:14:12 +08:00

250 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from ..models.checkpoint import Checkpoint
from ..models.section_data import SectionData
from .base import BaseService
class CheckpointService(BaseService[Checkpoint]):
    """Service layer for Checkpoint (observation point) records."""

    def __init__(self):
        super().__init__(Checkpoint)

    def get_by_point_id(self, db: Session, point_id: str) -> Optional[Checkpoint]:
        """Return the checkpoint with the given point_id, or None if absent."""
        checkpoints = self.get_by_field(db, "point_id", point_id)
        return checkpoints[0] if checkpoints else None

    def search_checkpoints(self, db: Session,
                           aname: Optional[str] = None,
                           section_id: Optional[str] = None,
                           point_id: Optional[str] = None) -> List[Checkpoint]:
        """Search checkpoints by any combination of name, section id and point id.

        Only the criteria that are not None are applied.
        """
        conditions: Dict[str, Any] = {}
        if aname is not None:
            conditions["aname"] = aname
        if section_id is not None:
            conditions["section_id"] = section_id
        if point_id is not None:
            conditions["point_id"] = point_id
        return self.search_by_conditions(db, conditions)

    def _check_section_exists(self, db: Session, section_id: str) -> bool:
        """Return True if a SectionData row with this section_id exists."""
        section = db.query(SectionData).filter(SectionData.section_id == section_id).first()
        return section is not None

    def batch_import_checkpoints(self, db: Session, data: List) -> Dict[str, Any]:
        """Bulk-import checkpoint rows (performance-optimized).

        Replaces per-row lookups with two IN-queries (referenced sections,
        existing checkpoints), then inserts in batches of 500 rows.

        Rules:
          1. Rows whose section_id is missing or unknown are skipped.
          2. Rows whose point_id already exists (in the DB or earlier in the
             same payload) are skipped; no update is performed.

        The whole import runs in one transaction; if any row fails the
        transaction is rolled back, and on an unexpected error the import is
        retried once.

        Returns a summary dict with keys: success, message, total_count,
        success_count, failed_count, failed_items.
        """
        import logging
        logger = logging.getLogger(__name__)

        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items: List[Dict[str, Any]] = []

        if total_count == 0:
            return {
                'success': False,
                'message': '导入数据不能为空',
                'total_count': 0,
                'success_count': 0,
                'failed_count': 0,
                'failed_items': []
            }

        for attempt in range(2):  # at most one retry
            try:
                # NOTE(review): Session.begin() raises InvalidRequestError when a
                # transaction is already active (sessions autobegin by default) —
                # this relies on the project's session configuration; confirm.
                db.begin()
                # Reset counters so a retry does not double-count.
                success_count = 0
                failed_count = 0
                failed_items = []

                # ---- Optimization 1: one IN-query for all referenced sections ----
                # section_id is a VARCHAR column, so normalize every id to str.
                section_id_list = list({
                    str(item.get('section_id'))
                    for item in data if item.get('section_id')
                })
                logger.info("Checking %d unique section_ids in section data", len(section_id_list))
                sections = db.query(SectionData).filter(
                    SectionData.section_id.in_(section_id_list)
                ).all()
                # Only the key set is needed — no point keeping the full objects.
                existing_section_ids = {s.section_id for s in sections}
                missing_section_ids = set(section_id_list) - existing_section_ids

                def _has_valid_section(item: Dict[str, Any]) -> bool:
                    # BUGFIX: rows with a missing/empty section_id previously
                    # slipped past the missing-section check (str(None) == "None"
                    # never appears in missing_section_ids) and were inserted
                    # with the literal string "None". Reject them instead.
                    raw = item.get('section_id')
                    return bool(raw) and str(raw) not in missing_section_ids

                # Record rows whose section is missing or unknown.
                for item_data in data:
                    if not _has_valid_section(item_data):
                        failed_count += 1
                        failed_items.append({
                            'data': item_data,
                            'error': '断面ID不存在跳过插入操作'
                        })

                # If every row failed the section check, stop early.
                if failed_count == total_count:
                    db.rollback()
                    return {
                        'success': False,
                        'message': '所有断面ID都不存在',
                        'total_count': total_count,
                        'success_count': 0,
                        'failed_count': total_count,
                        'failed_items': failed_items
                    }

                # ---- Optimization 2: one IN-query for existing checkpoints ----
                valid_items = [item for item in data if _has_valid_section(item)]
                to_insert: List[Dict[str, Any]] = []
                if valid_items:
                    # point_id is also VARCHAR — normalize to str.
                    point_id_list = list({
                        str(item.get('point_id'))
                        for item in valid_items if item.get('point_id')
                    })
                    existing_checkpoints = db.query(Checkpoint).filter(
                        Checkpoint.point_id.in_(point_id_list)
                    ).all()
                    existing_ids = {cp.point_id for cp in existing_checkpoints}
                    logger.info("Found %d existing checkpoints", len(existing_checkpoints))

                    # ---- Optimization 3: classify rows as insert vs. skip ----
                    # BUGFIX: also dedupe point_ids *within* the payload, so two
                    # payload rows sharing an id cannot both be queued for insert.
                    seen_point_ids: set = set()
                    for item_data in valid_items:
                        point_id = str(item_data.get('point_id'))
                        if point_id in existing_ids or point_id in seen_point_ids:
                            logger.info("Continue checkpoint data: %s", point_id)
                            failed_count += 1
                            failed_items.append({
                                'data': item_data,
                                'error': '数据已存在,跳过插入操作'
                            })
                        else:
                            seen_point_ids.add(point_id)
                            to_insert.append(item_data)

                # ---- Bulk insert, 500 rows per batch to keep SQL short ----
                if to_insert:
                    logger.info("Inserting %d new records", len(to_insert))
                    batch_size = 500
                    for i in range(0, len(to_insert), batch_size):
                        batch = to_insert[i:i + batch_size]
                        try:
                            checkpoint_list = [
                                Checkpoint(
                                    point_id=str(item.get('point_id')),
                                    aname=item.get('aname'),
                                    section_id=str(item.get('section_id')),
                                    burial_date=item.get('burial_date')
                                )
                                for item in batch
                            ]
                            db.add_all(checkpoint_list)
                            success_count += len(batch)
                            logger.info("Inserted batch %d: %d records", i // batch_size + 1, len(batch))
                        except Exception as e:
                            failed_count += len(batch)
                            failed_items.extend([
                                {
                                    'data': item,
                                    'error': f'插入失败: {str(e)}'
                                }
                                for item in batch
                            ])
                            logger.error("Failed to insert batch: %s", str(e))
                            # Re-raise so the outer handler rolls back and retries.
                            raise e

                # Any failed row aborts the whole import (all-or-nothing).
                if failed_items:
                    db.rollback()
                    return {
                        'success': False,
                        'message': f'批量导入失败: {len(failed_items)}条记录处理失败',
                        'total_count': total_count,
                        'success_count': success_count,
                        'failed_count': failed_count,
                        'failed_items': failed_items
                    }

                db.commit()
                logger.info("Batch import checkpoints completed. Success: %d, Failed: %d",
                            success_count, failed_count)
                break
            except Exception as e:
                db.rollback()
                logger.warning("Batch import attempt %d failed: %s", attempt + 1, str(e))
                if attempt == 1:  # final attempt also failed
                    logger.error("Batch import checkpoints failed after retries")
                    return {
                        'success': False,
                        'message': f'批量导入失败: {str(e)}',
                        'total_count': total_count,
                        'success_count': 0,
                        'failed_count': total_count,
                        'failed_items': failed_items
                    }

        return {
            'success': True,
            'message': '批量导入完成' if failed_count == 0 else '部分导入失败',
            'total_count': total_count,
            'success_count': success_count,
            'failed_count': failed_count,
            'failed_items': failed_items
        }

    def get_by_nyid(self, db: Session, nyid: str) -> List[Checkpoint]:
        """Return all checkpoints with the given NYID."""
        return self.get_by_field(db, "NYID", nyid)

    def get_by_section_id(self, db: Session, section_id: str) -> List[Checkpoint]:
        """Return all checkpoints belonging to one section."""
        return self.get_by_field(db, "section_id", section_id)

    def get_by_section_ids_batch(self, db: Session, section_ids: List[str]) -> List[Checkpoint]:
        """Return checkpoints for many sections via a single IN-query."""
        if not section_ids:
            return []
        return db.query(Checkpoint).filter(Checkpoint.section_id.in_(section_ids)).all()

    def get_by_section_ids(self, db: Session, section_ids: List[str]) -> List[Checkpoint]:
        """Return checkpoints for many sections.

        NOTE(review): duplicates get_by_section_ids_batch but without the
        empty-list guard; kept for backward compatibility — consider merging.
        """
        return db.query(Checkpoint).filter(Checkpoint.section_id.in_(section_ids)).all()

    def get_point_ids_by_linecode(self, db: Session, linecode: str) -> List[str]:
        """Return all observation point ids reachable from a level-line code.

        Flow:
          1. Look up LevelData rows by linecode and collect distinct NYIDs.
          2. Look up SettlementData rows by those NYIDs (one IN-query) and
             collect distinct point_ids.
          3. Return the point_id list (empty if no NYIDs were found).
        """
        from ..models.level_data import LevelData
        from ..models.settlement_data import SettlementData

        nyid_query = db.query(LevelData.NYID).filter(LevelData.linecode == linecode).distinct()
        nyid_list = [result.NYID for result in nyid_query.all() if result.NYID]
        if not nyid_list:
            return []

        point_id_query = db.query(SettlementData.point_id).filter(
            SettlementData.NYID.in_(nyid_list)
        ).distinct()
        return [result.point_id for result in point_id_query.all() if result.point_id]