Skip to content

Commit

Permalink
Executor: rename funcx_client arg to client
Browse files Browse the repository at this point in the history
This is achieved by adding a new client argument that has higher
priority than the existing funcx_client argument, and by emitting a
warning to users if they supply funcx_client.

Also standardize on the usage of '' vs `` in the Executor deprecation
warning messages ('' is used by warnings in other classes).
  • Loading branch information
chris-janidlo committed Dec 7, 2023
1 parent 24bcd99 commit 18ba19b
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 32 deletions.
4 changes: 4 additions & 0 deletions changelog.d/20231204_155243_chris_remove_funcx_client.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Deprecated
^^^^^^^^^^

- The ``funcx_client`` argument to the ``Executor`` has been deprecated and replaced with ``client``.
41 changes: 25 additions & 16 deletions compute_sdk/globus_compute_sdk/sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,20 @@ def __init__(
self,
endpoint_id: UUID_LIKE_T | None = None,
container_id: UUID_LIKE_T | None = None,
funcx_client: Client | None = None,
client: Client | None = None,
task_group_id: UUID_LIKE_T | None = None,
user_endpoint_config: dict[str, t.Any] | None = None,
label: str = "",
batch_size: int = 128,
funcx_client: Client | None = None,
amqp_port: int | None = None,
**kwargs,
):
"""
:param endpoint_id: id of the endpoint to which to submit tasks
:param container_id: id of the container in which to execute tasks
:param funcx_client: instance of Client to be used by the
executor. If not provided, the executor will instantiate one with default
arguments.
:param client: instance of Client to be used by the executor. If not provided,
the executor will instantiate one with default arguments.
:param task_group_id: The Task Group to which to associate tasks. If not set,
one will be instantiated.
:param user_endpoint_config: User endpoint configuration values as described
Expand All @@ -138,6 +138,7 @@ def __init__(
logging and advanced needs with multiple executors.
:param batch_size: the maximum number of tasks to coalesce before
sending upstream [min: 1, default: 128]
:param funcx_client: [DEPRECATED] alias for client.
:param batch_interval: [DEPRECATED; unused] number of seconds to coalesce tasks
before submitting upstream
:param batch_enabled: [DEPRECATED; unused] whether to batch results
Expand All @@ -148,16 +149,24 @@ def __init__(
for key in kwargs:
if key in deprecated_kwargs:
warnings.warn(
f"`{key}` is not utilized and will be removed in a future release",
f"'{key}' is not utilized and will be removed in a future release",
DeprecationWarning,
)
continue
msg = f"'{key}' is an invalid argument for {self.__class__.__name__}"
raise TypeError(msg)

if not funcx_client:
funcx_client = Client()
self.funcx_client = funcx_client
if funcx_client:
warnings.warn(
"'funcx_client' is superseded by 'client'"
" and will be removed in a future release",
DeprecationWarning,
)

if not client:
client = funcx_client if funcx_client else Client()

self.client = client

self.endpoint_id = endpoint_id

Expand Down Expand Up @@ -380,7 +389,7 @@ def register_function(
reg_kwargs.update(func_register_kwargs)

try:
func_reg_id = self.funcx_client.register_function(fn, **reg_kwargs)
func_reg_id = self.client.register_function(fn, **reg_kwargs)
except Exception:
log.error(f"Unable to register function: {fn.__name__}")
self.shutdown(wait=False, cancel_futures=True)
Expand Down Expand Up @@ -579,7 +588,7 @@ def reload_tasks(
assert task_group_id is not None # mypy: we _just_ proved this

# step 2: from server, acquire list of related task ids and make futures
r = self.funcx_client.web_client.get_taskgroup_tasks(task_group_id)
r = self.client.web_client.get_taskgroup_tasks(task_group_id)
if r["taskgroup_id"] != str(task_group_id):
msg = (
"Server did not respond with requested TaskGroup Tasks. "
Expand All @@ -595,7 +604,7 @@ def reload_tasks(
if task_ids:
# Complete the futures that already have results.
pending: list[ComputeFuture] = []
deserialize = self.funcx_client.fx_serializer.deserialize
deserialize = self.client.fx_serializer.deserialize
chunk_size = 1024
num_chunks = len(task_ids) // chunk_size + 1
for chunk_no, id_chunk in enumerate(
Expand All @@ -611,7 +620,7 @@ def reload_tasks(
len(id_chunk),
)

res = self.funcx_client.web_client.get_batch_status(id_chunk)
res = self.client.web_client.get_batch_status(id_chunk)
for task_id, task in res.data.get("results", {}).items():
fut = ComputeFuture(task_id)
futures.append(fut)
Expand Down Expand Up @@ -850,7 +859,7 @@ def _submit_tasks(
if taskgroup_uuid is None and self.task_group_id:
taskgroup_uuid = self.task_group_id

batch = self.funcx_client.create_batch(
batch = self.client.create_batch(
taskgroup_uuid, user_endpoint_config, create_websocket_queue=True
)
submitted_futs_by_fn: t.DefaultDict[str, list[ComputeFuture]] = defaultdict(
Expand All @@ -863,7 +872,7 @@ def _submit_tasks(
log.debug("Added task to Globus Compute batch: %s", task)

try:
batch_response = self.funcx_client.batch_run(endpoint_uuid, batch)
batch_response = self.client.batch_run(endpoint_uuid, batch)
except Exception as e:
log.exception(f"Error submitting {len(tasks)} tasks to Globus Compute")
for fut_list in submitted_futs_by_fn.values():
Expand Down Expand Up @@ -1163,7 +1172,7 @@ def _match_results_to_futures(self):
This method will set the _open_futures_empty event if there are no open
futures *at the time of processing*.
"""
deserialize = self.funcx_executor.funcx_client.fx_serializer.deserialize
deserialize = self.funcx_executor.client.fx_serializer.deserialize
with self._new_futures_lock:
futures_to_complete = [
self._open_futures.pop(tid)
Expand Down Expand Up @@ -1290,7 +1299,7 @@ def _stop_ioloop(self):

def _connect(self) -> pika.SelectConnection:
with self._new_futures_lock:
res = self.funcx_executor.funcx_client.get_result_amqp_url()
res = self.funcx_executor.client.get_result_amqp_url()
self._queue_prefix = res["queue_prefix"]
connection_url = res["connection_url"]

Expand Down
10 changes: 5 additions & 5 deletions compute_sdk/tests/integration/test_executor_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def test_resultwatcher_graceful_shutdown():
service_url = os.environ["COMPUTE_INTEGRATION_TEST_WEB_URL"]
gcc = Client(funcx_service_address=service_url)
gce = Executor(funcx_client=gcc)
gce = Executor(client=gcc)
rw = _ResultWatcher(gce)
rw._start_consuming = mock.Mock()
rw.start()
Expand All @@ -38,12 +38,12 @@ def test_executor_atexit_handler_catches_all_instances(tmp_path):
from globus_compute_sdk import Executor
from globus_compute_sdk.sdk.executor import _REGISTERED_FXEXECUTORS
gcc = " a fake funcx_client"
gcc = " a fake client"
num_executors = random.randrange(1, 10)
for i in range(num_executors):
Executor(funcx_client=gcc) # start N threads, none shutdown
gce = Executor(funcx_client=gcc) # intentionally overwritten
gce = Executor(funcx_client=gcc)
Executor(client=gcc) # start N threads, none shutdown
gce = Executor(client=gcc) # intentionally overwritten
gce = Executor(client=gcc)
num_executors += 2
assert len(_REGISTERED_FXEXECUTORS) == num_executors, (
Expand Down
16 changes: 8 additions & 8 deletions compute_sdk/tests/unit/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def noop():

class MockedExecutor(Executor):
def __init__(self, *args, **kwargs):
kwargs.update({"funcx_client": mock.Mock(spec=Client)})
kwargs.update({"client": mock.Mock(spec=Client)})
super().__init__(*args, **kwargs)
self._test_paused = threading.Event()
self._time_to_stop_mock = threading.Event()
Expand Down Expand Up @@ -76,7 +76,7 @@ def join(self, timeout: float | None = None) -> None:
@pytest.fixture
def gc_executor(mocker):
gcc = mock.MagicMock()
gce = Executor(funcx_client=gcc)
gce = Executor(client=gcc)
watcher = mocker.patch(f"{_MOCK_BASE}_ResultWatcher", autospec=True)

def create_mock_watcher(*args, **kwargs):
Expand Down Expand Up @@ -136,10 +136,10 @@ def test_task_submission_info_stringification(tg_id, fn_id, ep_id, uep_config):
def test_deprecated_args_warned(argname, mocker):
mock_warn = mocker.patch(f"{_MOCK_BASE}warnings")
gcc = mock.Mock(spec=Client)
Executor(funcx_client=gcc).shutdown()
Executor(client=gcc).shutdown()
mock_warn.warn.assert_not_called()

Executor(funcx_client=gcc, **{argname: 1}).shutdown()
Executor(client=gcc, **{argname: 1}).shutdown()
mock_warn.warn.assert_called()


Expand Down Expand Up @@ -645,7 +645,7 @@ def test_task_submitter_stops_executor_on_upstream_error_response(randomstring):
gce = MockedExecutor()

upstream_error = Exception(f"Upstream error {randomstring}!!")
gce.funcx_client.batch_run.side_effect = upstream_error
gce.client.batch_run.side_effect = upstream_error
gce.task_group_id = uuid.uuid4()
tsi = _TaskSubmissionInfo(
task_num=12345,
Expand Down Expand Up @@ -931,7 +931,7 @@ def test_resultwatcher_match_sets_exception(randomstring):
res = Result(task_id=fut.task_id, error_details=err_details, data=payload)

mrw = MockedResultWatcher(mock.Mock())
mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize
mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize
mrw._received_results[fut.task_id] = (mock.Mock(timestamp=5), res)
mrw.watch_for_task_results([fut])
mrw.start()
Expand All @@ -949,7 +949,7 @@ def test_resultwatcher_match_sets_result(randomstring):
res = Result(task_id=fut.task_id, data=fxs.serialize(payload))

mrw = MockedResultWatcher(mock.Mock())
mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize
mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize
mrw._received_results[fut.task_id] = (None, res)
mrw.watch_for_task_results([fut])
mrw.start()
Expand All @@ -966,7 +966,7 @@ def test_resultwatcher_match_handles_deserialization_error():
res = Result(task_id=fut.task_id, data=invalid_payload)

mrw = MockedResultWatcher(mock.Mock())
mrw.funcx_executor.funcx_client.fx_serializer.deserialize = fxs.deserialize
mrw.funcx_executor.client.fx_serializer.deserialize = fxs.deserialize
mrw._received_results[fut.task_id] = (None, res)
mrw.watch_for_task_results([fut])
mrw.start()
Expand Down
4 changes: 2 additions & 2 deletions docs/sdk.rst
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ More details on the Globus Compute login manager prototcol are available `here.
compute_login_manager.ensure_logged_in()
gc = Client(login_manager=compute_login_manager)
gce = Executor(endpoint_id=tutorial_endpoint, funcx_client=gc)
gce = Executor(endpoint_id=tutorial_endpoint, client=gc)
Specifying a Serialization Strategy
Expand All @@ -281,7 +281,7 @@ another serializer, use the ``code_serialization_strategy`` and
code_serialization_strategy=CombinedCode(),
data_serialization_strategy=DillDataBase64()
)
gcx = Executor('4b116d3c-1703-4f8f-9f6f-39921e5864df', funcx_client=gcc)
gcx = Executor('4b116d3c-1703-4f8f-9f6f-39921e5864df', client=gcc)
# do something with gcx
Expand Down
2 changes: 1 addition & 1 deletion smoke_tests/tests/test_running_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_executor(compute_client, endpoint, tutorial_function_id):
num_tasks = 10
submit_count = 2 # we've had at least one bug that prevented executor re-use

with Executor(endpoint_id=endpoint, funcx_client=compute_client) as gce:
with Executor(endpoint_id=endpoint, client=compute_client) as gce:
for _ in range(submit_count):
futures = [
gce.submit_to_registered_function(tutorial_function_id)
Expand Down

0 comments on commit 18ba19b

Please sign in to comment.