diff --git a/compute_endpoint/globus_compute_endpoint/cli.py b/compute_endpoint/globus_compute_endpoint/cli.py index 9e9d9d38e..e98b77e7e 100644 --- a/compute_endpoint/globus_compute_endpoint/cli.py +++ b/compute_endpoint/globus_compute_endpoint/cli.py @@ -627,7 +627,8 @@ def _do_start_endpoint( ) reg_info = {} - config_str = None + config_str: str | None = None + fn_allow_list: list[str] | None | int = 0 if sys.stdin and not (sys.stdin.closed or sys.stdin.isatty()): try: stdin_data = json.loads(sys.stdin.read()) @@ -641,6 +642,7 @@ def _do_start_endpoint( reg_info = stdin_data.get("amqp_creds", {}) config_str = stdin_data.get("config", None) + fn_allow_list = stdin_data.get("allowed_functions", fn_allow_list) del stdin_data # clarity for intended scope @@ -656,6 +658,10 @@ def _do_start_endpoint( ep_config = get_config(ep_dir) del config_str + if fn_allow_list != 0: + # 0 is not a valid type for this field, so used to distinguish None + ep_config.allowed_functions = fn_allow_list + if not state.debug and ep_config.debug: setup_logging( logfile=ep_dir / "endpoint.log", diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py index a9b7ad3e1..f0d295b24 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py @@ -117,7 +117,7 @@ def heartbeat_period(self, val: float | int): @property def allowed_functions(self): - if self._allowed_functions: + if self._allowed_functions is not None: return tuple(map(str, self._allowed_functions)) return None diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py index 537c21561..b36b733cb 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py @@ -1073,10 +1073,13 @@ def cmd_start_endpoint( self._config, template_str, user_config_schema, user_opts, user_runtime ) stdin_data_dict = { + "allowed_functions": self._config.allowed_functions, "amqp_creds": kwargs.get("amqp_creds"), "config": user_config, } - stdin_data = json.dumps(stdin_data_dict) + + stdin_data = json.dumps(stdin_data_dict, separators=(",", ":")) + exit_code += 1 # Reminder: this is *os*.open, not *open*. Descriptors will not be closed diff --git a/compute_endpoint/tests/unit/test_cli_behavior.py b/compute_endpoint/tests/unit/test_cli_behavior.py index 76510a6de..d041e71b2 100644 --- a/compute_endpoint/tests/unit/test_cli_behavior.py +++ b/compute_endpoint/tests/unit/test_cli_behavior.py @@ -313,6 +313,34 @@ def test_start_ep_reads_stdin( assert reg_info_found == {} +@pytest.mark.parametrize("fn_count", range(-1, 5)) +def test_start_ep_stdin_allowed_fns_overrides_conf( + mocker, run_line, mock_cli_state, make_endpoint_dir, ep_name, fn_count +): + if fn_count == -1: + allowed_fns = None + else: + allowed_fns = tuple(str(uuid.uuid4()) for _ in range(fn_count)) + + conf = UserEndpointConfig(executors=[ThreadPoolEngine]) + conf.allowed_functions = [uuid.uuid4() for _ in range(5)] # to be overridden + mock_get_config = mocker.patch(f"{_MOCK_BASE}get_config") + mock_get_config.return_value = conf + + mock_sys = mocker.patch(f"{_MOCK_BASE}sys") + mock_sys.stdin.closed = False + mock_sys.stdin.isatty.return_value = False + mock_sys.stdin.read.return_value = json.dumps({"allowed_functions": allowed_fns}) + + make_endpoint_dir() + + run_line(f"start {ep_name}") + mock_ep, _ = mock_cli_state + assert mock_ep.start_endpoint.called + (_, _, found_conf, *_), _k = mock_ep.start_endpoint.call_args + assert found_conf.allowed_functions == allowed_fns, "allowed field not overridden!" + + @pytest.mark.parametrize("use_uuid", (True, False)) @mock.patch(f"{_MOCK_BASE}get_config") def test_stop_endpoint( diff --git a/compute_endpoint/tests/unit/test_endpointmanager_unit.py b/compute_endpoint/tests/unit/test_endpointmanager_unit.py index 95c5d1596..4d3e5b392 100644 --- a/compute_endpoint/tests/unit/test_endpointmanager_unit.py +++ b/compute_endpoint/tests/unit/test_endpointmanager_unit.py @@ -2006,8 +2006,7 @@ def test_pipe_size_limit(mocker, mock_log, successful_exec_from_mocked_root, con conf_str = "v: " + "$" * (conf_size - 3) - # Add 34 bytes for dict keys, etc. - stdin_data_size = conf_size + 34 + stdin_data_size = conf_size + 56 # overhead for JSON dict keys, etc. pipe_buffer_size = 512 # Subtract 256 for hard-coded buffer in-code is_valid = pipe_buffer_size - 256 - stdin_data_size >= 0 @@ -2039,6 +2038,30 @@ def _remove_user_config_template(*args, **kwargs): assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'" +@pytest.mark.parametrize("fn_count", (0, 1, 2, 3, random.randint(4, 100))) +def test_set_uep_allowed_functions( + successful_exec_from_mocked_root, mock_conf_root, fn_count +): + mock_os, *_, em = successful_exec_from_mocked_root + + m = mock.Mock() + mock_os.fdopen.return_value.__enter__.return_value = m + + fns = [str(uuid.uuid4()) for _ in range(fn_count)] + mock_conf_root.allowed_functions = fns + with mock.patch.object(fcntl, "fcntl", return_value=2**20): + # 2**20 == plenty for test + with pytest.raises(SystemExit) as pyexc: + em._event_loop() + + assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'" + + (received_stdin,), _k = m.write.call_args + parsed_stdin = json.loads(received_stdin) + assert "allowed_functions" in parsed_stdin, "Even empty list should be stated" + assert parsed_stdin["allowed_functions"] == fns + + def test_redirect_stdstreams_to_user_log( successful_exec_from_mocked_root, conf_dir, command_payload ):