Files
Crm-Backend/repositories/mixins.py

100 lines
3.1 KiB
Python

from typing import Type, Optional
from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession
class RepBaseMixin[EntityType]:
session: AsyncSession
class RepDeleteMixin[EntityType](RepBaseMixin[EntityType]):
async def delete(self, obj: EntityType, is_soft: bool) -> None:
if not is_soft:
await self.session.delete(obj)
await self.session.commit()
return
if not hasattr(obj, "is_deleted"):
raise AttributeError(
f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
)
obj.is_deleted = True
self.session.add(obj)
await self.session.commit()
class RepCreateMixin[EntityType, CreateType](RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
async def create(self, data: CreateType) -> int:
obj = self.entity_class(**data.model_dump())
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj.id
class RepUpdateMixin[EntityType, UpdateType](RepBaseMixin[EntityType]):
async def _apply_update_data_to_model(
self,
model: EntityType,
data: UpdateType,
with_commit: Optional[bool] = False,
fields: Optional[list[str]] = None,
) -> EntityType:
if fields is None:
fields = data.model_dump().keys()
for field in fields:
value = getattr(data, field)
if value is not None:
setattr(model, field, value)
if with_commit:
self.session.add(model)
await self.session.commit()
await self.session.refresh(model)
return model
async def update(self, entity: EntityType, data: UpdateType) -> EntityType:
pass
class RepGetByIdMixin[EntityType](RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
return stmt
async def get_by_id(self, item_id: int) -> Optional[EntityType]:
stmt = select(self.entity_class).where(self.entity_class.id == item_id)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
stmt = self._process_get_by_id_stmt(stmt)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
class RepGetAllMixin[EntityType](RepBaseMixin[EntityType]):
entity_class: Type[EntityType]
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
return stmt
def _process_get_all_stmt(self, stmt: Select) -> Select:
return stmt
async def get_all(self, *args) -> list[EntityType]:
stmt = select(self.entity_class)
if hasattr(self, "is_deleted"):
stmt = stmt.where(self.entity_class.is_deleted.is_(False))
if args:
stmt = self._process_get_all_stmt_with_args(stmt, *args)
else:
stmt = self._process_get_all_stmt(stmt)
result = await self.session.execute(stmt)
return list(result.scalars().all())