from typing import Type, Optional, TypeVar, Generic

from sqlalchemy import select, Select
from sqlalchemy.ext.asyncio import AsyncSession

from repositories.base import BaseRepository
from schemas.base import BaseSchema
from utils.exceptions import ObjectNotFoundException

EntityType = TypeVar("EntityType")
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)


class RepBaseMixin(Generic[EntityType]):
    """Common base for the repository mixins below: declares ``session``.

    The session is only annotated here; a concrete repository must assign it
    (presumably through ``BaseRepository`` -- confirm).
    """

    session: AsyncSession


class RepDeleteMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds hard or soft deletion of a single entity."""

    async def _before_delete(self, obj: EntityType) -> None:
        """Hook run right before ``obj`` is deleted; does nothing by default."""

    async def delete(self, obj: EntityType, is_soft: bool) -> None:
        """Delete ``obj`` and commit the session.

        Args:
            obj: Entity to delete.
            is_soft: When true, mark the entity with ``is_deleted = True``
                instead of removing the row.

        Raises:
            AttributeError: ``is_soft`` is true but ``obj`` has no
                ``is_deleted`` attribute. This is checked before
                ``_before_delete`` runs, so the hook never fires for a
                delete that cannot be performed.
        """
        # Validate first: running the hook (which may touch session state)
        # and then failing would leave the session half-modified.
        if is_soft and not hasattr(obj, "is_deleted"):
            raise AttributeError(
                f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)"
            )

        await self._before_delete(obj)

        if is_soft:
            obj.is_deleted = True
            self.session.add(obj)
        else:
            await self.session.delete(obj)
        await self.session.commit()


class RepCreateMixin(Generic[EntityType, CreateSchemaType], RepBaseMixin[EntityType]):
    """Adds creation of an entity from a create schema."""

    entity_class: Type[EntityType]

    async def _prepare_create(self, data: CreateSchemaType) -> dict:
        """Return constructor kwargs for the new entity (``data.model_dump()`` by default)."""
        return data.model_dump()

    async def _after_create(self, obj: EntityType, data: CreateSchemaType) -> None:
        """Hook run after the new entity is flushed and refreshed, before commit; does nothing by default."""

    async def create(self, data: CreateSchemaType) -> int:
        """Build an entity from ``data``, persist it, commit, and return its ``id``.

        Args:
            data: Create schema; converted via ``_prepare_create``.

        Returns:
            The ``id`` attribute of the created entity.
        """
        prepared_data = await self._prepare_create(data)
        obj = self.entity_class(**prepared_data)
        self.session.add(obj)
        # Flush (not commit) so the INSERT is emitted and generated values
        # are loaded before the hook runs, while still in the same transaction.
        await self.session.flush()
        await self.session.refresh(obj)
        await self._after_create(obj, data)
        await self.session.commit()
        # Reload after commit; presumably the session expires attributes on
        # commit (SQLAlchemy's default), and a lazy load of `obj.id` is not
        # possible in async code -- confirm session configuration.
        await self.session.refresh(obj)
        return obj.id


class RepUpdateMixin(Generic[EntityType, UpdateSchemaType], RepBaseMixin[EntityType]):
    """Adds partial updates of an entity from an update schema."""

    async def _apply_update_data_to_model(
        self,
        model: EntityType,
        data: UpdateSchemaType,
        with_commit: Optional[bool] = False,
        fields: Optional[list[str]] = None,
    ) -> EntityType:
        """Copy every non-``None`` value of ``data`` onto ``model``.

        Args:
            model: Entity mutated in place.
            data: Update schema providing the new values.
            with_commit: When true, add ``model`` to the session, commit,
                and refresh it.
            fields: Names of the fields to copy; defaults to every key of
                ``data.model_dump()``.

        Returns:
            The same ``model`` instance.

        Note:
            ``None`` values are skipped, so this cannot set a column to NULL.
        """
        if fields is None:
            fields = list(data.model_dump().keys())
        for field in fields:
            value = getattr(data, field)
            if value is not None:
                setattr(model, field, value)
        if with_commit:
            self.session.add(model)
            await self.session.commit()
            await self.session.refresh(model)
        return model

    async def update(self, entity: EntityType, data: UpdateSchemaType) -> EntityType:
        """Apply ``data`` to ``entity`` (skipping ``None`` values) and commit."""
        return await self._apply_update_data_to_model(entity, data, True)


class RepGetByIdMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds lookup of a single entity by primary key ``id``."""

    entity_class: Type[EntityType]
    entity_not_found_msg = "Entity not found"

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Hook to customize the ``get_by_id`` query (e.g. eager loads); identity by default."""
        return stmt

    async def get_by_id(
        self, item_id: int, raise_if_not_found: Optional[bool] = True
    ) -> Optional[EntityType]:
        """Return the entity whose ``id`` equals ``item_id``.

        Entities whose class has an ``is_deleted`` attribute are filtered so
        that soft-deleted rows are never returned.

        Args:
            item_id: Primary key to look up.
            raise_if_not_found: When true, raise instead of returning ``None``.

        Returns:
            The entity, or ``None`` if missing and ``raise_if_not_found`` is false.

        Raises:
            ObjectNotFoundException: No matching entity and
                ``raise_if_not_found`` is true.
        """
        stmt = select(self.entity_class).where(self.entity_class.id == item_id)
        # The soft-delete flag lives on the entity class, not on the repository.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))
        stmt = self._process_get_by_id_stmt(stmt)
        result = (await self.session.execute(stmt)).scalar_one_or_none()
        if result is None and raise_if_not_found:
            raise ObjectNotFoundException(self.entity_not_found_msg)
        return result


class RepGetAllMixin(Generic[EntityType], RepBaseMixin[EntityType]):
    """Adds listing of all (non-soft-deleted) entities."""

    entity_class: Type[EntityType]

    def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
        """Hook to customize the ``get_all`` query when positional args are given; identity by default."""
        return stmt

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Hook to customize the ``get_all`` query when no args are given; identity by default."""
        return stmt

    async def get_all(self, *args) -> list[EntityType]:
        """Return all entities, excluding soft-deleted ones when supported.

        Args:
            *args: Forwarded to ``_process_get_all_stmt_with_args``; when empty,
                ``_process_get_all_stmt`` is used instead.
        """
        stmt = select(self.entity_class)
        # The soft-delete flag lives on the entity class, not on the repository.
        if hasattr(self.entity_class, "is_deleted"):
            stmt = stmt.where(self.entity_class.is_deleted.is_(False))
        if args:
            stmt = self._process_get_all_stmt_with_args(stmt, *args)
        else:
            stmt = self._process_get_all_stmt(stmt)
        result = await self.session.execute(stmt)
        return list(result.scalars().all())


class RepCrudMixin(
    Generic[EntityType, CreateSchemaType, UpdateSchemaType],
    BaseRepository,
    RepGetAllMixin[EntityType],
    RepCreateMixin[EntityType, CreateSchemaType],
    RepUpdateMixin[EntityType, UpdateSchemaType],
    RepGetByIdMixin[EntityType],
    RepDeleteMixin[EntityType],
):
    """Full CRUD repository: combines ``BaseRepository`` with the get-all,
    create, update, get-by-id and delete mixins above."""