diff --git a/src/controller/python/chip/clusters/ClusterObjects.py b/src/controller/python/chip/clusters/ClusterObjects.py index 13f35f240ae3cb..99893c4b186579 100644 --- a/src/controller/python/chip/clusters/ClusterObjects.py +++ b/src/controller/python/chip/clusters/ClusterObjects.py @@ -219,6 +219,12 @@ def must_use_timed_invoke(cls) -> bool: return False +# The below dictionaries will be filled dynamically +# and are used for quick lookup/mapping from cluster/attribute id to the correct class +ALL_CLUSTERS = {} +ALL_ATTRIBUTES = {} + + class Cluster(ClusterObject): ''' When send read requests with returnClusterObject=True, we will set the data_version property of the object. @@ -228,6 +234,18 @@ class Cluster(ClusterObject): especially the TLV decoding logic. Also ThreadNetworkDiagnostics has an attribute with the same name so we picked data_version as its name. ''' + + def __init_subclass__(cls, *args, **kwargs) -> None: + """Register a subclass.""" + super().__init_subclass__(*args, **kwargs) + # register this cluster in the ALL_CLUSTERS dict for quick lookups + try: + ALL_CLUSTERS[cls.id] = cls + except NotImplementedError: + # handle case where the Cluster class is not (fully) subclassed + # and accessing the id property throws a NotImplementedError. + pass + @property def data_version(self) -> int: return self._data_version @@ -254,6 +272,20 @@ class ClusterAttributeDescriptor: The implementation of this functions is quite tricky, it will create a cluster object on-the-fly, and use it for actual encode / decode routine to save lines of code. ''' + + def __init_subclass__(cls, *args, **kwargs) -> None: + """Register a subclass.""" + super().__init_subclass__(*args, **kwargs) + try: + if cls.cluster_id not in ALL_ATTRIBUTES: + ALL_ATTRIBUTES[cls.cluster_id] = {} + # register this clusterattribute in the ALL_ATTRIBUTES dict for quick lookups + ALL_ATTRIBUTES[cls.cluster_id][cls.attribute_id] = cls + except NotImplementedError: + # handle case where the ClusterAttribute class is not (fully) subclassed + # and accessing the id property throws a NotImplementedError. + pass + @classmethod def ToTLV(cls, tag: Union[int, None], value): writer = tlv.TLVWriter()