diff --git a/repositories/board.py b/repositories/board.py index 32a2e94..be4263c 100644 --- a/repositories/board.py +++ b/repositories/board.py @@ -1,28 +1,29 @@ -from sqlalchemy import select, Select +from sqlalchemy import Select from sqlalchemy.orm import selectinload from models import Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin +from repositories.mixins import ( + RepDeleteMixin, + RepCreateMixin, + GetByIdMixin, + GetAllMixin, +) from schemas.board import UpdateBoardSchema, CreateBoardSchema class BoardRepository( BaseRepository, + GetAllMixin[Board], RepDeleteMixin[Board], RepCreateMixin[Board, CreateBoardSchema], GetByIdMixin[Board], ): entity_class = Board - async def get_all(self, project_id: int) -> list[Board]: - stmt = ( - select(Board) - .where(Board.is_deleted.is_(False), Board.project_id == project_id) - .order_by(Board.lexorank) - ) - result = await self.session.execute(stmt) - return list(result.scalars().all()) + def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select: + project_id = args[0] + return stmt.where(Board.project_id == project_id).order_by(Board.lexorank) def _process_get_by_id_stmt(self, stmt: Select) -> Select: return stmt.options(selectinload(Board.deals)) diff --git a/repositories/mixins.py b/repositories/mixins.py index 93e1408..4dc6a12 100644 --- a/repositories/mixins.py +++ b/repositories/mixins.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar, Type, Optional +from typing import Generic, TypeVar, Type, Optional, overload from sqlalchemy import select, Select from sqlalchemy.ext.asyncio import AsyncSession @@ -52,3 +52,25 @@ class GetByIdMixin(RepBaseMixin[EntityType]): stmt = self._process_get_by_id_stmt(stmt) result = await self.session.execute(stmt) return result.scalar_one_or_none() + + +class GetAllMixin(RepBaseMixin[EntityType]): + entity_class: Type[EntityType] + + def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select: + return stmt + + def _process_get_all_stmt(self, stmt: Select) -> Select: + return stmt + + async def get_all(self, *args) -> list[EntityType]: + stmt = select(self.entity_class).order_by(self.entity_class.id) + if hasattr(self, "is_deleted"): + stmt = stmt.where(self.entity_class.is_deleted.is_(False)) + + if args: + stmt = self._process_get_all_stmt_with_args(stmt, *args) + else: + stmt = self._process_get_all_stmt(stmt) + result = await self.session.execute(stmt) + return list(result.scalars().all()) diff --git a/repositories/project.py b/repositories/project.py index 7fa1801..706e0a6 100644 --- a/repositories/project.py +++ b/repositories/project.py @@ -1,24 +1,30 @@ -from sqlalchemy import select, Select +from typing import overload + +from sqlalchemy import Select from sqlalchemy.orm import selectinload from models.project import Project from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin +from repositories.mixins import ( + RepDeleteMixin, + RepCreateMixin, + GetByIdMixin, + GetAllMixin, +) from schemas.project import CreateProjectSchema, UpdateProjectSchema class ProjectRepository( BaseRepository, + GetAllMixin[Project], RepDeleteMixin[Project], RepCreateMixin[Project, CreateProjectSchema], GetByIdMixin[Project], ): entity_class = Project - async def get_all(self) -> list[Project]: - stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id) - result = await self.session.execute(stmt) - return list(result.scalars().all()) + def _process_get_all_stmt(self, stmt: Select) -> Select: + return stmt.order_by(Project.id) def _process_get_by_id_stmt(self, stmt: Select) -> Select: return stmt.options(selectinload(Project.boards)) diff --git a/repositories/status.py b/repositories/status.py index f7c048d..7d1f6f1 100644 --- a/repositories/status.py +++ b/repositories/status.py @@ -1,19 +1,29 @@ -from sqlalchemy import select, func +from sqlalchemy import select, func, Select from models import Status, Deal from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin +from repositories.mixins import ( + RepDeleteMixin, + RepCreateMixin, + GetByIdMixin, + GetAllMixin, +) from schemas.status import UpdateStatusSchema, CreateStatusSchema class StatusRepository( BaseRepository, + GetAllMixin[Status], RepDeleteMixin[Status], RepCreateMixin[Status, CreateStatusSchema], GetByIdMixin[Status], ): entity_class = Status + def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select: + board_id = args[0] + return stmt.where(Status.board_id == board_id).order_by(Status.lexorank) + async def get_all(self, board_id: int) -> list[Status]: stmt = ( select(Status)