refactor: repository get by id mixin

This commit is contained in:
2025-09-05 09:53:16 +04:00
parent e5be35be35
commit c1d3ac98f0
5 changed files with 40 additions and 49 deletions

View File

@ -1,11 +1,9 @@
from typing import Optional from sqlalchemy import select, Select
from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from models import Board from models import Board
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.board import UpdateBoardSchema, CreateBoardSchema from schemas.board import UpdateBoardSchema, CreateBoardSchema
@ -13,6 +11,7 @@ class BoardRepository(
BaseRepository, BaseRepository,
RepDeleteMixin[Board], RepDeleteMixin[Board],
RepCreateMixin[Board, CreateBoardSchema], RepCreateMixin[Board, CreateBoardSchema],
GetByIdMixin[Board],
): ):
entity_class = Board entity_class = Board
@ -25,14 +24,8 @@ class BoardRepository(
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return list(result.scalars().all()) return list(result.scalars().all())
async def get_by_id(self, board_id: int) -> Optional[Board]: def _process_get_by_id_stmt(self, stmt: Select) -> Select:
stmt = ( return stmt.options(selectinload(Board.deals))
select(Board)
.where(Board.id == board_id, Board.is_deleted.is_(False))
.options(selectinload(Board.deals))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def update(self, board: Board, data: UpdateBoardSchema) -> Board: async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
board.lexorank = data.lexorank if data.lexorank else board.lexorank board.lexorank = data.lexorank if data.lexorank else board.lexorank

View File

@ -1,18 +1,21 @@
from typing import Optional from typing import Optional
from sqlalchemy import select from sqlalchemy import select, Select
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from models import Deal, CardStatusHistory, Board from models import Deal, CardStatusHistory, Board
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.base import SortDir from schemas.base import SortDir
from schemas.deal import UpdateDealSchema, CreateDealSchema from schemas.deal import UpdateDealSchema, CreateDealSchema
from utils.sorting import apply_sorting from utils.sorting import apply_sorting
class DealRepository( class DealRepository(
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema] BaseRepository,
RepDeleteMixin[Deal],
RepCreateMixin[Deal, CreateDealSchema],
GetByIdMixin[Deal],
): ):
entity_class = Deal entity_class = Deal
@ -58,14 +61,8 @@ class DealRepository(
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return list(result.scalars().all()), total_items return list(result.scalars().all()), total_items
async def get_by_id(self, deal_id: int) -> Optional[Deal]: def _process_get_by_id_stmt(self, stmt: Select) -> Select:
stmt = ( return stmt.options(joinedload(Deal.status), joinedload(Deal.board))
select(Deal)
.options(joinedload(Deal.status), joinedload(Deal.board))
.where(Deal.id == deal_id, Deal.is_deleted.is_(False))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal: async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank deal.lexorank = data.lexorank if data.lexorank else deal.lexorank

View File

@ -1,5 +1,6 @@
from typing import Generic, TypeVar, Type from typing import Generic, TypeVar, Type, Optional
from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
EntityType = TypeVar("EntityType") EntityType = TypeVar("EntityType")
@ -26,8 +27,7 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
await self.session.commit() await self.session.commit()
class RepCreateMixin(Generic[EntityType, CreateType]): class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]):
session: AsyncSession
entity_class: Type[EntityType] entity_class: Type[EntityType]
async def create(self, data: CreateType) -> int: async def create(self, data: CreateType) -> int:
@ -36,3 +36,19 @@ class RepCreateMixin(Generic[EntityType, CreateType]):
await self.session.commit() await self.session.commit()
await self.session.refresh(obj) await self.session.refresh(obj)
return obj.id return obj.id
class GetByIdMixin(RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt
async def get_by_id(self, item_id: int) -> Optional[EntityType]:
stmt = select(self.entity_class).where(self.entity_class.id == item_id)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
stmt = self._process_get_by_id_stmt(stmt)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()

View File

@ -1,11 +1,9 @@
from typing import Optional from sqlalchemy import select, Select
from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from models.project import Project from models.project import Project
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.project import CreateProjectSchema, UpdateProjectSchema from schemas.project import CreateProjectSchema, UpdateProjectSchema
@ -13,6 +11,7 @@ class ProjectRepository(
BaseRepository, BaseRepository,
RepDeleteMixin[Project], RepDeleteMixin[Project],
RepCreateMixin[Project, CreateProjectSchema], RepCreateMixin[Project, CreateProjectSchema],
GetByIdMixin[Project],
): ):
entity_class = Project entity_class = Project
@ -21,14 +20,8 @@ class ProjectRepository(
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return list(result.scalars().all()) return list(result.scalars().all())
async def get_by_id(self, project_id: int) -> Optional[Project]: def _process_get_by_id_stmt(self, stmt: Select) -> Select:
stmt = ( return stmt.options(selectinload(Project.boards))
select(Project)
.where(Project.id == project_id, Project.is_deleted.is_(False))
.options(selectinload(Project.boards))
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def update(self, project: Project, data: UpdateProjectSchema) -> Project: async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
project.name = data.name if data.name else project.name project.name = data.name if data.name else project.name

View File

@ -1,17 +1,16 @@
from typing import Optional
from sqlalchemy import select, func from sqlalchemy import select, func
from models import Status, Deal from models import Status, Deal
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from schemas.status import UpdateStatusSchema, CreateStatusSchema from schemas.status import UpdateStatusSchema, CreateStatusSchema
class StatusRepository( class StatusRepository(
BaseRepository, BaseRepository,
RepDeleteMixin[Status], RepDeleteMixin[Status],
RepCreateMixin[Deal, CreateStatusSchema], RepCreateMixin[Status, CreateStatusSchema],
GetByIdMixin[Status],
): ):
entity_class = Status entity_class = Status
@ -24,13 +23,6 @@ class StatusRepository(
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return list(result.scalars().all()) return list(result.scalars().all())
async def get_by_id(self, status_id: int) -> Optional[Status]:
stmt = select(Status).where(
Status.id == status_id, Status.is_deleted.is_(False)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_deals_count(self, status_id: int) -> int: async def get_deals_count(self, status_id: int) -> int:
stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id) stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)