from collections import defaultdict
from datetime import datetime
import logging
from typing import List, Dict, Any, Optional

import pandas as pd
from sqlalchemy.orm import Session

from ..models.account import Account
from ..models.section_data import SectionData
from ..models.checkpoint import Checkpoint
from ..models.settlement_data import SettlementData
from ..models.level_data import LevelData
from ..services.section_data import SectionDataService
from ..services.checkpoint import CheckpointService
from ..services.settlement_data import SettlementDataService
from ..services.level_data import LevelDataService
from ..services.account import AccountService
from ..core.exceptions import DataNotFoundException, AccountNotFoundException

logger = logging.getLogger(__name__)


class ExportExcelService:
    """Export settlement-monitoring data for one project to an Excel workbook.

    Joins settlement rows with their related section, checkpoint and (optional)
    level-survey rows, flattens each join into one record keyed by the SQLAlchemy
    column comments, and writes everything to a single sheet.
    """

    def __init__(self) -> None:
        # One service per table; all DB access in this class goes through these.
        self.account_service = AccountService()
        self.section_service = SectionDataService()
        self.checkpoint_service = CheckpointService()
        self.settlement_service = SettlementDataService()
        self.level_service = LevelDataService()

    def get_field_comments(self, model_class) -> Dict[str, str]:
        """Return {column_name: column_comment} for a SQLAlchemy model.

        Columns whose ``comment`` is empty/None are omitted, so callers can fall
        back to the raw column name via ``dict.get``.
        """
        return {
            column.name: column.comment
            for column in model_class.__table__.columns
            if column.comment
        }

    def merge_settlement_with_related_data(self, db: Session,
                                           settlement_data: SettlementData,
                                           section_data: SectionData,
                                           checkpoint_data: Checkpoint,
                                           level_data: Optional[LevelData]) -> Dict[str, Any]:
        """Flatten one settlement row plus its related rows into a single dict.

        Keys are the models' column comments (falling back to the column name);
        related tables get a prefix so their keys cannot collide with the
        settlement columns. Technical/duplicate fields (``id`` and the joining
        foreign keys) are dropped.

        Args:
            db: kept for interface compatibility; not used here.
            settlement_data: the settlement measurement row.
            section_data: the section the checkpoint belongs to.
            checkpoint_data: the observation point the measurement was taken at.
            level_data: level-survey row for the measurement's period, if any.

        Returns:
            One flat record suitable for a DataFrame row.
        """
        result: Dict[str, Any] = {}

        def add_fields(model_cls, row: Dict[str, Any], skip: set, prefix: str = "") -> None:
            # Map each surviving field through the model's column comments.
            comments = self.get_field_comments(model_cls)
            for field_name, value in row.items():
                if field_name in skip:
                    continue
                key = comments.get(field_name, field_name)
                result[f"{prefix}{key}"] = value

        # Settlement fields: drop only the surrogate id, no prefix.
        add_fields(SettlementData, settlement_data.to_dict(), {'id'})
        # Related rows: drop id + the foreign key used for the join, add a prefix.
        add_fields(SectionData, section_data.to_dict(), {'id', 'account_id'}, "断面_")
        add_fields(Checkpoint, checkpoint_data.to_dict(), {'id', 'section_id'}, "观测点_")
        if level_data is not None:
            add_fields(LevelData, level_data.to_dict(), {'id', 'NYID'}, "水准_")
        return result

    def export_settlement_data_to_file(self, db: Session, project_name: str, file_path: str) -> None:
        """Export all settlement data for ``project_name`` to ``file_path``.

        Uses batched queries (one per table, not one per row) to avoid N+1
        lookups against the settlement and level tables.

        Raises:
            AccountNotFoundException: no account matches ``project_name``.
            DataNotFoundException: any stage of the join yields no data.
        """
        logger.info(f"开始导出项目 '{project_name}' 的沉降数据到文件: {file_path}")

        # 1. Resolve the project name to an account id.
        account_responses = self.account_service.search_accounts(db, project_name=project_name)
        if not account_responses:
            logger.warning(f"未找到项目名称为 '{project_name}' 的账号")
            raise AccountNotFoundException(f"未找到项目名称为 '{project_name}' 的账号")
        account_id = str(account_responses[0].account_id)
        logger.info(f"找到账号 ID: {account_id}")

        # 2. All sections belonging to the account.
        sections = self.section_service.search_section_data(db, account_id=account_id, limit=10000)
        if not sections:
            logger.warning(f"账号 {account_id} 下未找到断面数据")
            raise DataNotFoundException(f"账号 {account_id} 下未找到断面数据")
        logger.info(f"找到 {len(sections)} 个断面")

        # 3. Checkpoints per section (section_id -> [checkpoints]).
        section_checkpoint_map: Dict[Any, List[Checkpoint]] = {}
        all_checkpoints: List[Checkpoint] = []
        for section in sections:
            checkpoints = self.checkpoint_service.get_by_section_id(db, section.section_id)
            if checkpoints:
                section_checkpoint_map[section.section_id] = checkpoints
                all_checkpoints.extend(checkpoints)
        if not all_checkpoints:
            logger.warning("未找到任何观测点数据")
            raise DataNotFoundException("未找到任何观测点数据")
        logger.info(f"找到 {len(all_checkpoints)} 个观测点")

        # 4. One batched settlement query for every checkpoint (key optimization).
        point_ids = [cp.point_id for cp in all_checkpoints]
        logger.info(f"开始批量查询 {len(point_ids)} 个观测点的沉降数据")
        all_settlements = self.settlement_service.get_by_point_ids(db, point_ids)
        if not all_settlements:
            logger.warning("未找到任何沉降数据")
            logger.info(f"观测点id集合{point_ids}")
            raise DataNotFoundException("未找到任何沉降数据")
        logger.info(f"批量查询到 {len(all_settlements)} 条沉降数据")

        # 5. Group settlements by checkpoint and collect the survey-period NYIDs.
        point_settlement_map: Dict[Any, List[SettlementData]] = defaultdict(list)
        nyid_set = set()
        for settlement in all_settlements:
            point_settlement_map[settlement.point_id].append(settlement)
            if settlement.NYID:
                nyid_set.add(settlement.NYID)

        # 6. One batched level-data query; first row seen per NYID wins.
        nyid_list = list(nyid_set)
        logger.info(f"开始批量查询 {len(nyid_list)} 个期数的水准数据")
        all_level_data = self.level_service.get_by_nyids(db, nyid_list)
        logger.info(f"批量查询到 {len(all_level_data)} 条水准数据")
        nyid_level_map: Dict[Any, LevelData] = {}
        for level_data in all_level_data:
            nyid_level_map.setdefault(level_data.NYID, level_data)

        # 7. Merge section + checkpoint + settlement (+ optional level) rows.
        all_settlement_records = [
            self.merge_settlement_with_related_data(
                db, settlement, section, checkpoint, nyid_level_map.get(settlement.NYID)
            )
            for section in sections
            for checkpoint in section_checkpoint_map.get(section.section_id, [])
            for settlement in point_settlement_map.get(checkpoint.point_id, [])
        ]
        if not all_settlement_records:
            logger.warning("未能合并任何数据记录")
            raise DataNotFoundException("未能合并任何数据记录")
        logger.info(f"共找到 {len(all_settlement_records)} 条沉降数据记录")

        self._write_excel(pd.DataFrame(all_settlement_records), file_path)
        logger.info("Excel文件生成完成")

    @staticmethod
    def _write_excel(df: pd.DataFrame, file_path: str) -> None:
        """Write ``df`` to ``file_path`` with auto-fitted column widths."""
        with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='沉降数据')
            worksheet = writer.sheets['沉降数据']
            for column in worksheet.columns:
                column_letter = column[0].column_letter
                max_length = 0
                for cell in column:
                    try:
                        max_length = max(max_length, len(str(cell.value)))
                    except Exception:
                        # Best-effort sizing only; was a bare except, narrowed so
                        # KeyboardInterrupt/SystemExit are no longer swallowed.
                        pass
                # +2 chars of padding, capped at 50 to keep the sheet readable.
                worksheet.column_dimensions[column_letter].width = min(max_length + 2, 50)