refactor: repository get all mixin

This commit is contained in:
2025-09-05 11:13:49 +04:00
parent c1d3ac98f0
commit 7990e7d460
4 changed files with 58 additions and 19 deletions

View File

@ -1,28 +1,29 @@
from sqlalchemy import select, Select from sqlalchemy import Select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from models import Board from models import Board
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.board import UpdateBoardSchema, CreateBoardSchema from schemas.board import UpdateBoardSchema, CreateBoardSchema
class BoardRepository( class BoardRepository(
BaseRepository, BaseRepository,
GetAllMixin[Board],
RepDeleteMixin[Board], RepDeleteMixin[Board],
RepCreateMixin[Board, CreateBoardSchema], RepCreateMixin[Board, CreateBoardSchema],
GetByIdMixin[Board], GetByIdMixin[Board],
): ):
entity_class = Board entity_class = Board
async def get_all(self, project_id: int) -> list[Board]: def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
stmt = ( project_id = args[0]
select(Board) return stmt.where(Board.project_id == project_id).order_by(Board.lexorank)
.where(Board.is_deleted.is_(False), Board.project_id == project_id)
.order_by(Board.lexorank)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
def _process_get_by_id_stmt(self, stmt: Select) -> Select: def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Board.deals)) return stmt.options(selectinload(Board.deals))

View File

@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Type, Optional from typing import Generic, TypeVar, Type, Optional, overload
from sqlalchemy import select, Select from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -52,3 +52,25 @@ class GetByIdMixin(RepBaseMixin[EntityType]):
stmt = self._process_get_by_id_stmt(stmt) stmt = self._process_get_by_id_stmt(stmt)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
class GetAllMixin(RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
return stmt
def _process_get_all_stmt(self, stmt: Select) -> Select:
return stmt
async def get_all(self, *args) -> list[EntityType]:
stmt = select(self.entity_class).order_by(self.entity_class.id)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
if args:
stmt = self._process_get_all_stmt_with_args(stmt, *args)
else:
stmt = self._process_get_all_stmt(stmt)
result = await self.session.execute(stmt)
return list(result.scalars().all())

View File

@ -1,24 +1,30 @@
from sqlalchemy import select, Select from typing import overload
from sqlalchemy import Select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from models.project import Project from models.project import Project
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.project import CreateProjectSchema, UpdateProjectSchema from schemas.project import CreateProjectSchema, UpdateProjectSchema
class ProjectRepository( class ProjectRepository(
BaseRepository, BaseRepository,
GetAllMixin[Project],
RepDeleteMixin[Project], RepDeleteMixin[Project],
RepCreateMixin[Project, CreateProjectSchema], RepCreateMixin[Project, CreateProjectSchema],
GetByIdMixin[Project], GetByIdMixin[Project],
): ):
entity_class = Project entity_class = Project
async def get_all(self) -> list[Project]: def _process_get_all_stmt(self, stmt: Select) -> Select:
stmt = select(Project).where(Project.is_deleted.is_(False)).order_by(Project.id) return stmt.order_by(Project.id)
result = await self.session.execute(stmt)
return list(result.scalars().all())
def _process_get_by_id_stmt(self, stmt: Select) -> Select: def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt.options(selectinload(Project.boards)) return stmt.options(selectinload(Project.boards))

View File

@ -1,19 +1,29 @@
from sqlalchemy import select, func from sqlalchemy import select, func, Select
from models import Status, Deal from models import Status, Deal
from repositories.base import BaseRepository from repositories.base import BaseRepository
from repositories.mixins import RepDeleteMixin, RepCreateMixin, GetByIdMixin from repositories.mixins import (
RepDeleteMixin,
RepCreateMixin,
GetByIdMixin,
GetAllMixin,
)
from schemas.status import UpdateStatusSchema, CreateStatusSchema from schemas.status import UpdateStatusSchema, CreateStatusSchema
class StatusRepository( class StatusRepository(
BaseRepository, BaseRepository,
GetAllMixin[Status],
RepDeleteMixin[Status], RepDeleteMixin[Status],
RepCreateMixin[Status, CreateStatusSchema], RepCreateMixin[Status, CreateStatusSchema],
GetByIdMixin[Status], GetByIdMixin[Status],
): ):
entity_class = Status entity_class = Status
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
board_id = args[0]
return stmt.where(Status.board_id == board_id).order_by(Status.lexorank)
async def get_all(self, board_id: int) -> list[Status]: async def get_all(self, board_id: int) -> list[Status]:
stmt = ( stmt = (
select(Status) select(Status)