refactor: crud mixins for repositories and services
This commit is contained in:
@ -1,14 +1,21 @@
|
||||
from typing import Type, Optional
|
||||
from typing import Type, Optional, TypeVar, Generic
|
||||
|
||||
from sqlalchemy import select, Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from repositories.base import BaseRepository
|
||||
from schemas.base import BaseSchema
|
||||
|
||||
class RepBaseMixin[EntityType]:
|
||||
EntityType = TypeVar("EntityType")
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)
|
||||
|
||||
|
||||
class RepBaseMixin(Generic[EntityType]):
|
||||
session: AsyncSession
|
||||
|
||||
|
||||
class RepDeleteMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
class RepDeleteMixin(Generic[EntityType], RepBaseMixin[EntityType]):
|
||||
async def delete(self, obj: EntityType, is_soft: bool) -> None:
|
||||
if not is_soft:
|
||||
await self.session.delete(obj)
|
||||
@ -24,10 +31,10 @@ class RepDeleteMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
class RepCreateMixin[EntityType, CreateType](RepBaseMixin[EntityType]):
|
||||
class RepCreateMixin(Generic[EntityType, CreateSchemaType], RepBaseMixin[EntityType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
async def create(self, data: CreateType) -> int:
|
||||
async def create(self, data: CreateSchemaType) -> int:
|
||||
obj = self.entity_class(**data.model_dump())
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
@ -35,11 +42,11 @@ class RepCreateMixin[EntityType, CreateType](RepBaseMixin[EntityType]):
|
||||
return obj.id
|
||||
|
||||
|
||||
class RepUpdateMixin[EntityType, UpdateType](RepBaseMixin[EntityType]):
|
||||
class RepUpdateMixin(Generic[EntityType, UpdateSchemaType], RepBaseMixin[EntityType]):
|
||||
async def _apply_update_data_to_model(
|
||||
self,
|
||||
model: EntityType,
|
||||
data: UpdateType,
|
||||
data: UpdateSchemaType,
|
||||
with_commit: Optional[bool] = False,
|
||||
fields: Optional[list[str]] = None,
|
||||
) -> EntityType:
|
||||
@ -57,11 +64,11 @@ class RepUpdateMixin[EntityType, UpdateType](RepBaseMixin[EntityType]):
|
||||
await self.session.refresh(model)
|
||||
return model
|
||||
|
||||
async def update(self, entity: EntityType, data: UpdateType) -> EntityType:
|
||||
async def update(self, entity: EntityType, data: UpdateSchemaType) -> EntityType:
|
||||
pass
|
||||
|
||||
|
||||
class RepGetByIdMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
class RepGetByIdMixin(Generic[EntityType], RepBaseMixin[EntityType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
def _process_get_by_id_stmt(self, stmt: Select) -> Select:
|
||||
@ -77,7 +84,7 @@ class RepGetByIdMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
class RepGetAllMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
class RepGetAllMixin(Generic[EntityType], RepBaseMixin[EntityType]):
|
||||
entity_class: Type[EntityType]
|
||||
|
||||
def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
|
||||
@ -97,3 +104,15 @@ class RepGetAllMixin[EntityType](RepBaseMixin[EntityType]):
|
||||
stmt = self._process_get_all_stmt(stmt)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
class RepCrudMixin(
|
||||
Generic[EntityType, CreateSchemaType, UpdateSchemaType],
|
||||
BaseRepository,
|
||||
RepGetAllMixin[EntityType],
|
||||
RepCreateMixin[EntityType, CreateSchemaType],
|
||||
RepUpdateMixin[EntityType, UpdateSchemaType],
|
||||
RepGetByIdMixin[EntityType],
|
||||
RepDeleteMixin[EntityType],
|
||||
):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user