Files
railway_cloud/app/services/checkpoint.py

166 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from ..models.checkpoint import Checkpoint
from ..models.section_data import SectionData
from .base import BaseService
class CheckpointService(BaseService[Checkpoint]):
    """Service layer for observation points (``Checkpoint`` rows).

    Adds checkpoint-specific lookups and a transactional batch import on top
    of the generic helpers inherited from ``BaseService``.
    """

    def __init__(self):
        # Bind the generic base service to the Checkpoint model.
        super().__init__(Checkpoint)

    def get_by_point_id(self, db: Session, point_id: str) -> Optional[Checkpoint]:
        """Return the checkpoint whose ``point_id`` matches, or ``None``.

        If several rows share the same ``point_id``, the first one returned by
        ``get_by_field`` is used.
        """
        checkpoints = self.get_by_field(db, "point_id", point_id)
        return checkpoints[0] if checkpoints else None

    def search_checkpoints(self, db: Session,
                           aname: Optional[str] = None,
                           section_id: Optional[str] = None,
                           point_id: Optional[str] = None) -> List[Checkpoint]:
        """Search checkpoints by any combination of ``aname``, ``section_id``
        and ``point_id``.

        Only arguments that are not ``None`` become conditions. They are
        passed to ``search_by_conditions`` as a ``{field_name: value}`` dict.
        """
        conditions = {}
        if aname is not None:
            conditions["aname"] = aname
        if section_id is not None:
            conditions["section_id"] = section_id
        if point_id is not None:
            conditions["point_id"] = point_id
        return self.search_by_conditions(db, conditions)

    def _check_section_exists(self, db: Session, section_id: str) -> bool:
        """Return True if a ``SectionData`` row with ``section_id`` exists."""
        section = db.query(SectionData).filter(SectionData.section_id == section_id).first()
        return section is not None

    def batch_import_checkpoints(self, db: Session, data: List) -> Dict[str, Any]:
        """Import checkpoints in one transaction, upserting by ``point_id``.

        For each item (read via ``.get`` for the keys ``point_id``, ``aname``,
        ``section_id`` and ``burial_date``):

        * The referenced ``section_id`` must exist in ``SectionData``. If any
          item references a missing section, the whole batch is rolled back
          and nothing is imported.
        * An existing checkpoint with the same ``point_id`` is updated in
          place; otherwise a new one is created. Repeated ``point_id`` values
          within the same batch resolve to a single checkpoint (later items
          update it), regardless of the session's autoflush setting.

        Any exception rolls the transaction back. The whole batch is then
        retried once before giving up.

        Args:
            db: Session used for all queries and for commit/rollback.
            data: Sequence of dict-like items.

        Returns:
            A summary dict with the keys ``success``, ``message``,
            ``total_count``, ``success_count``, ``failed_count`` and
            ``failed_items``. On final failure, ``success_count`` is 0,
            ``failed_count`` equals ``total_count``, and ``failed_items`` holds
            the item that failed in the last attempt (empty if the failure
            occurred outside item processing, e.g. during commit).
        """
        import logging
        logger = logging.getLogger(__name__)

        max_attempts = 2  # initial attempt + one retry
        total_count = len(data)
        success_count = 0
        failed_count = 0
        failed_items = []

        for attempt in range(max_attempts):
            try:
                # Session.begin() raises if a transaction is already active
                # (e.g. autobegun by an earlier query on this session), which
                # used to waste the first attempt. Only start one when needed.
                if not db.in_transaction():
                    db.begin()
                success_count = 0
                failed_count = 0
                failed_items = []
                # Per-attempt caches. A rollback discards pending objects, so
                # nothing from a failed attempt may leak into the retry.
                section_exists: Dict[Any, bool] = {}
                batch_checkpoints: Dict[Any, Checkpoint] = {}
                for item_data in data:
                    point_id = None
                    try:
                        point_id = item_data.get('point_id')
                        section_id = item_data.get('section_id')

                        # Validate the section once per distinct section_id.
                        if section_id not in section_exists:
                            section_exists[section_id] = self._check_section_exists(db, section_id)
                        if not section_exists[section_id]:
                            logger.error(f"Section {section_id} not found")
                            raise Exception(f"Section {section_id} not found")

                        # Check this batch first. With autoflush disabled, a
                        # checkpoint added earlier in the batch is not visible
                        # to the DB query and would otherwise be duplicated.
                        checkpoint = batch_checkpoints.get(point_id)
                        if checkpoint is None:
                            checkpoint = self.get_by_point_id(db, point_id)

                        if checkpoint is not None:
                            # Update the existing checkpoint.
                            checkpoint.aname = item_data.get('aname')
                            checkpoint.section_id = section_id
                            checkpoint.burial_date = item_data.get('burial_date')
                            logger.info(f"Updated checkpoint: {point_id}")
                        else:
                            # Create a new checkpoint.
                            checkpoint = Checkpoint(
                                point_id=point_id,
                                aname=item_data.get('aname'),
                                section_id=section_id,
                                burial_date=item_data.get('burial_date'),
                            )
                            db.add(checkpoint)
                            logger.info(f"Created checkpoint: {point_id}")
                        batch_checkpoints[point_id] = checkpoint
                        success_count += 1
                    except Exception as e:
                        failed_count += 1
                        failed_items.append({
                            'data': item_data,
                            'error': str(e)
                        })
                        logger.error(f"Failed to process checkpoint {point_id}: {str(e)}")
                        # Abort the whole batch: all-or-nothing semantics.
                        raise
                db.commit()
                logger.info(f"Batch import checkpoints completed. Success: {success_count}, Failed: {failed_count}")
                break
            except Exception as e:
                db.rollback()
                logger.warning(f"Batch import attempt {attempt + 1} failed: {str(e)}")
                if attempt == max_attempts - 1:  # the retry failed as well
                    logger.error("Batch import checkpoints failed after retries")
                    return {
                        'success': False,
                        'message': f'批量导入失败: {str(e)}',
                        'total_count': total_count,
                        'success_count': 0,
                        'failed_count': total_count,
                        'failed_items': failed_items
                    }

        return {
            'success': True,
            'message': '批量导入完成' if failed_count == 0 else '部分导入失败',
            'total_count': total_count,
            'success_count': success_count,
            'failed_count': failed_count,
            'failed_items': failed_items
        }

    def get_by_nyid(self, db: Session, nyid: str) -> List[Checkpoint]:
        """Return all checkpoints whose ``NYID`` field equals ``nyid``."""
        # NOTE(review): in this file NYID is otherwise read from LevelData and
        # SettlementData, and Checkpoint is only ever built with point_id /
        # aname / section_id / burial_date. Confirm that the Checkpoint model
        # really has an NYID column.
        return self.get_by_field(db, "NYID", nyid)

    def get_by_section_id(self, db: Session, section_id: str) -> List[Checkpoint]:
        """Return all checkpoints belonging to ``section_id``."""
        return self.get_by_field(db, "section_id", section_id)

    def get_by_section_ids_batch(self, db: Session, section_ids: List[str]) -> List[Checkpoint]:
        """Return all checkpoints whose ``section_id`` is in ``section_ids``.

        Uses a single ``IN`` query. An empty ``section_ids`` returns ``[]``
        without touching the database.
        """
        if not section_ids:
            return []
        return db.query(Checkpoint).filter(Checkpoint.section_id.in_(section_ids)).all()

    def get_by_section_ids(self, db: Session, section_ids: List[str]) -> List[Checkpoint]:
        """Return all checkpoints whose ``section_id`` is in ``section_ids``.

        Kept for existing callers; identical to ``get_by_section_ids_batch``,
        including the no-query shortcut for empty input.
        """
        return self.get_by_section_ids_batch(db, section_ids)

    def get_point_ids_by_linecode(self, db: Session, linecode: str) -> List[str]:
        """Return the distinct checkpoint IDs recorded for a leveling line code.

        Steps:
            1. Collect the distinct, non-empty ``NYID`` values from
               ``LevelData`` rows with ``linecode``.
            2. Collect the distinct, non-empty ``point_id`` values from
               ``SettlementData`` rows whose ``NYID`` is in that set.
            3. Return those ``point_id`` values.

        Each step is one ``DISTINCT`` query (the second uses ``IN``) rather
        than one query per NYID. Returns ``[]`` when no NYID is found.
        """
        from ..models.level_data import LevelData
        from ..models.settlement_data import SettlementData

        # 1. Distinct NYIDs for this line code (falsy values dropped).
        nyid_query = db.query(LevelData.NYID).filter(LevelData.linecode == linecode).distinct()
        nyid_list = [row.NYID for row in nyid_query.all() if row.NYID]
        if not nyid_list:
            return []

        # 2. Distinct point_ids of the settlement data for those NYIDs.
        point_id_query = db.query(SettlementData.point_id).filter(
            SettlementData.NYID.in_(nyid_list)
        ).distinct()
        return [row.point_id for row in point_id_query.all() if row.point_id]