diff --git a/repositories/board.py b/repositories/board.py index f9cbf86..32a2e94 100644 --- a/repositories/board.py +++ b/repositories/board.py @@ -1,11 +1,9 @@ -from typing import Optional - -from sqlalchemy import select +from sqlalchemy import select, Select from sqlalchemy.orm import selectinload from models import Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from schemas.board import UpdateBoardSchema, CreateBoardSchema @@ -13,6 +11,7 @@ class BoardRepository( BaseRepository, RepDeleteMixin[Board], RepCreateMixin[Board, CreateBoardSchema], + GetByIdMixin[Board], ): entity_class = Board @@ -25,14 +24,8 @@ class BoardRepository( result = await self.session.execute(stmt) return list(result.scalars().all()) - async def get_by_id(self, board_id: int) -> Optional[Board]: - stmt = ( - select(Board) - .where(Board.id == board_id, Board.is_deleted.is_(False)) - .options(selectinload(Board.deals)) - ) - result = await self.session.execute(stmt) - return result.scalar_one_or_none() + def _process_get_by_id_stmt(self, stmt: Select) -> Select: + return stmt.options(selectinload(Board.deals)) async def update(self, board: Board, data: UpdateBoardSchema) -> Board: board.lexorank = data.lexorank if data.lexorank else board.lexorank diff --git a/repositories/deal.py b/repositories/deal.py index f007b2e..75a72d5 100644 --- a/repositories/deal.py +++ b/repositories/deal.py @@ -1,18 +1,21 @@ from typing import Optional -from sqlalchemy import select +from sqlalchemy import select, Select from sqlalchemy.orm import joinedload from models import Deal, CardStatusHistory, Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from schemas.base import SortDir from schemas.deal import UpdateDealSchema, CreateDealSchema from utils.sorting import apply_sorting class DealRepository( - BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema] + BaseRepository, + RepDeleteMixin[Deal], + RepCreateMixin[Deal, CreateDealSchema], + GetByIdMixin[Deal], ): entity_class = Deal @@ -58,14 +61,8 @@ class DealRepository( result = await self.session.execute(stmt) return list(result.scalars().all()), total_items - async def get_by_id(self, deal_id: int) -> Optional[Deal]: - stmt = ( - select(Deal) - .options(joinedload(Deal.status), joinedload(Deal.board)) - .where(Deal.id == deal_id, Deal.is_deleted.is_(False)) - ) - result = await self.session.execute(stmt) - return result.scalar_one_or_none() + def _process_get_by_id_stmt(self, stmt: Select) -> Select: + return stmt.options(joinedload(Deal.status), joinedload(Deal.board)) async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal: deal.lexorank = data.lexorank if data.lexorank else deal.lexorank diff --git a/repositories/mixins.py b/repositories/mixins.py index 5fd991f..93e1408 100644 --- a/repositories/mixins.py +++ b/repositories/mixins.py @@ -1,5 +1,6 @@ -from typing import Generic, TypeVar, Type +from typing import Generic, TypeVar, Type, Optional +from sqlalchemy import select, Select from sqlalchemy.ext.asyncio import AsyncSession EntityType = TypeVar("EntityType") @@ -26,8 +27,7 @@ class RepDeleteMixin(RepBaseMixin[EntityType]): await self.session.commit() -class RepCreateMixin(Generic[EntityType, CreateType]): - session: AsyncSession +class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]): entity_class: Type[EntityType] async def create(self, data: CreateType) -> int: @@ -36,3 +36,19 @@ class RepCreateMixin(Generic[EntityType, CreateType]): await self.session.commit() await self.session.refresh(obj) return obj.id + + +class GetByIdMixin(RepBaseMixin[EntityType]): + entity_class: Type[EntityType] + + def _process_get_by_id_stmt(self, stmt: Select) -> Select: + return stmt + + async def get_by_id(self, item_id: int) -> Optional[EntityType]: + stmt = select(self.entity_class).where(self.entity_class.id == item_id) + if hasattr(self, "is_deleted"): + stmt = stmt.where(self.entity_class.is_deleted.is_(False)) + + stmt = self._process_get_by_id_stmt(stmt) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() diff --git a/repositories/project.py b/repositories/project.py index 57ab0d1..7fa1801 100644 --- a/repositories/project.py +++ b/repositories/project.py @@ -1,11 +1,9 @@ -from typing import Optional - -from sqlalchemy import select +from sqlalchemy import select, Select from sqlalchemy.orm import selectinload from models.project import Project from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from schemas.project import CreateProjectSchema, UpdateProjectSchema @@ -13,6 +11,7 @@ class ProjectRepository( BaseRepository, RepDeleteMixin[Project], RepCreateMixin[Project, CreateProjectSchema], + GetByIdMixin[Project], ): entity_class = Project @@ -21,14 +20,8 @@ class ProjectRepository( result = await self.session.execute(stmt) return list(result.scalars().all()) - async def get_by_id(self, project_id: int) -> Optional[Project]: - stmt = ( - select(Project) - .where(Project.id == project_id, Project.is_deleted.is_(False)) - .options(selectinload(Project.boards)) - ) - result = await self.session.execute(stmt) - return result.scalar_one_or_none() + def _process_get_by_id_stmt(self, stmt: Select) -> Select: + return stmt.options(selectinload(Project.boards)) async def update(self, project: Project, data: UpdateProjectSchema) -> Project: project.name = data.name if data.name else project.name diff --git a/repositories/status.py b/repositories/status.py index fe77474..f7c048d 100644 --- a/repositories/status.py +++ b/repositories/status.py @@ -1,17 +1,16 @@ -from typing import Optional - from sqlalchemy import select, func from models import Status, Deal from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from schemas.status import UpdateStatusSchema, CreateStatusSchema class StatusRepository( BaseRepository, RepDeleteMixin[Status], - RepCreateMixin[Deal, CreateStatusSchema], + RepCreateMixin[Status, CreateStatusSchema], + GetByIdMixin[Status], ): entity_class = Status @@ -24,13 +23,6 @@ class StatusRepository( result = await self.session.execute(stmt) return list(result.scalars().all()) - async def get_by_id(self, status_id: int) -> Optional[Status]: - stmt = select(Status).where( - Status.id == status_id, Status.is_deleted.is_(False) - ) - result = await self.session.execute(stmt) - return result.scalar_one_or_none() - async def get_deals_count(self, status_id: int) -> int: stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id) result = await self.session.execute(stmt)