Skip to content

Commit

Permalink
q-dev: extract ext/utils
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrbartman committed Jun 12, 2024
1 parent 28f1789 commit 9796c57
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 168 deletions.
4 changes: 2 additions & 2 deletions qubes/api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,10 +1381,10 @@ async def vm_device_detach(self, endpoint):
scope='local', write=True)
async def vm_device_set_assignment(self, endpoint, untrusted_payload):
"""
Update assignment of already attached device.
Update assignment of an already attached device.
Payload:
`None` -> unassign device from qube
`None` -> unassign device from a qube
`False` -> device will be auto-attached to qube
`True` -> device is required to start qube
"""
Expand Down
180 changes: 94 additions & 86 deletions qubes/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,11 @@ def deserialize(
expected_backend_domain: 'qubes.vm.BaseVM',
expected_devclass: Optional[str] = None,
) -> 'DeviceInfo':
ident, _, rest = serialization.partition(b' ')
ident = ident.decode('ascii', errors='ignore')
try:
result = DeviceInfo._deserialize(
cls, serialization, expected_backend_domain, expected_devclass)
cls, rest, expected_backend_domain, ident, expected_devclass)
except Exception as exc:
print(exc, file=sys.stderr)
ident = serialization.split(b' ')[0].decode(
Expand All @@ -552,37 +554,25 @@ def deserialize(
@staticmethod
def _deserialize(
cls: Type,
serialization: bytes,
untrusted_serialization: bytes,
expected_backend_domain: 'qubes.vm.BaseVM',
expected_ident: str,
expected_devclass: Optional[str] = None,
) -> 'DeviceInfo':
decoded = serialization.decode('ascii', errors='ignore')
ident, _, rest = decoded.partition(' ')
keys = []
values = []
key, _, rest = rest.partition("='")
keys.append(key)
while "='" in rest:
value_key, _, rest = rest.partition("='")
value, _, key = value_key.rpartition("' ")
values.append(deserialize_str(value))
keys.append(key)
value = rest[:-1] # ending '
values.append(deserialize_str(value))

properties = dict()
for key, value in zip(keys, values):
if key.startswith("_"):
# it's handled in cls.__init__
properties[key[1:]] = value
else:
properties[key] = value

if properties['backend_domain'] != expected_backend_domain.name:
raise UnexpectedDeviceProperty(
f"Got device exposed by {properties['backend_domain']}"
f"when expected devices from {expected_backend_domain.name}.")
properties['backend_domain'] = expected_backend_domain
allowed_chars_key = string.digits + string.ascii_letters + '-_.'
allowed_chars_value = (
allowed_chars_key + ',+:' + string.punctuation + ' ')

properties, options = unpack_properties(
untrusted_serialization, allowed_chars_key, allowed_chars_value)
properties.update(options)

check_device_properties(
expected_backend_domain,
expected_ident,
expected_devclass,
properties
)

if 'attachment' not in properties or not properties['attachment']:
properties['attachment'] = None
Expand All @@ -596,11 +586,6 @@ def _deserialize(
f"Got {properties['devclass']} device "
f"when expected {expected_devclass}.")

if properties["ident"] != ident:
raise UnexpectedDeviceProperty(
f"Got device with id: {properties['ident']} "
f"when expected id: {ident}.")

interfaces = properties['interfaces']
interfaces = [
DeviceInterface(interfaces[i:i + 7])
Expand Down Expand Up @@ -672,6 +657,71 @@ def sanitize_str(
return result


def unpack_properties(
untrusted_serialization: bytes,
allowed_chars_key: str,
allowed_chars_value: str
):
ut_decoded = untrusted_serialization.decode(
'ascii', errors='strict').strip()

options = {}
keys = []
values = []
ut_key, _, ut_rest = ut_decoded.partition("='")

key = sanitize_str(
ut_key, allowed_chars_key,
error_message='Invalid chars in property name')
keys.append(key)
while "='" in ut_rest:
ut_value_key, _, ut_rest = ut_rest.partition("='")
ut_value, _, ut_key = ut_value_key.rpartition("' ")
value = sanitize_str(
deserialize_str(ut_value), allowed_chars_value,
error_message='Invalid chars in property value')
values.append(value)
key = sanitize_str(
ut_key, allowed_chars_key,
error_message='Invalid chars in property name')
keys.append(key)
ut_value = ut_rest[:-1] # ending '
value = sanitize_str(
deserialize_str(ut_value), allowed_chars_value,
error_message='Invalid chars in property value')
values.append(value)

properties = dict()
for key, value in zip(keys, values):
if key.startswith("_"):
# it's handled in cls.__init__
options[key[1:]] = value
else:
properties[key] = value

return properties, options


def check_device_properties(
expected_backend_domain, expected_ident, expected_devclass, properties
):
if properties['backend_domain'] != expected_backend_domain.name:
raise UnexpectedDeviceProperty(
f"Got device exposed by {properties['backend_domain']}"
f"when expected devices from {expected_backend_domain.name}.")
properties['backend_domain'] = expected_backend_domain

if properties['ident'] != expected_ident:
raise UnexpectedDeviceProperty(
f"Got device with id: {properties['ident']} "
f"when expected id: {expected_ident}.")

if expected_devclass and properties['devclass'] != expected_devclass:
raise UnexpectedDeviceProperty(
f"Got {properties['devclass']} device "
f"when expected {expected_devclass}.")


class UnknownDevice(DeviceInfo):
# pylint: disable=too-few-public-methods
"""Unknown device - for example exposed by domain not running currently"""
Expand Down Expand Up @@ -741,13 +791,13 @@ def device(self) -> DeviceInfo:
return self.backend_domain.devices[self.devclass][self.ident]

@property
def frontend_domain(self) -> Optional['qubes.vm.qubesvm.QubesVM']:
def frontend_domain(self) -> Optional['qubes.vm.BaseVM']:
""" Which domain the device is attached/assigned to. """
return self.__frontend_domain

@frontend_domain.setter
def frontend_domain(
self, frontend_domain: Optional[Union[str, 'qubes.vm.qubesvm.QubesVM']]
self, frontend_domain: Optional[Union[str, 'qubes.vm.BaseVM']]
):
""" Which domain the device is attached/assigned to. """
if isinstance(frontend_domain, str):
Expand Down Expand Up @@ -852,61 +902,19 @@ def _deserialize(
expected_ident: str,
expected_devclass: Optional[str] = None,
) -> 'DeviceAssignment':
options = {}
allowed_chars_key = string.digits + string.ascii_letters + '-_.'
allowed_chars_value = allowed_chars_key + ',+:'

untrusted_decoded = untrusted_serialization.decode(
'ascii', 'strict').strip()
keys = []
values = []
untrusted_key, _, untrusted_rest = untrusted_decoded.partition("='")

key = sanitize_str(
untrusted_key, allowed_chars_key,
error_message='Invalid chars in property name')
keys.append(key)
while "='" in untrusted_rest:
ut_value_key, _, untrusted_rest = untrusted_rest.partition("='")
untrusted_value, _, untrusted_key = ut_value_key.rpartition("' ")
value = sanitize_str(
deserialize_str(untrusted_value), allowed_chars_value,
error_message='Invalid chars in property value')
values.append(value)
key = sanitize_str(
untrusted_key, allowed_chars_key,
error_message='Invalid chars in property name')
keys.append(key)
untrusted_value = untrusted_rest[:-1] # ending '
value = sanitize_str(
deserialize_str(untrusted_value), allowed_chars_value,
error_message='Invalid chars in property value')
values.append(value)

properties = dict()
for key, value in zip(keys, values):
if key.startswith("_"):
options[key[1:]] = value
else:
properties[key] = value

properties, options = unpack_properties(
untrusted_serialization, allowed_chars_key, allowed_chars_value)
properties['options'] = options

if properties['backend_domain'] != expected_backend_domain.name:
raise UnexpectedDeviceProperty(
f"Got device exposed by {properties['backend_domain']} "
f"when expected devices from {expected_backend_domain.name}.")
properties['backend_domain'] = expected_backend_domain

if properties["ident"] != expected_ident:
raise UnexpectedDeviceProperty(
f"Got device with id: {properties['ident']} "
f"when expected id: {expected_ident}.")

if expected_devclass and properties['devclass'] != expected_devclass:
raise UnexpectedDeviceProperty(
f"Got {properties['devclass']} device "
f"when expected {expected_devclass}.")
check_device_properties(
expected_backend_domain,
expected_ident,
expected_devclass,
properties
)

properties['attach_automatically'] = qubes.property.bool(
None, None, properties['attach_automatically'])
Expand Down
97 changes: 17 additions & 80 deletions qubes/ext/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import qubes.devices
import qubes.ext
from qubes.ext.utils import device_list_change

name_re = re.compile(r"\A[a-z0-9-]{1,12}\Z")
device_re = re.compile(r"\A[a-z0-9/-]{1,64}\Z")
Expand Down Expand Up @@ -302,10 +303,10 @@ def on_domain_init_load(self, vm, event):
if event == 'domain-load':
# avoid building a cache on domain-init, as it isn't fully set yet,
# and definitely isn't running yet
current_devices = {
dev.ident: dev.attachment # TODO load all attachments at once
for dev in self.on_device_list_block(vm, None)
}
device_attachments = self.get_device_attachments(vm)
current_devices = dict(
(dev.ident, device_attachments.get(dev.ident, None))
for dev in self.on_device_list_block(vm, None))
self.devices_cache[vm.name] = current_devices.copy()
else:
self.devices_cache[vm.name] = {}.copy()
Expand All @@ -314,81 +315,14 @@ def on_domain_init_load(self, vm, event):
def on_qdb_change(self, vm, event, path):
"""A change in QubesDB means a change in a device list."""
# pylint: disable=unused-argument
if path is not None:
vm.fire_event('device-list-change:block')
device_attachments = self.get_device_attachments(vm)
current_devices = dict(
(dev.ident, device_attachments.get(dev.ident, None))
for dev in self.on_device_list_block(vm, None))
device_list_change(self, current_devices, vm, path, BlockDevice)

added, attached, detached, removed = (
self._compare_cache(vm, current_devices))

# send events about devices detached/attached outside by themselves
for dev_id, front_vm in detached.items():
dev = BlockDevice(vm, dev_id)
asyncio.ensure_future(front_vm.fire_event_async(
'device-detach:block', device=dev))
for dev_id in removed:
device = BlockDevice(vm, dev_id)
vm.fire_event('device-removed:block', device=device)
for dev_id in added:
device = BlockDevice(vm, dev_id)
vm.fire_event('device-added:block', device=device)
for dev_ident, front_vm in attached.items():
dev = BlockDevice(vm, dev_ident)
asyncio.ensure_future(front_vm.fire_event_async(
'device-attach:block', device=dev, options={}))

self.devices_cache[vm.name] = current_devices.copy()

for front_vm in vm.app.domains:
if not front_vm.is_running():
continue
for assignment in front_vm.devices['block'].get_assigned_devices():
if (assignment.backend_domain == vm
and assignment.ident in added
and assignment.ident not in attached
):
asyncio.ensure_future(self._attach_and_notify(
front_vm, assignment.device, assignment.options))

def _compare_cache(self, vm, current_devices):
# compare cached devices and current devices, collect:
# - newly appeared devices (ident)
# - devices attached from a vm to frontend vm (ident: frontend_vm)
# - devices detached from frontend vm (ident: frontend_vm)
# - disappeared devices, e.g., plugged out (ident)
added = set()
attached = dict()
detached = dict()
removed = set()
cache = self.devices_cache[vm.name]
for dev_id, front_vm in current_devices.items():
if dev_id not in cache:
added.add(dev_id)
if front_vm is not None:
attached[dev_id] = front_vm
elif cache[dev_id] != front_vm:
cached_front = cache[dev_id]
if front_vm is None:
detached[dev_id] = cached_front
elif cached_front is None:
attached[dev_id] = front_vm
else:
# a front changed from one to another, so we signal it as:
# detach from the first one and attach to the second one.
detached[dev_id] = cached_front
attached[dev_id] = front_vm

for dev_id, cached_front in cache.items():
if dev_id not in current_devices:
removed.add(dev_id)
if cached_front is not None:
detached[dev_id] = cached_front
return added, attached, detached, removed

def get_device_attachments(self, vm_):
@staticmethod
def get_device_attachments(vm_):
result = {}
for vm in vm_.app.domains:
if not vm.is_running():
Expand All @@ -405,7 +339,8 @@ def get_device_attachments(self, vm_):
result[ident] = vm
return result

def device_get(self, vm, ident):
@staticmethod
def device_get(vm, ident):
"""
Read information about a device from QubesDB
Expand Down Expand Up @@ -506,9 +441,11 @@ def on_device_list_attached(self, vm, event, **kwargs):

yield (BlockDevice(backend_domain, ident), options)

def find_unused_frontend(self, vm, devtype='disk'):
'''Find unused block frontend device node for <target dev=.../>
parameter'''
@staticmethod
def find_unused_frontend(vm, devtype='disk'):
"""
Find unused block frontend device node for <target dev=.../> parameter
"""
assert vm.is_running()

xml = vm.libvirt_domain.XMLDesc()
Expand Down Expand Up @@ -602,10 +539,10 @@ def on_device_pre_attached_block(self, vm, event, device, options):
async def on_domain_start(self, vm, _event, **_kwargs):
# pylint: disable=unused-argument
for assignment in vm.devices['block'].get_assigned_devices():
asyncio.ensure_future(self._attach_and_notify(
asyncio.ensure_future(self.attach_and_notify(
vm, assignment.device, assignment.options))

async def _attach_and_notify(self, vm, device, options):
async def attach_and_notify(self, vm, device, options):
# bypass DeviceCollection logic preventing double attach
try:
self.on_device_pre_attached_block(
Expand Down
Loading

0 comments on commit 9796c57

Please sign in to comment.