From e5be35be356b3d1a057ccd9c6e37eaed5a029cc8 Mon Sep 17 00:00:00 2001 From: AlexSserb Date: Fri, 5 Sep 2025 00:04:09 +0400 Subject: [PATCH] refactor: repository create mixin --- repositories/board.py | 17 ++++++++--------- repositories/deal.py | 15 ++++++--------- repositories/mixins.py | 17 +++++++++++++++-- repositories/project.py | 17 ++++++++--------- repositories/status.py | 17 ++++++++--------- services/board.py | 3 ++- services/deal.py | 5 +++-- services/project.py | 3 ++- services/status.py | 3 ++- 9 files changed, 54 insertions(+), 43 deletions(-) diff --git a/repositories/board.py b/repositories/board.py index 0715b77..f9cbf86 100644 --- a/repositories/board.py +++ b/repositories/board.py @@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload from models import Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin from schemas.board import UpdateBoardSchema, CreateBoardSchema -class BoardRepository(BaseRepository, RepDeleteMixin[Board]): +class BoardRepository( + BaseRepository, + RepDeleteMixin[Board], + RepCreateMixin[Board, CreateBoardSchema], +): + entity_class = Board + async def get_all(self, project_id: int) -> list[Board]: stmt = ( select(Board) @@ -28,13 +34,6 @@ class BoardRepository(BaseRepository, RepDeleteMixin[Board]): result = await self.session.execute(stmt) return result.scalar_one_or_none() - async def create(self, data: CreateBoardSchema) -> Board: - board = Board(**data.model_dump()) - self.session.add(board) - await self.session.commit() - await self.session.refresh(board) - return board - async def update(self, board: Board, data: UpdateBoardSchema) -> Board: board.lexorank = data.lexorank if data.lexorank else board.lexorank board.name = data.name if data.name else board.name diff --git a/repositories/deal.py b/repositories/deal.py index 0fb8e8f..f007b2e 100644 --- a/repositories/deal.py +++ b/repositories/deal.py @@ -5,13 +5,17 @@ from sqlalchemy.orm import joinedload from models import Deal, CardStatusHistory, Board from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin from schemas.base import SortDir from schemas.deal import UpdateDealSchema, CreateDealSchema from utils.sorting import apply_sorting -class DealRepository(BaseRepository, RepDeleteMixin[Deal]): +class DealRepository( + BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema] +): + entity_class = Deal + async def get_all( self, page: Optional[int], @@ -63,13 +67,6 @@ class DealRepository(BaseRepository, RepDeleteMixin[Deal]): result = await self.session.execute(stmt) return result.scalar_one_or_none() - async def create(self, data: CreateDealSchema) -> Deal: - deal = Deal(**data.model_dump()) - self.session.add(deal) - await self.session.commit() - await self.session.refresh(deal) - return await self.get_by_id(deal.id) - async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal: deal.lexorank = data.lexorank if data.lexorank else deal.lexorank deal.name = data.name if data.name else deal.name diff --git a/repositories/mixins.py b/repositories/mixins.py index 9c2f7b7..5fd991f 100644 --- a/repositories/mixins.py +++ b/repositories/mixins.py @@ -1,8 +1,9 @@ -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Type from sqlalchemy.ext.asyncio import AsyncSession EntityType = TypeVar("EntityType") +CreateType = TypeVar("CreateType") class RepBaseMixin(Generic[EntityType]): @@ -18,8 +19,20 @@ class RepDeleteMixin(RepBaseMixin[EntityType]): if not hasattr(obj, "is_deleted"): raise AttributeError( - f"{obj.__class__.__name__} does not support soft delete (missing `is_deleted` field)" + f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)" ) obj.is_deleted = True self.session.add(obj) await self.session.commit() + + +class RepCreateMixin(Generic[EntityType, CreateType]): + session: AsyncSession + entity_class: Type[EntityType] + + async def create(self, data: CreateType) -> int: + obj = self.entity_class(**data.model_dump()) + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj.id diff --git a/repositories/project.py b/repositories/project.py index 4b29214..57ab0d1 100644 --- a/repositories/project.py +++ b/repositories/project.py @@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload from models.project import Project from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin from schemas.project import CreateProjectSchema, UpdateProjectSchema -class ProjectRepository(BaseRepository, RepDeleteMixin[Project]): +class ProjectRepository( + BaseRepository, + RepDeleteMixin[Project], + RepCreateMixin[Project, CreateProjectSchema], +): + entity_class = Project + async def get_all(self) -> list[Project]: stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id) result = await self.session.execute(stmt) @@ -24,13 +30,6 @@ class ProjectRepository(BaseRepository, RepDeleteMixin[Project]): result = await self.session.execute(stmt) return result.scalar_one_or_none() - async def create(self, data: CreateProjectSchema) -> Project: - project = Project(**data.model_dump()) - self.session.add(project) - await self.session.commit() - await self.session.refresh(project) - return project - async def update(self, project: Project, data: UpdateProjectSchema) -> Project: project.name = data.name if data.name else project.name diff --git a/repositories/status.py b/repositories/status.py index 713d3ed..fe77474 100644 --- a/repositories/status.py +++ b/repositories/status.py @@ -4,11 +4,17 @@ from sqlalchemy import select, func from models import Status, Deal from repositories.base import BaseRepository -from repositories.mixins import RepDeleteMixin +from repositories.mixins import RepDeleteMixin, RepCreateMixin from schemas.status import UpdateStatusSchema, CreateStatusSchema -class StatusRepository(BaseRepository, RepDeleteMixin[Status]): +class StatusRepository( + BaseRepository, + RepDeleteMixin[Status], + RepCreateMixin[Deal, CreateStatusSchema], +): + entity_class = Status + async def get_all(self, board_id: int) -> list[Status]: stmt = ( select(Status) @@ -30,13 +36,6 @@ class StatusRepository(BaseRepository, RepDeleteMixin[Status]): result = await self.session.execute(stmt) return result.scalar() - async def create(self, data: CreateStatusSchema) -> Status: - status = Status(**data.model_dump()) - self.session.add(status) - await self.session.commit() - await self.session.refresh(status) - return status - async def update(self, status: Status, data: UpdateStatusSchema) -> Status: status.lexorank = data.lexorank if data.lexorank else status.lexorank status.name = data.name if data.name else status.name diff --git a/services/board.py b/services/board.py index 8ffc4aa..9521651 100644 --- a/services/board.py +++ b/services/board.py @@ -16,7 +16,8 @@ class BoardService: ) async def create_board(self, request: CreateBoardRequest) -> CreateBoardResponse: - board = await self.repository.create(request.entity) + board_id = await self.repository.create(request.entity) + board = await self.repository.get_by_id(board_id) return CreateBoardResponse( entity=BoardSchema.model_validate(board), message="Доска успешно создана", diff --git a/services/deal.py b/services/deal.py index b47d9f7..fa8d251 100644 --- a/services/deal.py +++ b/services/deal.py @@ -38,7 +38,8 @@ class DealService: ) async def create_deal(self, request: CreateDealRequest) -> CreateDealResponse: - deal = await self.repository.create(request.entity) + deal_id = await self.repository.create(request.entity) + deal = await self.repository.get_by_id(deal_id) return CreateDealResponse( entity=DealSchema.model_validate(deal), message="Сделка успешно создана", @@ -58,4 +59,4 @@ class DealService: raise HTTPException(status_code=404, detail="Сделка не найдена") await self.repository.delete(deal, True) - return DeleteDealResponse(message="Сделка успешно удалена") + return DeleteDealResponse(message="Сделка успешно удалена") \ No newline at end of file diff --git a/services/project.py b/services/project.py index b2974bb..51685f0 100644 --- a/services/project.py +++ b/services/project.py @@ -18,7 +18,8 @@ class ProjectService: async def create_project( self, request: CreateProjectRequest ) -> CreateProjectResponse: - project = await self.repository.create(request.entity) + project_id = await self.repository.create(request.entity) + project = await self.repository.get_by_id(project_id) return CreateProjectResponse( entity=ProjectSchema.model_validate(project), message="Проект успешно создан", diff --git a/services/status.py b/services/status.py index 5c18d15..4029fe8 100644 --- a/services/status.py +++ b/services/status.py @@ -17,7 +17,8 @@ class StatusService: ) async def create_status(self, request: CreateStatusRequest) -> CreateStatusResponse: - status = await self.repository.create(request.entity) + status_id = await self.repository.create(request.entity) + status = await self.repository.get_by_id(status_id) return CreateStatusResponse( entity=StatusSchema.model_validate(status), message="Статус успешно создан",