131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
from typing import Type, Optional, TypeVar, Generic
|
|
|
|
from sqlalchemy import select, Select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from repositories.base import BaseRepository
|
|
from schemas.base import BaseSchema
|
|
from utils.exceptions import ObjectNotFoundException
|
|
|
|
# Type variables shared by all repository mixins below.

# The mapped model class a repository operates on (deliberately unbounded).
EntityType = TypeVar("EntityType")

# Payload schemas for create/update operations; both must be BaseSchema
# subclasses (the mixins call ``model_dump()`` on them).
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)
|
|
|
|
|
|
class RepBaseMixin(Generic[EntityType]):
    """Shared base of the repository mixins: declares the session they use."""

    # Async database session used by every mixin method. Only the annotation
    # lives here; the concrete repository must supply the actual value.
    session: AsyncSession
|
|
|
|
|
|
class RepDeleteMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds hard and soft deletion of entities to a repository."""

    async def _before_delete(self, obj: EntityType) -> None:
        """Hook awaited before any deletion; does nothing unless overridden."""
        return None

    async def delete(self, obj: EntityType, is_soft: bool) -> None:
        """Delete ``obj`` and commit the session.

        Args:
            obj: The entity to delete.
            is_soft: When truthy, mark the entity with ``is_deleted = True``
                instead of removing its row; otherwise delete it outright.

        Raises:
            AttributeError: If a soft delete is requested for an entity that
                has no ``is_deleted`` attribute. The ``_before_delete`` hook
                has already run at that point, and nothing is committed.
        """
        await self._before_delete(obj)

        if is_soft:
            if not hasattr(obj, "is_deleted"):
                raise AttributeError(
                    f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
                )
            # Soft delete: flag the row and persist it instead of removing it.
            obj.is_deleted = True
            self.session.add(obj)
        else:
            await self.session.delete(obj)

        await self.session.commit()
|
|
|
|
|
|
class RepCreateMixin(Generic[EntityType, CreateSchemaType], RepBaseMixin[EntityType]):
    """Adds creation of entities from a create schema to a repository."""

    # Mapped model class that ``create`` instantiates.
    entity_class: Type[EntityType]

    async def create(self, data: CreateSchemaType) -> int:
        """Persist a new entity built from ``data`` and return its ``id``.

        Every key of ``data.model_dump()`` is passed as a keyword argument to
        ``entity_class``. The session is committed, and the new object is
        refreshed before its ``id`` is read.
        """
        new_entity = self.entity_class(**data.model_dump())
        self.session.add(new_entity)
        await self.session.commit()
        # Reload the committed row so attributes such as ``id`` are populated.
        await self.session.refresh(new_entity)
        return new_entity.id
|
|
|
|
|
|
class RepUpdateMixin(Generic[EntityType, UpdateSchemaType], RepBaseMixin[EntityType]):
    """Adds partial updates of entities from an update schema to a repository."""

    async def _apply_update_data_to_model(
        self,
        model: EntityType,
        data: UpdateSchemaType,
        with_commit: Optional[bool] = False,
        fields: Optional[list[str]] = None,
    ) -> EntityType:
        """Copy non-``None`` values from ``data`` onto ``model``.

        Args:
            model: The entity to mutate in place.
            data: Schema instance providing the new values.
            with_commit: When truthy, add ``model`` to the session, commit,
                and refresh it; otherwise only the in-memory object changes.
            fields: Names of the attributes to copy. Defaults to every key
                produced by ``data.model_dump()``.

        Returns:
            The same ``model`` object.
        """
        names = data.model_dump().keys() if fields is None else fields

        for name in names:
            new_value = getattr(data, name)
            if new_value is None:
                # None values are skipped, so the current value is kept; a
                # field cannot be cleared to None through this method.
                continue
            setattr(model, name, new_value)

        if not with_commit:
            return model

        self.session.add(model)
        await self.session.commit()
        await self.session.refresh(model)
        return model

    async def update(self, entity: EntityType, data: UpdateSchemaType) -> EntityType:
        """Apply ``data`` to ``entity``, commit, and return the refreshed entity."""
        # Third positional argument is ``with_commit``.
        return await self._apply_update_data_to_model(entity, data, True)
|
|
|
|
|
|
class RepGetByIdMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds primary-key lookup of entities to a repository."""

    # Mapped model class queried by ``get_by_id``.
    entity_class: Type[EntityType]
    # Message passed to ObjectNotFoundException when no row matches.
    entity_not_found_msg = "Entity not found"

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Hook to refine the lookup statement; returns it unchanged by default."""
        return stmt

    async def get_by_id(
        self, item_id: int, raise_if_not_found: Optional[bool] = True
    ) -> Optional[EntityType]:
        """Fetch the entity whose ``id`` equals ``item_id``.

        If the entity class has an ``is_deleted`` attribute, soft-deleted
        rows are excluded from the lookup.

        Args:
            item_id: Primary-key value to look up.
            raise_if_not_found: When truthy, raise instead of returning
                ``None`` for a missing entity.

        Returns:
            The matching entity, or ``None`` when nothing matches and
            ``raise_if_not_found`` is falsy.

        Raises:
            ObjectNotFoundException: If nothing matches and
                ``raise_if_not_found`` is truthy.
        """
        stmt = select(self.entity_class).where(self.entity_class.id == item_id)
        # Soft-delete support is a column of the mapped entity, so inspect
        # ``entity_class``. The old check on ``self`` (the repository) never
        # matched, so soft-deleted rows were returned.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))

        stmt = self._process_get_by_id_stmt(stmt)
        result = (await self.session.execute(stmt)).scalar_one_or_none()
        if result is None and raise_if_not_found:
            raise ObjectNotFoundException(self.entity_not_found_msg)

        return result
|
|
|
|
|
|
class RepGetAllMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds listing of entities to a repository."""

    # Mapped model class queried by ``get_all``.
    entity_class: Type[EntityType]

    def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
        """Hook to refine the statement using the arguments passed to ``get_all``.

        It is called only when ``get_all`` received at least one positional
        argument, and returns ``stmt`` unchanged by default.
        """
        return stmt

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Hook to refine the statement when ``get_all`` received no arguments.

        Returns ``stmt`` unchanged by default.
        """
        return stmt

    async def get_all(self, *args) -> list[EntityType]:
        """Return all entities of ``entity_class`` as a list.

        If the entity class has an ``is_deleted`` attribute, soft-deleted
        rows are excluded.

        Args:
            *args: Values forwarded verbatim to
                ``_process_get_all_stmt_with_args``. When none are given,
                ``_process_get_all_stmt`` is called instead.

        Returns:
            The selected entities; empty if nothing matches.
        """
        stmt = select(self.entity_class)
        # Soft-delete support is a column of the mapped entity, so inspect
        # ``entity_class``. The old check on ``self`` (the repository) never
        # matched, so soft-deleted rows were listed.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))

        if args:
            stmt = self._process_get_all_stmt_with_args(stmt, *args)
        else:
            stmt = self._process_get_all_stmt(stmt)

        result = await self.session.execute(stmt)
        return list(result.scalars().all())
|
|
|
|
|
|
class RepCrudMixin(
    Generic[EntityType, CreateSchemaType, UpdateSchemaType],
    BaseRepository,
    RepGetAllMixin[EntityType],
    RepCreateMixin[EntityType, CreateSchemaType],
    RepUpdateMixin[EntityType, UpdateSchemaType],
    RepGetByIdMixin[EntityType],
    RepDeleteMixin[EntityType],
):
    """Repository base combining list, create, update, get-by-id and delete.

    It adds no behaviour of its own. The base-class order above determines the
    method resolution order, so keep it unchanged.
    """
|