refactor: repository get by id mixin
This commit is contained in:
@ -1,11 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from models import Board
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||
from schemas.board import UpdateBoardSchema, CreateBoardSchema
|
||||
|
||||
|
||||
@ -13,6 +11,7 @@ class BoardRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Board],
|
||||
RepCreateMixin[Board, CreateBoardSchema],
|
||||
GetByIdMixin[Board],
|
||||
):
|
||||
entity_class = Board
|
||||
|
||||
@ -25,14 +24,8 @@ class BoardRepository(
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_id(self, board_id: int) -> Optional[Board]:
|
||||
stmt = (
|
||||
select(Board)
|
||||
.where(Board.id == board_id, Board.is_deleted.is_(False))
|
||||
.options(selectinload(Board.deals))
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||
return stmt.options(selectinload(Board.deals))
|
||||
|
||||
async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
|
||||
board.lexorank = data.lexorank if data.lexorank else board.lexorank
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from models import Deal, CardStatusHistory, Board
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||
from schemas.base import SortDir
|
||||
from schemas.deal import UpdateDealSchema, CreateDealSchema
|
||||
from utils.sorting import apply_sorting
|
||||
|
||||
|
||||
class DealRepository(
|
||||
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema]
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Deal],
|
||||
RepCreateMixin[Deal, CreateDealSchema],
|
||||
GetByIdMixin[Deal],
|
||||
):
|
||||
entity_class = Deal
|
||||
|
||||
@ -58,14 +61,8 @@ class DealRepository(
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all()), total_items
|
||||
|
||||
async def get_by_id(self, deal_id: int) -> Optional[Deal]:
|
||||
stmt = (
|
||||
select(Deal)
|
||||
.options(joinedload(Deal.status), joinedload(Deal.board))
|
||||
.where(Deal.id == deal_id, Deal.is_deleted.is_(False))
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||
return stmt.options(joinedload(Deal.status), joinedload(Deal.board))
|
||||
|
||||
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
|
||||
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Generic, TypeVar, Type
|
||||
from typing import Generic, TypeVar, Type, Optional
|
||||
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
EntityType = TypeVar("EntityType")
|
||||
@ -26,8 +27,7 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
class RepCreateMixin(Generic[EntityType, CreateType]):
|
||||
session: AsyncSession
|
||||
class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
async def create(self, data: CreateType) -> int:
|
||||
@ -36,3 +36,19 @@ class RepCreateMixin(Generic[EntityType, CreateType]):
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj.id
|
||||
|
||||
|
||||
class GetByIdMixin(RepBaseMixin[EntityType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||
return stmt
|
||||
|
||||
async def get_by_id(self, item_id: int) -> Optional[EntityType]:
|
||||
stmt = select(self.entity_class).where(self.entity_class.id == item_id)
|
||||
if hasattr(self, "is_deleted"):
|
||||
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
|
||||
|
||||
stmt = self._process_get_by_id_stmt(stmt)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from models.project import Project
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||
from schemas.project import CreateProjectSchema, UpdateProjectSchema
|
||||
|
||||
|
||||
@ -13,6 +11,7 @@ class ProjectRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Project],
|
||||
RepCreateMixin[Project, CreateProjectSchema],
|
||||
GetByIdMixin[Project],
|
||||
):
|
||||
entity_class = Project
|
||||
|
||||
@ -21,14 +20,8 @@ class ProjectRepository(
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_id(self, project_id: int) -> Optional[Project]:
|
||||
stmt = (
|
||||
select(Project)
|
||||
.where(Project.id == project_id, Project.is_deleted.is_(False))
|
||||
.options(selectinload(Project.boards))
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||
return stmt.options(selectinload(Project.boards))
|
||||
|
||||
async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
|
||||
project.name = data.name if data.name else project.name
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from models import Status, Deal
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||
from schemas.status import UpdateStatusSchema, CreateStatusSchema
|
||||
|
||||
|
||||
class StatusRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Status],
|
||||
RepCreateMixin[Deal, CreateStatusSchema],
|
||||
RepCreateMixin[Status, CreateStatusSchema],
|
||||
GetByIdMixin[Status],
|
||||
):
|
||||
entity_class = Status
|
||||
|
||||
@ -24,13 +23,6 @@ class StatusRepository(
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_id(self, status_id: int) -> Optional[Status]:
|
||||
stmt = select(Status).where(
|
||||
Status.id == status_id, Status.is_deleted.is_(False)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_deals_count(self, status_id: int) -> int:
|
||||
stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id)
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
Reference in New Issue
Block a user