77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from typing import Generic, TypeVar, Type, Optional, overload
|
|
|
|
from sqlalchemy import select, Select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Type variable for the mapped entity a repository works with; the mixins
# below use it for ``entity_class`` and for the objects they return/delete.
EntityType = TypeVar("EntityType")
|
|
# Type variable for the input object accepted by ``RepCreateMixin.create``;
# that method calls ``model_dump()`` on it (presumably a Pydantic v2 model — confirm).
CreateType = TypeVar("CreateType")
|
|
|
|
|
|
class RepBaseMixin(Generic[EntityType]):
    """Shared base of the repository mixins, generic over the entity type.

    Only the type of ``session`` is declared here; no value is assigned, so a
    concrete repository must set ``self.session`` before any mixin method runs.
    """

    # Async SQLAlchemy session that mixin methods read through ``self.session``.
    session: AsyncSession
|
|
|
|
|
|
class RepDeleteMixin(RepBaseMixin[EntityType]):
    """Mixin adding hard or soft deletion of a single entity."""

    async def delete(self, obj: EntityType, is_soft: bool) -> None:
        """Delete *obj* and commit the session.

        Args:
            obj: The entity instance to delete.
            is_soft: When truthy, mark the entity by setting ``is_deleted`` to
                ``True`` instead of removing the row; otherwise the row is
                removed with ``session.delete``.

        Raises:
            AttributeError: If a soft delete is requested for an object that
                has no ``is_deleted`` attribute. Nothing is committed in that case.
        """
        if is_soft:
            # Soft deletion only makes sense for entities carrying the flag.
            if not hasattr(obj, "is_deleted"):
                raise AttributeError(
                    f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
                )
            obj.is_deleted = True
            self.session.add(obj)
        else:
            await self.session.delete(obj)
        # Both paths finish by committing the pending change.
        await self.session.commit()
|
|
|
|
|
|
class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]):
    """Mixin that inserts a new entity built from a create-input object."""

    # Mapped class instantiated by ``create``; set by the concrete repository.
    entity_class: Type[EntityType]

    async def create(self, data: CreateType) -> int:
        """Build an ``entity_class`` instance from *data*, store it, return its id.

        The constructor keyword arguments come from ``data.model_dump()``. The
        session is committed and the instance refreshed before ``id`` is read.

        Returns:
            The ``id`` attribute of the newly stored entity.
        """
        payload = data.model_dump()
        entity = self.entity_class(**payload)
        self.session.add(entity)
        await self.session.commit()
        # Reload the row's attributes (e.g. a database-generated ``id``) onto
        # the instance before reading ``id``.
        await self.session.refresh(entity)
        return entity.id
|
|
|
|
|
|
class GetByIdMixin(RepBaseMixin[EntityType]):
    """Mixin providing lookup of a single entity by its ``id`` column.

    Subclasses set ``entity_class`` and may override
    ``_process_get_by_id_stmt`` to further refine the lookup statement.
    """

    # Mapped class queried by ``get_by_id``; set by the concrete repository.
    entity_class: Type[EntityType]

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Hook for subclasses to adjust the lookup statement; identity by default."""
        return stmt

    async def get_by_id(self, item_id: int) -> Optional[EntityType]:
        """Return the entity whose ``id`` equals *item_id*, or ``None``.

        If ``entity_class`` defines an ``is_deleted`` attribute, rows whose
        ``is_deleted`` is not ``False`` are excluded (soft-deleted rows are
        treated as absent).

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches
                (raised by ``Result.scalar_one_or_none``).
        """
        stmt = select(self.entity_class).where(self.entity_class.id == item_id)
        # The soft-delete flag is a column on the mapped entity, so probe the
        # entity class. The previous check probed ``self`` (the repository),
        # which presumably never has ``is_deleted``, so the filter was skipped.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))

        stmt = self._process_get_by_id_stmt(stmt)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()
|
|
|
|
|
|
class GetAllMixin(RepBaseMixin[EntityType]):
    """Mixin providing retrieval of all entities, ordered by ``id``.

    Subclasses can refine the statement through two hooks:
    ``_process_get_all_stmt_with_args`` runs when ``get_all`` receives
    positional arguments, and ``_process_get_all_stmt`` runs otherwise.
    """

    # Mapped class queried by ``get_all``; set by the concrete repository.
    entity_class: Type[EntityType]

    def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
        """Hook used when ``get_all`` gets arguments; by default *args* are ignored."""
        return stmt

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Hook used when ``get_all`` gets no arguments; identity by default."""
        return stmt

    async def get_all(self, *args) -> list[EntityType]:
        """Return every entity ordered by ``id``, after applying the hooks.

        If ``entity_class`` defines an ``is_deleted`` attribute, rows whose
        ``is_deleted`` is not ``False`` are excluded.

        Args:
            *args: Forwarded to ``_process_get_all_stmt_with_args`` when
                non-empty; their meaning is defined by the subclass override.
        """
        stmt = select(self.entity_class).order_by(self.entity_class.id)
        # Probe the mapped entity (where the column lives), not the repository;
        # checking ``self`` presumably never matched, so soft-deleted rows leaked.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))

        if args:
            stmt = self._process_get_all_stmt_with_args(stmt, *args)
        else:
            stmt = self._process_get_all_stmt(stmt)

        result = await self.session.execute(stmt)
        return list(result.scalars().all())
|