72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
from typing import Optional
|
||
|
||
from sqlalchemy import Select, select, delete
|
||
from sqlalchemy.orm import joinedload
|
||
|
||
from models import Deal
|
||
from modules.fulfillment_base.models import DealService
|
||
from modules.fulfillment_base.models.service import ServicesKit
|
||
from modules.fulfillment_base.schemas.deal_service import (
|
||
UpdateDealServiceSchema,
|
||
CreateDealServiceSchema,
|
||
)
|
||
from repositories.base import BaseRepository
|
||
from repositories.mixins import RepGetAllMixin, RepUpdateMixin
|
||
from utils.exceptions import ObjectNotFoundException
|
||
|
||
|
||
class DealServiceRepository(
    BaseRepository,
    RepGetAllMixin[DealService],
    RepUpdateMixin[DealService, UpdateDealServiceSchema],
):
    """Data access for ``DealService`` rows (links between a deal and a service).

    Listing and updating are inherited from ``RepGetAllMixin`` and
    ``RepUpdateMixin``. This class supplies the listing filter plus lookup,
    creation and deletion helpers.
    """

    entity_class = DealService

    def _process_get_all_stmt_with_args(self, stmt: Select, *args) -> Select:
        """Narrow a generic "get all" SELECT to the services of a single deal.

        Args:
            stmt: Base statement to refine (presumably built by
                ``RepGetAllMixin`` — confirm against the mixin).
            *args: Extra positional arguments; the first one is the deal id.

        Returns:
            The statement with ``DealService.service`` joined-loaded, filtered
            to that deal and ordered by ``service_id``.
        """
        target_deal_id = args[0]
        with_service = stmt.options(joinedload(DealService.service))
        only_this_deal = with_service.where(DealService.deal_id == target_deal_id)
        return only_this_deal.order_by(DealService.service_id)

    async def get_by_id(
        self, deal_id: int, service_id: int, raise_if_not_found: Optional[bool] = True
    ) -> Optional[DealService]:
        """Fetch the link between ``deal_id`` and ``service_id``.

        The related ``service`` is eager-loaded with a joined load.

        Args:
            deal_id: Deal side of the link.
            service_id: Service side of the link.
            raise_if_not_found: When truthy, a missing link raises instead of
                returning ``None``.

        Returns:
            The matching ``DealService``, or ``None`` when no link exists and
            ``raise_if_not_found`` is falsy.

        Raises:
            ObjectNotFoundException: No link exists and ``raise_if_not_found``
                is truthy.
        """
        query = select(DealService).options(joinedload(DealService.service))
        query = query.where(
            DealService.deal_id == deal_id,
            DealService.service_id == service_id,
        )
        rows = await self.session.execute(query)
        link = rows.scalar_one_or_none()
        if link is not None or not raise_if_not_found:
            return link
        raise ObjectNotFoundException("Связь сделки с услугой не найдена")

    async def create(self, data: CreateDealServiceSchema):
        """Insert a new deal/service link built from ``data`` and commit.

        Args:
            data: Schema whose ``model_dump()`` output is passed as keyword
                arguments to the ``DealService`` constructor.
        """
        fields = data.model_dump()
        self.session.add(DealService(**fields))
        await self.session.commit()

    async def delete(self, obj: DealService):
        """Delete ``obj`` through the session and commit immediately.

        Args:
            obj: The link instance to remove.
        """
        await self.session.delete(obj)
        await self.session.commit()

    async def delete_deal_services(self, deal_id: int):
        """Bulk-delete every service link belonging to ``deal_id``.

        This method only flushes. Unlike the other write methods in this class,
        it does not commit, so persisting the delete is left to the caller
        (presumably as part of a larger transaction — confirm at call sites).

        Args:
            deal_id: Deal whose links are removed.
        """
        await self.session.execute(
            delete(DealService).where(DealService.deal_id == deal_id)
        )
        await self.session.flush()

    async def add_services_kit(self, deal: Deal, services_kit: ServicesKit):
        """Attach every service of ``services_kit`` to ``deal`` and commit.

        Each new link copies the service's ``price`` and starts with
        ``quantity=1``. The commit is issued even when the kit has no services.

        NOTE(review): if (deal_id, service_id) is unique in the model, a
        service that is already linked to the deal would presumably make the
        commit fail — confirm against the model / callers.

        Args:
            deal: Deal receiving the services; its ``id`` becomes ``deal_id``.
            services_kit: Kit whose ``services`` are linked to the deal.
        """
        for kit_service in services_kit.services:
            # deal.id is read per item, so an empty kit never touches `deal`.
            link = DealService(
                deal_id=deal.id,
                service_id=kit_service.id,
                price=kit_service.price,
                quantity=1,
            )
            self.session.add(link)
        await self.session.commit()
|