Files
Crm-Backend/repositories/mixins.py
2025-10-04 10:13:24 +04:00

131 lines
4.2 KiB
Python

from typing import Type, Optional, TypeVar, Generic
from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession
from repositories.base import BaseRepository
from schemas.base import BaseSchema
from utils.exceptions import ObjectNotFoundException
# Type of the mapped entity (model) class a repository operates on.
EntityType = TypeVar("EntityType")
# Schema accepted by ``create``; must derive from BaseSchema (exposes ``model_dump``).
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
# Schema accepted by ``update``; its non-None fields are copied onto the entity.
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)
class RepBaseMixin(Generic[EntityType]):
    """Shared base of the repository mixins.

    Declares the ``session`` attribute that every mixin uses for queries and
    commits. No value is assigned here, so the concrete repository is expected
    to provide it.
    """

    # Async SQLAlchemy session used by all mixin operations.
    session: AsyncSession
class RepDeleteMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds hard and soft deletion of entities to a repository."""

    async def _before_delete(self, obj: EntityType) -> None:
        """Hook awaited right before ``obj`` is deleted; a no-op by default.

        Subclasses may override it to do extra work before the deletion is
        committed.
        """

    async def delete(self, obj: EntityType, is_soft: bool) -> None:
        """Delete ``obj`` and commit the session.

        Args:
            obj: Entity instance to delete.
            is_soft: When true, mark the entity as deleted by setting its
                ``is_deleted`` attribute to ``True`` instead of removing it
                from the database.

        Raises:
            AttributeError: If a soft delete is requested for an entity that
                has no ``is_deleted`` attribute. This is checked before
                ``_before_delete`` is awaited, so the hook never runs for a
                deletion that cannot happen.
        """
        # Validate before running the hook: an overriding hook may stage
        # changes in the session, and failing afterwards would leave that
        # half-done work pending for the next commit.
        if is_soft and not hasattr(obj, "is_deleted"):
            raise AttributeError(
                f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
            )

        await self._before_delete(obj)

        if is_soft:
            obj.is_deleted = True
            self.session.add(obj)
        else:
            await self.session.delete(obj)
        await self.session.commit()
class RepCreateMixin(Generic[EntityType, CreateSchemaType], RepBaseMixin[EntityType]):
    """Adds creation of entities from a create-schema to a repository."""

    # Model class instantiated by ``create``; expected to be set by subclasses.
    entity_class: Type[EntityType]

    async def create(self, data: CreateSchemaType) -> int:
        """Persist a new entity built from ``data`` and return its ``id``.

        Every key of ``data.model_dump()`` is passed as a keyword argument to
        ``entity_class``. The new instance is committed and then refreshed from
        the database before its ``id`` is read.
        """
        payload = data.model_dump()
        new_entity = self.entity_class(**payload)

        session = self.session
        session.add(new_entity)
        await session.commit()
        await session.refresh(new_entity)

        return new_entity.id
class RepUpdateMixin(Generic[EntityType, UpdateSchemaType], RepBaseMixin[EntityType]):
    """Adds partial updates of entities from an update-schema to a repository."""

    async def _apply_update_data_to_model(
        self,
        model: EntityType,
        data: UpdateSchemaType,
        with_commit: Optional[bool] = False,
        fields: Optional[list[str]] = None,
    ) -> EntityType:
        """Copy the non-None attribute values of ``data`` onto ``model``.

        Args:
            model: Entity instance modified in place.
            data: Schema carrying the new values.
            with_commit: When truthy, add ``model`` to the session, commit, and
                refresh it afterwards.
            fields: Names of the schema attributes to copy. ``None`` means
                every key produced by ``data.model_dump()``.

        Returns:
            The same ``model`` instance.

        Note:
            ``None`` values are skipped, so this method cannot clear a field.
        """
        field_names = data.model_dump().keys() if fields is None else fields

        for name in field_names:
            new_value = getattr(data, name)
            if new_value is None:
                continue  # None means "not provided" for partial updates
            setattr(model, name, new_value)

        if not with_commit:
            return model

        self.session.add(model)
        await self.session.commit()
        await self.session.refresh(model)
        return model

    async def update(self, entity: EntityType, data: UpdateSchemaType) -> EntityType:
        """Apply every non-None field of ``data`` to ``entity``, then commit and refresh it."""
        return await self._apply_update_data_to_model(entity, data, True)
class RepGetByIdMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds lookup of a single entity by its ``id`` to a repository."""

    # Model class queried by ``get_by_id``; expected to be set by subclasses.
    entity_class: Type[EntityType]
    # Message of the ObjectNotFoundException raised when no row matches.
    entity_not_found_msg = "Entity not found"

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Hook letting subclasses adjust the lookup statement.

        Returns ``stmt`` unchanged by default.
        """
        return stmt

    async def get_by_id(
        self, item_id: int, raise_if_not_found: Optional[bool] = True
    ) -> Optional[EntityType]:
        """Fetch the entity whose ``id`` equals ``item_id``.

        Rows whose ``is_deleted`` is not ``False`` are excluded when the entity
        class defines an ``is_deleted`` attribute.

        Args:
            item_id: Value compared against ``entity_class.id``.
            raise_if_not_found: When truthy, a missing entity raises instead of
                returning ``None``.

        Returns:
            The matching entity, or ``None`` if nothing matched and
            ``raise_if_not_found`` is falsy.

        Raises:
            ObjectNotFoundException: If nothing matched and
                ``raise_if_not_found`` is truthy.
        """
        stmt = select(self.entity_class).where(self.entity_class.id == item_id)
        # Soft-delete support is a property of the mapped entity class, not of
        # the repository, so inspect ``entity_class`` rather than ``self``.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))
        stmt = self._process_get_by_id_stmt(stmt)

        result = (await self.session.execute(stmt)).scalar_one_or_none()
        if result is None and raise_if_not_found:
            raise ObjectNotFoundException(self.entity_not_found_msg)
        return result
class RepGetAllMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds listing of all entities to a repository."""

    # Model class queried by ``get_all``; expected to be set by subclasses.
    entity_class: Type[EntityType]

    def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
        """Hook used instead of ``_process_get_all_stmt`` when ``get_all``
        receives positional arguments; returns ``stmt`` unchanged by default.
        """
        return stmt

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Hook used when ``get_all`` is called without arguments; returns
        ``stmt`` unchanged by default.
        """
        return stmt

    async def get_all(self, *args) -> list[EntityType]:
        """Return all entities, excluding soft-deleted ones where supported.

        Rows whose ``is_deleted`` is not ``False`` are excluded when the entity
        class defines an ``is_deleted`` attribute.

        Args:
            *args: Forwarded to ``_process_get_all_stmt_with_args``. When
                empty, ``_process_get_all_stmt`` is used instead.

        Returns:
            A list of the entities selected by the final statement.
        """
        stmt = select(self.entity_class)
        # Soft-delete support is a property of the mapped entity class, not of
        # the repository, so inspect ``entity_class`` rather than ``self``.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))

        if args:
            stmt = self._process_get_all_stmt_with_args(stmt, *args)
        else:
            stmt = self._process_get_all_stmt(stmt)

        result = await self.session.execute(stmt)
        return list(result.scalars().all())
class RepCrudMixin(
    Generic[EntityType, CreateSchemaType, UpdateSchemaType],
    BaseRepository,
    RepGetAllMixin[EntityType],
    RepCreateMixin[EntityType, CreateSchemaType],
    RepUpdateMixin[EntityType, UpdateSchemaType],
    RepGetByIdMixin[EntityType],
    RepDeleteMixin[EntityType],
):
    """Repository base combining get-all, create, update, get-by-id and delete.

    Adds no behavior of its own. The order of the base classes determines
    method resolution, so keep it unchanged.
    """