refactor: repository create mixin
This commit is contained in:
@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload
|
||||
|
||||
from models import Board
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from schemas.board import UpdateBoardSchema, CreateBoardSchema
|
||||
|
||||
|
||||
class BoardRepository(BaseRepository, RepDeleteMixin[Board]):
|
||||
class BoardRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Board],
|
||||
RepCreateMixin[Board, CreateBoardSchema],
|
||||
):
|
||||
entity_class = Board
|
||||
|
||||
async def get_all(self, project_id: int) -> list[Board]:
|
||||
stmt = (
|
||||
select(Board)
|
||||
@ -28,13 +34,6 @@ class BoardRepository(BaseRepository, RepDeleteMixin[Board]):
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, data: CreateBoardSchema) -> Board:
|
||||
board = Board(**data.model_dump())
|
||||
self.session.add(board)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(board)
|
||||
return board
|
||||
|
||||
async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
|
||||
board.lexorank = data.lexorank if data.lexorank else board.lexorank
|
||||
board.name = data.name if data.name else board.name
|
||||
|
||||
@ -5,13 +5,17 @@ from sqlalchemy.orm import joinedload
|
||||
|
||||
from models import Deal, CardStatusHistory, Board
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from schemas.base import SortDir
|
||||
from schemas.deal import UpdateDealSchema, CreateDealSchema
|
||||
from utils.sorting import apply_sorting
|
||||
|
||||
|
||||
class DealRepository(BaseRepository, RepDeleteMixin[Deal]):
|
||||
class DealRepository(
|
||||
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema]
|
||||
):
|
||||
entity_class = Deal
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
page: Optional[int],
|
||||
@ -63,13 +67,6 @@ class DealRepository(BaseRepository, RepDeleteMixin[Deal]):
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, data: CreateDealSchema) -> Deal:
|
||||
deal = Deal(**data.model_dump())
|
||||
self.session.add(deal)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(deal)
|
||||
return await self.get_by_id(deal.id)
|
||||
|
||||
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
|
||||
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank
|
||||
deal.name = data.name if data.name else deal.name
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Generic, TypeVar, Type
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
EntityType = TypeVar("EntityType")
|
||||
CreateType = TypeVar("CreateType")
|
||||
|
||||
|
||||
class RepBaseMixin(Generic[EntityType]):
|
||||
@ -18,8 +19,20 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
|
||||
|
||||
if not hasattr(obj, "is_deleted"):
|
||||
raise AttributeError(
|
||||
f"{obj.__class__.__name__} does not support soft delete (missing `is_deleted` field)"
|
||||
f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
|
||||
)
|
||||
obj.is_deleted = True
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
class RepCreateMixin(Generic[EntityType, CreateType]):
|
||||
session: AsyncSession
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
async def create(self, data: CreateType) -> int:
|
||||
obj = self.entity_class(**data.model_dump())
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj.id
|
||||
|
||||
@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload
|
||||
|
||||
from models.project import Project
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from schemas.project import CreateProjectSchema, UpdateProjectSchema
|
||||
|
||||
|
||||
class ProjectRepository(BaseRepository, RepDeleteMixin[Project]):
|
||||
class ProjectRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Project],
|
||||
RepCreateMixin[Project, CreateProjectSchema],
|
||||
):
|
||||
entity_class = Project
|
||||
|
||||
async def get_all(self) -> list[Project]:
|
||||
stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id)
|
||||
result = await self.session.execute(stmt)
|
||||
@ -24,13 +30,6 @@ class ProjectRepository(BaseRepository, RepDeleteMixin[Project]):
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, data: CreateProjectSchema) -> Project:
|
||||
project = Project(**data.model_dump())
|
||||
self.session.add(project)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(project)
|
||||
return project
|
||||
|
||||
async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
|
||||
project.name = data.name if data.name else project.name
|
||||
|
||||
|
||||
@ -4,11 +4,17 @@ from sqlalchemy import select, func
|
||||
|
||||
from models import Status, Deal
|
||||
from repositories.base import BaseRepository
|
||||
from repositories.mixins import RepDeleteMixin
|
||||
from repositories.mixins import RepDeleteMixin, RepCreateMixin
|
||||
from schemas.status import UpdateStatusSchema, CreateStatusSchema
|
||||
|
||||
|
||||
class StatusRepository(BaseRepository, RepDeleteMixin[Status]):
|
||||
class StatusRepository(
|
||||
BaseRepository,
|
||||
RepDeleteMixin[Status],
|
||||
RepCreateMixin[Deal, CreateStatusSchema],
|
||||
):
|
||||
entity_class = Status
|
||||
|
||||
async def get_all(self, board_id: int) -> list[Status]:
|
||||
stmt = (
|
||||
select(Status)
|
||||
@ -30,13 +36,6 @@ class StatusRepository(BaseRepository, RepDeleteMixin[Status]):
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar()
|
||||
|
||||
async def create(self, data: CreateStatusSchema) -> Status:
|
||||
status = Status(**data.model_dump())
|
||||
self.session.add(status)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(status)
|
||||
return status
|
||||
|
||||
async def update(self, status: Status, data: UpdateStatusSchema) -> Status:
|
||||
status.lexorank = data.lexorank if data.lexorank else status.lexorank
|
||||
status.name = data.name if data.name else status.name
|
||||
|
||||
Reference in New Issue
Block a user