Files
Crm-Backend/repositories/attribute.py

213 lines
7.1 KiB
Python

from collections import defaultdict
from sqlalchemy import and_
from sqlalchemy.orm import joinedload
from models import (
Attribute,
AttributeLabel,
AttributeType,
AttributeValue,
module_attribute,
Module,
Project,
)
from repositories.mixins import *
from schemas.attribute import (
CreateAttributeSchema,
UpdateAttributeSchema,
UpdateDealModuleAttributeSchema,
)
from utils.exceptions import ForbiddenException
class AttributeRepository(
    RepCrudMixin[Attribute, CreateAttributeSchema, UpdateAttributeSchema]
):
    """Repository for attributes, their per-module labels and per-deal values.

    Generic CRUD behaviour comes from ``RepCrudMixin``; the methods below
    customise its statements/hooks and add attribute-specific queries.
    """

    session: AsyncSession
    entity_class = Attribute

    def _process_get_all_stmt(self, stmt: Select) -> Select:
        """Customise the "get all" statement.

        Eager-loads the attribute type, hides soft-deleted attributes and
        orders built-in attributes first, then by id.
        """
        return (
            stmt.options(joinedload(Attribute.type))
            .where(Attribute.is_deleted.is_(False))
            .order_by(Attribute.is_built_in.desc(), Attribute.id)
        )

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Customise the "get by id" statement to eager-load the attribute type."""
        return stmt.options(joinedload(Attribute.type))

    async def update(self, attr: Attribute, data: UpdateAttributeSchema) -> Attribute:
        """Apply ``data`` to ``attr`` and commit.

        NOTE(review): ``set_if_value_is_not_none=False`` presumably means that
        ``None`` values from the schema are written too (allowing fields to be
        cleared) -- confirm against ``_apply_update_data_to_model``.
        """
        return await self._apply_update_data_to_model(
            attr, data, with_commit=True, set_if_value_is_not_none=False
        )

    async def _before_delete(self, attribute: Attribute) -> None:
        """Reject the operation for built-in attributes.

        Raises:
            ForbiddenException: if ``attribute.is_built_in`` is truthy.
        """
        if attribute.is_built_in:
            raise ForbiddenException("Нельзя менять встроенный атрибут")

    async def _get_all_attributes_for_deal(
        self, project_id: int
    ) -> list[tuple[Attribute, int]]:
        """Return ``(attribute, module_id)`` pairs available in a project.

        Only non-deleted attributes reachable through non-deleted modules of
        the given non-deleted project are returned.
        """
        stmt = (
            select(Attribute, Module.id)
            .join(Attribute.modules)
            .join(Module.projects)
            .where(
                # Soft-deleted attributes are excluded here as in every other
                # attribute query of this repository; otherwise new deals
                # would receive default values for deleted attributes.
                Attribute.is_deleted.is_(False),
                Module.is_deleted.is_(False),
                Project.is_deleted.is_(False),
                Project.id == project_id,
            )
            # Guards against repeated (attribute, module) pairs produced by
            # the joins; with column arguments this renders as DISTINCT ON
            # on PostgreSQL.
            .distinct(Attribute.id, Module.id)
        )
        result = await self.session.execute(stmt)
        return list(result.all())

    async def create_attributes_for_new_deal(
        self, deal_id: int, project_id: int
    ) -> None:
        """Add default attribute values for a new deal to the session.

        For every attribute/module pair of the project, a value row is added
        using ``default_option_id`` when set, otherwise ``default_value``.
        Attributes without any default are skipped. Nothing is committed
        here; the rows are only added to the session.
        """
        attributes = await self._get_all_attributes_for_deal(project_id)
        for attribute, module_id in attributes:
            def_val = (
                attribute.default_option_id
                if attribute.default_option_id is not None
                else attribute.default_value
            )
            if def_val is None:
                continue

            self.session.add(
                AttributeValue(
                    attribute_id=attribute.id,
                    deal_id=deal_id,
                    module_id=module_id,
                    value=def_val,
                )
            )

    async def _get_attribute_module_label(
        self, module_id: int, attribute_id: int
    ) -> Optional[AttributeLabel]:
        """Return the label of an attribute within a module, or ``None``.

        Raises whatever SQLAlchemy raises when more than one row matches.
        """
        stmt = select(AttributeLabel).where(
            AttributeLabel.attribute_id == attribute_id,
            AttributeLabel.module_id == module_id,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_or_update_attribute_label(
        self, module_id: int, attribute_id: int, label: str
    ) -> None:
        """Set the module-specific label of an attribute and commit.

        Updates the existing ``AttributeLabel`` row when one exists,
        otherwise creates a new one.
        """
        attribute_label = await self._get_attribute_module_label(
            module_id, attribute_id
        )
        if attribute_label:
            attribute_label.label = label
        else:
            attribute_label = AttributeLabel(
                module_id=module_id,
                attribute_id=attribute_id,
                label=label,
            )
            self.session.add(attribute_label)
        await self.session.commit()

    async def get_attribute_types(self) -> list[AttributeType]:
        """Return all attribute types that are not soft-deleted."""
        stmt = select(AttributeType).where(AttributeType.is_deleted.is_(False))
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_deal_module_attributes(
        self, deal_id: int, module_id: int
    ) -> list[tuple[Attribute, Optional[AttributeValue], Optional[AttributeLabel]]]:
        """Return the attributes of a module with the deal's values and labels.

        Each row is ``(attribute, value, label)``. ``value`` and ``label``
        come from outer joins, so either is ``None`` when the deal has no
        stored value or the module has no custom label for the attribute.
        Soft-deleted attributes are excluded.
        """
        stmt = (
            select(
                Attribute,
                AttributeValue,
                AttributeLabel,
            )
            # Restrict to attributes linked to the requested module.
            .join(
                module_attribute,
                and_(
                    module_attribute.c.attribute_id == Attribute.id,
                    module_attribute.c.module_id == module_id,
                ),
            )
            .outerjoin(
                AttributeValue,
                and_(
                    AttributeValue.attribute_id == Attribute.id,
                    AttributeValue.module_id == module_id,
                    AttributeValue.deal_id == deal_id,
                ),
            )
            .outerjoin(
                AttributeLabel,
                and_(
                    AttributeLabel.attribute_id == Attribute.id,
                    AttributeLabel.module_id == module_id,
                ),
            )
            .where(Attribute.is_deleted.is_(False))
            .options(joinedload(Attribute.select))
        )
        result = await self.session.execute(stmt)
        return list(result.all())

    async def _get_deals_attribute_values(
        self, deal_ids: list[int], module_id: int
    ) -> list[AttributeValue]:
        """Return stored values of non-deleted attributes for deals in a module."""
        stmt = (
            select(AttributeValue)
            .join(Attribute, AttributeValue.attribute_id == Attribute.id)
            .where(
                AttributeValue.deal_id.in_(deal_ids),
                AttributeValue.module_id == module_id,
                Attribute.is_deleted.is_(False),
            )
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def update_or_create_deals_attribute_values(
        self,
        main_deal_id: int,
        group_deal_ids: list[int],
        module_id: int,
        attributes: list[UpdateDealModuleAttributeSchema],
    ) -> None:
        """Upsert attribute values for a deal and, where applicable, its group.

        Attributes flagged ``is_applicable_to_group`` are written to every
        deal in ``group_deal_ids``; all others only to ``main_deal_id``.
        Existing rows are updated in place (including to ``None``); a new row
        is created only when the incoming value is not ``None``. The session
        is committed at the end.
        """
        # Load existing values for the main deal as well as the group: the
        # main deal is written to directly and may not be part of
        # ``group_deal_ids``; without it an existing row would be missed and
        # a duplicate inserted.
        deal_ids_to_load = list({main_deal_id, *group_deal_ids})
        old_deal_attribute_values = await self._get_deals_attribute_values(
            deal_ids_to_load, module_id
        )

        # deal_id -> attribute_id -> existing value row
        dict_old_attrs: dict[int, dict[int, AttributeValue]] = defaultdict(dict)
        for deal_attribute in old_deal_attribute_values:
            dict_old_attrs[deal_attribute.deal_id][deal_attribute.attribute_id] = (
                deal_attribute
            )

        for attribute in attributes:
            if attribute.is_applicable_to_group:
                deal_ids_to_apply = group_deal_ids
            else:
                deal_ids_to_apply = [main_deal_id]

            for deal_id in deal_ids_to_apply:
                existing = dict_old_attrs[deal_id].get(attribute.attribute_id)
                if existing is not None:
                    existing.value = attribute.value
                    continue

                # Do not create rows that would only hold an empty value.
                if attribute.value is None:
                    continue

                self.session.add(
                    AttributeValue(
                        attribute_id=attribute.attribute_id,
                        deal_id=deal_id,
                        module_id=module_id,
                        value=attribute.value,
                    )
                )

        await self.session.commit()