Files
railway_cloud/app/services/section_data.py
2025-11-17 16:37:14 +08:00

396 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from ..models.section_data import SectionData
from .base import BaseService
from .checkpoint import CheckpointService
from .settlement_data import SettlementDataService
from .level_data import LevelDataService
from .original_data import OriginalDataService
from typing import Dict
class SectionDataService(BaseService[SectionData]):
    """Service layer for SectionData: lookups, composite queries and batch import."""

    def __init__(self):
        # Generic CRUD helpers come from BaseService, bound to the SectionData model.
        super().__init__(SectionData)
        # Sibling services used to resolve related checkpoints, settlement,
        # level and original observation data for a section.
        self.checkpoint_service = CheckpointService()
        self.settlement_service = SettlementDataService()
        self.level_service = LevelDataService()
        self.original_service = OriginalDataService()
def get_by_section_id(self, db: Session, section_id: str) -> Optional[SectionData]:
    """Return the section record matching ``section_id``, or None if absent."""
    matches = self.get_by_field(db, "section_id", section_id)
    if not matches:
        return None
    return matches[0]
def get_by_account_id(self, db: Session, account_id: str) -> List[SectionData]:
    """Return every section record belonging to the given account id."""
    rows = self.get_by_field(db, "account_id", account_id)
    return rows or []
def get_by_account_id_batch(self, db: Session, account_ids: List[str]) -> List[SectionData]:
    """Batch-fetch sections for many account ids with a single IN query."""
    if not account_ids:
        return []
    query = db.query(SectionData).filter(SectionData.account_id.in_(account_ids))
    return query.all()
def get_by_number(self, db: Session, number: str) -> List[SectionData]:
    """Return section records matching the given bridge pier/abutment number."""
    return self.get_by_field(db, "number", number)
def search_section_data(self, db: Session,
                        id: Optional[int] = None,
                        section_id: Optional[str] = None,
                        mileage: Optional[str] = None,
                        work_site: Optional[str] = None,
                        number: Optional[str] = None,
                        status: Optional[str] = None,
                        basic_types: Optional[str] = None,
                        account_id: Optional[str] = None,
                        skip: int = 0,
                        limit: Optional[int] = None) -> List[SectionData]:
    """Search sections by any combination of equality filters.

    Only parameters that are not None are applied; ``skip``/``limit``
    control pagination. Returns the matching SectionData rows.
    """
    # Map each parameter onto its column name, then drop the unset (None)
    # ones in a single comprehension instead of eight repeated if-blocks.
    candidates = {
        "id": id,
        "section_id": section_id,
        "mileage": mileage,
        "work_site": work_site,
        "number": number,
        "status": status,
        "basic_types": basic_types,
        "account_id": account_id,
    }
    conditions = {field: value for field, value in candidates.items() if value is not None}
    return self.search_by_conditions(db, conditions, skip, limit)
def search_sections_with_checkpoints(self, db: Session,
                                     id: Optional[int] = None,
                                     section_id: Optional[str] = None,
                                     mileage: Optional[str] = None,
                                     work_site: Optional[str] = None,
                                     number: Optional[str] = None,
                                     status: Optional[str] = None,
                                     account_id: Optional[str] = None,
                                     skip: int = 0,
                                     limit: Optional[int] = None) -> Dict[str, Any]:
    """Search sections and return paginated dicts with embedded checkpoints.

    Only non-None parameters are applied as equality filters. Returns
    ``{"data": [...], "total": <count>, "skip": skip, "limit": limit}``.
    """
    # Drop unset (None) filters in one comprehension instead of a long if-chain.
    candidates = {
        "id": id,
        "section_id": section_id,
        "mileage": mileage,
        "work_site": work_site,
        "number": number,
        "status": status,
        "account_id": account_id,
    }
    conditions = {field: value for field, value in candidates.items() if value is not None}
    # Total row count for the pagination envelope.
    total_count = self.search_by_conditions_count(db, conditions)
    # Current page of matching sections.
    sections = self.search_by_conditions(db, conditions, skip, limit)
    result = []
    for section in sections:
        # NOTE(review): one checkpoint query per section (N+1 pattern);
        # consider a batched lookup in checkpoint_service if pages grow large.
        checkpoints = self.checkpoint_service.get_by_section_id(db, section.section_id)
        section_dict = {
            "id": section.id,
            "section_id": section.section_id,
            "mileage": section.mileage,
            "work_site": section.work_site,
            "basic_types": section.basic_types,
            "height": section.height,
            "status": section.status,
            "number": section.number,
            "transition_paragraph": section.transition_paragraph,
            "design_fill_height": section.design_fill_height,
            "compression_layer_thickness": section.compression_layer_thickness,
            "treatment_depth": section.treatment_depth,
            "foundation_treatment_method": section.foundation_treatment_method,
            "rock_mass_classification": section.rock_mass_classification,
            "account_id": section.account_id,
            "checkpoints": [
                {
                    "id": cp.id,
                    "point_id": cp.point_id,
                    "aname": cp.aname,
                    "burial_date": cp.burial_date,
                    "section_id": cp.section_id
                } for cp in checkpoints
            ]
        }
        result.append(section_dict)
    return {
        "data": result,
        "total": total_count,
        "skip": skip,
        "limit": limit
    }
