Skip to content

Commit

Permalink
more mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alami-Amine committed Jul 23, 2024
1 parent 2b1dd69 commit 9f579e2
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 73 deletions.
53 changes: 26 additions & 27 deletions src/controller/python/chip/ChipDeviceCtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ctypes import (CDLL, CFUNCTYPE, POINTER, Structure, byref, c_bool, c_char, c_char_p, c_int, c_int32, c_size_t, c_uint8,
c_uint16, c_uint32, c_uint64, c_void_p, create_string_buffer, pointer, py_object, resize, string_at)
from dataclasses import dataclass
from typing import Optional

import dacite # type: ignore

Expand Down Expand Up @@ -267,7 +268,7 @@ class CommissioningContext(CallbackContext):
This context also resets commissioning related device controller state.
"""

def __init__(self, devCtrl: ChipDeviceController, lock: asyncio.Lock) -> None:
def __init__(self, devCtrl: ChipDeviceControllerBase, lock: asyncio.Lock) -> None:
super().__init__(lock)
self._devCtrl = devCtrl

Expand Down Expand Up @@ -388,12 +389,12 @@ def attestationChallenge(self) -> bytes:


class ChipDeviceControllerBase():
activeList = set()
activeList: typing.Set = set()

def __init__(self, name: str = ''):
self.devCtrl = None
self._ChipStack = builtins.chipStack
self._dmLib = None
self._dmLib: typing.Any = None

self._InitLib()

Expand Down Expand Up @@ -553,10 +554,6 @@ def _enablePairingCompleteCallback(self, value: bool):
self._dmLib.pychip_ScriptDevicePairingDelegate_SetExpectingPairingComplete(
self.pairingDelegate, value)

@property
def fabricAdmin(self) -> FabricAdmin.FabricAdmin:
return self._fabricAdmin

@property
def nodeId(self) -> int:
return self._nodeId
Expand Down Expand Up @@ -595,7 +592,7 @@ def Shutdown(self):
ChipDeviceController.activeList.remove(self)
self._isActive = False

def ShutdownAll():
def ShutdownAll(self):
''' Shut down all active controllers and reclaim any used resources.
'''
#
Expand Down Expand Up @@ -766,7 +763,7 @@ def GetAddressAndPort(self, nodeid):

return (address.value.decode(), port.value) if error == 0 else None

async def DiscoverCommissionableNodes(self, filterType: discovery.FilterType = discovery.FilterType.NONE, filter: typing.Any = None,
async def DiscoverCommissionableNodes(self, filterType: discovery.FilterType = discovery.FilterType.NONE, filter: Optional[typing.Any] = None,
stopOnFirst: bool = False, timeoutSecond: int = 5) -> typing.Union[None, CommissionableNode, typing.List[CommissionableNode]]:
''' Discover commissionable nodes via DNS-SD with specified filters.
Supported filters are:
Expand Down Expand Up @@ -928,7 +925,7 @@ def GetClusterHandler(self):

return self._Cluster

async def FindOrEstablishPASESession(self, setupCode: str, nodeid: int, timeoutMs: int = None) -> typing.Optional[DeviceProxyWrapper]:
async def FindOrEstablishPASESession(self, setupCode: str, nodeid: int, timeoutMs: Optional[int] = None) -> typing.Optional[DeviceProxyWrapper]:
''' Returns CommissioneeDeviceProxy if we can find or establish a PASE connection to the specified device'''
self.CheckIsActive()
returnDevice = c_void_p(None)
Expand All @@ -944,7 +941,9 @@ async def FindOrEstablishPASESession(self, setupCode: str, nodeid: int, timeoutM
if res.is_success:
return DeviceProxyWrapper(returnDevice, DeviceProxyWrapper.DeviceProxyType.COMMISSIONEE, self._dmLib)

def GetConnectedDeviceSync(self, nodeid, allowPASE=True, timeoutMs: int = None):
return None

def GetConnectedDeviceSync(self, nodeid, allowPASE=True, timeoutMs: Optional[int] = None):
''' Gets an OperationalDeviceProxy or CommissioneeDeviceProxy for the specified Node.
nodeId: Target's Node ID
Expand All @@ -956,8 +955,8 @@ def GetConnectedDeviceSync(self, nodeid, allowPASE=True, timeoutMs: int = None):
'''
self.CheckIsActive()

returnDevice = c_void_p(None)
returnErr = None
returnDevice: ctypes.c_void_p = c_void_p(None)
returnErr: typing.Any = None
deviceAvailableCV = threading.Condition()

if allowPASE:
Expand Down Expand Up @@ -1014,7 +1013,7 @@ async def WaitForActive(self, nodeid, *, timeoutSeconds=30.0, stayActiveDuration
await WaitForCheckIn(ScopedNodeId(nodeid, self._fabricIndex), timeoutSeconds=timeoutSeconds)
return await self.SendCommand(nodeid, 0, Clusters.IcdManagement.Commands.StayActiveRequest(stayActiveDuration=stayActiveDurationMs))

async def GetConnectedDevice(self, nodeid, allowPASE: bool = True, timeoutMs: int = None):
async def GetConnectedDevice(self, nodeid, allowPASE: bool = True, timeoutMs: Optional[int] = None):
''' Gets an OperationalDeviceProxy or CommissioneeDeviceProxy for the specified Node.
nodeId: Target's Node ID
Expand Down Expand Up @@ -1412,7 +1411,7 @@ def _parseEventPathTuple(self, pathTuple: typing.Union[
else:
raise ValueError("Unsupported Attribute Path")

async def Read(self, nodeid: int, attributes: typing.List[typing.Union[
async def Read(self, nodeid: int, attributes: Optional[typing.List[typing.Union[
None, # Empty tuple, all wildcard
typing.Tuple[int], # Endpoint
# Wildcard endpoint, Cluster id present
Expand All @@ -1423,9 +1422,9 @@ async def Read(self, nodeid: int, attributes: typing.List[typing.Union[
typing.Tuple[int, typing.Type[ClusterObjects.Cluster]],
# Concrete path
typing.Tuple[int, typing.Type[ClusterObjects.ClusterAttributeDescriptor]]
]] = None,
dataVersionFilters: typing.List[typing.Tuple[int, typing.Type[ClusterObjects.Cluster], int]] = None, events: typing.List[
typing.Union[
]]] = None,
dataVersionFilters: Optional[typing.List[typing.Tuple[int, typing.Type[ClusterObjects.Cluster], int]]] = None, events: Optional[typing.List[
typing.Union[
None, # Empty tuple, all wildcard
typing.Tuple[str, int], # all wildcard with urgency set
typing.Tuple[int, int], # Endpoint,
Expand All @@ -1437,9 +1436,9 @@ async def Read(self, nodeid: int, attributes: typing.List[typing.Union[
typing.Tuple[int, typing.Type[ClusterObjects.Cluster], int],
# Concrete path
typing.Tuple[int, typing.Type[ClusterObjects.ClusterEvent], int]
]] = None,
]]] = None,
eventNumberFilter: typing.Optional[int] = None,
returnClusterObject: bool = False, reportInterval: typing.Tuple[int, int] = None,
returnClusterObject: bool = False, reportInterval: Optional[typing.Tuple[int, int]] = None,
fabricFiltered: bool = True, keepSubscriptions: bool = False, autoResubscribe: bool = True):
'''
Read a list of attributes and/or events from a target node
Expand Down Expand Up @@ -1528,9 +1527,9 @@ async def ReadAttribute(self, nodeid: int, attributes: typing.List[typing.Union[
typing.Tuple[int, typing.Type[ClusterObjects.Cluster]],
# Concrete path
typing.Tuple[int, typing.Type[ClusterObjects.ClusterAttributeDescriptor]]
]], dataVersionFilters: typing.List[typing.Tuple[int, typing.Type[ClusterObjects.Cluster], int]] = None,
]], dataVersionFilters: Optional[typing.List[typing.Tuple[int, typing.Type[ClusterObjects.Cluster], int]]] = None,
returnClusterObject: bool = False,
reportInterval: typing.Tuple[int, int] = None,
reportInterval: Optional[typing.Tuple[int, int]] = None,
fabricFiltered: bool = True, keepSubscriptions: bool = False, autoResubscribe: bool = True):
'''
Read a list of attributes from a target node, this is a wrapper of DeviceController.Read()
Expand Down Expand Up @@ -1611,7 +1610,7 @@ async def ReadEvent(self, nodeid: int, events: typing.List[typing.Union[
typing.Tuple[int, typing.Type[ClusterObjects.ClusterEvent], int]
]], eventNumberFilter: typing.Optional[int] = None,
fabricFiltered: bool = True,
reportInterval: typing.Tuple[int, int] = None,
reportInterval: Optional[typing.Tuple[int, int]] = None,
keepSubscriptions: bool = False,
autoResubscribe: bool = True):
'''
Expand Down Expand Up @@ -1928,7 +1927,7 @@ class ChipDeviceController(ChipDeviceControllerBase):
'''

def __init__(self, opCredsContext: ctypes.c_void_p, fabricId: int, nodeId: int, adminVendorId: int, catTags: typing.List[int] = [
], paaTrustStorePath: str = "", useTestCommissioner: bool = False, fabricAdmin: FabricAdmin = None, name: str = None, keypair: p256keypair.P256Keypair = None):
], paaTrustStorePath: str = "", useTestCommissioner: bool = False, fabricAdmin: Optional[FabricAdmin] = None, name: Optional[str] = None, keypair: Optional[p256keypair.P256Keypair] = None):
super().__init__(
name or
f"caIndex({fabricAdmin.caIndex:x})/fabricId(0x{fabricId:016X})/nodeId(0x{nodeId:016X})"
Expand Down Expand Up @@ -1971,7 +1970,7 @@ def caIndex(self) -> int:
return self._caIndex

@property
def fabricAdmin(self) -> FabricAdmin:
def fabricAdmin(self) -> FabricAdmin.FabricAdmin:
return self._fabricAdmin

async def Commission(self, nodeid) -> int:
Expand Down Expand Up @@ -2114,7 +2113,7 @@ def GetFabricCheckResult(self) -> int:
return self._fabricCheckNodeId

async def CommissionOnNetwork(self, nodeId: int, setupPinCode: int,
filterType: DiscoveryFilterType = DiscoveryFilterType.NONE, filter: typing.Any = None,
filterType: DiscoveryFilterType = DiscoveryFilterType.NONE, filter: Optional[typing.Any] = None,
discoveryTimeoutMsec: int = 30000) -> int:
'''
Does the routine for OnNetworkCommissioning, with a filter for mDNS discovery.
Expand Down Expand Up @@ -2217,7 +2216,7 @@ class BareChipDeviceController(ChipDeviceControllerBase):
'''

def __init__(self, operationalKey: p256keypair.P256Keypair, noc: bytes,
icac: typing.Union[bytes, None], rcac: bytes, ipk: typing.Union[bytes, None], adminVendorId: int, name: str = None):
icac: typing.Union[bytes, None], rcac: bytes, ipk: typing.Union[bytes, None], adminVendorId: int, name: Optional[str] = None):
'''Creates a controller without AutoCommissioner.
The allocated controller uses the noc, icac, rcac and ipk instead of the default,
Expand Down
52 changes: 21 additions & 31 deletions src/controller/python/chip/clusters/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class EventPriority(Enum):

@dataclass(frozen=True)
class AttributePath:
EndpointId: int = int()
ClusterId: int = int()
AttributeId: int = int()
EndpointId: int = None
ClusterId: int = None
AttributeId: int = None

@staticmethod
def from_cluster(EndpointId: int, Cluster: Cluster) -> AttributePath:
Expand All @@ -80,12 +80,12 @@ def __str__(self) -> str:

@dataclass(frozen=True)
class DataVersionFilter:
EndpointId: int = None
ClusterId: int = None
DataVersion: int = None
EndpointId: int
ClusterId: int
DataVersion: int

@staticmethod
def from_cluster(EndpointId: int, Cluster: Cluster, DataVersion: int = None) -> AttributePath:
def from_cluster(EndpointId: int, Cluster: Cluster, DataVersion: int) -> DataVersionFilter:
if Cluster is None:
raise ValueError("Cluster cannot be None")
return DataVersionFilter(EndpointId=EndpointId, ClusterId=Cluster.id, DataVersion=DataVersion)
Expand All @@ -99,13 +99,13 @@ class TypedAttributePath:
''' Encapsulates an attribute path that has strongly typed references to cluster and attribute
cluster object types. These types serve as keys into the attribute cache.
'''
ClusterType: Cluster = None
AttributeType: ClusterAttributeDescriptor = None
AttributeName: str = None
Path: AttributePath = None
ClusterType: Cluster
AttributeType: ClusterAttributeDescriptor
AttributeName: str
Path: AttributePath

def __init__(self, ClusterType: Cluster = None, AttributeType: ClusterAttributeDescriptor = None,
Path: AttributePath = None):
def __init__(self, ClusterType: Optional[Cluster] = None, AttributeType: Optional[ClusterAttributeDescriptor] = None,
Path: Optional[AttributePath] = None):
''' Only one of either ClusterType and AttributeType OR Path may be provided.
'''

Expand Down Expand Up @@ -156,13 +156,13 @@ class EventPath:
Urgent: int = None

@staticmethod
def from_cluster(EndpointId: int, Cluster: Cluster, EventId: int = None, Urgent: int = None) -> "EventPath":
def from_cluster(EndpointId: int, Cluster: Cluster, EventId: Optional[int] = None, Urgent: Optional[int] = None) -> "EventPath":
if Cluster is None:
raise ValueError("Cluster cannot be None")
return EventPath(EndpointId=EndpointId, ClusterId=Cluster.id, EventId=EventId, Urgent=Urgent)

@staticmethod
def from_event(EndpointId: int, Event: ClusterEvent, Urgent: int = None) -> "EventPath":
def from_event(EndpointId: int, Event: ClusterEvent, Urgent: Optional[int] = None) -> "EventPath":
if Event is None:
raise ValueError("Event cannot be None")
return EventPath(EndpointId=EndpointId, ClusterId=Event.cluster_id, EventId=Event.event_id, Urgent=Urgent)
Expand All @@ -181,16 +181,6 @@ class EventHeader:
Timestamp: int = None
TimestampType: EventTimestampType = None

def __init__(self, EndpointId: int = None, ClusterId: int = None,
EventId: int = None, EventNumber=None, Priority=None, Timestamp=None, TimestampType=None):
self.EndpointId = EndpointId
self.ClusterId = ClusterId
self.EventId = EventId
self.EventNumber = EventNumber
self.Priority = Priority
self.Timestamp = Timestamp
self.TimestampType = TimestampType

def __str__(self) -> str:
return (f"{self.EndpointId}/{self.ClusterId}/{self.EventId}/"
f"{self.EventNumber}/{self.Priority}/{self.Timestamp}/{self.TimestampType}")
Expand Down Expand Up @@ -1044,9 +1034,9 @@ def WriteGroupAttributes(groupId: int, devCtrl: c_void_p, attributes: List[Attri


def Read(future: Future, eventLoop, device, devCtrl,
attributes: List[AttributePath] = None, dataVersionFilters: List[DataVersionFilter] = None,
events: List[EventPath] = None, eventNumberFilter: Optional[int] = None, returnClusterObject: bool = True,
subscriptionParameters: SubscriptionParameters = None,
attributes: Optional[List[AttributePath]] = None, dataVersionFilters: Optional[List[DataVersionFilter]] = None,
events: Optional[List[EventPath]] = None, eventNumberFilter: Optional[int] = None, returnClusterObject: bool = True,
subscriptionParameters: Optional[SubscriptionParameters] = None,
fabricFiltered: bool = True, keepSubscriptions: bool = False, autoResubscribe: bool = True) -> PyChipError:
if (not attributes) and dataVersionFilters:
raise ValueError(
Expand Down Expand Up @@ -1164,9 +1154,9 @@ def Read(future: Future, eventLoop, device, devCtrl,


def ReadAttributes(future: Future, eventLoop, device, devCtrl,
attributes: List[AttributePath], dataVersionFilters: List[DataVersionFilter] = None,
attributes: List[AttributePath], dataVersionFilters: Optional[List[DataVersionFilter]] = None,
returnClusterObject: bool = True,
subscriptionParameters: SubscriptionParameters = None, fabricFiltered: bool = True) -> int:
subscriptionParameters: Optional[SubscriptionParameters] = None, fabricFiltered: bool = True) -> int:
return Read(future=future, eventLoop=eventLoop, device=device,
devCtrl=devCtrl, attributes=attributes, dataVersionFilters=dataVersionFilters,
events=None, returnClusterObject=returnClusterObject,
Expand All @@ -1175,7 +1165,7 @@ def ReadAttributes(future: Future, eventLoop, device, devCtrl,

def ReadEvents(future: Future, eventLoop, device, devCtrl,
events: List[EventPath], eventNumberFilter=None, returnClusterObject: bool = True,
subscriptionParameters: SubscriptionParameters = None, fabricFiltered: bool = True) -> int:
subscriptionParameters: Optional[SubscriptionParameters] = None, fabricFiltered: bool = True) -> int:
return Read(future=future, eventLoop=eventLoop, device=device, devCtrl=devCtrl, attributes=None,
dataVersionFilters=None, events=events, eventNumberFilter=eventNumberFilter,
returnClusterObject=returnClusterObject,
Expand Down
10 changes: 5 additions & 5 deletions src/controller/python/chip/clusters/ClusterObjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def TagDictToLabelDict(self, debugPath: str, tlvData: Dict[int, Any]) -> Dict[st

def TLVToDict(self, tlvBuf: bytes) -> Dict[str, Any]:
tlvData = tlv.TLVReader(tlvBuf).get().get('Any', {})
return self.TagDictToLabelDict([], tlvData)
return self.TagDictToLabelDict('', tlvData)

def DictToTLVWithWriter(self, debugPath: str, tag, data: Mapping, writer: tlv.TLVWriter):
writer.startStructure(tag)
Expand Down Expand Up @@ -210,11 +210,11 @@ def descriptor(cls):

# The below dictionaries will be filled dynamically
# and are used for quick lookup/mapping from cluster/attribute id to the correct class
ALL_CLUSTERS = {}
ALL_ATTRIBUTES = {}
ALL_CLUSTERS: typing.Dict = {}
ALL_ATTRIBUTES: typing.Dict = {}
# These need to be separate because there can be overlap in command ids for commands and responses.
ALL_ACCEPTED_COMMANDS = {}
ALL_GENERATED_COMMANDS = {}
ALL_ACCEPTED_COMMANDS: typing.Dict = {}
ALL_GENERATED_COMMANDS: typing.Dict = {}


class ClusterCommand(ClusterObject):
Expand Down
10 changes: 5 additions & 5 deletions src/controller/python/chip/crypto/p256keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import abc
import hashlib
from ctypes import (CFUNCTYPE, POINTER, c_bool, c_char, c_size_t, c_uint8, c_uint32, c_void_p, cast, memmove, pointer, py_object,
string_at)
from ctypes import (CFUNCTYPE, POINTER, _Pointer, c_bool, c_char, c_size_t, c_uint8, c_uint32, c_void_p, cast, memmove,
py_object, string_at)
from typing import TYPE_CHECKING

from chip import native
Expand All @@ -30,7 +30,7 @@ class pointer_fix:
@classmethod
def __class_getitem__(cls, item):
return POINTER(item)
pointer = pointer_fix
_Pointer = pointer_fix


_pychip_P256Keypair_ECDSA_sign_msg_func = CFUNCTYPE(
Expand All @@ -43,15 +43,15 @@ def __class_getitem__(cls, item):


@ _pychip_P256Keypair_ECDSA_sign_msg_func
def _pychip_ECDSA_sign_msg(self_: 'P256Keypair', message_buf: pointer[c_uint8], message_size: int, signature_buf: pointer[c_uint8], signature_buf_size: pointer[c_size_t]) -> bool:
def _pychip_ECDSA_sign_msg(self_: 'P256Keypair', message_buf: _Pointer[c_uint8], message_size: int, signature_buf: _Pointer[c_uint8], signature_buf_size: _Pointer[c_size_t]) -> bool:
res = self_.ECDSA_sign_msg(string_at(message_buf, message_size)[:])
memmove(signature_buf, res, len(res))
signature_buf_size.contents.value = len(res)
return True


@ _pychip_P256Keypair_ECDH_derive_secret_func
def _pychip_ECDH_derive_secret(self_: 'P256Keypair', remote_pubkey: pointer[c_uint8], out_secret_buf: pointer[c_uint8], out_secret_buf_size: pointer[c_uint32]) -> bool:
def _pychip_ECDH_derive_secret(self_: 'P256Keypair', remote_pubkey: _Pointer[c_uint8], out_secret_buf: _Pointer[c_uint8], out_secret_buf_size: _Pointer[c_uint32]) -> bool:
res = self_.ECDH_derive_secret(
string_at(remote_pubkey, P256_PUBLIC_KEY_LENGTH)[:])
memmove(out_secret_buf, res, len(res))
Expand Down
Loading

0 comments on commit 9f579e2

Please sign in to comment.