80 lines
3.2 KiB
Python
80 lines
3.2 KiB
Python
from typing import Type, TypeVar, Generic, List, Optional, Any, Dict
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.ext.declarative import DeclarativeMeta
|
|
|
|
# Type parameter standing for the model class a BaseService is built around.
# It is unbounded, so any class is accepted here; BaseService itself assumes
# the class exposes an `id` attribute.
ModelType = TypeVar("ModelType")
|
|
|
|
class BaseService(Generic[ModelType]):
    """Generic CRUD helper bound to a single SQLAlchemy model class.

    Every method receives the ``Session`` to use explicitly. The writing
    methods (``create``, ``update``, ``delete``) commit that session
    themselves; if the commit fails, the session is rolled back and the
    exception is re-raised.
    """

    def __init__(self, model: Type[ModelType]):
        """Remember the model class this service queries and instantiates.

        Args:
            model: The mapped model class. ``get_by_id``, ``update`` and
                ``delete`` filter on its ``id`` attribute.
        """
        self.model = model

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit *db*, rolling it back and re-raising if the commit fails.

        After a failed commit, a SQLAlchemy session refuses further work until
        ``rollback()`` is called, so do that here instead of leaving the
        caller's session unusable.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise

    def _apply_conditions(self, query, conditions: Dict[str, Any]):
        """Return *query* with one filter added per usable entry of *conditions*.

        Entries whose key is not an attribute of the model, or whose value is
        ``None``, are ignored. String values are matched with SQL ``LIKE``;
        every other value is compared with ``==``. Filters are added in the
        dict's iteration order.
        """
        for field_name, field_value in conditions.items():
            if not hasattr(self.model, field_name) or field_value is None:
                continue
            field = getattr(self.model, field_name)
            if isinstance(field_value, str):
                # NOTE(review): the pattern carries no wildcards, so this acts
                # like equality (subject to the database's LIKE case rules),
                # except that any `%` / `_` inside the value are honoured as
                # wildcards. Confirm whether a contains-match
                # (f"%{field_value}%") was intended.
                query = query.filter(field.like(f"{field_value}"))
            else:
                query = query.filter(field == field_value)
        return query

    def create(self, db: Session, obj_data: Dict[str, Any]) -> ModelType:
        """Create, persist and return a new record.

        Args:
            db: Session to add the record to; it is committed.
            obj_data: Keyword arguments passed to the model constructor.

        Returns:
            The new instance, refreshed from the database after the commit.

        Raises:
            Any exception raised by the commit, after the session has been
            rolled back.
        """
        db_obj = self.model(**obj_data)
        db.add(db_obj)
        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    def get_by_id(self, db: Session, obj_id: int) -> Optional[ModelType]:
        """Return the record whose ``id`` equals *obj_id*, or ``None``."""
        return db.query(self.model).filter(self.model.id == obj_id).first()

    def get_all(self, db: Session, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return up to *limit* records after skipping the first *skip*.

        No ORDER BY is applied, so which rows land on a given page depends on
        the database's row order.
        """
        return db.query(self.model).offset(skip).limit(limit).all()

    def update(self, db: Session, obj_id: int, update_data: Dict[str, Any]) -> Optional[ModelType]:
        """Apply *update_data* to the record with ``id`` *obj_id* and commit.

        Keys that are not attributes of the loaded instance are silently
        skipped.

        Returns:
            The refreshed instance, or ``None`` if no record has that id.

        Raises:
            Any exception raised by the commit, after the session has been
            rolled back.
        """
        db_obj = self.get_by_id(db, obj_id)
        if db_obj is None:
            return None
        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)
        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    def delete(self, db: Session, obj_id: int) -> bool:
        """Delete the record with ``id`` *obj_id* and commit.

        Returns:
            ``True`` if a record was found and deleted, ``False`` otherwise.

        Raises:
            Any exception raised by the commit, after the session has been
            rolled back.
        """
        db_obj = self.get_by_id(db, obj_id)
        if db_obj is None:
            return False
        db.delete(db_obj)
        self._commit(db)
        return True

    def get_by_field(self, db: Session, field_name: str, field_value: Any) -> List[ModelType]:
        """Return every record whose *field_name* equals *field_value*.

        Returns an empty list when the model has no attribute *field_name*.
        """
        if hasattr(self.model, field_name):
            field = getattr(self.model, field_name)
            return db.query(self.model).filter(field == field_value).all()
        return []

    def search_by_conditions(self, db: Session, conditions: Dict[str, Any], skip: int = 0, limit: Optional[int] = None) -> List[ModelType]:
        """Return records matching every usable entry of *conditions*.

        Filtering rules are those of ``_apply_conditions``. *skip* rows are
        skipped; *limit* caps the result size when it is not ``None``.
        """
        query = self._apply_conditions(db.query(self.model), conditions)
        query = query.offset(skip)
        if limit is not None:
            query = query.limit(limit)
        return query.all()

    def search_by_conditions_count(self, db: Session, conditions: Dict[str, Any]) -> int:
        """Return how many records ``search_by_conditions`` would match.

        This is the count without pagination. It uses the same filter builder,
        so the count stays consistent with the search results.
        """
        return self._apply_conditions(db.query(self.model), conditions).count()