refactor: repository get by id mixin

This commit is contained in:
2025-09-05 09:53:16 +04:00
parent e5be35be35
commit c1d3ac98f0
5 changed files with 40 additions and 49 deletions

View File

@ -1,11 +1,9 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy import select, Select
from sqlalchemy.orm import selectinload
from models import Board
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.board import UpdateBoardSchema, CreateBoardSchema
@ -13,6 +11,7 @@ class BoardRepository(
BaseRepository,
RepDeleteMixin[Board],
RepCreateMixin[Board, CreateBoardSchema],
GetByIdMixin[Board],
):
entity_class = Board
@ -25,14 +24,8 @@ class BoardRepository(
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def get_by_id(self, board_id: int) -> Optional[Board]:
stmt = (
select(Board)
.where(Board.id == board_id, Board.is_deleted.is_(False))
.options(selectinload(Board.deals))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Board.deals))
async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
board.lexorank = data.lexorank if data.lexorank else board.lexorank

View File

@ -1,18 +1,21 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy import select, Select
from sqlalchemy.orm import joinedload
from models import Deal, CardStatusHistory, Board
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.base import SortDir
from schemas.deal import UpdateDealSchema, CreateDealSchema
from utils.sorting import apply_sorting
class DealRepository(
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema]
BaseRepository,
RepDeleteMixin[Deal],
RepCreateMixin[Deal, CreateDealSchema],
GetByIdMixin[Deal],
):
entity_class = Deal
@ -58,14 +61,8 @@ class DealRepository(
result = await self.session.execute(stmt)
return list(result.scalars().all()), total_items
async def get_by_id(self, deal_id: int) -> Optional[Deal]:
stmt = (
select(Deal)
.options(joinedload(Deal.status), joinedload(Deal.board))
.where(Deal.id == deal_id, Deal.is_deleted.is_(False))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(joinedload(Deal.status), joinedload(Deal.board))
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank

View File

@ -1,5 +1,6 @@
from typing import Generic, TypeVar, Type
from typing import Generic, TypeVar, Type, Optional
from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession
EntityType = TypeVar("EntityType")
@ -26,8 +27,7 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
await self.session.commit()
class RepCreateMixin(Generic[EntityType, CreateType]):
session: AsyncSession
class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]):
entity_class: Type[EntityType]
async def create(self, data: CreateType) -> int:
@ -36,3 +36,19 @@ class RepCreateMixin(Generic[EntityType, CreateType]):
await self.session.commit()
await self.session.refresh(obj)
return obj.id
class GetByIdMixin(RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt
async def get_by_id(self, item_id: int) -> Optional[EntityType]:
stmt = select(self.entity_class).where(self.entity_class.id == item_id)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
stmt = self._process_get_by_id_stmt(stmt)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()

View File

@ -1,11 +1,9 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy import select, Select
from sqlalchemy.orm import selectinload
from models.project import Project
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.project import CreateProjectSchema, UpdateProjectSchema
@ -13,6 +11,7 @@ class ProjectRepository(
BaseRepository,
RepDeleteMixin[Project],
RepCreateMixin[Project, CreateProjectSchema],
GetByIdMixin[Project],
):
entity_class = Project
@ -21,14 +20,8 @@ class ProjectRepository(
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def get_by_id(self, project_id: int) -> Optional[Project]:
stmt = (
select(Project)
.where(Project.id == project_id, Project.is_deleted.is_(False))
.options(selectinload(Project.boards))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Project.boards))
async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
project.name = data.name if data.name else project.name

View File

@ -1,17 +1,16 @@
from typing import Optional
from sqlalchemy import select, func
from models import Status, Deal
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.status import UpdateStatusSchema, CreateStatusSchema
class StatusRepository(
BaseRepository,
RepDeleteMixin[Status],
RepCreateMixin[Deal, CreateStatusSchema],
RepCreateMixin[Status, CreateStatusSchema],
GetByIdMixin[Status],
):
entity_class = Status
@ -24,13 +23,6 @@ class StatusRepository(
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def get_by_id(self, status_id: int) -> Optional[Status]:
stmt = select(Status).where(
Status.id == status_id, Status.is_deleted.is_(False)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_deals_count(self, status_id: int) -> int:
stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id)
result = await self.session.execute(stmt)