def get_section_with_checkpoints(self, db: Session, section_id: str) -> Dict[str, Any]:
    """Return a section together with its checkpoints, or {} when not found."""
    section = self.get_by_section_id(db, section_id)
    if not section:
        return {}
    points = self.checkpoint_service.get_by_section_id(db, section_id)
    return {
        "section": section,
        "checkpoints": points,
        "checkpoint_count": len(points),
    }
def get_section_with_all_data(self, db: Session, section_id: str) -> Dict[str, Any]:
    """Collect a section plus all related data (checkpoints, settlement, level, original)."""
    section = self.get_by_section_id(db, section_id)
    if not section:
        return {}
    checkpoints = self.checkpoint_service.get_by_section_id(db, section_id)
    settlement_rows = []
    level_rows = []
    original_rows = []
    for cp in checkpoints:
        per_point = self.settlement_service.get_by_point_id(db, cp.point_id)
        settlement_rows.extend(per_point)
        for row in per_point:
            # Level data first, then original data, keyed by the same period id.
            level_rows.extend(self.level_service.get_by_nyid(db, row.NYID))
            original_rows.extend(self.original_service.get_by_nyid(db, row.NYID))
    return {
        "section": section,
        "checkpoints": checkpoints,
        "settlement_data": settlement_rows,
        "level_data": level_rows,
        "original_data": original_rows,
        "summary": {
            "checkpoint_count": len(checkpoints),
            "settlement_data_count": len(settlement_rows),
            "level_data_count": len(level_rows),
            "original_data_count": len(original_rows),
        },
    }
def get_sections_by_checkpoint_point_id(self, db: Session, point_id: str) -> List[SectionData]:
    """Reverse lookup: find the section owning the given checkpoint point_id.

    Returns an empty list when the checkpoint does not exist, and also —
    fixing the previous behavior of returning ``[None]`` — when the
    checkpoint references a section that no longer exists.
    """
    checkpoint = self.checkpoint_service.get_by_point_id(db, point_id)
    if not checkpoint:
        return []
    section = self.get_by_section_id(db, checkpoint.section_id)
    # Guard against a dangling checkpoint -> section reference.
    return [section] if section is not None else []
def get_sections_by_settlement_nyid(self, db: Session, nyid: str) -> List[SectionData]:
    """Reverse lookup from a period id (NYID) to the distinct sections it touches.

    Walks settlement rows -> checkpoint -> section, preserving first-seen
    order and deduplicating sections.
    """
    settlement_data = self.settlement_service.get_by_nyid(db, nyid)
    sections: List[SectionData] = []
    # O(1) membership via seen section ids instead of the previous
    # O(n) `section not in sections` scan per row.
    seen_ids = set()
    for settlement in settlement_data:
        checkpoint = self.checkpoint_service.get_by_point_id(db, settlement.point_id)
        if not checkpoint:
            continue
        section = self.get_by_section_id(db, checkpoint.section_id)
        if section and section.section_id not in seen_ids:
            seen_ids.add(section.section_id)
            sections.append(section)
    return sections
def get_settlement_data_by_section(self, db: Session, section_id: str) -> List:
    """Return every settlement record for all checkpoints of a section."""
    collected = []
    for cp in self.checkpoint_service.get_by_section_id(db, section_id):
        collected.extend(self.settlement_service.get_by_point_id(db, cp.point_id))
    return collected
def get_original_data_by_section(self, db: Session, section_id: str) -> List:
    """Return every original-data record tied to a section's settlement rows."""
    collected = []
    for settlement in self.get_settlement_data_by_section(db, section_id):
        collected.extend(self.original_service.get_by_nyid(db, settlement.NYID))
    return collected
def get_level_data_by_section(self, db: Session, section_id: str) -> List:
    """Return every level (leveling survey) record tied to a section's settlement rows."""
    collected = []
    for settlement in self.get_settlement_data_by_section(db, section_id):
        collected.extend(self.level_service.get_by_nyid(db, settlement.NYID))
    return collected
