import uuid

from sqlalchemy import and_, or_
from sqlalchemy.orm import selectinload

from models import Module, Attribute, AttributeLabel, module_attribute, ModuleTab
from models.module import DeviceType
# NOTE(review): the wildcard import presumably supplies the names used below
# that are not imported explicitly here (RepCrudMixin, Select, select).
from repositories.mixins import *
from schemas.module import UpdateModuleCommonInfoSchema, CreateModuleSchema


class ModuleRepository(
    RepCrudMixin[Module, CreateModuleSchema, UpdateModuleCommonInfoSchema]
):
    """Data-access layer for ``Module`` rows, their attributes and their tabs."""

    entity_class = Module

    def _process_get_by_id_stmt(self, stmt: Select) -> Select:
        """Add eager loading of ``Module.attributes`` and each ``Attribute.type``.

        NOTE(review): presumably a hook invoked by ``RepCrudMixin`` when it
        builds its get-by-id statement -- confirm against the mixin.
        """
        return stmt.options(selectinload(Module.attributes).joinedload(Attribute.type))

    async def get_by_ids(self, ids: list[int]) -> list[Module]:
        """Return the modules whose primary key is in ``ids``.

        An empty ``ids`` returns an empty list without querying the database.
        """
        if not ids:
            # ``IN ()`` can never match, so skip the round trip entirely.
            return []
        stmt = select(Module).where(Module.id.in_(ids))
        modules = await self.session.scalars(stmt)
        # Materialise as a real list so the value matches the annotation.
        return list(modules.all())

    @staticmethod
    def _get_stmt_modules_with_tuples() -> Select:
        """Build the SELECT yielding ``(Module, Attribute, AttributeLabel)`` rows.

        Every join is a LEFT OUTER JOIN, so the ``Attribute`` and
        ``AttributeLabel`` entities of a row may be ``None``. Soft-deleted
        modules and soft-deleted attributes are filtered out, and rows are
        ordered by module id, then attribute id.
        """
        label_on = and_(
            Module.id == AttributeLabel.module_id,
            Attribute.id == AttributeLabel.attribute_id,
        )
        return (
            select(Module, Attribute, AttributeLabel)
            .join(
                module_attribute,
                Module.id == module_attribute.c.module_id,
                isouter=True,
            )
            .join(
                Attribute,
                module_attribute.c.attribute_id == Attribute.id,
                isouter=True,
            )
            .join(AttributeLabel, label_on, isouter=True)
            .where(
                Module.is_deleted.is_(False),
                # Keep modules that have no attributes at all (NULL attribute
                # side of the outer join) as well as live attributes.
                # NOTE(review): because this filter is applied after the
                # joins, a module whose linked attributes are ALL soft-deleted
                # produces no row at all -- confirm this is intended.
                or_(Attribute.id.is_(None), Attribute.is_deleted.is_(False)),
            )
            .order_by(Module.id, Attribute.id)
        )

    async def get_with_attributes_as_tuples(
        self,
    ) -> list[tuple[Module, Attribute, AttributeLabel]]:
        """Return de-duplicated ``(Module, Attribute, AttributeLabel)`` rows.

        The ``Attribute`` and ``AttributeLabel`` elements may be ``None``
        because the underlying statement uses outer joins.
        """
        stmt = self._get_stmt_modules_with_tuples()
        result = await self.session.execute(stmt)
        return list(result.unique().all())

    async def get_with_attributes_as_tuple_by_id(
        self, pk: int
    ) -> list[tuple[Module, Attribute, AttributeLabel]]:
        """Same as ``get_with_attributes_as_tuples`` but limited to module ``pk``."""
        stmt = self._get_stmt_modules_with_tuples().where(Module.id == pk)
        result = await self.session.execute(stmt)
        return list(result.unique().all())

    async def _prepare_create(self, data: CreateModuleSchema) -> dict:
        """Dump ``data`` to a dict and add a freshly generated UUID4 ``key``.

        NOTE(review): presumably a ``RepCrudMixin`` hook run before the row is
        inserted -- confirm against the mixin.
        """
        dump = data.model_dump()
        dump["key"] = str(uuid.uuid4())
        return dump

    async def _after_create(self, module: Module, _) -> None:
        """Add a ``ModuleTab`` mirroring the module's key and label to the session.

        The second positional argument is ignored. The tab is only added to
        the session here; nothing is flushed or committed by this method.
        """
        tab = ModuleTab(
            key=module.key,
            label=module.label,
            icon_name=None,
            module_id=module.id,
            device=DeviceType.BOTH,
        )
        self.session.add(tab)

    async def get_module_tabs_by_module_id(self, module_id: int) -> list[ModuleTab]:
        """Return all tabs whose ``module_id`` equals ``module_id``."""
        stmt = select(ModuleTab).where(ModuleTab.module_id == module_id)
        result = await self.session.scalars(stmt)
        return list(result.all())

    async def update(
        self, module: Module, data: UpdateModuleCommonInfoSchema
    ) -> Module:
        """Copy ``data.label`` onto every tab of ``module``, then update the module.

        NOTE(review): ``data.label`` is written to the tabs unconditionally;
        if the schema allows ``label`` to be omitted/``None`` this would blank
        the tab labels -- confirm against the schema.
        """
        tabs = await self.get_module_tabs_by_module_id(module.id)
        for tab in tabs:
            tab.label = data.label
            self.session.add(tab)
        return await self._apply_update_data_to_model(module, data, True)

    async def add_attribute_to_module(self, module: Module, attribute: Attribute):
        """Append ``attribute`` to ``module.attributes`` and commit the session.

        NOTE(review): unlike the other methods of this class, this one commits
        the session itself. ``module.attributes`` presumably has to be loaded
        already (e.g. through ``_process_get_by_id_stmt``) -- confirm.
        """
        module.attributes.append(attribute)
        await self.session.commit()

    async def delete_attribute_from_module(self, module: Module, attribute: Attribute):
        """Remove ``attribute`` from ``module.attributes`` and commit the session.

        ``remove`` is called directly on the collection, so an attribute that
        is not currently linked to the module raises ``ValueError``.
        """
        module.attributes.remove(attribute)
        await self.session.commit()