Skip to content

Commit

Permalink
Replace WebClient with _ComputeWebClient
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmello committed Dec 9, 2024
1 parent 638878e commit 956fd05
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from globus_compute_endpoint.endpoint import endpoint
from globus_compute_endpoint.endpoint.config import UserEndpointConfig
from globus_compute_sdk.sdk.web_client import WebClient
from globus_compute_sdk.sdk.client import _ComputeWebClient

_MOCK_BASE = "globus_compute_endpoint.endpoint.endpoint."
_SVC_ADDY = "http://api.funcx.fqdn" # something clearly not correct
Expand All @@ -38,7 +38,7 @@ def patch_compute_client(mocker):
login_manager=mock.Mock(),
)
gcc.web_service_address = _SVC_ADDY
gcc.web_client = WebClient(base_url=_SVC_ADDY)
gcc._compute_web_client = _ComputeWebClient(base_url=_SVC_ADDY)

yield mocker.patch(f"{_MOCK_BASE}Client", return_value=gcc)

Expand Down
99 changes: 65 additions & 34 deletions compute_sdk/globus_compute_sdk/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import typing as t
import warnings

import globus_sdk
from globus_compute_common.sdk_version_sharing import user_agent_substring
from globus_compute_sdk.errors import (
SerializationError,
TaskExecutionFailed,
Expand All @@ -21,7 +23,6 @@
)
from globus_compute_sdk.serialize import ComputeSerializer, SerializationStrategy
from globus_compute_sdk.version import __version__, compare_versions
from globus_sdk import GlobusApp
from globus_sdk.version import __version__ as __version_globus__

from .auth.auth_client import ComputeAuthClient
Expand All @@ -34,6 +35,17 @@
logger = logging.getLogger(__name__)


class _ComputeWebClient:
def __init__(
self,
*args,
**kwargs,
):
kwargs["app_name"] = user_agent_substring(__version__)
self.v2 = globus_sdk.ComputeClientV2(*args, **kwargs)
self.v3 = globus_sdk.ComputeClientV3(*args, **kwargs)


class Client:
"""Main class for interacting with the Globus Compute service
Expand All @@ -60,7 +72,7 @@ def __init__(
code_serialization_strategy: SerializationStrategy | None = None,
data_serialization_strategy: SerializationStrategy | None = None,
login_manager: LoginManagerProtocol | None = None,
app: GlobusApp | None = None,
app: globus_sdk.GlobusApp | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -109,7 +121,7 @@ def __init__(

self._task_status_table: dict[str, dict] = {}

self.app: GlobusApp | None = None
self.app: globus_sdk.GlobusApp | None = None
self.login_manager: LoginManagerProtocol | None = None

if app and login_manager:
Expand All @@ -120,10 +132,16 @@ def __init__(
self.web_client = self.login_manager.get_web_client(
base_url=self.web_service_address
)
self._compute_web_client = _ComputeWebClient(
base_url=self.web_service_address, authorizer=self.web_client.authorizer
)
else:
self.app = app if app else get_globus_app(environment=environment)
self.auth_client = ComputeAuthClient(app=self.app)
self.web_client = WebClient(base_url=self.web_service_address, app=self.app)
self._compute_web_client = _ComputeWebClient(
base_url=self.web_service_address, app=self.app
)

self.fx_serializer = ComputeSerializer(
strategy_code=code_serialization_strategy,
Expand All @@ -140,7 +158,7 @@ def version_check(self, endpoint_version: str | None = None) -> None:
Raises a VersionMismatch error on failure.
"""
data = self.web_client.get_version()
data = self._compute_web_client.v2.get_version(service="all")

min_ep_version = data["min_ep_version"]
min_sdk_version = data["min_sdk_version"]
Expand Down Expand Up @@ -247,7 +265,7 @@ def get_task(self, task_id):
if task.get("pending", True) is False:
return task

r = self.web_client.get_task(task_id)
r = self._compute_web_client.v2.get_task(task_id)
logger.debug(f"Response string : {r}")
return self._update_task_table(r.text, task_id)

Expand Down Expand Up @@ -294,7 +312,7 @@ def get_batch_result(self, task_id_list):
results = {}

if pending_task_ids:
r = self.web_client.get_batch_status(pending_task_ids)
r = self._compute_web_client.v2.get_task_batch(pending_task_ids)
logger.debug(f"Response string : {r}")