def batch_import_sections(self, db: Session, data: List) -> Dict[str, Any]:
    """
    Batch-import section data (performance-optimized version).

    Uses one bulk existence query plus batched inserts to speed up imports.
    Rows whose section_id already exists are SKIPPED (no update) and reported
    in ``failed_items``. Runs inside a transaction; on failure the whole
    attempt is rolled back and retried once.

    Returns a summary dict with ``success``, ``message``, ``total_count``,
    ``success_count``, ``failed_count`` and ``failed_items``.
    """
    import logging
    logger = logging.getLogger(__name__)
    total_count = len(data)
    success_count = 0
    failed_count = 0
    failed_items = []
    # Empty input is rejected up front, before opening a transaction.
    if total_count == 0:
        return {
            'success': False,
            'message': '导入数据不能为空',
            'total_count': 0,
            'success_count': 0,
            'failed_count': 0,
            'failed_items': []
        }
    for attempt in range(2):  # at most one retry
        try:
            # NOTE(review): db.begin() on a Session that already has an
            # active transaction may raise — confirm against BaseService's
            # session management.
            db.begin()
            # Counters are reset per attempt so a retry starts clean.
            success_count = 0
            failed_count = 0
            failed_items = []
            # ===== Optimization 1: bulk-check existing sections with one IN query =====
            # Normalize ids to str because the section_id column is VARCHAR.
            section_id_list = list(set(str(item.get('section_id')) for item in data if item.get('section_id')))
            logger.info(f"Checking {len(section_id_list)} unique section_ids")
            existing_sections = db.query(SectionData).filter(SectionData.section_id.in_(section_id_list)).all()
            # Lookup table keyed by section_id.
            existing_map = {
                section.section_id: section
                for section in existing_sections
            }
            logger.info(f"Found {len(existing_sections)} existing sections")
            # ===== Optimization 2: split rows into "skip" and "insert" in one pass =====
            to_insert = []
            for item_data in data:
                # NOTE(review): a missing section_id becomes the literal
                # string 'None' here and would be inserted as such — confirm
                # upstream validation guarantees the key is present.
                section_id = str(item_data.get('section_id'))  # normalize to str
                if section_id in existing_map:
                    # Duplicate: skip, and record it as a "failed" item.
                    logger.info(f"Continue section data: {section_id}")
                    failed_count += 1
                    failed_items.append({
                        'data': item_data,
                        'error': '数据已存在,跳过插入操作'
                    })
                else:
                    # Queue for insertion.
                    to_insert.append(item_data)
            # ===== Execute the batched inserts =====
            if to_insert:
                logger.info(f"Inserting {len(to_insert)} new records")
                # Insert in chunks of 500 to keep each SQL statement bounded.
                batch_size = 500
                for i in range(0, len(to_insert), batch_size):
                    batch = to_insert[i:i + batch_size]
                    try:
                        section_data_list = [
                            SectionData(
                                section_id=str(item.get('section_id')),  # normalize to str
                                mileage=item.get('mileage'),
                                work_site=item.get('work_site'),
                                basic_types=item.get('basic_types'),
                                height=item.get('height'),
                                status=item.get('status'),
                                number=str(item.get('number')) if item.get('number') else None,  # normalize to str
                                transition_paragraph=item.get('transition_paragraph'),
                                design_fill_height=item.get('design_fill_height'),
                                compression_layer_thickness=item.get('compression_layer_thickness'),
                                treatment_depth=item.get('treatment_depth'),
                                foundation_treatment_method=item.get('foundation_treatment_method'),
                                rock_mass_classification=item.get('rock_mass_classification'),
                                account_id=str(item.get('account_id')) if item.get('account_id') else None  # normalize to str
                            )
                            for item in batch
                        ]
                        db.add_all(section_data_list)
                        success_count += len(batch)
                        logger.info(f"Inserted batch {i//batch_size + 1}: {len(batch)} records")
                    except Exception as e:
                        # Record the whole batch as failed, then re-raise so
                        # the outer handler rolls back and retries the attempt.
                        failed_count += len(batch)
                        failed_items.extend([
                            {
                                'data': item,
                                'error': f'插入失败: {str(e)}'
                            }
                            for item in batch
                        ])
                        logger.error(f"Failed to insert batch: {str(e)}")
                        raise e
            # If any item failed to INSERT (as opposed to being a duplicate
            # skip), do not commit: skips are benign, insert failures are not.
            insert_failed_items = [item for item in failed_items if '插入失败' in item.get('error', '')]
            if insert_failed_items:
                db.rollback()
                return {
                    'success': False,
                    'message': f'批量导入失败: {len(insert_failed_items)}条记录插入失败',
                    'total_count': total_count,
                    'success_count': success_count,
                    'failed_count': failed_count,
                    'failed_items': failed_items
                }
            db.commit()
            logger.info(f"Batch import sections completed. Success: {success_count}, Failed: {failed_count}")
            break
        except Exception as e:
            # Roll back this attempt; retry once, then give up with a summary.
            db.rollback()
            logger.warning(f"Batch import attempt {attempt + 1} failed: {str(e)}")
            if attempt == 1:  # final retry also failed
                logger.error("Batch import sections failed after retries")
                return {
                    'success': False,
                    'message': f'批量导入失败: {str(e)}',
                    'total_count': total_count,
                    'success_count': 0,
                    'failed_count': total_count,
                    'failed_items': failed_items
                }
    return {
        'success': True,
        'message': '批量导入完成' if failed_count == 0 else f'部分导入失败',
        'total_count': total_count,
        'success_count': success_count,
        'failed_count': failed_count,
        'failed_items': failed_items
    }