diff --git a/qubesadmin/device_protocol.py b/qubesadmin/device_protocol.py index 711a1fc4..91ab73c8 100644 --- a/qubesadmin/device_protocol.py +++ b/qubesadmin/device_protocol.py @@ -205,14 +205,14 @@ def parse_basic_device_properties( properties['port'] = expected @staticmethod - def serialize_str(value: str): + def serialize_str(value: str) -> str: """ Serialize python string to ensure consistency. """ return "'" + str(value).replace("'", r"\'") + "'" @staticmethod - def deserialize_str(value: str): + def deserialize_str(value: str) -> str: """ Deserialize python string to ensure consistency. """ @@ -255,7 +255,11 @@ class Port: port_id (str): A unique (in backend domain) identifier for the port. devclass (str): The class of the port (e.g., 'usb', 'pci'). """ - def __init__(self, backend_domain, port_id, devclass): + def __init__(self, + backend_domain: Optional[QubesVM], + port_id: str, + devclass: str + ): self.__backend_domain = backend_domain self.__port_id = port_id self.__devclass = devclass @@ -288,7 +292,7 @@ def __str__(self): @property def backend_name(self) -> str: # pylint: disable=missing-function-docstring - if self.backend_domain not in (None, "*"): + if self.backend_domain is not None: return self.backend_domain.name return "*" @@ -359,12 +363,22 @@ def devclass(self) -> str: return self.__devclass return "peripheral" - @property def has_devclass(self): return self.__devclass is not None +class AnyPort(Port): + def __init__(self, devclass: str): + super().__init__(None, "*", devclass) + + def __repr__(self): + return "*" + + def __str__(self): + return "*" + + class VirtualDevice: """ Class of a device connected to *port*. @@ -378,7 +392,7 @@ def __init__( port: Optional[Port] = None, device_id: Optional[str] = None, ): - assert port is not None or device_id is not None + assert not isinstance(port, AnyPort) or device_id is not None self.port: Optional[Port] = port self._device_id = device_id @@ -394,19 +408,26 @@ def clone(self, **kwargs) -> 'VirtualDevice': return VirtualDevice(**attr) @property - def port(self) -> Union[Port, str]: + def port(self) -> Port: # pylint: disable=missing-function-docstring return self._port @port.setter def port(self, value: Union[Port, str, None]): # pylint: disable=missing-function-docstring - self._port = value if value is not None else '*' + if isinstance(value, Port): + self._port = value + return + if isinstance(value, str) and value != '*': + raise ValueError("Unsupported value for port") + if self.device_id == '*': + raise ValueError("Cannot set port to '*' if device_is is '*'") + self._port = AnyPort(self.devclass) @property def device_id(self) -> str: # pylint: disable=missing-function-docstring - if self._device_id is not None: + if self.is_device_id_set: return self._device_id return '*' @@ -418,34 +439,26 @@ def is_device_id_set(self) -> bool: return self._device_id is not None @property - def backend_domain(self) -> Union[QubesVM, str]: + def backend_domain(self) -> Optional[QubesVM]: # pylint: disable=missing-function-docstring - if self.port != '*' and self.port.backend_domain is not None: - return self.port.backend_domain - return '*' + return self.port.backend_domain @property def backend_name(self) -> str: """ Return backend domain name if any or `*`. """ - if self.port != '*': - return self.port.backend_name - return '*' + return self.port.backend_name @property def port_id(self) -> str: # pylint: disable=missing-function-docstring - if self.port != '*' and self.port.port_id is not None: - return self.port.port_id - return '*' + return self.port.port_id @property def devclass(self) -> str: # pylint: disable=missing-function-docstring - if self.port != '*' and self.port.devclass is not None: - return self.port.devclass - return '*' + return self.port.devclass @property def description(self) -> str: @@ -483,9 +496,11 @@ def __lt__(self, other): 4. *:* """ if isinstance(other, (VirtualDevice, DeviceAssignment)): - if self.port == '*' and other.port != '*': + if (isinstance(self.port, AnyPort) + and not isinstance(other.port, AnyPort)): return True - if self.port != '*' and other.port == '*': + if (not isinstance(self.port, AnyPort) + and isinstance(other.port, AnyPort)): return False reprs = {self: [self.port], other: [other.port]} for obj, obj_repr in reprs.items(): @@ -509,10 +524,10 @@ def __str__(self): def from_qarg( cls, representation: str, - devclass, + devclass: Optional[str], domains, - blind=False, - backend=None, + blind: bool = False, + backend: Optional[QubesVM] = None, ) -> 'VirtualDevice': """ Parse qrexec argument +: to get device info @@ -528,8 +543,12 @@ def from_qarg( @classmethod def from_str( - cls, representation: str, devclass: Optional[str], domains, - blind=False, backend=None + cls, + representation: str, + devclass: Optional[str], + domains, + blind: bool = False, + backend: Optional[QubesVM] = None, ) -> 'VirtualDevice': """ Parse string +: to get device info @@ -549,7 +568,7 @@ def _parse( representation: str, devclass: Optional[str], get_domain: Callable, - backend, + backend: Optional[QubesVM], sep: str ) -> 'VirtualDevice': """ @@ -942,7 +961,8 @@ def subdevices(self) -> List[VirtualDevice]: If the device has subdevices (e.g., partitions of a USB stick), the subdevices id should be here. """ - return [dev for dev in self.backend_domain.devices[self.devclass] + return [dev for devclass in self.backend_domain.devices.keys() + for dev in self.backend_domain.devices[devclass] if dev.parent_device.port.port_id == self.port_id] @property @@ -956,7 +976,7 @@ def serialize(self) -> bytes: """ Serialize an object to be transmitted via Qubes API. """ - properties = VirtualDevice.serialize(self) + properties = super().serialize() # 'attachment', 'interfaces', 'data', 'parent_device' # are not string, so they need special treatment default = DeviceInfo(self.port)