from collections import defaultdict

from sqlalchemy import and_
from sqlalchemy.orm import joinedload

from models import (
    Attribute,
    AttributeLabel,
    AttributeType,
    AttributeValue,
    module_attribute,
    Module,
    Project,
)
from repositories.mixins import *
from schemas.attribute import (
    CreateAttributeSchema,
    UpdateAttributeSchema,
    UpdateDealModuleAttributeSchema,
)
from utils.exceptions import ForbiddenException


class AttributeRepository(
    RepCrudMixin[Attribute, CreateAttributeSchema, UpdateAttributeSchema]
):
    """Data-access layer for attributes, their per-module labels and per-deal values.

    Generic CRUD behaviour comes from ``RepCrudMixin``. Names such as
    ``AsyncSession``, ``Select``, ``select``, ``Optional`` and
    ``ObjectNotFoundException`` are presumably provided by the
    ``repositories.mixins`` star import — TODO confirm.
    """

    session: AsyncSession
    entity_class = Attribute

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Eager-load the type, hide soft-deleted attributes and order by id."""
        return (
            stmt.options(joinedload(Attribute.type))
            .where(Attribute.is_deleted.is_(False))
            .order_by(Attribute.id)
        )

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Eager-load the attribute type when fetching a single attribute."""
        return stmt.options(joinedload(Attribute.type))

    async def _get_attribute_type_by_id(self, type_id: int) -> AttributeType:
        """Return the ``AttributeType`` whose primary key is *type_id*.

        Raises:
            ObjectNotFoundException: if no type with that id exists.
        """
        stmt = select(AttributeType).where(AttributeType.id == type_id)
        attribute_type = (await self.session.execute(stmt)).scalar_one_or_none()
        if attribute_type is None:
            raise ObjectNotFoundException("Тип аттрибута не найден")
        return attribute_type

    async def update(self, attr: Attribute, data: UpdateAttributeSchema) -> Attribute:
        """Apply *data* to *attr* and return the result of the mixin update.

        When ``data.type`` is set, only its ``id`` is used: the type is re-read
        from the database, so an unknown id raises ``ObjectNotFoundException``.
        """
        if data.type:
            data.type = await self._get_attribute_type_by_id(data.type.id)
        return await self._apply_update_data_to_model(attr, data, True)

    async def _before_delete(self, attribute: Attribute) -> None:
        """Refuse to delete built-in attributes.

        Raises:
            ForbiddenException: if ``attribute.is_built_in`` is truthy.
        """
        if attribute.is_built_in:
            raise ForbiddenException("Нельзя менять встроенный атрибут")

    async def _get_all_attributes_for_deal(
        self, project_id: int
    ) -> list[tuple[Attribute, int]]:
        """Return ``(attribute, module_id)`` pairs available to deals of *project_id*.

        Only non-deleted attributes attached to non-deleted modules of the
        non-deleted project are included.
        """
        stmt = (
            select(Attribute, Module.id)
            .join(Attribute.modules)
            .join(Module.projects)
            .where(
                # Match the soft-delete filtering used by every other attribute
                # query here; otherwise new deals get values for deleted attributes.
                Attribute.is_deleted.is_(False),
                Module.is_deleted.is_(False),
                Project.is_deleted.is_(False),
                Project.id == project_id,
            )
            # Column-level distinct (DISTINCT ON in PostgreSQL) keeps one row
            # per (attribute, module) pair.
            .distinct(Attribute.id, Module.id)
        )
        result = await self.session.execute(stmt)
        return list(result.all())

    async def create_attributes_for_new_deal(
        self, deal_id: int, project_id: int
    ) -> None:
        """Add ``AttributeValue`` rows with default values for a freshly created deal.

        Attributes whose ``default_value`` is ``None`` are skipped. The new rows
        are only added to the session; this method does not commit.
        """
        attributes = await self._get_all_attributes_for_deal(project_id)
        for attribute, module_id in attributes:
            if attribute.default_value is None:
                continue
            value = AttributeValue(
                attribute_id=attribute.id,
                deal_id=deal_id,
                module_id=module_id,
                value=attribute.default_value,
            )
            self.session.add(value)

    async def _get_attribute_module_label(
        self, module_id: int, attribute_id: int
    ) -> Optional[AttributeLabel]:
        """Return the label of *attribute_id* within *module_id*, or ``None``."""
        stmt = select(AttributeLabel).where(
            AttributeLabel.attribute_id == attribute_id,
            AttributeLabel.module_id == module_id,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_or_update_attribute_label(
        self, module_id: int, attribute_id: int, label: str
    ) -> None:
        """Set the module-specific *label* of an attribute, creating it if needed, then commit."""
        attribute_label = await self._get_attribute_module_label(
            module_id, attribute_id
        )
        if attribute_label:
            attribute_label.label = label
        else:
            attribute_label = AttributeLabel(
                module_id=module_id,
                attribute_id=attribute_id,
                label=label,
            )
            self.session.add(attribute_label)
        await self.session.commit()

    async def get_attribute_types(self) -> list[AttributeType]:
        """Return all attribute types that are not soft-deleted."""
        stmt = select(AttributeType).where(AttributeType.is_deleted.is_(False))
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_deal_module_attributes(
        self, deal_id: int, module_id: int
    ) -> list[tuple[Attribute, AttributeValue, AttributeLabel]]:
        """Return ``(attribute, value, label)`` rows for every attribute of *module_id*.

        Value and label are LEFT-joined, so either may be ``None`` when the deal
        has no stored value or the module has no custom label. Soft-deleted
        attributes are excluded.
        """
        stmt = (
            select(
                Attribute,
                AttributeValue,
                AttributeLabel,
            )
            .join(
                module_attribute,
                and_(
                    module_attribute.c.attribute_id == Attribute.id,
                    module_attribute.c.module_id == module_id,
                ),
            )
            .outerjoin(
                AttributeValue,
                and_(
                    AttributeValue.attribute_id == Attribute.id,
                    AttributeValue.module_id == module_id,
                    AttributeValue.deal_id == deal_id,
                ),
            )
            .outerjoin(
                AttributeLabel,
                and_(
                    AttributeLabel.attribute_id == Attribute.id,
                    AttributeLabel.module_id == module_id,
                ),
            )
            .where(
                Attribute.is_deleted.is_(False),
            )
        )
        result = await self.session.execute(stmt)
        return list(result.all())

    async def _get_deals_attribute_values(
        self, deal_ids: list[int], module_id: int
    ) -> list[AttributeValue]:
        """Return stored values of non-deleted attributes for *deal_ids* in *module_id*."""
        stmt = (
            select(AttributeValue)
            .join(Attribute, AttributeValue.attribute_id == Attribute.id)
            .where(
                AttributeValue.deal_id.in_(deal_ids),
                AttributeValue.module_id == module_id,
                Attribute.is_deleted.is_(False),
            )
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def update_or_create_deals_attribute_values(
        self,
        main_deal_id: int,
        group_deal_ids: list[int],
        module_id: int,
        attributes: list[UpdateDealModuleAttributeSchema],
    ) -> None:
        """Upsert attribute values of *module_id* for a deal (and its group), then commit.

        An attribute with ``is_applicable_to_group`` set is written to every deal
        in *group_deal_ids*; any other attribute is written to *main_deal_id*
        only. Existing values are overwritten (including with ``None``). A new
        row is created only when the incoming value is not ``None``.
        """
        # Load the main deal's values too: a non-group attribute targets
        # main_deal_id, which may not be listed in group_deal_ids. Without it the
        # existing row would be missed and a duplicate inserted.
        deal_ids_to_load = list({*group_deal_ids, main_deal_id})
        existing_values = await self._get_deals_attribute_values(
            deal_ids_to_load, module_id
        )

        # deal_id -> attribute_id -> AttributeValue
        values_by_deal: dict[int, dict[int, AttributeValue]] = defaultdict(dict)
        for existing in existing_values:
            values_by_deal[existing.deal_id][existing.attribute_id] = existing

        for attribute in attributes:
            if attribute.is_applicable_to_group:
                target_deal_ids = group_deal_ids
            else:
                target_deal_ids = [main_deal_id]

            for deal_id in target_deal_ids:
                deal_values = values_by_deal[deal_id]
                attribute_value = deal_values.get(attribute.attribute_id)
                if attribute_value is not None:
                    attribute_value.value = attribute.value
                    continue
                if attribute.value is None:
                    continue
                attribute_value = AttributeValue(
                    attribute_id=attribute.attribute_id,
                    deal_id=deal_id,
                    module_id=module_id,
                    value=attribute.value,
                )
                self.session.add(attribute_value)
                # Remember the new row so a repeated attribute_id / deal_id in
                # the input updates it instead of adding a second row.
                deal_values[attribute.attribute_id] = attribute_value

        await self.session.commit()