From 0ebb7f0819e3c88f82327fc695112859556394d5 Mon Sep 17 00:00:00 2001 From: Chris Janidlo Date: Mon, 4 Dec 2023 15:55:11 -0600 Subject: [PATCH] Executor: rename funcx_client arg to compute_client This is achieved by adding a new compute_client argument that has higher priority than the existing funcx_client argument, and by emitting a warning to users if they supply funcx_client. Also standardize on the usage of '' vs `` in the Executor deprecation warning messages ('' is used by warnings in other classes). --- ...31204_155243_chris_remove_funcx_client.rst | 4 ++ .../globus_compute_sdk/sdk/executor.py | 41 +++++++++++-------- .../tests/integration/test_executor_int.py | 10 ++--- compute_sdk/tests/unit/test_executor.py | 16 ++++---- docs/sdk.rst | 4 +- smoke_tests/tests/test_running_functions.py | 2 +- 6 files changed, 45 insertions(+), 32 deletions(-) create mode 100644 changelog.d/20231204_155243_chris_remove_funcx_client.rst diff --git a/changelog.d/20231204_155243_chris_remove_funcx_client.rst b/changelog.d/20231204_155243_chris_remove_funcx_client.rst new file mode 100644 index 000000000..130102203 --- /dev/null +++ b/changelog.d/20231204_155243_chris_remove_funcx_client.rst @@ -0,0 +1,4 @@ +Deprecated +^^^^^^^^^^ + +- The ``funcx_client`` argument to the ``Executor`` has been deprecated and replaced with ``client``. diff --git a/compute_sdk/globus_compute_sdk/sdk/executor.py b/compute_sdk/globus_compute_sdk/sdk/executor.py index 9b58cdc08..22e00f69b 100644 --- a/compute_sdk/globus_compute_sdk/sdk/executor.py +++ b/compute_sdk/globus_compute_sdk/sdk/executor.py @@ -115,20 +115,20 @@ def __init__( self, endpoint_id: UUID_LIKE_T | None = None, container_id: UUID_LIKE_T | None = None, - funcx_client: Client | None = None, + client: Client | None = None, task_group_id: UUID_LIKE_T | None = None, user_endpoint_config: dict[str, t.Any] | None = None, label: str = "", batch_size: int = 128, + funcx_client: Client | None = None, amqp_port: int | None = None, **kwargs, ): """ :param endpoint_id: id of the endpoint to which to submit tasks :param container_id: id of the container in which to execute tasks - :param funcx_client: instance of Client to be used by the - executor. If not provided, the executor will instantiate one with default - arguments. + :param client: instance of Client to be used by the executor. If not provided, + the executor will instantiate one with default arguments. :param task_group_id: The Task Group to which to associate tasks. If not set, one will be instantiated. :param user_endpoint_config: User endpoint configuration values as described @@ -138,6 +138,7 @@ def __init__( logging and advanced needs with multiple executors. :param batch_size: the maximum number of tasks to coalesce before sending upstream [min: 1, default: 128] + :param funcx_client: [DEPRECATED] alias for client. :param batch_interval: [DEPRECATED; unused] number of seconds to coalesce tasks before submitting upstream :param batch_enabled: [DEPRECATED; unused] whether to batch results @@ -148,16 +149,24 @@ def __init__( for key in kwargs: if key in deprecated_kwargs: warnings.warn( - f"`{key}` is not utilized and will be removed in a future release", + f"'{key}' is not utilized and will be removed in a future release", DeprecationWarning, ) continue msg = f"'{key}' is an invalid argument for {self.__class__.__name__}" raise TypeError(msg) - if not funcx_client: - funcx_client = Client() - self.funcx_client = funcx_client + if funcx_client: + warnings.warn( + "'funcx_client' is superceded by 'client'" + " and will be removed in a future release", + DeprecationWarning, + ) + + if not client: + client = funcx_client if funcx_client else Client() + + self.client = client self.endpoint_id = endpoint_id @@ -380,7 +389,7 @@ def register_function( reg_kwargs.update(func_register_kwargs) try: - func_reg_id = self.funcx_client.register_function(fn, **reg_kwargs) + func_reg_id = self.client.register_function(fn, **reg_kwargs) except Exception: log.error(f"Unable to register function: {fn.__name__}") self.shutdown(wait=False, cancel_futures=True) @@ -579,7 +588,7 @@ def reload_tasks( assert task_group_id is not None # mypy: we _just_ proved this # step 2: from server, acquire list of related task ids and make futures - r = self.funcx_client.web_client.get_taskgroup_tasks(task_group_id) + r = self.client.web_client.get_taskgroup_tasks(task_group_id) if r["taskgroup_id"] != str(task_group_id): msg = ( "Server did not respond with requested TaskGroup Tasks. " @@ -595,7 +604,7 @@ def reload_tasks( if task_ids: # Complete the futures that already have results. pending: list[ComputeFuture] = [] - deserialize = self.funcx_client.fx_serializer.deserialize + deserialize = self.client.fx_serializer.deserialize chunk_size = 1024 num_chunks = len(task_ids) // chunk_size + 1 for chunk_no, id_chunk in enumerate( @@ -611,7 +620,7 @@ def reload_tasks( len(id_chunk), ) - res = self.funcx_client.web_client.get_batch_status(id_chunk) + res = self.client.web_client.get_batch_status(id_chunk) for task_id, task in res.data.get("results", {}).items(): fut = ComputeFuture(task_id) futures.append(fut) @@ -850,7 +859,7 @@ def _submit_tasks( if taskgroup_uuid is None and self.task_group_id: taskgroup_uuid = self.task_group_id - batch = self.funcx_client.create_batch( + batch = self.client.create_batch( taskgroup_uuid, user_endpoint_config, create_websocket_queue=True ) submitted_futs_by_fn: t.DefaultDict[str, list[ComputeFuture]] = defaultdict( @@ -863,7 +872,7 @@ def _submit_tasks( log.debug("Added task to Globus Compute batch: %s", task) try: - batch_response = self.funcx_client.batch_run(endpoint_uuid, batch) + batch_response = self.client.batch_run(endpoint_uuid, batch) except Exception as e: log.exception(f"Error submitting {len(tasks)} tasks to Globus Compute") for fut_list in submitted_futs_by_fn.values(): @@ -1163,7 +1172,7 @@ def _match_results_to_futures(self): This method will set the _open_futures_empty event if there are no open futures *at the time of processing*. """ - deserialize = self.funcx_executor.funcx_client.fx_serializer.deserialize + deserialize = self.funcx_executor.client.fx_serializer.deserialize with self._new_futures_lock: futures_to_complete = [ self._open_futures.pop(tid) @@ -1290,7 +1299,7 @@ def _stop_ioloop(self): def _connect(self) -> pika.SelectConnection: with self._new_futures_lock: - res = self.funcx_executor.funcx_client.get_result_amqp_url() + res = self.funcx_executor.client.get_result_amqp_url() self._queue_prefix = res["queue_prefix"] connection_url = res["connection_url"] diff --git a/compute_sdk/tests/integration/test_executor_int.py b/compute_sdk/tests/integration/test_executor_int.py index ce0a21893..64493d9f3 100644 --- a/compute_sdk/tests/integration/test_executor_int.py +++ b/compute_sdk/tests/integration/test_executor_int.py @@ -16,7 +16,7 @@ def test_resultwatcher_graceful_shutdown(): service_url = os.environ["COMPUTE_INTEGRATION_TEST_WEB_URL"] gcc = Client(funcx_service_address=service_url) - gce = Executor(funcx_client=gcc) + gce = Executor(client=gcc) rw = _ResultWatcher(gce) rw._start_consuming = mock.Mock() rw.start() @@ -38,12 +38,12 @@ def test_executor_atexit_handler_catches_all_instances(tmp_path): from globus_compute_sdk import Executor from globus_compute_sdk.sdk.executor import _REGISTERED_FXEXECUTORS - gcc = " a fake funcx_client" + gcc = " a fake client" num_executors = random.randrange(1, 10) for i in range(num_executors): - Executor(funcx_client=gcc) # start N threads, none shutdown - gce = Executor(funcx_client=gcc) # intentionally overwritten - gce = Executor(funcx_client=gcc) + Executor(client=gcc) # start N threads, none shutdown + gce = Executor(client=gcc) # intentionally overwritten + gce = Executor(client=gcc) num_executors += 2 assert len(_REGISTERED_FXEXECUTORS) == num_executors, ( diff --git a/compute_sdk/tests/unit/test_executor.py b/compute_sdk/tests/unit/test_executor.py index 1e0b94bb7..29fd1af7a 100644 --- a/compute_sdk/tests/unit/test_executor.py +++ b/compute_sdk/tests/unit/test_executor.py @@ -36,7 +36,7 @@ def noop(): class MockedExecutor(Executor): def __init__(self, *args, **kwargs): - kwargs.update({"funcx_client": mock.Mock(spec=Client)}) + kwargs.update({"client": mock.Mock(spec=Client)}) super().__init__(*args, **kwargs) self._test_paused = threading.Event() self._time_to_stop_mock = threading.Event() @@ -76,7 +76,7 @@ def join(self, timeout: float | None = None) -> None: @pytest.fixture def gc_executor(mocker): gcc = mock.MagicMock() - gce = Executor(funcx_client=gcc) + gce = Executor(client=gcc) watcher = mocker.patch(f"{_MOCK_BASE}_ResultWatcher", autospec=True) def create_mock_watcher(*args, **kwargs): @@ -136,10 +136,10 @@ def test_task_submission_info_stringification(tg_id, fn_id, ep_id, uep_config): def test_deprecated_args_warned(argname, mocker): mock_warn = mocker.patch(f"{_MOCK_BASE}warnings") gcc = mock.Mock(spec=Client) - Executor(funcx_client=gcc).shutdown() + Executor(client=gcc).shutdown() mock_warn.warn.assert_not_called() - Executor(funcx_client=gcc, **{argname: 1}).shutdown() + Executor(client=gcc, **{argname: 1}).shutdown() mock_warn.warn.assert_called() @@ -645,7 +645,7 @@ def test_task_submitter_stops_executor_on_upstream_error_response(randomstring): gce = MockedExecutor() upstream_error = Exception(f"Upstream error {randomstring}!!") - gce.funcx_client.batch_run.side_effect = upstream_error + gce.client.batch_run.side_effect = upstream_error gce.task_group_id = uuid.uuid4() tsi = _TaskSubmissionInfo( task_num=12345, @@ -931,7 +931,7 @@ def test_resultwatcher_match_sets_exception(randomstring): res = Result(task_id=fut.task_id, error_details=err_details, data=payload) mrw = MockedResultWatcher(mock.Mock()) - mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize + mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize mrw._received_results[fut.task_id] = (mock.Mock(timestamp=5), res) mrw.watch_for_task_results([fut]) mrw.start() @@ -949,7 +949,7 @@ def test_resultwatcher_match_sets_result(randomstring): res = Result(task_id=fut.task_id, data=fxs.serialize(payload)) mrw = MockedResultWatcher(mock.Mock()) - mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize + mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize mrw._received_results[fut.task_id] = (None, res) mrw.watch_for_task_results([fut]) mrw.start() @@ -966,7 +966,7 @@ def test_resultwatcher_match_handles_deserialization_error(): res = Result(task_id=fut.task_id, data=invalid_payload) mrw = MockedResultWatcher(mock.Mock()) - mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize + mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize mrw._received_results[fut.task_id] = (None, res) mrw.watch_for_task_results([fut]) mrw.start() diff --git a/docs/sdk.rst b/docs/sdk.rst index 7ea0f4a8f..b5a3789d1 100644 --- a/docs/sdk.rst +++ b/docs/sdk.rst @@ -254,7 +254,7 @@ More details on the Globus Compute login manager prototcol are available `here. compute_login_manager.ensure_logged_in() gc = Client(login_manager=compute_login_manager) - gce = Executor(endpoint_id=tutorial_endpoint, funcx_client=gc) + gce = Executor(endpoint_id=tutorial_endpoint, client=gc) Specifying a Serialization Strategy @@ -281,7 +281,7 @@ another serializer, use the ``code_serialization_strategy`` and code_serialization_strategy=CombinedCode(), data_serialization_strategy=DillDataBase64() ) - gcx = Executor('4b116d3c-1703-4f8f-9f6f-39921e5864df', funcx_client=gcc) + gcx = Executor('4b116d3c-1703-4f8f-9f6f-39921e5864df', client=gcc) # do something with gcx diff --git a/smoke_tests/tests/test_running_functions.py b/smoke_tests/tests/test_running_functions.py index 5b99b31dd..241913797 100644 --- a/smoke_tests/tests/test_running_functions.py +++ b/smoke_tests/tests/test_running_functions.py @@ -96,7 +96,7 @@ def test_executor(compute_client, endpoint, tutorial_function_id): num_tasks = 10 submit_count = 2 # we've had at least one bug that prevented executor re-use - with Executor(endpoint_id=endpoint, funcx_client=compute_client) as gce: + with Executor(endpoint_id=endpoint, client=compute_client) as gce: for _ in range(submit_count): futures = [ gce.submit_to_registered_function(tutorial_function_id)