from sqlalchemy import delete

from modules.fulfillment_base.models import Service
from modules.fulfillment_base.models.service import ServicePriceRange
from modules.fulfillment_base.schemas.service import (
    CreateServiceSchema,
    UpdateServiceSchema,
    ServicePriceRangeSchema,
)
from repositories.mixins import *


class ServiceRepository(
    BaseRepository,
    RepGetAllMixin[Service],
    RepDeleteMixin[Service],
    RepUpdateMixin[Service, UpdateServiceSchema],
    RepGetByIdMixin[Service],
):
    """Repository for ``Service`` entities and their ``ServicePriceRange`` rows."""

    entity_class = Service
    # Runtime string (Russian for "Service not found"); presumably consumed by
    # the mixins when a lookup fails -- verify against repositories.mixins.
    entity_not_found_msg = "Услуга не найдена"

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Return ``stmt`` ordered by ``Service.lexorank``."""
        return stmt.order_by(Service.lexorank)

    @staticmethod
    def _price_ranges_schemas_to_models(
        price_ranges: list[ServicePriceRangeSchema],
    ) -> list[ServicePriceRange]:
        """Convert price-range schemas into unsaved ``ServicePriceRange`` models.

        Only ``from_quantity``, ``to_quantity`` and ``price`` are copied; the
        resulting models are not attached to any service or session here.
        """
        return [
            ServicePriceRange(
                from_quantity=price_range.from_quantity,
                to_quantity=price_range.to_quantity,
                price=price_range.price,
            )
            for price_range in price_ranges
        ]

    async def create(self, data: CreateServiceSchema) -> int:
        """Persist a new service with its price ranges and return its id.

        The schema's ``category`` field is not passed to the ``Service``
        constructor, and ``price_ranges`` is replaced by ORM models built from
        the schema entries.
        """
        price_ranges = self._price_ranges_schemas_to_models(data.price_ranges)
        # Exclude the fields that are handled separately instead of dumping
        # them and then deleting/overwriting the resulting dict entries.
        data_dict = data.model_dump(exclude={"category", "price_ranges"})
        service = Service(**data_dict, price_ranges=price_ranges)

        self.session.add(service)
        await self.session.commit()
        await self.session.refresh(service)
        return service.id

    async def _delete_price_ranges_by_service_id(self, service_id: int) -> None:
        """Delete every ``ServicePriceRange`` row belonging to ``service_id``.

        NOTE(review): this commits immediately, so a caller that performs
        further work afterwards (e.g. ``update``) is not atomic with respect
        to the delete -- confirm this is intended.
        """
        stmt = delete(ServicePriceRange).where(
            ServicePriceRange.service_id == service_id
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def update(self, service: Service, data: UpdateServiceSchema) -> Service:
        """Replace the service's price ranges and apply the remaining fields.

        Existing price ranges are deleted, the ones from ``data`` are appended
        to ``service.price_ranges``, and the rest of ``data`` (without
        ``price_ranges`` and ``category``) is handed to
        ``_apply_update_data_to_model``. The caller's ``data`` object is left
        untouched.
        """
        await self._delete_price_ranges_by_service_id(service.id)

        for price_range in self._price_ranges_schemas_to_models(data.price_ranges):
            service.price_ranges.append(price_range)

        # Strip the separately-handled fields from a shallow copy so the
        # schema instance owned by the caller is not mutated.
        update_data = data.model_copy()
        del update_data.price_ranges
        del update_data.category
        # The meaning of the positional ``True`` is defined by the update
        # mixin -- presumably "commit after applying"; TODO confirm.
        return await self._apply_update_data_to_model(service, update_data, True)

    async def get_by_ids(self, ids: list[int]) -> list[Service]:
        """Return the services whose id is in ``ids`` (no ordering applied)."""
        if not ids:
            # Nothing can match an empty id list; skip the database round trip.
            return []
        stmt = select(Service).where(Service.id.in_(ids))
        result = await self.session.execute(stmt)
        # ``all()`` returns a Sequence; materialise it to honour the annotation.
        return list(result.scalars().all())