refactor: repository get all mixin

This commit is contained in:
2025-09-05 11:13:49 +04:00
parent c1d3ac98f0
commit 7990e7d460
4 changed files with 58 additions and 19 deletions

View File

@ -1,28 +1,29 @@
from sqlalchemy import select, Select
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from models import Board
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.board import UpdateBoardSchema, CreateBoardSchema
class BoardRepository(
BaseRepository,
GetAllMixin[Board],
RepDeleteMixin[Board],
RepCreateMixin[Board, CreateBoardSchema],
GetByIdMixin[Board],
):
entity_class = Board
async def get_all(self, project_id: int) -> list[Board]:
stmt = (
select(Board)
.where(Board.is_deleted.is_(False), Board.project_id == project_id)
.order_by(Board.lexorank)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
project_id = args[0]
return stmt.where(Board.project_id == project_id).order_by(Board.lexorank)
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Board.deals))

View File

@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Type, Optional
from typing import Generic, TypeVar, Type, Optional, overload
from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession
@ -52,3 +52,25 @@ class GetByIdMixin(RepBaseMixin[EntityType]):
stmt = self._process_get_by_id_stmt(stmt)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
class GetAllMixin(RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
return stmt
def _process_get_all_stmt(self, stmt: Select) -> Select:
return stmt
async def get_all(self, *args) -> list[EntityType]:
stmt = select(self.entity_class).order_by(self.entity_class.id)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
if args:
stmt = self._process_get_all_stmt_with_args(stmt, *args)
else:
stmt = self._process_get_all_stmt(stmt)
result = await self.session.execute(stmt)
return list(result.scalars().all())

View File

@ -1,24 +1,30 @@
from sqlalchemy import select, Select
from typing import overload
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from models.project import Project
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.project import CreateProjectSchema, UpdateProjectSchema
class ProjectRepository(
BaseRepository,
GetAllMixin[Project],
RepDeleteMixin[Project],
RepCreateMixin[Project, CreateProjectSchema],
GetByIdMixin[Project],
):
entity_class = Project
async def get_all(self) -> list[Project]:
stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id)
result = await self.session.execute(stmt)
return list(result.scalars().all())
def _process_get_all_stmt(self, stmt: Select) -> Select:
return stmt.order_by(Project.id)
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Project.boards))

View File

@ -1,19 +1,29 @@
from sqlalchemy import select, func
from sqlalchemy import select, func, Select
from models import Status, Deal
from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin
from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.status import UpdateStatusSchema, CreateStatusSchema
class StatusRepository(
BaseRepository,
GetAllMixin[Status],
RepDeleteMixin[Status],
RepCreateMixin[Status, CreateStatusSchema],
GetByIdMixin[Status],
):
entity_class = Status
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
board_id = args[0]
return stmt.where(Status.board_id == board_id).order_by(Status.lexorank)
async def get_all(self, board_id: int) -> list[Status]:
stmt = (
select(Status)