diff --git a/.pylintrc b/.pylintrc index cb085b3f..ce5735a9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -368,4 +368,4 @@ exclude-protected=_asdict,_fields,_replace,_source,_make # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=Exception +overgeneral-exceptions=builtins.Exception,builtins.BaseException diff --git a/Makefile b/Makefile index ade10bf1..41f2a5aa 100644 --- a/Makefile +++ b/Makefile @@ -66,6 +66,8 @@ install-dom0: all-dom0 install -t $(DESTDIR)/etc/qubes/policy.d/include -m 664 policy.d/include/* install -d $(DESTDIR)/lib/systemd/system -m 755 install -t $(DESTDIR)/lib/systemd/system -m 644 systemd/qubes-qrexec-policy-daemon.service + install -m 755 -d $(DESTDIR)/usr/lib/tmpfiles.d/ + install -m 0644 -t $(DESTDIR)/usr/lib/tmpfiles.d/ systemd/qrexec.conf .PHONY: install-dom0 diff --git a/qrexec/__init__.py b/qrexec/__init__.py index da17325f..143647dc 100644 --- a/qrexec/__init__.py +++ b/qrexec/__init__.py @@ -38,10 +38,12 @@ RPC_PATH = "/etc/qubes-rpc" POLICY_AGENT_SOCKET_PATH = "/var/run/qubes/policy-agent.sock" POLICYPATH = pathlib.Path("/etc/qubes/policy.d") +RUNTIME_POLICY_PATH = pathlib.Path("/run/qubes/policy.d") POLICYSOCKET = pathlib.Path("/var/run/qubes/policy.sock") POLICY_EVAL_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalSimple") POLICY_GUI_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalGUI") INCLUDEPATH = POLICYPATH / "include" +RUNTIME_INCLUDE_PATH = RUNTIME_POLICY_PATH / "include" POLICYSUFFIX = ".policy" POLICYPATH_OLD = pathlib.Path("/etc/qubes-rpc/policy") diff --git a/qrexec/policy/parser.py b/qrexec/policy/parser.py index 8013edb2..d226bd4b 100644 --- a/qrexec/policy/parser.py +++ b/qrexec/policy/parser.py @@ -47,7 +47,7 @@ Sequence, ) -from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX +from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX, RUNTIME_POLICY_PATH from ..utils import FullSystemInfo from .. import exc from ..exc import ( @@ -1790,22 +1790,54 @@ class AbstractFileSystemLoader(AbstractDirectoryLoader, AbstractFileLoader): """This class is used when policy is stored as regular files in a directory. Args: - policy_path (pathlib.Path): Load this directory. Paths given to - ``!include`` etc. directives are interpreted relative to this path. + policy_path: Load these directories. Paths given to + ``!include`` etc. directives in a file are interpreted relative to + the path from which the file was loaded. """ - def __init__(self, *, policy_path=POLICYPATH, **kwds): - super().__init__(**kwds) - self.policy_path = pathlib.Path(policy_path) - + policy_path: Optional[pathlib.Path] + def __init__( + self, + *, + policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]] + ) -> None: + super().__init__() + if policy_path is None: + iterable_policy_paths = [RUNTIME_POLICY_PATH, POLICYPATH] + elif isinstance(policy_path, pathlib.Path): + iterable_policy_paths = [policy_path] + elif isinstance(policy_path, list): + iterable_policy_paths = policy_path + else: + raise TypeError("unexpected type of policy path in AbstractFileSystemLoader.__init__!") try: - self.load_policy_dir(self.policy_path) + self.load_policy_dirs(iterable_policy_paths) except OSError as err: raise AccessDenied( "failed to load {} file: {!s}".format(err.filename, err) ) from err - - def resolve_path(self, included_path): + self.policy_path = None + + def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None: + already_seen = set() + final_list = [] + for path in paths: + for file_path in filter_filepaths(pathlib.Path(path).iterdir()): + basename = file_path.name + if basename not in already_seen: + already_seen.add(basename) + final_list.append(file_path) + final_list.sort(key=lambda x: x.name) + for file_path in final_list: + with file_path.open() as file: + self.policy_path = file_path.parent + try: + self.load_policy_file(file, file_path) + finally: + self.policy_path = None + + def resolve_path(self, included_path: pathlib.PurePosixPath) -> pathlib.Path: + assert self.policy_path is not None, "Tried to resolve a path when not loading policy" return (self.policy_path / included_path).resolve() @@ -1840,12 +1872,21 @@ class ValidateParser(FilePolicy): """ def __init__( - self, *args, overrides: Dict[pathlib.Path, Optional[str]], **kwds - ): + self, + *, + overrides: Dict[pathlib.Path, Optional[str]], + policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]] = None, + ) -> None: self.overrides = overrides - super().__init__(*args, **kwds) + super().__init__(policy_path=policy_path) - def load_policy_dir(self, dirpath): + def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None: + assert len(paths) == 1 + path, = paths + self.policy_path = path + self.load_policy_dir(path) + + def load_policy_dir(self, dirpath: pathlib.Path) -> None: for path in filter_filepaths(dirpath.iterdir()): if path not in self.overrides: with path.open() as file: diff --git a/qrexec/policy/utils.py b/qrexec/policy/utils.py index cce5757f..76416718 100644 --- a/qrexec/policy/utils.py +++ b/qrexec/policy/utils.py @@ -20,18 +20,20 @@ import asyncio import os.path import pyinotify -from qrexec import POLICYPATH, POLICYPATH_OLD +from qrexec import POLICYPATH, POLICYPATH_OLD, RUNTIME_POLICY_PATH from . import parser class PolicyCache: - def __init__(self, path=POLICYPATH, use_legacy=True, lazy_load=False): - self.path = path + def __init__( + self, path=(RUNTIME_POLICY_PATH, POLICYPATH), use_legacy=True, lazy_load=False + ) -> None: + self.paths = list(path) self.outdated = lazy_load if lazy_load: self.policy = None else: - self.policy = parser.FilePolicy(policy_path=self.path) + self.policy = parser.FilePolicy(policy_path=self.paths) # default policy paths are listed manually, for compatibility with R4.0 # to be removed in Qubes 5.0 @@ -62,22 +64,20 @@ def initialize_watcher(self): self.watch_manager, loop, default_proc_fun=PolicyWatcher(self) ) - if str(self.path) not in self.default_policy_paths and os.path.exists( - self.path - ): - self.watches.append( - self.watch_manager.add_watch( - str(self.path), mask, rec=True, auto_add=True + for path in self.paths: + str_path = str(path) + if str_path not in self.default_policy_paths and os.path.exists(str_path): + self.watches.append( + self.watch_manager.add_watch( + str_path, mask, rec=True, auto_add=True + ) ) - ) for path in self.default_policy_paths: if not os.path.exists(path): continue self.watches.append( - self.watch_manager.add_watch( - str(path), mask, rec=True, auto_add=True - ) + self.watch_manager.add_watch(str(path), mask, rec=True, auto_add=True) ) def cleanup(self): @@ -92,7 +92,7 @@ def cleanup(self): def get_policy(self): if self.outdated: - self.policy = parser.FilePolicy(policy_path=self.path) + self.policy = parser.FilePolicy(policy_path=self.paths) self.outdated = False return self.policy diff --git a/qrexec/tests/cli.py b/qrexec/tests/cli.py index bb6c88d9..ffb69af0 100644 --- a/qrexec/tests/cli.py +++ b/qrexec/tests/cli.py @@ -96,7 +96,7 @@ def policy(): yield policy assert mock_policy.mock_calls == [ - mock.call(policy_path=PosixPath("/etc/qubes/policy.d")) + mock.call(policy_path=[PosixPath("/run/qubes/policy.d"), PosixPath("/etc/qubes/policy.d")]), ] diff --git a/qrexec/tests/policy_cache.py b/qrexec/tests/policy_cache.py index b1ca4634..e7baae01 100644 --- a/qrexec/tests/policy_cache.py +++ b/qrexec/tests/policy_cache.py @@ -23,11 +23,20 @@ import pytest import unittest import unittest.mock +import pathlib from ..policy import utils class TestPolicyCache: + @pytest.fixture + def tmp_paths(self, tmp_path: pathlib.Path) -> list[pathlib.Path]: + path1 = tmp_path / "path1" + path2 = tmp_path / "path2" + path1.mkdir() + path2.mkdir() + return [path1, path2] + @pytest.fixture def mock_parser(self, monkeypatch): mock_parser = unittest.mock.Mock() @@ -37,58 +46,60 @@ def mock_parser(self, monkeypatch): return mock_parser def test_00_policy_init(self, tmp_path, mock_parser): - cache = utils.PolicyCache(tmp_path) - mock_parser.assert_called_once_with(policy_path=tmp_path) + cache = utils.PolicyCache([tmp_path]) + mock_parser.assert_called_once_with(policy_path=[tmp_path]) @pytest.mark.asyncio - async def test_10_file_created(self, tmp_path, mock_parser): - cache = utils.PolicyCache(tmp_path) - cache.initialize_watcher() + async def test_10_file_created(self, tmp_paths, mock_parser): + for i in tmp_paths: + cache = utils.PolicyCache(tmp_paths) + cache.initialize_watcher() - assert not cache.outdated + assert not cache.outdated - file = tmp_path / "test" - file.write_text("test") + (i / "file").write_text("test") - await asyncio.sleep(1) + await asyncio.sleep(1) - assert cache.outdated + assert cache.outdated @pytest.mark.asyncio - async def test_11_file_changed(self, tmp_path, mock_parser): - file = tmp_path / "test" - file.write_text("test") + async def test_11_file_changed(self, tmp_paths, mock_parser): + for i in tmp_paths: + file = i / "test" + file.write_text("test") - cache = utils.PolicyCache(tmp_path) - cache.initialize_watcher() + cache = utils.PolicyCache(tmp_paths) + cache.initialize_watcher() - assert not cache.outdated + assert not cache.outdated - file.write_text("new_content") + file.write_text("new_content") - await asyncio.sleep(1) + await asyncio.sleep(1) - assert cache.outdated + assert cache.outdated @pytest.mark.asyncio - async def test_12_file_deleted(self, tmp_path, mock_parser): - file = tmp_path / "test" - file.write_text("test") + async def test_12_file_deleted(self, tmp_paths, mock_parser): + for i in tmp_paths: + file = i / "test" + file.write_text("test") - cache = utils.PolicyCache(tmp_path) - cache.initialize_watcher() + cache = utils.PolicyCache(tmp_paths) + cache.initialize_watcher() - assert not cache.outdated + assert not cache.outdated - os.remove(file) + os.remove(file) - await asyncio.sleep(1) + await asyncio.sleep(1) - assert cache.outdated + assert cache.outdated @pytest.mark.asyncio - async def test_13_no_change(self, tmp_path, mock_parser): - cache = utils.PolicyCache(tmp_path) + async def test_13_no_change(self, tmp_paths, mock_parser): + cache = utils.PolicyCache(tmp_paths) cache.initialize_watcher() assert not cache.outdated @@ -101,10 +112,10 @@ async def test_13_no_change(self, tmp_path, mock_parser): async def test_14_policy_move(self, tmp_path, mock_parser): policy_path = tmp_path / "policy" policy_path.mkdir() - cache = utils.PolicyCache(policy_path) + cache = utils.PolicyCache([policy_path]) cache.initialize_watcher() - mock_parser.assert_called_once_with(policy_path=policy_path) + mock_parser.assert_called_once_with(policy_path=[policy_path]) assert not cache.outdated @@ -135,27 +146,35 @@ async def test_14_policy_move(self, tmp_path, mock_parser): cache.get_policy() - call = unittest.mock.call(policy_path=policy_path) + call = unittest.mock.call(policy_path=[policy_path]) assert mock_parser.mock_calls == [call, call, call] @pytest.mark.asyncio - async def test_20_policy_updates(self, tmp_path, mock_parser): - cache = utils.PolicyCache(tmp_path) + async def test_20_policy_updates(self, tmp_paths, mock_parser): + cache = utils.PolicyCache(tmp_paths) cache.initialize_watcher() + count = 0 - mock_parser.assert_called_once_with(policy_path=tmp_path) + for i in tmp_paths: + call = unittest.mock.call(policy_path=tmp_paths) - assert not cache.outdated + count += 2 + assert mock_parser.mock_calls == [call] * (count - 1) + cache = utils.PolicyCache(tmp_paths) + cache.initialize_watcher() - file = tmp_path / "test" - file.write_text("test") + l = len(mock_parser.mock_calls) + assert mock_parser.mock_calls == [call] * l - await asyncio.sleep(1) + assert not cache.outdated - assert cache.outdated + file = i / "test" + file.write_text("test") - cache.get_policy() + await asyncio.sleep(1) + + assert cache.outdated - call = unittest.mock.call(policy_path=tmp_path) + cache.get_policy() - assert mock_parser.mock_calls == [call, call] + assert mock_parser.mock_calls == [call] * (count + 1) diff --git a/qrexec/tools/qrexec_legacy_convert.py b/qrexec/tools/qrexec_legacy_convert.py index 02388dd3..939904ee 100644 --- a/qrexec/tools/qrexec_legacy_convert.py +++ b/qrexec/tools/qrexec_legacy_convert.py @@ -288,7 +288,7 @@ def main(args=None): str(POLICYPATH), '--full-output'], output=current_state_string) current_state = set(current_state_string.getvalue().split('\n')) - except Exception: #pylint: disable-broad-except + except Exception: # pylint: disable=broad-except current_state = 'ERROR' if initial_state != current_state: diff --git a/qrexec/tools/qrexec_policy_daemon.py b/qrexec/tools/qrexec_policy_daemon.py index 5e8ddf28..ecd99d54 100644 --- a/qrexec/tools/qrexec_policy_daemon.py +++ b/qrexec/tools/qrexec_policy_daemon.py @@ -27,7 +27,7 @@ from ..utils import sanitize_domain_name, get_system_info from .qrexec_policy_exec import handle_request -from .. import POLICYPATH, POLICYSOCKET, POLICY_EVAL_SOCKET, POLICY_GUI_SOCKET +from .. import POLICYPATH, POLICYSOCKET, POLICY_EVAL_SOCKET, POLICY_GUI_SOCKET, RUNTIME_POLICY_PATH from ..policy.utils import PolicyCache argparser = argparse.ArgumentParser(description="Evaluate qrexec policy daemon") @@ -35,8 +35,9 @@ argparser.add_argument( "--policy-path", type=pathlib.Path, - default=POLICYPATH, + default=[RUNTIME_POLICY_PATH, POLICYPATH], help="Use alternative policy path", + action='append', ) argparser.add_argument( "--socket-path", @@ -291,6 +292,8 @@ async def handle_qrexec_connection( async def start_serving(args=None): args = argparser.parse_args(args) + if len(args.policy_path) > 2: + args.policy_path = args.policy_path[2:] logging.basicConfig(format="%(message)s") log = logging.getLogger("policy") diff --git a/qrexec/tools/qrexec_policy_exec.py b/qrexec/tools/qrexec_policy_exec.py index cabd9de9..31507b1b 100644 --- a/qrexec/tools/qrexec_policy_exec.py +++ b/qrexec/tools/qrexec_policy_exec.py @@ -27,12 +27,13 @@ import subprocess from typing import Optional, List, Union, Dict, Type -from .. import DEFAULT_POLICY, QREXEC_CLIENT, POLICYPATH +from .. import DEFAULT_POLICY, QREXEC_CLIENT, POLICYPATH, RUNTIME_POLICY_PATH from .. import exc from .. import utils from ..policy import parser from ..policy.utils import PolicyCache from ..server import call_socket_service +from ..utils import FullSystemInfo def create_default_policy(service_name): @@ -222,8 +223,9 @@ def prepare_resolution_types( argparser.add_argument( "--path", type=pathlib.Path, - default=POLICYPATH, + default=[RUNTIME_POLICY_PATH, POLICYPATH], help="Use alternative policy path", + action='append', ) argparser.add_argument( "args", @@ -233,6 +235,8 @@ def prepare_resolution_types( # pylint: disable=too-many-locals def get_result(args: Optional[List[str]]) -> Union[str, int]: parsed_args = argparser.parse_args(args) + if len(parsed_args.path) > 2: + parsed_args.path = args.path[2:] log = logging.getLogger("policy") log.setLevel(logging.INFO) @@ -326,8 +330,8 @@ async def handle_request( just_evaluate: bool = False, assume_yes_for_ask: bool = False, allow_resolution_type: Optional[type]=None, - policy_cache=None, - system_info=None, + policy_cache: Optional[PolicyCache]=None, + system_info: Optional[FullSystemInfo]=None, ) -> str: # Add source domain information, required by qrexec-client for establishing # connection @@ -345,10 +349,8 @@ async def handle_request( service, argument = service_and_arg, "+" try: - if policy_cache: - policy = policy_cache.get_policy() - else: - policy = parser.FilePolicy(policy_path=POLICYPATH) + assert policy_cache is not None + policy = policy_cache.get_policy() allow_resolution_class: Type[parser.AllowResolution] if allow_resolution_type is None: diff --git a/rpm_spec/qubes-qrexec-dom0.spec.in b/rpm_spec/qubes-qrexec-dom0.spec.in index 89155bcd..97f1e615 100644 --- a/rpm_spec/qubes-qrexec-dom0.spec.in +++ b/rpm_spec/qubes-qrexec-dom0.spec.in @@ -122,6 +122,7 @@ rm -f %{name}-%{version} %{_sysconfdir}/qubes-rpc/policy.include.Get %{_sysconfdir}/qubes-rpc/policy.include.Replace %{_sysconfdir}/qubes-rpc/policy.include.Remove +%{_tmpfilesdir}/qrexec.conf /lib/systemd/system/qubes-qrexec-policy-daemon.service diff --git a/systemd/qrexec.conf b/systemd/qrexec.conf new file mode 100644 index 00000000..de57e146 --- /dev/null +++ b/systemd/qrexec.conf @@ -0,0 +1,2 @@ +d /run/qubes 2770 root qubes +d /run/qubes/policy.d 2770 root qubes