diff --git a/compute_sdk/globus_compute_sdk/sdk/executor.py b/compute_sdk/globus_compute_sdk/sdk/executor.py index c3acf0f7c..24ba6f79a 100644 --- a/compute_sdk/globus_compute_sdk/sdk/executor.py +++ b/compute_sdk/globus_compute_sdk/sdk/executor.py @@ -959,6 +959,14 @@ def __hash__(self): api_burst_fill.append(fill_percent) to_watch = [f for f in fut_list if f.task_id and not f.done()] + log.debug( + "%r (tid:%s): Submission complete (to %s);" + " count to watcher: %d", + self, + _tid, + ep_uuid, + len(to_watch), + ) if not to_watch: continue @@ -1043,6 +1051,8 @@ def _submit_tasks( set when function completes successfully. :param tasks: a list of tasks to submit upstream in a batch. """ + assert len(futs) == len(tasks), "Developer reminder" + _tid = threading.get_ident() if taskgroup_uuid is None and self.task_group_id: taskgroup_uuid = self.task_group_id @@ -1083,10 +1093,12 @@ def _submit_tasks( try: received_tasks_by_fn: dict[str, list[str]] = batch_response["tasks"] - new_tg_id: str = batch_response["task_group_id"] + new_tg_id = uuid.UUID(batch_response["task_group_id"]) + _request_id = uuid.UUID(batch_response["request_id"]) + _endpoint_id = uuid.UUID(batch_response["endpoint_id"]) except Exception as e: log.debug( - f"Server response ({batch_response}) missing an expected field" + f"Invalid or unexpected server response ({batch_response})" f" [({type(e).__name__}) {e}]" ) for fut_list in submitted_futs_by_fn.values(): @@ -1095,7 +1107,7 @@ def _submit_tasks( self._submitter_thread_exception_captured = True raise - if str(self.task_group_id) != new_tg_id: + if self.task_group_id != new_tg_id: log.info(f"Updating task_group_id from {self.task_group_id} to {new_tg_id}") self.task_group_id = new_tg_id @@ -1113,17 +1125,15 @@ def _submit_tasks( for fn_id, fut_list in submitted_futs_by_fn.items(): task_uuids = received_tasks_by_fn.get(fn_id) + fut_exc = None if task_uuids is None: fut_exc = Exception( f"The Globus Compute Service ignored tasks for function {fn_id}!" " This 'should not happen,' so please reach out to the Globus" " Compute team if you are able to recreate this behavior." ) - for fut in fut_list: - fut.set_exception(fut_exc) - continue - if len(fut_list) != len(task_uuids): + elif len(fut_list) != len(task_uuids): fut_exc = Exception( "The Globus Compute Service only partially initiated requested" f" tasks for function {fn_id}! It is unclear which tasks it" @@ -1131,12 +1141,20 @@ def _submit_tasks( " to the Globus Compute team if you are able to recreate this" " behavior." ) + + if fut_exc: for fut in fut_list: + fut._metadata["request_uuid"] = _request_id + fut._metadata["endpoint_uuid"] = _endpoint_id + fut._metadata["task_group_uuid"] = new_tg_id fut.set_exception(fut_exc) continue # Happy -- expected -- path for fut, task_id in zip(fut_list, task_uuids): + fut._metadata["request_uuid"] = _request_id + fut._metadata["endpoint_uuid"] = _endpoint_id + fut._metadata["task_group_uuid"] = new_tg_id fut.task_id = task_id diff --git a/compute_sdk/tests/unit/test_executor.py b/compute_sdk/tests/unit/test_executor.py index b3ddb3ff2..f12661a02 100644 --- a/compute_sdk/tests/unit/test_executor.py +++ b/compute_sdk/tests/unit/test_executor.py @@ -44,14 +44,21 @@ def noop(): class MockedExecutor(Executor): def __init__(self, *args, **kwargs): - kwargs.setdefault( - "client", - mock.Mock( - spec=Client, - web_client=mock.Mock(spec=WebClient), - fx_serializer=mock.Mock(spec=ComputeSerializer), - ), + mock_client = mock.Mock( + spec=Client, + web_client=mock.Mock(spec=WebClient), + fx_serializer=mock.Mock(spec=ComputeSerializer), ) + # Unless test overrides, set default response: + fn_id = str(uuid.uuid4()) + mock_client.register_function.return_value = fn_id + mock_client.batch_run.return_value = { + "tasks": {fn_id: [str(uuid.uuid4())]}, + "task_group_id": str(uuid.uuid4()), + "request_id": str(uuid.uuid4()), + "endpoint_id": str(uuid.uuid4()), + } + kwargs.setdefault("client", mock_client) super().__init__(*args, **kwargs) Executor._default_task_group_id = None # Reset for each test self._test_task_submitter_exception: t.Type[Exception] | None = None @@ -106,12 +113,27 @@ def mock_result_watcher(mocker: MockerFixture): @pytest.fixture def gce(mock_result_watcher): gc_executor = MockedExecutor() + gc_executor.endpoint_id = gc_executor.client.batch_run.return_value["endpoint_id"] yield gc_executor gc_executor.shutdown(wait=False, cancel_futures=True) try_for_timeout(_is_stopped(gc_executor._task_submitter)) + if gc_executor._submitter_thread_exception_captured: + raise RuntimeError( + "Test-unhandled task submitter exception: raising for awareness" + "\n When test is complete, set flag to False to avoid this warning" + "\n\n Hint: consider `--log-cli-level=DEBUG` to pytest" + ) + + if gc_executor._test_task_submitter_exception: + raise RuntimeError( + "Test-unhandled task submitter exception: raising for awareness" + "\n When test is complete, set `_test_task_submitter_exception` to `None`" + "\n\n Hint: consider `--log-cli-level=DEBUG` to pytest" + ) from gc_executor._test_task_submitter_exception + if not _is_stopped(gc_executor._task_submitter)(): trepr = repr(gc_executor._task_submitter) raise RuntimeError( @@ -323,7 +345,7 @@ def test_executor_shutdown_cancel_futures(cancel_futures: bool, gce: Executor): # we are about to mock it and manually effect it below; see comment in # clear_queue()) gce._tasks_to_send.put((None, None)) - try_assert(lambda: not gce._task_submitter.is_alive(), "Test setup") + try_assert(_is_stopped(gce._task_submitter), "Verify test setup") gce._task_submitter = mock.Mock(spec=threading.Thread) gce._task_submitter.join.side_effect = lambda: None @@ -363,7 +385,7 @@ def some_func(*a, **k): gcc = gce.client gce._tasks_to_send.put((None, None)) # shutdown actual thread before ... - try_assert(lambda: not gce._task_submitter.is_alive(), "Verify test setup") + try_assert(_is_stopped(gce._task_submitter), "Verify test setup") gcc.register_function.return_value = str(fn_id) gce.endpoint_id = uuid.uuid4() @@ -576,13 +598,7 @@ def test_submit_raises_if_thread_stopped(gce): def test_submit_auto_registers_function(gce): gcc = gce.client - - fn_id = uuid.uuid4() - gcc.register_function.return_value = str(fn_id) - gcc.batch_run.return_value = { - "task_group_id": str(uuid.uuid4()), - "tasks": {str(fn_id): [str(uuid.uuid4())]}, - } + gcc.register_function.return_value = str(uuid.uuid4()) gce.endpoint_id = uuid.uuid4() gce.submit(noop) @@ -590,6 +606,7 @@ def test_submit_auto_registers_function(gce): def test_submit_value_error_if_no_endpoint(gce): + gce.endpoint_id = None # undo fixture setup with pytest.raises(ValueError) as pytest_exc: gce.submit(noop) @@ -942,19 +959,23 @@ def test_task_submitter_respects_batch_size(gce, batch_size: int): fn_id = str(uuid.uuid4()) gcc.register_function.return_value = fn_id - gcc.batch_run.return_value = { - "tasks": {fn_id: [str(uuid.uuid4()) for _ in range(batch_size)]}, - "task_group_id": uuid.uuid4(), + gcc.batch_run.return_value["tasks"] = { + fn_id: [str(uuid.uuid4()) for _ in range(batch_size)] } num_batches = 50 - gce.endpoint_id = uuid.uuid4() gce.batch_size = batch_size - with mock.patch(f"{_MOCK_BASE}time.sleep"): - for _ in range(num_batches * batch_size): - gce.submit(noop) + gce._tasks_to_send.put((None, None)) # stop the thread + try_assert(lambda: gce._test_task_submitter_done, "Test setup") - try_assert(lambda: gcc.batch_run.call_count >= num_batches) + for _ in range(num_batches * batch_size): + gce.submit(noop) + gce._tasks_to_send.put((None, None)) # let method stop + + with mock.patch(f"{_MOCK_BASE}time.sleep"): + # mock sleep to avoid wait for API-friendly delay + gce._task_submitter_impl() # now actually run the method + assert gcc.batch_run.call_count >= num_batches for args, _kwargs in gcc.batch_run.call_args_list: *_, batch = args @@ -966,10 +987,11 @@ def test_task_submitter_stops_executor_on_exception(gce): try_assert(lambda: gce._stopped) try_assert(lambda: isinstance(gce._test_task_submitter_exception, ValueError)) + gce._test_task_submitter_exception = None # exception was test-intentional def test_task_submitter_stops_executor_on_upstream_error_response(gce, randomstring): - upstream_error = Exception(f"Upstream error {randomstring}!!") + upstream_error = Exception(f"Upstream error {randomstring()}!!") gce.client.batch_run.side_effect = upstream_error gce.task_group_id = uuid.uuid4() tsi = _TaskSubmissionInfo( @@ -989,46 +1011,40 @@ def test_task_submitter_stops_executor_on_upstream_error_response(gce, randomstr try_assert(lambda: gce._test_task_submitter_done, "Expect graceful shutdown") assert cf.exception() is upstream_error assert gce._test_task_submitter_exception is None, "handled by future" + gce._submitter_thread_exception_captured = False -def test_sc25897_task_submit_correctly_handles_multiple_tg_ids(mocker, gce): +def test_sc25897_task_submit_correctly_handles_multiple_tg_ids(gce): gcc = gce.client - gce.endpoint_id = uuid.uuid4() - gcc.register_function.return_value = uuid.uuid4() - - can_continue = threading.Event() - def _mock_max(*a, **k): - can_continue.wait() - return max(*a, **k) + gce._tasks_to_send.put((None, None)) # stop thread + try_assert(_is_stopped(gce._task_submitter), "Verify test setup") - mocker.patch(f"{_MOCK_BASE}max", side_effect=_mock_max) - func_id = gce.register_function(noop) + tg_id_1 = str(uuid.uuid4()) + tg_id_2 = str(uuid.uuid4()) - tg_id_1 = uuid.uuid4() - tg_id_2 = uuid.uuid4() gcc.batch_run.side_effect = ( - ({"task_group_id": str(tg_id_1), "tasks": {func_id: [str(uuid.uuid4())]}}), - ({"task_group_id": str(tg_id_2), "tasks": {func_id: [str(uuid.uuid4())]}}), + ({**gcc.batch_run.return_value, "task_group_id": tg_id_1}), + ({**gcc.batch_run.return_value, "task_group_id": tg_id_2}), ) gce.task_group_id = tg_id_1 gce.submit(noop) gce.task_group_id = tg_id_2 gce.submit(noop) assert not gcc.create_batch.called, "Verify test setup" - can_continue.set() - try_assert(lambda: gcc.batch_run.call_count == 2) + gce._tasks_to_send.put((None, None)) # stop function + gce._task_submitter_impl() + assert gcc.batch_run.call_count == 2, "Two different task groups" for expected, (a, _k) in zip((tg_id_1, tg_id_2), gcc.create_batch.call_args_list): - found_tg_uuid = a[0] + found_tg_uuid = str(a[0]) assert found_tg_uuid == expected @pytest.mark.parametrize("burst_limit", (2, 3, 4)) @pytest.mark.parametrize("burst_window", (2, 3, 4)) def test_task_submitter_api_rate_limit(gce, mock_log, burst_limit, burst_window): - gce.endpoint_id = uuid.uuid4() gce._submit_tasks = mock.Mock() gce._function_registry[gce._fn_cache_key(noop)] = str(uuid.uuid4()) @@ -1058,26 +1074,10 @@ def test_task_submitter_api_rate_limit(gce, mock_log, burst_limit, burst_window) assert exp_perc_text == a[-1], "Expect to share batch utilization %" -def test_task_submit_handles_multiple_user_endpoint_configs(mocker: MockerFixture, gce): +def test_task_submit_handles_multiple_user_endpoint_configs(gce): gcc = gce.client - gce.endpoint_id = uuid.uuid4() - - func_uuid_str = str(uuid.uuid4()) - tg_uuid_str = str(uuid.uuid4()) - gcc.register_function.return_value = func_uuid_str - gcc.batch_run.side_effect = ( - ({"task_group_id": tg_uuid_str, "tasks": {func_uuid_str: [str(uuid.uuid4())]}}), - ({"task_group_id": tg_uuid_str, "tasks": {func_uuid_str: [str(uuid.uuid4())]}}), - ) - - # Temporarily block the task submitter loop - can_continue = threading.Event() - - def _mock_max(*a, **k): - can_continue.wait() - return max(*a, **k) - - mocker.patch(f"{_MOCK_BASE}max", side_effect=_mock_max) + gce._tasks_to_send.put((None, None)) # stop internal thread + try_assert(_is_stopped(gce._task_submitter), "Verify test setup") uep_config_1 = {"heartbeat": 10} uep_config_2 = {"heartbeat": 20} @@ -1085,11 +1085,12 @@ def _mock_max(*a, **k): gce.submit(noop) gce.user_endpoint_config = uep_config_2 gce.submit(noop) + gce._tasks_to_send.put((None, None)) # allow function to stop when called assert not gcc.create_batch.called, "Verify test setup" - can_continue.set() - try_assert(lambda: gcc.batch_run.call_count == 2) + gce._task_submitter_impl() + assert gcc.batch_run.call_count == 2, "two different configs" for expected, (a, _k) in zip( (uep_config_1, uep_config_2), gcc.create_batch.call_args_list ): @@ -1098,17 +1099,6 @@ def _mock_max(*a, **k): def test_task_submitter_handles_stale_result_watcher_gracefully(gce: Executor): - gcc = gce.client - gcc.register_function.return_value = uuid.uuid4() - gce.endpoint_id = uuid.uuid4() - - fn_id = str(uuid.uuid4()) - gce._function_registry[gce._fn_cache_key(noop)] = fn_id - task_id = str(uuid.uuid4()) - gcc.batch_run.return_value = { - "tasks": {fn_id: [task_id]}, - "task_group_id": str(uuid.uuid4()), - } gce.submit(noop) try_assert(lambda: bool(_RESULT_WATCHERS), "Test prerequisite") @@ -1123,10 +1113,15 @@ def test_task_submitter_handles_stale_result_watcher_gracefully(gce: Executor): try_assert(lambda: watcher_1 is not _RESULT_WATCHERS.get(gce.task_group_id)) -def test_task_submitter_sets_future_task_ids(gce): +@pytest.mark.parametrize("num_tasks", (random.randint(3, 20),)) +@pytest.mark.parametrize("ignore_tasks", (False, True)) +@pytest.mark.parametrize("too_few", (False, True)) +def test_task_submitter_sets_future_metadata(gce, num_tasks, ignore_tasks, too_few): gcc = gce.client - num_tasks = random.randint(2, 20) + req_id = uuid.UUID(gcc.batch_run.return_value["request_id"]) + ep_id = uuid.UUID(gcc.batch_run.return_value["endpoint_id"]) + tg_id = uuid.UUID(gcc.batch_run.return_value["task_group_id"]) futs = [ComputeFuture() for _ in range(num_tasks)] tasks = [ mock.MagicMock(function_uuid="fn_id", args=[], kwargs={}) @@ -1134,15 +1129,22 @@ def test_task_submitter_sets_future_task_ids(gce): ] batch_ids = [uuid.uuid4() for _ in range(num_tasks)] - gcc.batch_run.return_value = { - "request_id": "rq_id", - "task_group_id": str(uuid.uuid4()), - "endpoint_id": "ep_id", - "tasks": {"fn_id": batch_ids}, - } - gce._submit_tasks("tg_id", "ep_id", None, None, futs, tasks) + submitted_tasks = {"fn_id": batch_ids} + if ignore_tasks: + submitted_tasks["other_fn_id"] = submitted_tasks.pop("fn_id") + if too_few: + batch_ids.pop() + + gcc.batch_run.return_value["tasks"] = submitted_tasks + gce._submit_tasks(tg_id, ep_id, None, None, futs, tasks) - assert all(f.task_id == task_id for f, task_id in zip(futs, batch_ids)) + for f_idx, f in enumerate(futs): + assert f._metadata["request_uuid"] == req_id, (f_idx, f._metadata, req_id) + assert f._metadata["endpoint_uuid"] == ep_id, (f_idx, f._metadata, ep_id) + assert f._metadata["task_group_uuid"] == tg_id, (f_idx, f._metadata, tg_id) + if not (too_few or ignore_tasks): + for f_idx, (f, task_id) in enumerate(zip(futs, batch_ids)): + assert f.task_id == task_id, f_idx @pytest.mark.parametrize("batch_response", [{"tasks": "foo"}, {"task_group_id": "foo"}]) @@ -1163,24 +1165,22 @@ def test_submit_tasks_stops_futures_on_bad_response(gce, batch_response): for fut in futs: assert fut.exception() is pyt_exc.value + gce._submitter_thread_exception_captured = False # yep; we got it + def test_one_resultwatcher_per_task_group(gce: Executor): gcc = gce.client - gce.endpoint_id = uuid.uuid4() + fn_id = gcc.register_function.return_value def runit(tg_id: uuid.UUID, num_watchers: int): - fn_id = uuid.uuid4() - batch_run_data = { - "task_group_id": str(tg_id), - "tasks": {str(fn_id): [str(uuid.uuid4())]}, - } - gcc.batch_run.return_value = batch_run_data - gcc.register_function.return_value = str(fn_id) - gce.task_group_id = tg_id - f = gce.submit(lambda: uuid.uuid4()) + gcc.batch_run.return_value["task_group_id"] = str(tg_id) + gcc.batch_run.return_value["tasks"] = {fn_id: [str(uuid.uuid4())]} + + f = gce.submit(noop) - try_assert(lambda: len(_RESULT_WATCHERS) == num_watchers) + try_for_timeout(lambda: len(_RESULT_WATCHERS) == num_watchers) + assert len(_RESULT_WATCHERS) == num_watchers, f.exception() rw = _RESULT_WATCHERS.get(gce.task_group_id) assert rw.task_group_id == gce.task_group_id try_assert(lambda: f.task_id in rw._open_futures)