212 lines
9.2 KiB
Python
212 lines
9.2 KiB
Python
from sqlalchemy.orm import Session
|
||
from typing import List, Dict, Any, Optional
|
||
from ..models.account import Account
|
||
from ..models.section_data import SectionData
|
||
from ..models.checkpoint import Checkpoint
|
||
from ..models.settlement_data import SettlementData
|
||
from ..models.level_data import LevelData
|
||
from ..services.section_data import SectionDataService
|
||
from ..services.checkpoint import CheckpointService
|
||
from ..services.settlement_data import SettlementDataService
|
||
from ..services.level_data import LevelDataService
|
||
from ..services.account import AccountService
|
||
from ..core.exceptions import DataNotFoundException, AccountNotFoundException
|
||
import pandas as pd
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
# Module-level logger namespaced to this module, per the logging cookbook.
logger = logging.getLogger(__name__)
|
||
|
||
class ExportExcelService:
    """Export a project's settlement monitoring data to an Excel workbook.

    The export flattens four related tables — settlement records, their
    observation checkpoints, the sections those checkpoints belong to, and
    (optionally) level-survey data — into a single sheet, using each
    SQLAlchemy column's ``comment`` as the human-readable header.
    """

    def __init__(self):
        # One service facade per table that participates in the export join.
        self.account_service = AccountService()
        self.section_service = SectionDataService()
        self.checkpoint_service = CheckpointService()
        self.settlement_service = SettlementDataService()
        self.level_service = LevelDataService()

    def get_field_comments(self, model_class) -> Dict[str, str]:
        """Return ``{column name: column comment}`` for a SQLAlchemy model.

        Columns that carry no comment are omitted, so callers fall back to
        the raw field name.
        """
        return {
            column.name: column.comment
            for column in model_class.__table__.columns
            if column.comment
        }

    def _merge_fields(self,
                      result: Dict[str, Any],
                      model_class,
                      record,
                      skip_fields,
                      prefix: str = "") -> None:
        """Copy ``record``'s fields into ``result``.

        Keys are the column comments (falling back to the raw field name),
        optionally prefixed to keep headers unique across joined tables.
        Fields listed in ``skip_fields`` (ids / duplicated FKs) are dropped.
        """
        comments = self.get_field_comments(model_class)
        for field_name, value in record.to_dict().items():
            if field_name in skip_fields:
                continue
            key = comments.get(field_name, field_name)
            result[f"{prefix}{key}"] = value

    def merge_settlement_with_related_data(self,
                                           db: Session,
                                           settlement_data: SettlementData,
                                           section_data: SectionData,
                                           checkpoint_data: Checkpoint,
                                           level_data: Optional[LevelData]) -> Dict[str, Any]:
        """Merge one settlement row with its related rows into a flat dict.

        Primary keys and duplicated foreign keys are removed; related tables
        get a Chinese prefix so column headers stay unique. ``db`` is unused
        but retained for interface compatibility with existing callers.
        """
        result: Dict[str, Any] = {}
        # Settlement fields keep their bare comment names; drop only the PK.
        self._merge_fields(result, SettlementData, settlement_data, ('id',))
        # Related rows are prefixed; their PKs / duplicated FKs are dropped.
        self._merge_fields(result, SectionData, section_data, ('id', 'account_id'), "断面_")
        self._merge_fields(result, Checkpoint, checkpoint_data, ('id', 'section_id'), "观测点_")
        if level_data is not None:
            self._merge_fields(result, LevelData, level_data, ('id', 'NYID'), "水准_")
        return result

    def _resolve_account_id(self, db: Session, project_name: str) -> str:
        """Look up the account id for ``project_name``.

        Raises:
            AccountNotFoundException: when no account matches the project.
        """
        account_responses = self.account_service.search_accounts(db, project_name=project_name)
        if not account_responses:
            logger.warning(f"未找到项目名称为 '{project_name}' 的账号")
            raise AccountNotFoundException(f"未找到项目名称为 '{project_name}' 的账号")
        # Multiple matches are possible; the original behavior keeps the first.
        account_id = str(account_responses[0].account_id)
        logger.info(f"找到账号 ID: {account_id}")
        return account_id

    def _load_sections(self, db: Session, account_id: str):
        """Return the account's sections; raise when none exist."""
        sections = self.section_service.search_section_data(db, account_id=account_id, limit=10000)
        if not sections:
            logger.warning(f"账号 {account_id} 下未找到断面数据")
            raise DataNotFoundException(f"账号 {account_id} 下未找到断面数据")
        logger.info(f"找到 {len(sections)} 个断面")
        return sections

    def _load_checkpoints(self, db: Session, sections):
        """Return ``(section_id -> checkpoints, flat checkpoint list)``.

        Raises:
            DataNotFoundException: when no section has any checkpoint.
        """
        section_checkpoint_map: Dict[Any, list] = {}
        all_checkpoints: list = []
        for section in sections:
            checkpoints = self.checkpoint_service.get_by_section_id(db, section.section_id)
            if checkpoints:
                section_checkpoint_map[section.section_id] = checkpoints
                all_checkpoints.extend(checkpoints)
        if not all_checkpoints:
            logger.warning("未找到任何观测点数据")
            raise DataNotFoundException("未找到任何观测点数据")
        logger.info(f"找到 {len(all_checkpoints)} 个观测点")
        return section_checkpoint_map, all_checkpoints

    def _load_settlements(self, db: Session, all_checkpoints):
        """Batch-fetch settlement rows for every checkpoint (key optimization).

        Raises:
            DataNotFoundException: when the batch query returns nothing.
        """
        point_ids = [cp.point_id for cp in all_checkpoints]
        logger.info(f"开始批量查询 {len(point_ids)} 个观测点的沉降数据")
        all_settlements = self.settlement_service.get_by_point_ids(db, point_ids)
        if not all_settlements:
            logger.warning("未找到任何沉降数据")
            logger.info(f"观测点id集合{point_ids}")
            raise DataNotFoundException("未找到任何沉降数据")
        logger.info(f"批量查询到 {len(all_settlements)} 条沉降数据")
        return all_settlements

    def _load_level_map(self, db: Session, nyid_set) -> Dict[Any, LevelData]:
        """Batch-fetch level-survey rows and map them by NYID.

        When several rows share an NYID only the first one seen is kept,
        matching the original behavior.
        """
        nyid_list = list(nyid_set)
        logger.info(f"开始批量查询 {len(nyid_list)} 个期数的水准数据")
        all_level_data = self.level_service.get_by_nyids(db, nyid_list)
        logger.info(f"批量查询到 {len(all_level_data)} 条水准数据")
        nyid_level_map: Dict[Any, LevelData] = {}
        for level_data in all_level_data:
            if level_data.NYID not in nyid_level_map:
                nyid_level_map[level_data.NYID] = level_data
        return nyid_level_map

    @staticmethod
    def _autofit_columns(worksheet) -> None:
        """Widen each worksheet column to its longest rendered value, capped at 50.

        ``str(cell.value)`` never raises for ordinary values (``None`` renders
        as "None", as the original width calculation did), so the former bare
        ``except:`` — which also swallowed KeyboardInterrupt — is gone.
        """
        for column in worksheet.columns:
            column_letter = column[0].column_letter
            max_length = max((len(str(cell.value)) for cell in column), default=0)
            worksheet.column_dimensions[column_letter].width = min(max_length + 2, 50)

    def export_settlement_data_to_file(self, db: Session, project_name: str, file_path: str):
        """Export all settlement data for ``project_name`` to ``file_path``.

        Uses batched queries for settlements and level data instead of one
        query per checkpoint/period.

        Raises:
            AccountNotFoundException: no account matches the project name.
            DataNotFoundException: any join stage yields no data.
        """
        logger.info(f"开始导出项目 '{project_name}' 的沉降数据到文件: {file_path}")

        # 1. Resolve the account id from the project name.
        account_id = self._resolve_account_id(db, project_name)

        # 2. Load the account's sections.
        sections = self._load_sections(db, account_id)

        # 3. Collect checkpoints and the section -> checkpoints mapping.
        section_checkpoint_map, all_checkpoints = self._load_checkpoints(db, sections)

        # 4. Batch-fetch all settlement rows (key optimization).
        all_settlements = self._load_settlements(db, all_checkpoints)

        # 5. Group settlements by checkpoint and collect survey-period ids.
        point_settlement_map: Dict[Any, list] = {}
        nyid_set = set()
        for settlement in all_settlements:
            point_settlement_map.setdefault(settlement.point_id, []).append(settlement)
            if settlement.NYID:
                nyid_set.add(settlement.NYID)

        # 6. Batch-fetch level data keyed by period id (key optimization).
        nyid_level_map = self._load_level_map(db, nyid_set)

        # 7. Join section + checkpoint + settlement (+ level) into flat records.
        all_settlement_records = []
        for section in sections:
            for checkpoint in section_checkpoint_map.get(section.section_id, []):
                for settlement in point_settlement_map.get(checkpoint.point_id, []):
                    level_data = nyid_level_map.get(settlement.NYID)
                    all_settlement_records.append(
                        self.merge_settlement_with_related_data(
                            db, settlement, section, checkpoint, level_data
                        )
                    )

        if not all_settlement_records:
            logger.warning("未能合并任何数据记录")
            raise DataNotFoundException("未能合并任何数据记录")

        logger.info(f"共找到 {len(all_settlement_records)} 条沉降数据记录")

        # 8. Write a single worksheet and auto-size its columns.
        df = pd.DataFrame(all_settlement_records)
        with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='沉降数据')
            self._autofit_columns(writer.sheets['沉降数据'])

        logger.info("Excel文件生成完成")