from typing import Type, Optional from sqlalchemy import select, Select from sqlalchemy.ext.asyncio import AsyncSession class RepBaseMixin[EntityType]: session: AsyncSession class RepDeleteMixin[EntityType](RepBaseMixin[EntityType]): async def delete(self, obj: EntityType, is_soft: bool) -> None: if not is_soft: await self.session.delete(obj) await self.session.commit() return if not hasattr(obj, "is_deleted"): raise AttributeError( f"{obj.__class__.__name__} does not support soft delete (missing is_deleted field)" ) obj.is_deleted = True self.session.add(obj) await self.session.commit() class RepCreateMixin[EntityType, CreateType](RepBaseMixin[EntityType]): entity_class: Type[EntityType] async def create(self, data: CreateType) -> int: obj = self.entity_class(**data.model_dump()) self.session.add(obj) await self.session.commit() await self.session.refresh(obj) return obj.id class RepUpdateMixin[EntityType, UpdateType](RepBaseMixin[EntityType]): async def _apply_update_data_to_model( self, model: EntityType, data: UpdateType, with_commit: Optional[bool] = False, fields: Optional[list[str]] = None, ) -> EntityType: if fields is None: fields = data.model_dump().keys() for field in fields: value = getattr(data, field) if value is not None: setattr(model, field, value) if with_commit: self.session.add(model) await self.session.commit() await self.session.refresh(model) return model async def update(self, entity: EntityType, data: UpdateType) -> EntityType: pass class RepGetByIdMixin[EntityType](RepBaseMixin[EntityType]): entity_class: Type[EntityType] def _process_get_by_id_stmt(self, stmt: Select) -> Select: return stmt async def get_by_id(self, item_id: int) -> Optional[EntityType]: stmt = select(self.entity_class).where(self.entity_class.id == item_id) if hasattr(self, "is_deleted"): stmt = stmt.where(self.entity_class.is_deleted.is_(False)) stmt = self._process_get_by_id_stmt(stmt) result = await self.session.execute(stmt) return result.scalar_one_or_none() class RepGetAllMixin[EntityType](RepBaseMixin[EntityType]): entity_class: Type[EntityType] def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select: return stmt def _process_get_all_stmt(self, stmt: Select) -> Select: return stmt async def get_all(self, *args) -> list[EntityType]: stmt = select(self.entity_class) if hasattr(self, "is_deleted"): stmt = stmt.where(self.entity_class.is_deleted.is_(False)) if args: stmt = self._process_get_all_stmt_with_args(stmt, *args) else: stmt = self._process_get_all_stmt(stmt) result = await self.session.execute(stmt) return list(result.scalars().all())