refactor: repository get all mixin
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import Generic, TypeVar, Type, Optional
|
||||
from typing import Generic, TypeVar, Type, Optional, overload
|
||||
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -52,3 +52,25 @@ class GetByIdMixin(RepBaseMixin[EntityType]):
|
||||
stmt = self._process_get_by_id_stmt(stmt)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
class GetAllMixin(RepBaseMixin[EntityType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
|
||||
return stmt
|
||||
|
||||
def _process_get_all_stmt(self, stmt: Select) -> Select:
|
||||
return stmt
|
||||
|
||||
async def get_all(self, *args) -> list[EntityType]:
|
||||
stmt = select(self.entity_class).order_by(self.entity_class.id)
|
||||
if hasattr(self, "is_deleted"):
|
||||
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
|
||||
|
||||
if args:
|
||||
stmt = self._process_get_all_stmt_with_args(stmt, *args)
|
||||
else:
|
||||
stmt = self._process_get_all_stmt(stmt)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
Reference in New Issue
Block a user