refactor: repository create mixin

This commit is contained in:
2025-09-05 00:04:09 +04:00
parent c632fb8037
commit e5be35be35
9 changed files with 54 additions and 43 deletions

View File

@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload
from models import Board from models import Board
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin
from schemas.board import UpdateBoardSchema, CreateBoardSchema from schemas.board import UpdateBoardSchema, CreateBoardSchema
class BoardRepository(BaseRepository, RepDeleteMixin[Board]): class BoardRepository(
BaseRepository,
RepDeleteMixin[Board],
RepCreateMixin[Board, CreateBoardSchema],
):
entity_class = Board
async def get_all(self, project_id: int) -> list[Board]: async def get_all(self, project_id: int) -> list[Board]:
stmt = ( stmt = (
select(Board) select(Board)
@ -28,13 +34,6 @@ class BoardRepository(BaseRepository, RepDeleteMixin[Board]):
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def create(self, data: CreateBoardSchema) -> Board:
board = Board(**data.model_dump())
self.session.add(board)
await self.session.commit()
await self.session.refresh(board)
return board
async def update(self, board: Board, data: UpdateBoardSchema) -> Board: async def update(self, board: Board, data: UpdateBoardSchema) -> Board:
board.lexorank = data.lexorank if data.lexorank else board.lexorank board.lexorank = data.lexorank if data.lexorank else board.lexorank
board.name = data.name if data.name else board.name board.name = data.name if data.name else board.name

View File

@ -5,13 +5,17 @@ from sqlalchemy.orm import joinedload
from models import Deal, CardStatusHistory, Board from models import Deal, CardStatusHistory, Board
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin
from schemas.base import SortDir from schemas.base import SortDir
from schemas.deal import UpdateDealSchema, CreateDealSchema from schemas.deal import UpdateDealSchema, CreateDealSchema
from utils.sorting import apply_sorting from utils.sorting import apply_sorting
class DealRepository(BaseRepository, RepDeleteMixin[Deal]): class DealRepository(
BaseRepository, RepDeleteMixin[Deal], RepCreateMixin[Deal, CreateDealSchema]
):
entity_class = Deal
async def get_all( async def get_all(
self, self,
page: Optional[int], page: Optional[int],
@ -63,13 +67,6 @@ class DealRepository(BaseRepository, RepDeleteMixin[Deal]):
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def create(self, data: CreateDealSchema) -> Deal:
deal = Deal(**data.model_dump())
self.session.add(deal)
await self.session.commit()
await self.session.refresh(deal)
return await self.get_by_id(deal.id)
async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal: async def update(self, deal: Deal, data: UpdateDealSchema) -> Deal:
deal.lexorank = data.lexorank if data.lexorank else deal.lexorank deal.lexorank = data.lexorank if data.lexorank else deal.lexorank
deal.name = data.name if data.name else deal.name deal.name = data.name if data.name else deal.name

View File

@ -1,8 +1,9 @@
from typing import Generic, TypeVar from typing import Generic, TypeVar, Type
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
EntityType = TypeVar("EntityType") EntityType = TypeVar("EntityType")
CreateType = TypeVar("CreateType")
class RepBaseMixin(Generic[EntityType]): class RepBaseMixin(Generic[EntityType]):
@ -18,8 +19,20 @@ class RepDeleteMixin(RepBaseMixin[EntityType]):
if not hasattr(obj, "is_deleted"): if not hasattr(obj, "is_deleted"):
raise AttributeError( raise AttributeError(
f"{obj.__class__.__name__} does not support soft delete (missing `is_deleted` field)" f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
) )
obj.is_deleted = True obj.is_deleted = True
self.session.add(obj) self.session.add(obj)
await self.session.commit() await self.session.commit()
class RepCreateMixin(Generic[EntityType, CreateType]):
session: AsyncSession
entity_class: Type[EntityType]
async def create(self, data: CreateType) -> int:
obj = self.entity_class(**data.model_dump())
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj.id

View File

@ -5,11 +5,17 @@ from sqlalchemy.orm import selectinload
from models.project import Project from models.project import Project
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin
from schemas.project import CreateProjectSchema, UpdateProjectSchema from schemas.project import CreateProjectSchema, UpdateProjectSchema
class ProjectRepository(BaseRepository, RepDeleteMixin[Project]): class ProjectRepository(
BaseRepository,
RepDeleteMixin[Project],
RepCreateMixin[Project, CreateProjectSchema],
):
entity_class = Project
async def get_all(self) -> list[Project]: async def get_all(self) -> list[Project]:
stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id) stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
@ -24,13 +30,6 @@ class ProjectRepository(BaseRepository, RepDeleteMixin[Project]):
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def create(self, data: CreateProjectSchema) -> Project:
project = Project(**data.model_dump())
self.session.add(project)
await self.session.commit()
await self.session.refresh(project)
return project
async def update(self, project: Project, data: UpdateProjectSchema) -> Project: async def update(self, project: Project, data: UpdateProjectSchema) -> Project:
project.name = data.name if data.name else project.name project.name = data.name if data.name else project.name

View File

@ -4,11 +4,17 @@ from sqlalchemy import select, func
from models import Status, Deal from models import Status, Deal
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin from repositories.mixins import RepDeleteMixin, RepCreateMixin
from schemas.status import UpdateStatusSchema, CreateStatusSchema from schemas.status import UpdateStatusSchema, CreateStatusSchema
class StatusRepository(BaseRepository, RepDeleteMixin[Status]): class StatusRepository(
BaseRepository,
RepDeleteMixin[Status],
RepCreateMixin[Deal, CreateStatusSchema],
):
entity_class = Status
async def get_all(self, board_id: int) -> list[Status]: async def get_all(self, board_id: int) -> list[Status]:
stmt = ( stmt = (
select(Status) select(Status)
@ -30,13 +36,6 @@ class StatusRepository(BaseRepository, RepDeleteMixin[Status]):
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar() return result.scalar()
async def create(self, data: CreateStatusSchema) -> Status:
status = Status(**data.model_dump())
self.session.add(status)
await self.session.commit()
await self.session.refresh(status)
return status
async def update(self, status: Status, data: UpdateStatusSchema) -> Status: async def update(self, status: Status, data: UpdateStatusSchema) -> Status:
status.lexorank = data.lexorank if data.lexorank else status.lexorank status.lexorank = data.lexorank if data.lexorank else status.lexorank
status.name = data.name if data.name else status.name status.name = data.name if data.name else status.name

View File

@ -16,7 +16,8 @@ class BoardService:
) )
async def create_board(self, request: CreateBoardRequest) -> CreateBoardResponse: async def create_board(self, request: CreateBoardRequest) -> CreateBoardResponse:
board = await self.repository.create(request.entity) board_id = await self.repository.create(request.entity)
board = await self.repository.get_by_id(board_id)
return CreateBoardResponse( return CreateBoardResponse(
entity=BoardSchema.model_validate(board), entity=BoardSchema.model_validate(board),
message="Доска успешно создана", message="Доска успешно создана",

View File

@ -38,7 +38,8 @@ class DealService:
) )
async def create_deal(self, request: CreateDealRequest) -> CreateDealResponse: async def create_deal(self, request: CreateDealRequest) -> CreateDealResponse:
deal = await self.repository.create(request.entity) deal_id = await self.repository.create(request.entity)
deal = await self.repository.get_by_id(deal_id)
return CreateDealResponse( return CreateDealResponse(
entity=DealSchema.model_validate(deal), entity=DealSchema.model_validate(deal),
message="Сделка успешно создана", message="Сделка успешно создана",
@ -58,4 +59,4 @@ class DealService:
raise HTTPException(status_code=404, detail="Сделка не найдена") raise HTTPException(status_code=404, detail="Сделка не найдена")
await self.repository.delete(deal, True) await self.repository.delete(deal, True)
return DeleteDealResponse(message="Сделка успешно удалена") return DeleteDealResponse(message="Сделка успешно удалена")

View File

@ -18,7 +18,8 @@ class ProjectService:
async def create_project( async def create_project(
self, request: CreateProjectRequest self, request: CreateProjectRequest
) -> CreateProjectResponse: ) -> CreateProjectResponse:
project = await self.repository.create(request.entity) project_id = await self.repository.create(request.entity)
project = await self.repository.get_by_id(project_id)
return CreateProjectResponse( return CreateProjectResponse(
entity=ProjectSchema.model_validate(project), entity=ProjectSchema.model_validate(project),
message="Проект успешно создан", message="Проект успешно создан",

View File

@ -17,7 +17,8 @@ class StatusService:
) )
async def create_status(self, request: CreateStatusRequest) -> CreateStatusResponse: async def create_status(self, request: CreateStatusRequest) -> CreateStatusResponse:
status = await self.repository.create(request.entity) status_id = await self.repository.create(request.entity)
status = await self.repository.get_by_id(status_id)
return CreateStatusResponse( return CreateStatusResponse(
entity=StatusSchema.model_validate(status), entity=StatusSchema.model_validate(status),
message="Статус успешно создан", message="Статус успешно создан",