"""工区数据服务""" from typing import List, Dict, Any, Tuple from sqlalchemy import text from sqlalchemy.orm import Session from app.core.logging_config import get_logger from app.models.work_area import WorkArea from app.schemas.work_area import WorkAreaCreate, WorkAreaQuery from app.schemas.common import BatchImportResponse from .table_manager import TableManager logger = get_logger(__name__) class WorkAreaService: """工区数据服务""" @staticmethod def batch_import(db: Session, account_id: int, data: List[WorkAreaCreate]) -> BatchImportResponse: """批量导入工区数据""" table_name = WorkArea.get_table_name(account_id) # 确保表存在 if not TableManager.ensure_table_exists(db, "work_area", account_id): return BatchImportResponse( success=False, total=len(data), inserted=0, skipped=0, message="创建表失败" ) # 获取已存在的department_id department_ids = [item.department_id for item in data if item.department_id] existing_ids = set() if department_ids: placeholders = ",".join([f":id_{i}" for i in range(len(department_ids))]) params = {f"id_{i}": did for i, did in enumerate(department_ids)} result = db.execute( text(f"SELECT department_id FROM {table_name} WHERE department_id IN ({placeholders})"), params ) existing_ids = {row[0] for row in result.fetchall()} # 过滤重复数据 to_insert = [] skipped_ids = [] for item in data: if item.department_id in existing_ids: skipped_ids.append(item.department_id) else: to_insert.append(item) existing_ids.add(item.department_id) # 防止批次内重复 # 批量插入 if to_insert: try: values = [] params = {} for i, item in enumerate(to_insert): values.append(f"(:department_id_{i}, :parent_id_{i}, :type_{i}, :name_{i})") params[f"department_id_{i}"] = item.department_id params[f"parent_id_{i}"] = item.parent_id params[f"type_{i}"] = item.type params[f"name_{i}"] = item.name sql = f"INSERT INTO {table_name} (department_id, parent_id, type, name) VALUES {','.join(values)}" db.execute(text(sql), params) db.commit() logger.info(f"工区数据导入成功: account_id={account_id}, 插入={len(to_insert)}, 跳过={len(skipped_ids)}") except Exception as e: db.rollback() logger.error(f"工区数据导入失败: {e}") return BatchImportResponse( success=False, total=len(data), inserted=0, skipped=len(skipped_ids), skipped_ids=skipped_ids, message=f"插入失败: {str(e)}" ) return BatchImportResponse( success=True, total=len(data), inserted=len(to_insert), skipped=len(skipped_ids), skipped_ids=skipped_ids, message="导入成功" ) @staticmethod def query(db: Session, params: WorkAreaQuery) -> Tuple[List[Dict], int]: """查询工区数据""" table_name = WorkArea.get_table_name(params.account_id) # 确保表存在 if not TableManager.ensure_table_exists(db, "work_area", params.account_id): return [], 0 # 构建查询条件 conditions = [] query_params = {} if params.department_id: conditions.append("department_id = :department_id") query_params["department_id"] = params.department_id if params.parent_id: conditions.append("parent_id = :parent_id") query_params["parent_id"] = params.parent_id if params.type: conditions.append("type = :type") query_params["type"] = params.type if params.name: conditions.append("name LIKE :name") query_params["name"] = f"%{params.name}%" where_clause = " AND ".join(conditions) if conditions else "1=1" # 查询总数 count_sql = f"SELECT COUNT(*) FROM {table_name} WHERE {where_clause}" total = db.execute(text(count_sql), query_params).scalar() # 分页查询 offset = (params.page - 1) * params.page_size query_params["limit"] = params.page_size query_params["offset"] = offset data_sql = f"SELECT * FROM {table_name} WHERE {where_clause} LIMIT :limit OFFSET :offset" result = db.execute(text(data_sql), query_params) items = [dict(row._mapping) for row in result.fetchall()] return items, total