pending_task_ids = set(pending_task_ids)
Expand Down Expand Up @@ -415,7 +433,7 @@ def batch_run(
raise ValueError("No tasks specified for batch run")

# Send the data to Globus Compute
return self.web_client.submit(endpoint_id, batch.prepare()).data
return self._compute_web_client.v3.submit(endpoint_id, batch.prepare()).data

@requires_login
def register_endpoint(
Expand Down Expand Up @@ -462,22 +480,32 @@ def register_endpoint(
"""
self.version_check()

r = self.web_client.register_endpoint(
endpoint_name=name,
endpoint_id=endpoint_id,
metadata=metadata,
multi_user=multi_user,
display_name=display_name,
allowed_functions=allowed_functions,
auth_policy=auth_policy,
subscription_id=subscription_id,
public=public,
)
data: t.Dict[str, t.Any] = {"endpoint_name": name}
if display_name is not None:
data["display_name"] = display_name
if multi_user:
data["multi_user"] = multi_user
if metadata:
data["metadata"] = metadata
if allowed_functions:
data["allowed_functions"] = allowed_functions
if auth_policy:
data["authentication_policy"] = auth_policy
if subscription_id:
data["subscription_uuid"] = subscription_id
if public is not None:
data["public"] = public

if endpoint_id:
r = self._compute_web_client.v3.update_endpoint(endpoint_id, data)
else:
r = self._compute_web_client.v3.register_endpoint(data)

return r.data

@requires_login
def get_result_amqp_url(self) -> dict[str, str]:
r = self.web_client.get_result_amqp_url()
r = self._compute_web_client.v2.get_result_amqp_url()
return r.data

@requires_login
Expand All @@ -499,8 +527,7 @@ def get_containers(self, name, description=None):
The port to connect to and a list of containers
"""
data = {"endpoint_name": name, "description": description}

r = self.web_client.post("/v2/get_containers", data=data)
r = self._compute_web_client.v2.post("/v2/get_containers", data=data)
return r.data["endpoint_uuid"], r.data["endpoint_containers"]

@requires_login
Expand All @@ -521,7 +548,9 @@ def get_container(self, container_uuid, container_type):
"""
self.version_check()

r = self.web_client.get(f"/v2/containers/{container_uuid}/{container_type}")
r = self._compute_web_client.v2.get(
f"/v2/containers/{container_uuid}/{container_type}"
)
return r.data["container"]

@requires_login
Expand All @@ -538,7 +567,7 @@ def get_endpoint_status(self, endpoint_uuid):
dict
The details of the endpoint's stats
"""
r = self.web_client.get_endpoint_status(endpoint_uuid)
r = self._compute_web_client.v2.get_endpoint_status(endpoint_uuid)
return r.data

@requires_login
Expand All @@ -557,7 +586,7 @@ def get_endpoint_metadata(self, endpoint_uuid):
configuration values. If there were any issues deserializing this data, may
also include an "errors" key.
"""
r = self.web_client.get_endpoint_metadata(endpoint_uuid)
r = self._compute_web_client.v2.get_endpoint(endpoint_uuid)
return r.data

@requires_login
Expand All @@ -569,7 +598,7 @@ def get_endpoints(self):
list
A list of dictionaries which contain endpoint info
"""
r = self.web_client.get_endpoints()
r = self._compute_web_client.v2.get_endpoints()
return r.data

@requires_login
Expand Down Expand Up @@ -636,7 +665,7 @@ def register_function(
serializer=self.fx_serializer,
)
logger.info(f"Registering function : {data}")
r = self.web_client.register_function(data)
r = self._compute_web_client.v2.register_function(data.to_dict())
return r.data["function_uuid"]

@requires_login
Expand All @@ -654,7 +683,7 @@ def get_function(self, function_id: UUID_LIKE_T):
Information about the registered function, such as name, description,
serialized source code, python version, etc.
"""
r = self.web_client.get_function(function_id)
r = self._compute_web_client.v2.get_function(function_id)
return r.data

@requires_login
Expand Down Expand Up @@ -686,7 +715,7 @@ def register_container(self, location, container_type, name="", description=""):
"type": container_type,
}

r = self.web_client.post("/v2/containers", data=payload)
r = self._compute_web_client.v2.post("/v2/containers", data=payload)
return r.data["container_id"]

@requires_login
Expand Down Expand Up @@ -715,11 +744,13 @@ def build_container(self, container_spec):
ContainerBuildForbidden
User is not in the globus group that protects the build
"""
r = self.web_client.post("/v2/containers/build", data=container_spec.to_json())
r = self._compute_web_client.v2.post(
"/v2/containers/build", data=container_spec.to_json()
)
return r.data["container_id"]

def get_container_build_status(self, container_id):
r = self.web_client.get(f"/v2/containers/build/{container_id}")
r = self._compute_web_client.v2.get(f"/v2/containers/build/{container_id}")
if r.http_status == 200:
return r["status"]
elif r.http_status == 404:
Expand All @@ -744,7 +775,7 @@ def get_allowed_functions(self, endpoint_id: UUID_LIKE_T):
json
The response of the request
"""
return self.web_client.get_allowed_functions(endpoint_id).data
return self._compute_web_client.v3.get_endpoint_allowlist(endpoint_id).data

@requires_login
def stop_endpoint(self, endpoint_id: str):
Expand All @@ -760,7 +791,7 @@ def stop_endpoint(self, endpoint_id: str):
json
The response of the request
"""
return self.web_client.stop_endpoint(endpoint_id)
return self._compute_web_client.v2.lock_endpoint(endpoint_id)

@requires_login
def delete_endpoint(self, endpoint_id: str):
Expand All @@ -776,7 +807,7 @@ def delete_endpoint(self, endpoint_id: str):
json
The response of the request
"""
return self.web_client.delete_endpoint(endpoint_id)
return self._compute_web_client.v2.delete_endpoint(endpoint_id)

@requires_login
def delete_function(self, function_id: str):
Expand All @@ -792,7 +823,7 @@ def delete_function(self, function_id: str):
json
The response of the request
"""
return self.web_client.delete_function(function_id)
return self._compute_web_client.v2.delete_function(function_id)

@requires_login
def get_worker_hardware_details(self, endpoint_id: UUID_LIKE_T) -> str:
Expand Down
6 changes: 4 additions & 2 deletions compute_sdk/globus_compute_sdk/sdk/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
get_amqp_service_host,
get_web_service_url,
)
from globus_compute_sdk.sdk.client import _ComputeWebClient
from globus_compute_sdk.sdk.compute_dir import ensure_compute_dir
from globus_compute_sdk.sdk.executor import _RESULT_WATCHERS
from globus_compute_sdk.sdk.hardware_report import hardware_commands_list
from globus_compute_sdk.sdk.utils import display_name
from globus_compute_sdk.sdk.utils.sample_function import (
sdk_tutorial_sample_simple_function,
)
from globus_compute_sdk.sdk.web_client import WebClient
from globus_compute_sdk.serialize.concretes import DillCodeTextInspect
from rich import get_console
from rich.console import Console
Expand Down Expand Up @@ -122,7 +122,9 @@ def print_service_versions(base_url: str):
@display_name(f"get_service_versions({base_url})")
def kernel():
# Just adds a newline to the version info for better formatting
version_info = WebClient(base_url=base_url).get_version(service="all")
version_info = _ComputeWebClient(base_url=base_url).v2.get_version(
service="all"
)
print(f"{version_info}\n")

return kernel
Expand Down
4 changes: 2 additions & 2 deletions compute_sdk/globus_compute_sdk/sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def reload_tasks(
assert task_group_id is not None # mypy: we _just_ proved this

# step 1: from server, acquire list of related task ids and make futures
r = self.client.web_client.get_taskgroup_tasks(task_group_id)
r = self.client._compute_web_client.v2.get_task_group(task_group_id)
if r["taskgroup_id"] != str(task_group_id):
msg = (
"Server did not respond with requested TaskGroup Tasks. "
Expand Down Expand Up @@ -720,7 +720,7 @@ def reload_tasks(
len(id_chunk),
)

res = self.client.web_client.get_batch_status(id_chunk)
res = self.client._compute_web_client.v2.get_task_batch(id_chunk)
for task_id, task in res.data.get("results", {}).items():
if task_id in open_futures:
continue
Expand Down
3 changes: 2 additions & 1 deletion compute_sdk/tests/integration/test_executor_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
from globus_compute_sdk import Client
from globus_compute_sdk.sdk.client import _ComputeWebClient
from globus_compute_sdk.sdk.executor import Executor, _ResultWatcher
from tests.utils import try_assert

Expand All @@ -17,7 +18,7 @@ def test_resultwatcher_graceful_shutdown():
service_url = os.environ["COMPUTE_INTEGRATION_TEST_WEB_URL"]
gcc = Client()
gcc.web_service_address = service_url
gcc.web_client = gcc.login_manager.get_web_client(service_url)
gcc._compute_web_client = _ComputeWebClient(service_url, app=gcc.app)
gce = Executor(client=gcc)
rw = _ResultWatcher(gce)
rw._start_consuming = mock.Mock()
Expand Down
Loading

0 comments on commit 956fd05

Please sign in to comment.