refactor: repository get by id mixin
This commit is contained in:
@ -1,11 +1,9 @@
|
|||||||
from typing import Optional
|
from sqlalchemy import select, Select
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from models import Board
|
from models import Board
|
||||||
from repositories.base import BaseRepository
|
from repositories.base import BaseRepository
|
||||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||||
from schemas.board import UpdateBoardSchema, CreateBoardSchema
|
from schemas.board import UpdateBoardSchema, CreateBoardSchema
|
||||||
|
|
||||||
|
|
||||||
@ -13,6 +11,7 @@ class BoardRepository(
|
|||||||
BaseRepository,
|
BaseRepository,
|
||||||
RepDeleteMixin[Board],
|
RepDeleteMixin[Board],
|
||||||
RepCreateMixin[Board, CreateBoardSchema],
|
RepCreateMixin[Board, CreateBoardSchema],
|
||||||
|
GetByIdMixin[Board],
|
||||||
):
|
):
|
||||||
entity_class = Board
|
entity_class = Board
|
||||||
|
|
||||||
@ -25,14 +24,8 @@ class BoardRepository(
|
|||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def get_by_id(self, board_id: int) -> Optional[Board]:
|
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||||
stmt = (
|
return stmt.options(selectinload(Board.deals))
|
||||||
select(Board)
|
|
||||||
.where(Board.id == board_id, Board.is_deleted.is_(False))
|
|
||||||
.options(selectinload(Board.deals))
|
|
||||||
)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
|
async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
|
||||||
board.lexorank = data.lexorank if data.lexorank else board.lexorank
|
board.lexorank = data.lexorank if data.lexorank else board.lexorank
|
||||||
|
|||||||
@ -1,18 +1,21 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, Select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from models import Deal, CardStatusHistory, Board
|
from models import Deal, CardStatusHistory, Board
|
||||||
from repositories.base import BaseRepository
|
from repositories.base import BaseRepository
|
||||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||||
from schemas.base import SortDir
|
from schemas.base import SortDir
|
||||||
from schemas.deal import UpdateDealSchema, CreateDealSchema
|
from schemas.deal import UpdateDealSchema, CreateDealSchema
|
||||||
from utils.sorting import apply_sorting
|
from utils.sorting import apply_sorting
|
||||||
|
|
||||||
|
|
||||||
class DealRepository(
|
class DealRepository(
|
||||||
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema]
|
BaseRepository,
|
||||||
|
RepDeleteMixin[Deal],
|
||||||
|
RepCreateMixin[Deal, CreateDealSchema],
|
||||||
|
GetByIdMixin[Deal],
|
||||||
):
|
):
|
||||||
entity_class = Deal
|
entity_class = Deal
|
||||||
|
|
||||||
@ -58,14 +61,8 @@ class DealRepository(
|
|||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return list(result.scalars().all()), total_items
|
return list(result.scalars().all()), total_items
|
||||||
|
|
||||||
async def get_by_id(self, deal_id: int) -> Optional[Deal]:
|
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||||
stmt = (
|
return stmt.options(joinedload(Deal.status), joinedload(Deal.board))
|
||||||
select(Deal)
|
|
||||||
.options(joinedload(Deal.status), joinedload(Deal.board))
|
|
||||||
.where(Deal.id == deal_id, Deal.is_deleted.is_(False))
|
|
||||||
)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
|
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
|
||||||
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank
|
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from typing import Generic, TypeVar, Type
|
from typing import Generic, TypeVar, Type, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select, Select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
EntityType = TypeVar("EntityType")
|
EntityType = TypeVar("EntityType")
|
||||||
@ -26,8 +27,7 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
|
|||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
|
|
||||||
class RepCreateMixin(Generic[EntityType, CreateType]):
|
class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]):
|
||||||
session: AsyncSession
|
|
||||||
entity_class: Type[EntityType]
|
entity_class: Type[EntityType]
|
||||||
|
|
||||||
async def create(self, data: CreateType) -> int:
|
async def create(self, data: CreateType) -> int:
|
||||||
@ -36,3 +36,19 @@ class RepCreateMixin(Generic[EntityType, CreateType]):
|
|||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(obj)
|
await self.session.refresh(obj)
|
||||||
return obj.id
|
return obj.id
|
||||||
|
|
||||||
|
|
||||||
|
class GetByIdMixin(RepBaseMixin[EntityType]):
|
||||||
|
entity_class: Type[EntityType]
|
||||||
|
|
||||||
|
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
async def get_by_id(self, item_id: int) -> Optional[EntityType]:
|
||||||
|
stmt = select(self.entity_class).where(self.entity_class.id == item_id)
|
||||||
|
if hasattr(self, "is_deleted"):
|
||||||
|
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
|
||||||
|
|
||||||
|
stmt = self._process_get_by_id_stmt(stmt)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
from typing import Optional
|
from sqlalchemy import select, Select
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from models.project import Project
|
from models.project import Project
|
||||||
from repositories.base import BaseRepository
|
from repositories.base import BaseRepository
|
||||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||||
from schemas.project import CreateProjectSchema, UpdateProjectSchema
|
from schemas.project import CreateProjectSchema, UpdateProjectSchema
|
||||||
|
|
||||||
|
|
||||||
@ -13,6 +11,7 @@ class ProjectRepository(
|
|||||||
BaseRepository,
|
BaseRepository,
|
||||||
RepDeleteMixin[Project],
|
RepDeleteMixin[Project],
|
||||||
RepCreateMixin[Project, CreateProjectSchema],
|
RepCreateMixin[Project, CreateProjectSchema],
|
||||||
|
GetByIdMixin[Project],
|
||||||
):
|
):
|
||||||
entity_class = Project
|
entity_class = Project
|
||||||
|
|
||||||
@ -21,14 +20,8 @@ class ProjectRepository(
|
|||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def get_by_id(self, project_id: int) -> Optional[Project]:
|
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||||
stmt = (
|
return stmt.options(selectinload(Project.boards))
|
||||||
select(Project)
|
|
||||||
.where(Project.id == project_id, Project.is_deleted.is_(False))
|
|
||||||
.options(selectinload(Project.boards))
|
|
||||||
)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
|
async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
|
||||||
project.name = data.name if data.name else project.name
|
project.name = data.name if data.name else project.name
|
||||||
|
|||||||
@ -1,17 +1,16 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func
|
||||||
|
|
||||||
from models import Status, Deal
|
from models import Status, Deal
|
||||||
from repositories.base import BaseRepository
|
from repositories.base import BaseRepository
|
||||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
|
||||||
from schemas.status import UpdateStatusSchema, CreateStatusSchema
|
from schemas.status import UpdateStatusSchema, CreateStatusSchema
|
||||||
|
|
||||||
|
|
||||||
class StatusRepository(
|
class StatusRepository(
|
||||||
BaseRepository,
|
BaseRepository,
|
||||||
RepDeleteMixin[Status],
|
RepDeleteMixin[Status],
|
||||||
RepCreateMixin[Deal, CreateStatusSchema],
|
RepCreateMixin[Status, CreateStatusSchema],
|
||||||
|
GetByIdMixin[Status],
|
||||||
):
|
):
|
||||||
entity_class = Status
|
entity_class = Status
|
||||||
|
|
||||||
@ -24,13 +23,6 @@ class StatusRepository(
|
|||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def get_by_id(self, status_id: int) -> Optional[Status]:
|
|
||||||
stmt = select(Status).where(
|
|
||||||
Status.id == status_id, Status.is_deleted.is_(False)
|
|
||||||
)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def get_deals_count(self, status_id: int) -> int:
|
async def get_deals_count(self, status_id: int) -> int:
|
||||||
stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id)
|
stmt = select(func.count(Deal.id)).where(Deal.status_id == status_id)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
|
|||||||
Reference in New Issue
Block a user