Files
railway_cloud/app/services/daily.py
2025-10-30 11:43:32 +08:00

152 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any, Set, Tuple,Union
from ..models.level_data import LevelData
from ..models.daily import DailyData
from .base import BaseService
from ..models.settlement_data import SettlementData
from sqlalchemy import func, select, desc,over
from sqlalchemy.orm import Session
import logging
logger = logging.getLogger(__name__)
class DailyDataService(BaseService[DailyData]):
    """Service layer for DailyData rows.

    Extends the generic ``BaseService`` with:
      * bulk insert that de-duplicates on the (account_id, NYID) pair, and
      * a window-function lookup over ``SettlementData`` returning the top-N
        rows (by descending NYID) per ``point_id``.
    """

    def __init__(self):
        super().__init__(DailyData)

    def _dict_to_instance(self, data_dict: Dict) -> DailyData:
        """Convert a single dict into a DailyData instance.

        Keys that do not correspond to a mapped column on DailyData are
        silently dropped, so callers may pass superset payloads.
        """
        model_fields = [col.name for col in DailyData.__table__.columns]
        filtered_data = {k: v for k, v in data_dict.items() if k in model_fields}
        return DailyData(**filtered_data)

    def _ensure_instances(self, data: Union[List[Dict], List[DailyData]]) -> List[DailyData]:
        """Normalize *data* to a list of DailyData instances.

        Accepts a list whose elements are either DailyData instances (kept
        as-is) or dicts (converted via :meth:`_dict_to_instance`).

        Raises:
            TypeError: if *data* is not a list, or if an element is neither
                a dict nor a DailyData instance.
        """
        if not isinstance(data, list):
            raise TypeError(f"输入必须是列表,而非 {type(data)}")
        instances = []
        for item in data:
            if isinstance(item, DailyData):
                instances.append(item)
            elif isinstance(item, dict):
                instances.append(self._dict_to_instance(item))
            else:
                raise TypeError(f"列表元素必须是 dict 或 DailyData 实例,而非 {type(item)}")
        return instances

    def batch_create_by_account_nyid(self, db: Session, data: Union[List[Dict], List[DailyData]]) -> List[DailyData]:
        """Bulk-create records, skipping (account_id, NYID) pairs that already exist.

        Accepts either ``List[DailyData]`` or ``List[dict]`` (dicts are
        converted to instances). Existing pairs are looked up once, new rows
        are inserted with ``add_all`` so the session tracks them, and the
        committed instances are refreshed and returned.

        Args:
            db: active SQLAlchemy session.
            data: records to insert, as model instances or dicts.

        Returns:
            The newly created (and refreshed) DailyData instances; empty list
            if nothing was inserted.

        Raises:
            TypeError: propagated from :meth:`_ensure_instances` on bad input.
        """
        try:
            data_list = self._ensure_instances(data)
        except TypeError as e:
            logger.error(f"数据格式错误:{str(e)}")
            raise

        # Keys of the incoming records; rows missing either half of the pair
        # cannot participate in the duplicate check.
        target_pairs: List[Tuple[int, int]] = [
            (item.account_id, item.NYID)
            for item in data_list
            if item.account_id is not None and item.NYID is not None
        ]
        if not target_pairs:
            logger.warning("批量创建失败:所有记录缺少 account_id 或 NYID")
            return []

        # The two IN() filters fetch a superset (cross product) of candidate
        # rows; correctness comes from the exact-pair membership test below.
        existing_pairs: Set[Tuple[int, int]] = {
            (item.account_id, item.NYID)
            for item in db.query(DailyData.account_id, DailyData.NYID)
            .filter(DailyData.account_id.in_([p[0] for p in target_pairs]),
                    DailyData.NYID.in_([p[1] for p in target_pairs]))
            .all()
        }
        # NOTE(review): records with a None account_id or NYID are never in
        # existing_pairs and therefore always pass this filter — confirm that
        # inserting partially-keyed rows is intended.
        to_create = [
            item for item in data_list
            if (item.account_id, item.NYID) not in existing_pairs
        ]
        ignored_count = len(data_list) - len(to_create)
        if ignored_count > 0:
            logger.info(f"批量创建时忽略{ignored_count}条已存在记录account_id和NYID已存在")
        if not to_create:
            return []

        # add_all (not bulk_save_objects) so the instances stay attached to
        # the session and can be refreshed after commit.
        db.add_all(to_create)
        db.commit()
        for item in to_create:
            db.refresh(item)
        return to_create

    def get_nyid_by_point_id(
        self,
        db: Session,
        point_ids: Optional[List[int]] = None,
        max_num: int = 1
    ) -> List[List[dict]]:
        """Fetch up to *max_num* SettlementData rows per point_id, newest NYID first.

        Uses ROW_NUMBER() partitioned by ``point_id`` and ordered by
        descending ``NYID``, then keeps rows numbered <= *max_num*.

        Args:
            db: active SQLAlchemy session.
            point_ids: point ids to filter on; ``None``/empty means all points
                (result ordered by ascending point_id in that case).
            max_num: rows to keep per point; values below 1 are clamped to 1.

        Returns:
            A list of single-element lists, one per matching row
            (``[[record], [record], ...]``), each record a dict of the
            SettlementData columns.
        """
        point_ids = point_ids or []
        max_num = max(max_num, 1)

        # Window function: number rows within each point_id partition,
        # newest NYID gets row_num == 1.
        row_num = over(
            func.row_number(),
            partition_by=SettlementData.point_id,
            order_by=desc(SettlementData.NYID)
        ).label("row_num")

        # Select the flattened model columns (not the mapped entity) plus the
        # row number, so the outer query can filter on subquery.c.row_num.
        model_columns = [getattr(SettlementData, col.name) for col in SettlementData.__table__.columns]
        inner = select(*model_columns, row_num)
        if point_ids:
            inner = inner.where(SettlementData.point_id.in_(point_ids))
        subquery = inner.subquery()

        query = (
            select(subquery)
            .where(subquery.c.row_num <= max_num)
            .order_by(subquery.c.point_id, subquery.c.row_num)
        )
        results = db.execute(query).all()

        # Map row objects back to dicts keyed by the model's column names
        # (row_num is intentionally excluded from field_names).
        grouped: Dict[int, List[dict]] = {}
        field_names = [col.name for col in SettlementData.__table__.columns]
        for row in results:
            item_dict = {
                field: getattr(row, field)
                for field in field_names
            }
            pid = item_dict["point_id"]
            if pid not in grouped:
                grouped[pid] = []
            grouped[pid].append(item_dict)

        # Preserve the caller's point_ids order; fall back to sorted ids when
        # no filter was given.
        if not point_ids:
            point_ids = sorted(grouped.keys())
        # NOTE(review): each record is wrapped in its own singleton list, so
        # per-point groups are flattened — confirm callers expect
        # [[rec], [rec], ...] rather than [[rec, rec], ...].
        return [
            [record] for pid in point_ids for record in grouped.get(pid, [])
        ]