diff --git a/repositories/board.py b/repositories/board.py index be4263c..44f27e2 100644 --- a/repositories/board.py +++ b/repositories/board.py @@ -8,6 +8,7 @@ from repositories.mixins import ( RepCreateMixin, GetByIdMixin, GetAllMixin, + RepUpdateMixin, ) from schemas.board import UpdateBoardSchema, CreateBoardSchema @@ -17,6 +18,7 @@ class BoardRepository( GetAllMixin[Board], RepDeleteMixin[Board], RepCreateMixin[Board, CreateBoardSchema], + RepUpdateMixin[Board, UpdateBoardSchema], GetByIdMixin[Board], ): entity_class = Board @@ -29,10 +31,4 @@ class BoardRepository( return stmt.options(selectinload(Board.deals)) async def update(self, board: Board, data: UpdateBoardSchema) -> Board: - board.lexorank = data.lexorank if data.lexorank else board.lexorank - board.name = data.name if data.name else board.name - - self.session.add(board) - await self.session.commit() - await self.session.refresh(board) - return board + return await self._apply_update_data_to_model(board, data, True) diff --git a/repositories/deal.py b/repositories/deal.py index 75a72d5..8bab9bb 100644 --- a/repositories/deal.py +++ b/repositories/deal.py @@ -5,7 +5,12 @@ from sqlalchemy.orm import joinedload from models import Deal, CardStatusHistory, Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin +from repositories.mixins import ( + RepDeleteMixin, + RepCreateMixin, + GetByIdMixin, + RepUpdateMixin, +) from schemas.base import SortDir from schemas.deal import UpdateDealSchema, CreateDealSchema from utils.sorting import apply_sorting @@ -15,6 +20,7 @@ class DealRepository( BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema], + RepUpdateMixin[Deal, UpdateDealSchema], GetByIdMixin[Deal], ): entity_class = Deal @@ -65,9 +71,8 @@ class DealRepository( return stmt.options(joinedload(Deal.status), joinedload(Deal.board)) async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal: - deal.lexorank = data.lexorank if data.lexorank else deal.lexorank - deal.name = data.name if data.name else deal.name - deal.board_id = data.board_id if data.board_id else deal.board_id + fields = ["lexorank", "name", "board_id"] + deal = await self._apply_update_data_to_model(deal, data, False, fields) if data.status_id and deal.status_id != data.status_id: deal.status_history.append( diff --git a/repositories/mixins.py b/repositories/mixins.py index 36282c5..b06b708 100644 --- a/repositories/mixins.py +++ b/repositories/mixins.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession EntityType = TypeVar("EntityType") CreateType = TypeVar("CreateType") +UpdateType = TypeVar("UpdateType") class RepBaseMixin(Generic[EntityType]): @@ -38,6 +39,29 @@ class RepCreateMixin(RepBaseMixin[EntityType], Generic[EntityType, CreateType]): return obj.id +class RepUpdateMixin(RepBaseMixin[EntityType], Generic[EntityType, UpdateType]): + async def _apply_update_data_to_model( + self, + model: EntityType, + data: UpdateType, + with_commit: Optional[bool] = False, + fields: Optional[list[str]] = None, + ) -> EntityType: + if fields is None: + fields = data.model_dump().keys() + + for field in fields: + value = getattr(data, field) + if value is not None: + setattr(model, field, value) + + if with_commit: + self.session.add(model) + await self.session.commit() + await self.session.refresh(model) + return model + + class GetByIdMixin(RepBaseMixin[EntityType]): entity_class: Type[EntityType] diff --git a/repositories/project.py b/repositories/project.py index 706e0a6..b5ab5e1 100644 --- a/repositories/project.py +++ b/repositories/project.py @@ -1,5 +1,3 @@ -from typing import overload - from sqlalchemy import Select from sqlalchemy.orm import selectinload @@ -10,6 +8,7 @@ from repositories.mixins import ( RepCreateMixin, GetByIdMixin, GetAllMixin, + RepUpdateMixin, ) from schemas.project import CreateProjectSchema, UpdateProjectSchema @@ -19,6 +18,7 @@ class ProjectRepository( GetAllMixin[Project], RepDeleteMixin[Project], RepCreateMixin[Project, CreateProjectSchema], + RepUpdateMixin[Project, UpdateProjectSchema], GetByIdMixin[Project], ): entity_class = Project @@ -30,9 +30,4 @@ class ProjectRepository( return stmt.options(selectinload(Project.boards)) async def update(self, project: Project, data: UpdateProjectSchema) -> Project: - project.name = data.name if data.name else project.name - - self.session.add(project) - await self.session.commit() - await self.session.refresh(project) - return project + return await self._apply_update_data_to_model(project, data, True) diff --git a/repositories/status.py b/repositories/status.py index 7d1f6f1..ab8d4d3 100644 --- a/repositories/status.py +++ b/repositories/status.py @@ -7,6 +7,7 @@ from repositories.mixins import ( RepCreateMixin, GetByIdMixin, GetAllMixin, + RepUpdateMixin, ) from schemas.status import UpdateStatusSchema, CreateStatusSchema @@ -16,6 +17,7 @@ class StatusRepository( GetAllMixin[Status], RepDeleteMixin[Status], RepCreateMixin[Status, CreateStatusSchema], + RepUpdateMixin[Status, UpdateStatusSchema], GetByIdMixin[Status], ): entity_class = Status @@ -39,10 +41,4 @@ class StatusRepository( return result.scalar() async def update(self, status: Status, data: UpdateStatusSchema) -> Status: - status.lexorank = data.lexorank if data.lexorank else status.lexorank - status.name = data.name if data.name else status.name - - self.session.add(status) - await self.session.commit() - await self.session.refresh(status) - return status + return await self._apply_update_data_to_model(status, data, True)