From 722066daad5d01b8374cb1c7bac7ba9ea9002c5e Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:17:22 -0500 Subject: [PATCH] Replace `WebClient` with `_ComputeWebClient` --- .../endpoint/endpoint/test_endpoint.py | 4 +- compute_sdk/globus_compute_sdk/sdk/client.py | 99 ++++++++++------ .../globus_compute_sdk/sdk/diagnostic.py | 6 +- .../globus_compute_sdk/sdk/executor.py | 4 +- .../tests/integration/test_executor_int.py | 3 +- compute_sdk/tests/unit/test_client.py | 111 +++++++++++------- compute_sdk/tests/unit/test_executor.py | 69 ++++++----- smoke_tests/tests/test_version.py | 2 +- 8 files changed, 184 insertions(+), 114 deletions(-) diff --git a/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint.py b/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint.py index b773e690d..8a4429464 100644 --- a/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint.py +++ b/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint.py @@ -18,7 +18,7 @@ ) from globus_compute_endpoint.endpoint import endpoint from globus_compute_endpoint.endpoint.config import UserEndpointConfig -from globus_compute_sdk.sdk.web_client import WebClient +from globus_compute_sdk.sdk.client import _ComputeWebClient _MOCK_BASE = "globus_compute_endpoint.endpoint.endpoint." _SVC_ADDY = "http://api.funcx.fqdn" # something clearly not correct @@ -38,7 +38,7 @@ def patch_compute_client(mocker): login_manager=mock.Mock(), ) gcc.web_service_address = _SVC_ADDY - gcc.web_client = WebClient(base_url=_SVC_ADDY) + gcc._compute_web_client = _ComputeWebClient(base_url=_SVC_ADDY) yield mocker.patch(f"{_MOCK_BASE}Client", return_value=gcc) diff --git a/compute_sdk/globus_compute_sdk/sdk/client.py b/compute_sdk/globus_compute_sdk/sdk/client.py index b707533a1..8c757ceb6 100644 --- a/compute_sdk/globus_compute_sdk/sdk/client.py +++ b/compute_sdk/globus_compute_sdk/sdk/client.py @@ -6,6 +6,8 @@ import typing as t import warnings +import globus_sdk +from globus_compute_common.sdk_version_sharing import user_agent_substring from globus_compute_sdk.errors import ( SerializationError, TaskExecutionFailed, @@ -21,7 +23,6 @@ ) from globus_compute_sdk.serialize import ComputeSerializer, SerializationStrategy from globus_compute_sdk.version import __version__, compare_versions -from globus_sdk import GlobusApp from globus_sdk.version import __version__ as __version_globus__ from .auth.auth_client import ComputeAuthClient @@ -34,6 +35,17 @@ logger = logging.getLogger(__name__) +class _ComputeWebClient: + def __init__( + self, + *args, + **kwargs, + ): + kwargs["app_name"] = user_agent_substring(__version__) + self.v2 = globus_sdk.ComputeClientV2(*args, **kwargs) + self.v3 = globus_sdk.ComputeClientV3(*args, **kwargs) + + class Client: """Main class for interacting with the Globus Compute service @@ -60,7 +72,7 @@ def __init__( code_serialization_strategy: SerializationStrategy | None = None, data_serialization_strategy: SerializationStrategy | None = None, login_manager: LoginManagerProtocol | None = None, - app: GlobusApp | None = None, + app: globus_sdk.GlobusApp | None = None, **kwargs, ): """ @@ -109,7 +121,7 @@ def __init__( self._task_status_table: dict[str, dict] = {} - self.app: GlobusApp | None = None + self.app: globus_sdk.GlobusApp | None = None self.login_manager: LoginManagerProtocol | None = None if app and login_manager: @@ -120,10 +132,16 @@ def __init__( self.web_client = self.login_manager.get_web_client( base_url=self.web_service_address ) + self._compute_web_client = _ComputeWebClient( + base_url=self.web_service_address, authorizer=self.web_client.authorizer + ) else: self.app = app if app else get_globus_app(environment=environment) self.auth_client = ComputeAuthClient(app=self.app) self.web_client = WebClient(base_url=self.web_service_address, app=self.app) + self._compute_web_client = _ComputeWebClient( + base_url=self.web_service_address, app=self.app + ) self.fx_serializer = ComputeSerializer( strategy_code=code_serialization_strategy, @@ -140,7 +158,7 @@ def version_check(self, endpoint_version: str | None = None) -> None: Raises a VersionMismatch error on failure. """ - data = self.web_client.get_version() + data = self._compute_web_client.v2.get_version(service="all") min_ep_version = data["min_ep_version"] min_sdk_version = data["min_sdk_version"] @@ -247,7 +265,7 @@ def get_task(self, task_id): if task.get("pending", True) is False: return task - r = self.web_client.get_task(task_id) + r = self._compute_web_client.v2.get_task(task_id) logger.debug(f"Response string : {r}") return self._update_task_table(r.text, task_id) @@ -294,7 +312,7 @@ def get_batch_result(self, task_id_list): results = {} if pending_task_ids: - r = self.web_client.get_batch_status(pending_task_ids) + r = self._compute_web_client.v2.get_task_batch(pending_task_ids) logger.debug(f"Response string : {r}") pending_task_ids = set(pending_task_ids) @@ -415,7 +433,7 @@ def batch_run( raise ValueError("No tasks specified for batch run") # Send the data to Globus Compute - return self.web_client.submit(endpoint_id, batch.prepare()).data + return self._compute_web_client.v3.submit(endpoint_id, batch.prepare()).data @requires_login def register_endpoint( @@ -462,22 +480,32 @@ def register_endpoint( """ self.version_check() - r = self.web_client.register_endpoint( - endpoint_name=name, - endpoint_id=endpoint_id, - metadata=metadata, - multi_user=multi_user, - display_name=display_name, - allowed_functions=allowed_functions, - auth_policy=auth_policy, - subscription_id=subscription_id, - public=public, - ) + data: t.Dict[str, t.Any] = {"endpoint_name": name} + if display_name is not None: + data["display_name"] = display_name + if multi_user: + data["multi_user"] = multi_user + if metadata: + data["metadata"] = metadata + if allowed_functions: + data["allowed_functions"] = allowed_functions + if auth_policy: + data["authentication_policy"] = auth_policy + if subscription_id: + data["subscription_uuid"] = subscription_id + if public is not None: + data["public"] = public + + if endpoint_id: + r = self._compute_web_client.v3.update_endpoint(endpoint_id, data) + else: + r = self._compute_web_client.v3.register_endpoint(data) + return r.data @requires_login def get_result_amqp_url(self) -> dict[str, str]: - r = self.web_client.get_result_amqp_url() + r = self._compute_web_client.v2.get_result_amqp_url() return r.data @requires_login @@ -499,8 +527,7 @@ def get_containers(self, name, description=None): The port to connect to and a list of containers """ data = {"endpoint_name": name, "description": description} - - r = self.web_client.post("/v2/get_containers", data=data) + r = self._compute_web_client.v2.post("/v2/get_containers", data=data) return r.data["endpoint_uuid"], r.data["endpoint_containers"] @requires_login @@ -521,7 +548,9 @@ def get_container(self, container_uuid, container_type): """ self.version_check() - r = self.web_client.get(f"/v2/containers/{container_uuid}/{container_type}") + r = self._compute_web_client.v2.get( + f"/v2/containers/{container_uuid}/{container_type}" + ) return r.data["container"] @requires_login @@ -538,7 +567,7 @@ def get_endpoint_status(self, endpoint_uuid): dict The details of the endpoint's stats """ - r = self.web_client.get_endpoint_status(endpoint_uuid) + r = self._compute_web_client.v2.get_endpoint_status(endpoint_uuid) return r.data @requires_login @@ -557,7 +586,7 @@ def get_endpoint_metadata(self, endpoint_uuid): configuration values. If there were any issues deserializing this data, may also include an "errors" key. """ - r = self.web_client.get_endpoint_metadata(endpoint_uuid) + r = self._compute_web_client.v2.get_endpoint(endpoint_uuid) return r.data @requires_login @@ -569,7 +598,7 @@ def get_endpoints(self): list A list of dictionaries which contain endpoint info """ - r = self.web_client.get_endpoints() + r = self._compute_web_client.v2.get_endpoints() return r.data @requires_login @@ -636,7 +665,7 @@ def register_function( serializer=self.fx_serializer, ) logger.info(f"Registering function : {data}") - r = self.web_client.register_function(data) + r = self._compute_web_client.v2.register_function(data.to_dict()) return r.data["function_uuid"] @requires_login @@ -654,7 +683,7 @@ def get_function(self, function_id: UUID_LIKE_T): Information about the registered function, such as name, description, serialized source code, python version, etc. """ - r = self.web_client.get_function(function_id) + r = self._compute_web_client.v2.get_function(function_id) return r.data @requires_login @@ -686,7 +715,7 @@ def register_container(self, location, container_type, name="", description=""): "type": container_type, } - r = self.web_client.post("/v2/containers", data=payload) + r = self._compute_web_client.v2.post("/v2/containers", data=payload) return r.data["container_id"] @requires_login @@ -715,11 +744,13 @@ def build_container(self, container_spec): ContainerBuildForbidden User is not in the globus group that protects the build """ - r = self.web_client.post("/v2/containers/build", data=container_spec.to_json()) + r = self._compute_web_client.v2.post( + "/v2/containers/build", data=container_spec.to_json() + ) return r.data["container_id"] def get_container_build_status(self, container_id): - r = self.web_client.get(f"/v2/containers/build/{container_id}") + r = self._compute_web_client.v2.get(f"/v2/containers/build/{container_id}") if r.http_status == 200: return r["status"] elif r.http_status == 404: @@ -744,7 +775,7 @@ def get_allowed_functions(self, endpoint_id: UUID_LIKE_T): json The response of the request """ - return self.web_client.get_allowed_functions(endpoint_id).data + return self._compute_web_client.v3.get_endpoint_allowlist(endpoint_id).data @requires_login def stop_endpoint(self, endpoint_id: str): @@ -760,7 +791,7 @@ def stop_endpoint(self, endpoint_id: str): json The response of the request """ - return self.web_client.stop_endpoint(endpoint_id) + return self._compute_web_client.v2.lock_endpoint(endpoint_id) @requires_login def delete_endpoint(self, endpoint_id: str): @@ -776,7 +807,7 @@ def delete_endpoint(self, endpoint_id: str): json The response of the request """ - return self.web_client.delete_endpoint(endpoint_id) + return self._compute_web_client.v2.delete_endpoint(endpoint_id) @requires_login def delete_function(self, function_id: str): @@ -792,7 +823,7 @@ def delete_function(self, function_id: str): json The response of the request """ - return self.web_client.delete_function(function_id) + return self._compute_web_client.v2.delete_function(function_id) @requires_login def get_worker_hardware_details(self, endpoint_id: UUID_LIKE_T) -> str: diff --git a/compute_sdk/globus_compute_sdk/sdk/diagnostic.py b/compute_sdk/globus_compute_sdk/sdk/diagnostic.py index 0eb3a2f09..98d6c65d3 100644 --- a/compute_sdk/globus_compute_sdk/sdk/diagnostic.py +++ b/compute_sdk/globus_compute_sdk/sdk/diagnostic.py @@ -24,6 +24,7 @@ get_amqp_service_host, get_web_service_url, ) +from globus_compute_sdk.sdk.client import _ComputeWebClient from globus_compute_sdk.sdk.compute_dir import ensure_compute_dir from globus_compute_sdk.sdk.executor import _RESULT_WATCHERS from globus_compute_sdk.sdk.hardware_report import hardware_commands_list @@ -31,7 +32,6 @@ from globus_compute_sdk.sdk.utils.sample_function import ( sdk_tutorial_sample_simple_function, ) -from globus_compute_sdk.sdk.web_client import WebClient from globus_compute_sdk.serialize.concretes import DillCodeTextInspect from rich import get_console from rich.console import Console @@ -122,7 +122,9 @@ def print_service_versions(base_url: str): @display_name(f"get_service_versions({base_url})") def kernel(): # Just adds a newline to the version info for better formatting - version_info = WebClient(base_url=base_url).get_version(service="all") + version_info = _ComputeWebClient(base_url=base_url).v2.get_version( + service="all" + ) print(f"{version_info}\n") return kernel diff --git a/compute_sdk/globus_compute_sdk/sdk/executor.py b/compute_sdk/globus_compute_sdk/sdk/executor.py index c3acf0f7c..a76b731b6 100644 --- a/compute_sdk/globus_compute_sdk/sdk/executor.py +++ b/compute_sdk/globus_compute_sdk/sdk/executor.py @@ -684,7 +684,7 @@ def reload_tasks( assert task_group_id is not None # mypy: we _just_ proved this # step 1: from server, acquire list of related task ids and make futures - r = self.client.web_client.get_taskgroup_tasks(task_group_id) + r = self.client._compute_web_client.v2.get_task_group(task_group_id) if r["taskgroup_id"] != str(task_group_id): msg = ( "Server did not respond with requested TaskGroup Tasks. " @@ -720,7 +720,7 @@ def reload_tasks( len(id_chunk), ) - res = self.client.web_client.get_batch_status(id_chunk) + res = self.client._compute_web_client.v2.get_task_batch(id_chunk) for task_id, task in res.data.get("results", {}).items(): if task_id in open_futures: continue diff --git a/compute_sdk/tests/integration/test_executor_int.py b/compute_sdk/tests/integration/test_executor_int.py index 980026cde..7274b7894 100644 --- a/compute_sdk/tests/integration/test_executor_int.py +++ b/compute_sdk/tests/integration/test_executor_int.py @@ -6,6 +6,7 @@ import pytest from globus_compute_sdk import Client +from globus_compute_sdk.sdk.client import _ComputeWebClient from globus_compute_sdk.sdk.executor import Executor, _ResultWatcher from tests.utils import try_assert @@ -17,7 +18,7 @@ def test_resultwatcher_graceful_shutdown(): service_url = os.environ["COMPUTE_INTEGRATION_TEST_WEB_URL"] gcc = Client() gcc.web_service_address = service_url - gcc.web_client = gcc.login_manager.get_web_client(service_url) + gcc._compute_web_client = _ComputeWebClient(service_url, app=gcc.app) gce = Executor(client=gcc) rw = _ResultWatcher(gce) rw._start_consuming = mock.Mock() diff --git a/compute_sdk/tests/unit/test_client.py b/compute_sdk/tests/unit/test_client.py index 10d759036..358185a4f 100644 --- a/compute_sdk/tests/unit/test_client.py +++ b/compute_sdk/tests/unit/test_client.py @@ -8,17 +8,17 @@ from globus_compute_sdk import ContainerSpec, __version__ from globus_compute_sdk.errors import TaskExecutionFailed from globus_compute_sdk.sdk.auth.auth_client import ComputeAuthClient +from globus_compute_sdk.sdk.client import _ComputeWebClient from globus_compute_sdk.sdk.login_manager import LoginManager from globus_compute_sdk.sdk.utils import get_env_details from globus_compute_sdk.sdk.web_client import ( FunctionRegistrationData, - FunctionRegistrationMetadata, WebClient, _get_packed_code, ) from globus_compute_sdk.serialize import ComputeSerializer from globus_compute_sdk.serialize.concretes import SELECTABLE_STRATEGIES -from globus_sdk import UserApp +from globus_sdk import ComputeClientV2, ComputeClientV3, UserApp from globus_sdk import __version__ as __version_globus__ from pytest_mock import MockerFixture @@ -36,6 +36,9 @@ def gcc(): do_version_check=False, login_manager=mock.Mock(spec=LoginManager), ) + _gcc._compute_web_client = mock.Mock(spec=_ComputeWebClient) + _gcc._compute_web_client.v2 = mock.Mock(spec=ComputeClientV2) + _gcc._compute_web_client.v3 = mock.Mock(spec=ComputeClientV3) yield _gcc @@ -55,6 +58,7 @@ def funk(): def test_client_warns_on_unknown_kwargs(kwargs, mocker: MockerFixture): mocker.patch(f"{_MOCK_BASE}ComputeAuthClient") mocker.patch(f"{_MOCK_BASE}WebClient") + mocker.patch(f"{_MOCK_BASE}_ComputeWebClient") known_kwargs = [ "funcx_home", @@ -83,6 +87,7 @@ def test_client_init_sets_addresses_by_env( ): mocker.patch(f"{_MOCK_BASE}ComputeAuthClient") mocker.patch(f"{_MOCK_BASE}WebClient") + mocker.patch(f"{_MOCK_BASE}_ComputeWebClient") if env in (None, "production"): web_uri = "https://compute.api.globus.org" @@ -113,6 +118,13 @@ def test_client_init_sets_addresses_by_env( assert client.web_service_address == web_uri +def test_compute_web_client(): + gcc = gc.Client(do_version_check=False) + assert isinstance(gcc._compute_web_client, _ComputeWebClient) + assert isinstance(gcc._compute_web_client.v2, ComputeClientV2) + assert isinstance(gcc._compute_web_client.v3, ComputeClientV3) + + @pytest.mark.parametrize( "api_data", [ @@ -158,7 +170,7 @@ def test_pending_tasks_always_fetched(gcc): should_fetch_02 = str(uuid.uuid4()) no_fetch = str(uuid.uuid4()) - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc._task_status_table.update( { should_fetch_01: {"pending": True, "task_id": should_fetch_01}, @@ -170,7 +182,7 @@ def test_pending_tasks_always_fetched(gcc): # bulk avenue gcc.get_batch_result(task_id_list) - args, _ = gcc.web_client.get_batch_status.call_args + args, _ = gcc._compute_web_client.v2.get_task_batch.call_args assert should_fetch_01 in args[0] assert should_fetch_02 in args[0] assert no_fetch not in args[0] @@ -181,11 +193,11 @@ def test_pending_tasks_always_fetched(gcc): (True, should_fetch_02), (False, no_fetch), ): - gcc.web_client.get_task.reset_mock() + gcc._compute_web_client.v2.get_task.reset_mock() gcc.get_task(sf) - assert should_fetch is gcc.web_client.get_task.called + assert should_fetch is gcc._compute_web_client.v2.get_task.called if should_fetch: - args, _ = gcc.web_client.get_task.call_args + args, _ = gcc._compute_web_client.v2.get_task.call_args assert sf == args[0] @@ -204,8 +216,8 @@ def test_batch_created_websocket_queue(gcc, create_result_queue): gcc.batch_run(eid, batch) - assert gcc.web_client.submit.called - *_, submit_data = gcc.web_client.submit.call_args[0] + assert gcc._compute_web_client.v3.submit.called + *_, submit_data = gcc._compute_web_client.v3.submit.call_args[0] assert "create_queue" in submit_data assert submit_data["create_queue"] is bool(create_result_queue) @@ -244,7 +256,7 @@ def test_batch_includes_user_runtime_info(gcc): def test_build_container(mocker, gcc, randomstring): expected_container_id = randomstring() mock_data = mocker.Mock(data={"container_id": expected_container_id}) - gcc.web_client.post.return_value = mock_data + gcc._compute_web_client.v2.post.return_value = mock_data spec = ContainerSpec( name="MyContainer", pip=["matplotlib==3.5.1", "numpy==1.18.5"], @@ -255,8 +267,8 @@ def test_build_container(mocker, gcc, randomstring): container_id = gcc.build_container(spec) assert container_id == expected_container_id - assert gcc.web_client.post.called - a, k = gcc.web_client.post.call_args + assert gcc._compute_web_client.v2.post.called + a, k = gcc._compute_web_client.v2.post.call_args assert a[0] == "/v2/containers/build" assert k == {"data": spec.to_json()} @@ -270,7 +282,7 @@ def __init__(self): self["status"] = expected_status self.http_status = 200 - gcc.web_client.get.return_value = MockData() + gcc._compute_web_client.v2.get.return_value = MockData() status = gcc.get_container_build_status("123-434") assert status == expected_status @@ -281,7 +293,7 @@ def __init__(self): super().__init__() self.http_status = 404 - gcc.web_client.get.return_value = MockData() + gcc._compute_web_client.v2.get.return_value = MockData() look_for = randomstring() with pytest.raises(ValueError) as excinfo: @@ -297,7 +309,7 @@ def __init__(self): self.http_status = 500 self.http_reason = "This is a reason" - gcc.web_client.get.return_value = MockData() + gcc._compute_web_client.v2.get.return_value = MockData() with pytest.raises(SystemError) as excinfo: gcc.get_container_build_status("123-434") @@ -306,23 +318,21 @@ def __init__(self): def test_register_function(gcc): - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() metadata = {"python_version": "3.11.3", "sdk_version": "2.3.3"} gcc.register_function(funk, metadata=metadata) - a, _ = gcc.web_client.register_function.call_args + a, _ = gcc._compute_web_client.v2.register_function.call_args func_data = a[0] - assert isinstance(func_data, FunctionRegistrationData) - assert func_data.function_code is not None - assert isinstance(func_data.metadata, FunctionRegistrationMetadata) - assert func_data.metadata.python_version == metadata["python_version"] - assert func_data.metadata.sdk_version == metadata["sdk_version"] + assert func_data["function_code"] is not None + assert func_data["metadata"]["python_version"] == metadata["python_version"] + assert func_data["metadata"]["sdk_version"] == metadata["sdk_version"] @pytest.mark.parametrize("dep_arg", ["searchable", "function_name"]) def test_register_function_deprecated_args(gcc, dep_arg): - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() with pytest.deprecated_call() as pyt_wrn: gcc.register_function(funk, **{dep_arg: "foo"}) @@ -393,29 +403,28 @@ def _docstring_test_case_real_world(): ], ) def test_register_function_docstring(gcc, func): - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc.register_function(func) expected = inspect.getdoc(func) - a, _ = gcc.web_client.register_function.call_args + a, _ = gcc._compute_web_client.v2.register_function.call_args func_data = a[0] - assert func_data.description == expected + assert func_data["description"] == expected def test_register_function_no_metadata(gcc): - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc.register_function(funk) - a, _ = gcc.web_client.register_function.call_args + a, _ = gcc._compute_web_client.v2.register_function.call_args func_data = a[0] - assert isinstance(func_data, FunctionRegistrationData) - assert func_data.metadata is None + assert func_data["metadata"] is None def test_register_function_no_function(gcc): - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() with pytest.raises(ValueError) as pyt_exc: gcc.register_function(None) @@ -473,29 +482,29 @@ def test_function_registration_data_cant_have_both_function_and_name_code(random def test_get_function(gcc): func_uuid_str = str(uuid.uuid4()) - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc.get_function(func_uuid_str) - gcc.web_client.get_function.assert_called_with(func_uuid_str) + gcc._compute_web_client.v2.get_function.assert_called_with(func_uuid_str) def test_get_allowed_functions(gcc): ep_uuid_str = str(uuid.uuid4()) - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc.get_allowed_functions(ep_uuid_str) - gcc.web_client.get_allowed_functions.assert_called_with(ep_uuid_str) + gcc._compute_web_client.v3.get_endpoint_allowlist.assert_called_with(ep_uuid_str) def test_delete_function(gcc): func_uuid_str = str(uuid.uuid4()) - gcc.web_client = mock.MagicMock() + gcc._compute_web_client = mock.MagicMock() gcc.delete_function(func_uuid_str) - gcc.web_client.delete_function.assert_called_with(func_uuid_str) + gcc._compute_web_client.v2.delete_function.assert_called_with(func_uuid_str) def test_missing_task_info(gcc): @@ -516,7 +525,7 @@ def test_missing_task_info(gcc): }, }, } - gcc.web_client.get_batch_status.return_value = mock_resp + gcc._compute_web_client.v2.get_task_batch.return_value = mock_resp res = gcc.get_batch_result([tid1, tid2]) # dev note: gcc.gbr, not gcc.web_client @@ -575,7 +584,7 @@ def test_version_mismatch_from_details( } mock_warn = mocker.patch("globus_compute_sdk.sdk.client.warnings") - gcc.web_client.get_task.return_value = mock_response(200, returned_task) + gcc._compute_web_client.v2.get_task.return_value = mock_response(200, returned_task) assert gcc.get_result(tid) == result @@ -614,7 +623,7 @@ def test_version_mismatch_warns_on_failure_and_success( returned_task["result"] = gcc.fx_serializer.serialize(result) mock_warn = mocker.patch("globus_compute_sdk.sdk.client.warnings") - gcc.web_client.get_task.return_value = mock_response(200, returned_task) + gcc._compute_web_client.v2.get_task.return_value = mock_response(200, returned_task) if should_fail: with pytest.raises(TaskExecutionFailed) as exc_info: @@ -672,7 +681,9 @@ def test_version_mismatch_only_warns_once_per_ep(mocker, gcc, mock_response, ep_ tid = str(uuid.uuid4()) returned_task["task_id"] = tid returned_task["details"]["endpoint_id"] = ep_id - gcc.web_client.get_task.return_value = mock_response(200, returned_task) + gcc._compute_web_client.v2.get_task.return_value = mock_response( + 200, returned_task + ) assert gcc.get_result(tid) == result @@ -691,10 +702,15 @@ def test_client_globus_app_and_login_manager_mutually_exclusive(): def test_client_handles_globus_app( custom_app: bool, mocker: MockerFixture, randomstring ): - mock_auth_client = mocker.patch(f"{_MOCK_BASE}ComputeAuthClient") - mock_auth_client.return_value = mock.Mock(spec=ComputeAuthClient) - mock_web_client = mocker.patch(f"{_MOCK_BASE}WebClient") - mock_web_client.return_value = mock.Mock(spec=WebClient) + mock_auth_client = mocker.patch( + f"{_MOCK_BASE}ComputeAuthClient", return_value=mock.Mock(spec=ComputeAuthClient) + ) + mock_web_client = mocker.patch( + f"{_MOCK_BASE}WebClient", return_value=mock.Mock(spec=WebClient) + ) + mock_compute_web_client = mocker.patch( + f"{_MOCK_BASE}_ComputeWebClient", return_value=mock.Mock(spec=_ComputeWebClient) + ) mock_get_globus_app = mocker.patch(f"{_MOCK_BASE}get_globus_app") mock_app = mock.Mock(spec=UserApp) @@ -715,9 +731,13 @@ def test_client_handles_globus_app( assert client.app is mock_app assert client.web_client is mock_web_client.return_value + assert client._compute_web_client is mock_compute_web_client.return_value mock_web_client.assert_called_once_with( base_url=client.web_service_address, app=mock_app ) + mock_compute_web_client.assert_called_once_with( + base_url=client.web_service_address, app=mock_app + ) assert client.auth_client is mock_auth_client.return_value mock_auth_client.assert_called_once_with(app=mock_app) @@ -734,6 +754,7 @@ def test_client_handles_login_manager(): def test_client_logout_with_app(mocker): mocker.patch(f"{_MOCK_BASE}ComputeAuthClient") mocker.patch(f"{_MOCK_BASE}WebClient") + mocker.patch(f"{_MOCK_BASE}_ComputeWebClient") mock_app = mock.Mock(spec=UserApp) client = gc.Client(do_version_check=False, app=mock_app) client.logout() diff --git a/compute_sdk/tests/unit/test_executor.py b/compute_sdk/tests/unit/test_executor.py index b3ddb3ff2..3aa8e95b0 100644 --- a/compute_sdk/tests/unit/test_executor.py +++ b/compute_sdk/tests/unit/test_executor.py @@ -15,6 +15,7 @@ from globus_compute_sdk import Client, Executor, __version__ from globus_compute_sdk.errors import TaskExecutionFailed from globus_compute_sdk.sdk.asynchronous.compute_future import ComputeFuture +from globus_compute_sdk.sdk.client import _ComputeWebClient from globus_compute_sdk.sdk.executor import ( _RESULT_WATCHERS, _ResultWatcher, @@ -24,6 +25,7 @@ from globus_compute_sdk.sdk.utils.uuid_like import as_optional_uuid, as_uuid from globus_compute_sdk.sdk.web_client import WebClient from globus_compute_sdk.serialize.facade import ComputeSerializer +from globus_sdk import ComputeClientV2, ComputeClientV3 from pytest_mock import MockerFixture from tests.utils import try_assert, try_for_timeout @@ -49,6 +51,11 @@ def __init__(self, *args, **kwargs): mock.Mock( spec=Client, web_client=mock.Mock(spec=WebClient), + _compute_web_client=mock.Mock( + spec=_ComputeWebClient, + v2=mock.Mock(ComputeClientV2), + v3=mock.Mock(ComputeClientV3), + ), fx_serializer=mock.Mock(spec=ComputeSerializer), ), ) @@ -627,7 +634,9 @@ def test_reload_tasks_sets_passed_task_group_id(gce): gcc = gce.client # for less mocking: - gcc.web_client.get_taskgroup_tasks.side_effect = RuntimeError("bailing out early") + gcc._compute_web_client.v2.get_task_group.side_effect = RuntimeError( + "bailing out early" + ) tg_id = uuid.uuid4() with pytest.raises(RuntimeError) as e: @@ -649,8 +658,8 @@ def test_reload_tasks_none_completed(gce, mock_log, num_tasks): mock_batch_result = {t["id"]: t for t in mock_data["tasks"]} mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result client_futures = list(gce.reload_tasks()) if num_tasks == 0: @@ -687,8 +696,8 @@ def test_reload_tasks_some_completed(gce, mock_log, num_tasks): mock_batch_result[t_id]["result"] = serialize("abc") mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result client_futures = list(gce.reload_tasks()) if num_tasks == 0: @@ -725,8 +734,8 @@ def test_reload_tasks_all_completed(gce: Executor): mock_batch_result = {t["id"]: t for t in mock_data["tasks"]} mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result client_futures = list(gce.reload_tasks()) @@ -777,8 +786,8 @@ def test_reload_chunks_tasks_requested(mock_log, gce, num_tasks): mock_batch_result = mock.Mock(data={"results": {}}) - gbs: mock.Mock = gcc.web_client.get_batch_status # convenience - gcc.web_client.get_taskgroup_tasks.return_value = mock_data + gbs: mock.Mock = gcc._compute_web_client.v2.get_task_batch # convenience + gcc._compute_web_client.v2.get_task_group.return_value = mock_data gbs.return_value = mock_batch_result gce.reload_tasks() @@ -813,8 +822,8 @@ def test_reload_does_not_start_new_watcher(gce: Executor): mock_batch_result = {t["id"]: t for t in mock_data["tasks"]} mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result client_futures = list(gce.reload_tasks()) assert len(client_futures) == num_tasks @@ -843,8 +852,8 @@ def update_mock_data(task_ids: t.List[str]): "tasks": [{"id": task_id} for task_id in task_ids], } mock_batch_status = {_t["id"]: {} for _t in mock_data["tasks"]} - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock.Mock( + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock.Mock( data={"results": mock_batch_status} ) @@ -867,7 +876,7 @@ def update_mock_data(task_ids: t.List[str]): assert all(fut.task_id in task_ids_2 for fut in futures_2) -def test_reload_client_taskgroup_tasks_fails_gracefully(gce): +def test_reload_client_taskgroup_tasks_fails_gracefully(gce: Executor): gcc = gce.client gce.task_group_id = uuid.uuid4() @@ -878,7 +887,7 @@ def test_reload_client_taskgroup_tasks_fails_gracefully(gce): ) for expected_exc_class, md in mock_datum: - gcc.web_client.get_taskgroup_tasks.return_value = md + gcc._compute_web_client.v2.get_task_group.return_value = md if expected_exc_class: with pytest.raises(expected_exc_class): gce.reload_tasks() @@ -886,7 +895,7 @@ def test_reload_client_taskgroup_tasks_fails_gracefully(gce): gce.reload_tasks() -def test_reload_sets_failed_tasks(gce): +def test_reload_sets_failed_tasks(gce: Executor): gcc = gce.client gce.task_group_id = uuid.uuid4() @@ -901,15 +910,15 @@ def test_reload_sets_failed_tasks(gce): mock_batch_result = {t["id"]: t for t in mock_data["tasks"]} mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result futs = list(gce.reload_tasks()) assert all("doh!" in str(fut.exception()) for fut in futs) -def test_reload_handles_deseralization_error_gracefully(gce): +def test_reload_handles_deseralization_error_gracefully(gce: Executor): gcc = gce.client gcc.fx_serializer = ComputeSerializer() @@ -925,8 +934,8 @@ def test_reload_handles_deseralization_error_gracefully(gce): mock_batch_result = {t["id"]: t for t in mock_data["tasks"]} mock_batch_result = mock.MagicMock(data={"results": mock_batch_result}) - gcc.web_client.get_taskgroup_tasks.return_value = mock_data - gcc.web_client.get_batch_status.return_value = mock_batch_result + gcc._compute_web_client.v2.get_task_group.return_value = mock_data + gcc._compute_web_client.v2.get_task_batch.return_value = mock_batch_result futs = list(gce.reload_tasks()) @@ -934,7 +943,7 @@ def test_reload_handles_deseralization_error_gracefully(gce): @pytest.mark.parametrize("batch_size", tuple(range(1, 11))) -def test_task_submitter_respects_batch_size(gce, batch_size: int): +def test_task_submitter_respects_batch_size(gce: Executor, batch_size: int): gcc = gce.client # make a new MagicMock every time create_batch is called @@ -961,14 +970,16 @@ def test_task_submitter_respects_batch_size(gce, batch_size: int): assert 0 < batch.add.call_count <= batch_size -def test_task_submitter_stops_executor_on_exception(gce): +def test_task_submitter_stops_executor_on_exception(gce: Executor): gce._tasks_to_send.put(("too", "much", "destructuring", "!!")) try_assert(lambda: gce._stopped) try_assert(lambda: isinstance(gce._test_task_submitter_exception, ValueError)) -def test_task_submitter_stops_executor_on_upstream_error_response(gce, randomstring): +def test_task_submitter_stops_executor_on_upstream_error_response( + gce: Executor, randomstring +): upstream_error = Exception(f"Upstream error {randomstring}!!") gce.client.batch_run.side_effect = upstream_error gce.task_group_id = uuid.uuid4() @@ -991,7 +1002,7 @@ def test_task_submitter_stops_executor_on_upstream_error_response(gce, randomstr assert gce._test_task_submitter_exception is None, "handled by future" -def test_sc25897_task_submit_correctly_handles_multiple_tg_ids(mocker, gce): +def test_sc25897_task_submit_correctly_handles_multiple_tg_ids(mocker, gce: Executor): gcc = gce.client gce.endpoint_id = uuid.uuid4() gcc.register_function.return_value = uuid.uuid4() @@ -1027,7 +1038,9 @@ def _mock_max(*a, **k): @pytest.mark.parametrize("burst_limit", (2, 3, 4)) @pytest.mark.parametrize("burst_window", (2, 3, 4)) -def test_task_submitter_api_rate_limit(gce, mock_log, burst_limit, burst_window): +def test_task_submitter_api_rate_limit( + gce: Executor, mock_log, burst_limit, burst_window +): gce.endpoint_id = uuid.uuid4() gce._submit_tasks = mock.Mock() @@ -1058,7 +1071,9 @@ def test_task_submitter_api_rate_limit(gce, mock_log, burst_limit, burst_window) assert exp_perc_text == a[-1], "Expect to share batch utilization %" -def test_task_submit_handles_multiple_user_endpoint_configs(mocker: MockerFixture, gce): +def test_task_submit_handles_multiple_user_endpoint_configs( + mocker: MockerFixture, gce: Executor +): gcc = gce.client gce.endpoint_id = uuid.uuid4() diff --git a/smoke_tests/tests/test_version.py b/smoke_tests/tests/test_version.py index c41d1c83e..c8a5aa236 100644 --- a/smoke_tests/tests/test_version.py +++ b/smoke_tests/tests/test_version.py @@ -3,7 +3,7 @@ def test_web_service(compute_client, endpoint, compute_test_config): """This test checks 1) web-service is online, 2) version of the web-service""" - response = compute_client.web_client.get_version() + response = compute_client._compute_web_client.v2.get_version() assert response.http_status == 200, ( "Request to version expected status_code=200, "