diff --git a/qubes/api/admin.py b/qubes/api/admin.py index 9172d8e4e..859c98066 100644 --- a/qubes/api/admin.py +++ b/qubes/api/admin.py @@ -354,6 +354,49 @@ def vm_volume_import(self): return '{} {}'.format(size, path) + @qubes.api.method('admin.vm.tag.List', no_payload=True) + @asyncio.coroutine + def vm_tag_list(self): + assert not self.arg + + tags = self.dest.tags + + tags = self.fire_event_for_filter(tags) + + return ''.join('{}\n'.format(tag) for tag in sorted(tags)) + + @qubes.api.method('admin.vm.tag.Get', no_payload=True) + @asyncio.coroutine + def vm_tag_get(self): + qubes.vm.Tags.validate_tag(self.arg) + + self.fire_event_for_permission() + + return '1' if self.arg in self.dest.tags else '0' + + @qubes.api.method('admin.vm.tag.Set', no_payload=True) + @asyncio.coroutine + def vm_tag_set(self): + qubes.vm.Tags.validate_tag(self.arg) + + self.fire_event_for_permission() + + self.dest.tags.add(self.arg) + self.app.save() + + @qubes.api.method('admin.vm.tag.Remove', no_payload=True) + @asyncio.coroutine + def vm_tag_remove(self): + qubes.vm.Tags.validate_tag(self.arg) + + self.fire_event_for_permission() + + try: + self.dest.tags.remove(self.arg) + except KeyError: + raise qubes.exc.QubesTagNotFoundError(self.dest, self.arg) + self.app.save() + @qubes.api.method('admin.pool.List', no_payload=True) @asyncio.coroutine def pool_list(self): diff --git a/qubes/exc.py b/qubes/exc.py index 5e708c980..8b3fa68c0 100644 --- a/qubes/exc.py +++ b/qubes/exc.py @@ -162,3 +162,12 @@ def __init__(self, domain, feature): 'Feature not set for domain {}: {}'.format(domain, feature)) self.feature = feature self.vm = domain + +class QubesTagNotFoundError(QubesException, KeyError): + '''Tag not set for a given domain''' + + def __init__(self, domain, tag): + super().__init__('Tag not set for domain {}: {}'.format( + domain, tag)) + self.vm = domain + self.tag = tag diff --git a/qubes/tests/api_admin.py b/qubes/tests/api_admin.py index 6a18e2212..7eec7619f 100644 --- a/qubes/tests/api_admin.py +++ b/qubes/tests/api_admin.py @@ -1687,6 +1687,53 @@ def coroutine_mock(*args, **kwargs): self.assertEqual(func_mock.mock_calls, []) self.assertFalse(self.app.save.called) + def test_530_tag_list(self): + self.vm.tags.add('tag1') + self.vm.tags.add('tag2') + value = self.call_mgmt_func(b'admin.vm.tag.List', b'test-vm1') + self.assertEqual(value, 'tag1\ntag2\n') + self.assertFalse(self.app.save.called) + + def test_540_tag_get(self): + self.vm.tags.add('tag1') + value = self.call_mgmt_func(b'admin.vm.tag.Get', b'test-vm1', + b'tag1') + self.assertEqual(value, '1') + self.assertFalse(self.app.save.called) + + def test_541_tag_get_absent(self): + value = self.call_mgmt_func(b'admin.vm.tag.Get', b'test-vm1', b'tag1') + self.assertEqual(value, '0') + self.assertFalse(self.app.save.called) + + def test_550_tag_remove(self): + self.vm.tags.add('tag1') + value = self.call_mgmt_func(b'admin.vm.tag.Remove', b'test-vm1', + b'tag1') + self.assertIsNone(value, None) + self.assertNotIn('tag1', self.vm.tags) + self.assertTrue(self.app.save.called) + + def test_551_tag_remove_absent(self): + with self.assertRaises(qubes.exc.QubesTagNotFoundError): + self.call_mgmt_func(b'admin.vm.tag.Remove', + b'test-vm1', b'tag1') + self.assertFalse(self.app.save.called) + + def test_560_tag_set(self): + value = self.call_mgmt_func(b'admin.vm.tag.Set', + b'test-vm1', b'tag1') + self.assertIsNone(value) + self.assertIn('tag1', self.vm.tags) + self.assertTrue(self.app.save.called) + + def test_561_tag_set_invalid(self): + with self.assertRaises(AssertionError): + self.call_mgmt_func(b'admin.vm.tag.Set', + b'test-vm1', b'+.some-tag') + self.assertNotIn('+.some-tag', self.vm.tags) + self.assertFalse(self.app.save.called) + def test_990_vm_unexpected_payload(self): methods_with_no_payload = [ b'admin.vm.List', diff --git a/qubes/vm/__init__.py b/qubes/vm/__init__.py index 44edbf853..885bd0e2f 100644 --- a/qubes/vm/__init__.py +++ b/qubes/vm/__init__.py @@ -237,6 +237,11 @@ def remove(self, elem): # end of overriding # + @staticmethod + def validate_tag(tag): + safe_set = string.ascii_letters + string.digits + '-_' + assert all((x in safe_set) for x in tag) + class BaseVM(qubes.PropertyHolder): '''Base class for all VMs