Skip to content

Commit

Permalink
Replace WebClient with _ComputeWebClient
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmello committed Dec 9, 2024
1 parent 0d75cbf commit b860cb4
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from globus_compute_endpoint.endpoint import endpoint
from globus_compute_endpoint.endpoint.config import UserEndpointConfig
from globus_compute_sdk.sdk.web_client import WebClient
from globus_compute_sdk.sdk.web_client import _ComputeWebClient

_MOCK_BASE = "globus_compute_endpoint.endpoint.endpoint."
_SVC_ADDY = "http://api.funcx.fqdn" # something clearly not correct
Expand All @@ -38,7 +38,7 @@ def patch_compute_client(mocker):
login_manager=mock.Mock(),
)
gcc.web_service_address = _SVC_ADDY
gcc.web_client = WebClient(base_url=_SVC_ADDY)
gcc._compute_web_client = _ComputeWebClient(base_url=_SVC_ADDY)

yield mocker.patch(f"{_MOCK_BASE}Client", return_value=gcc)

Expand Down
83 changes: 51 additions & 32 deletions compute_sdk/globus_compute_sdk/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .batch import Batch, UserRuntime
from .login_manager import LoginManagerProtocol, requires_login
from .utils import get_env_var_with_deprecation
from .web_client import WebClient
from .web_client import WebClient, _ComputeWebClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -120,10 +120,16 @@ def __init__(
self.web_client = self.login_manager.get_web_client(
base_url=self.web_service_address
)
self._compute_web_client = _ComputeWebClient(
base_url=self.web_service_address, authorizer=self.web_client.authorizer
)
else:
self.app = app if app else get_globus_app(environment=environment)
self.auth_client = ComputeAuthClient(app=self.app)
self.web_client = WebClient(base_url=self.web_service_address, app=self.app)
self._compute_web_client = _ComputeWebClient(
base_url=self.web_service_address, app=self.app
)

self.fx_serializer = ComputeSerializer(
strategy_code=code_serialization_strategy,
Expand All @@ -140,7 +146,7 @@ def version_check(self, endpoint_version: str | None = None) -> None:
Raises a VersionMismatch error on failure.
"""
data = self.web_client.get_version()
data = self._compute_web_client.v2.get_version(service="all")

min_ep_version = data["min_ep_version"]
min_sdk_version = data["min_sdk_version"]
Expand Down Expand Up @@ -247,7 +253,7 @@ def get_task(self, task_id):
if task.get("pending", True) is False:
return task

r = self.web_client.get_task(task_id)
r = self._compute_web_client.v2.get_task(task_id)
logger.debug(f"Response string : {r}")
return self._update_task_table(r.text, task_id)

Expand Down Expand Up @@ -294,7 +300,7 @@ def get_batch_result(self, task_id_list):
results = {}

if pending_task_ids:
r = self.web_client.get_batch_status(pending_task_ids)
r = self._compute_web_client.v2.get_task_batch(pending_task_ids)
logger.debug(f"Response string : {r}")

pending_task_ids = set(pending_task_ids)
Expand Down Expand Up @@ -415,7 +421,7 @@ def batch_run(
raise ValueError("No tasks specified for batch run")

# Send the data to Globus Compute
return self.web_client.submit(endpoint_id, batch.prepare()).data
return self._compute_web_client.v3.submit(endpoint_id, batch.prepare()).data

@requires_login
def register_endpoint(
Expand Down Expand Up @@ -462,22 +468,32 @@ def register_endpoint(
"""
self.version_check()

r = self.web_client.register_endpoint(
endpoint_name=name,
endpoint_id=endpoint_id,
metadata=metadata,
multi_user=multi_user,
display_name=display_name,
allowed_functions=allowed_functions,
auth_policy=auth_policy,
subscription_id=subscription_id,
public=public,
)
data: t.Dict[str, t.Any] = {"endpoint_name": name}
if display_name is not None:
data["display_name"] = display_name
if multi_user:
data["multi_user"] = multi_user
if metadata:
data["metadata"] = metadata
if allowed_functions:
data["allowed_functions"] = allowed_functions
if auth_policy:
data["authentication_policy"] = auth_policy
if subscription_id:
data["subscription_uuid"] = subscription_id
if public is not None:
data["public"] = public

if endpoint_id:
r = self._compute_web_client.v3.update_endpoint(endpoint_id, data)
else:
r = self._compute_web_client.v3.register_endpoint(data)

return r.data

@requires_login
def get_result_amqp_url(self) -> dict[str, str]:
r = self.web_client.get_result_amqp_url()
r = self._compute_web_client.v2.get_result_amqp_url()
return r.data

@requires_login
Expand All @@ -499,8 +515,7 @@ def get_containers(self, name, description=None):
The port to connect to and a list of containers
"""
data = {"endpoint_name": name, "description": description}

r = self.web_client.post("/v2/get_containers", data=data)
r = self._compute_web_client.v2.post("/v2/get_containers", data=data)
return r.data["endpoint_uuid"], r.data["endpoint_containers"]

@requires_login
Expand All @@ -521,7 +536,9 @@ def get_container(self, container_uuid, container_type):
"""
self.version_check()

r = self.web_client.get(f"/v2/containers/{container_uuid}/{container_type}")
r = self._compute_web_client.v2.get(
f"/v2/containers/{container_uuid}/{container_type}"
)
return r.data["container"]

@requires_login
Expand All @@ -538,7 +555,7 @@ def get_endpoint_status(self, endpoint_uuid):
dict
The details of the endpoint's stats
"""
r = self.web_client.get_endpoint_status(endpoint_uuid)
r = self._compute_web_client.v2.get_endpoint_status(endpoint_uuid)
return r.data

@requires_login
Expand All @@ -557,7 +574,7 @@ def get_endpoint_metadata(self, endpoint_uuid):
configuration values. If there were any issues deserializing this data, may
also include an "errors" key.
"""
r = self.web_client.get_endpoint_metadata(endpoint_uuid)
r = self._compute_web_client.v2.get_endpoint(endpoint_uuid)
return r.data

@requires_login
Expand All @@ -569,7 +586,7 @@ def get_endpoints(self):
list
A list of dictionaries which contain endpoint info
"""
r = self.web_client.get_endpoints()
r = self._compute_web_client.v2.get_endpoints()
return r.data

@requires_login
Expand Down Expand Up @@ -636,7 +653,7 @@ def register_function(
serializer=self.fx_serializer,
)
logger.info(f"Registering function : {data}")
r = self.web_client.register_function(data)
r = self._compute_web_client.v2.register_function(data.to_dict())
return r.data["function_uuid"]

@requires_login
Expand All @@ -654,7 +671,7 @@ def get_function(self, function_id: UUID_LIKE_T):
Information about the registered function, such as name, description,
serialized source code, python version, etc.
"""
r = self.web_client.get_function(function_id)
r = self._compute_web_client.v2.get_function(function_id)
return r.data

@requires_login
Expand Down Expand Up @@ -686,7 +703,7 @@ def register_container(self, location, container_type, name="", description=""):
"type": container_type,
}

r = self.web_client.post("/v2/containers", data=payload)
r = self._compute_web_client.v2.post("/v2/containers", data=payload)
return r.data["container_id"]

@requires_login
Expand Down Expand Up @@ -715,11 +732,13 @@ def build_container(self, container_spec):
ContainerBuildForbidden
User is not in the globus group that protects the build
"""
r = self.web_client.post("/v2/containers/build", data=container_spec.to_json())
r = self._compute_web_client.v2.post(
"/v2/containers/build", data=container_spec.to_json()
)
return r.data["container_id"]

def get_container_build_status(self, container_id):
r = self.web_client.get(f"/v2/containers/build/{container_id}")
r = self._compute_web_client.v2.get(f"/v2/containers/build/{container_id}")
if r.http_status == 200:
return r["status"]
elif r.http_status == 404:
Expand All @@ -744,7 +763,7 @@ def get_allowed_functions(self, endpoint_id: UUID_LIKE_T):
json
The response of the request
"""
return self.web_client.get_allowed_functions(endpoint_id).data
return self._compute_web_client.v3.get_endpoint_allowlist(endpoint_id).data

@requires_login
def stop_endpoint(self, endpoint_id: str):
Expand All @@ -760,7 +779,7 @@ def stop_endpoint(self, endpoint_id: str):
json
The response of the request
"""
return self.web_client.stop_endpoint(endpoint_id)
return self._compute_web_client.v2.lock_endpoint(endpoint_id)

@requires_login
def delete_endpoint(self, endpoint_id: str):
Expand All @@ -776,7 +795,7 @@ def delete_endpoint(self, endpoint_id: str):
json
The response of the request
"""
return self.web_client.delete_endpoint(endpoint_id)
return self._compute_web_client.v2.delete_endpoint(endpoint_id)

@requires_login
def delete_function(self, function_id: str):
Expand All @@ -792,7 +811,7 @@ def delete_function(self, function_id: str):
json
The response of the request
"""
return self.web_client.delete_function(function_id)
return self._compute_web_client.v2.delete_function(function_id)

@requires_login
def get_worker_hardware_details(self, endpoint_id: UUID_LIKE_T) -> str:
Expand Down
6 changes: 4 additions & 2 deletions compute_sdk/globus_compute_sdk/sdk/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from globus_compute_sdk.sdk.utils.sample_function import (
sdk_tutorial_sample_simple_function,
)
from globus_compute_sdk.sdk.web_client import WebClient
from globus_compute_sdk.sdk.web_client import _ComputeWebClient
from globus_compute_sdk.serialize.concretes import DillCodeTextInspect
from rich import get_console
from rich.console import Console
Expand Down Expand Up @@ -122,7 +122,9 @@ def print_service_versions(base_url: str):
@display_name(f"get_service_versions({base_url})")
def kernel():
# Just adds a newline to the version info for better formatting
version_info = WebClient(base_url=base_url).get_version(service="all")
version_info = _ComputeWebClient(base_url=base_url).v2.get_version(
service="all"
)
print(f"{version_info}\n")

return kernel
Expand Down
4 changes: 2 additions & 2 deletions compute_sdk/globus_compute_sdk/sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def reload_tasks(
assert task_group_id is not None # mypy: we _just_ proved this

# step 1: from server, acquire list of related task ids and make futures
r = self.client.web_client.get_taskgroup_tasks(task_group_id)
r = self.client._compute_web_client.v2.get_task_group(task_group_id)
if r["taskgroup_id"] != str(task_group_id):
msg = (
"Server did not respond with requested TaskGroup Tasks. "
Expand Down Expand Up @@ -720,7 +720,7 @@ def reload_tasks(
len(id_chunk),
)

res = self.client.web_client.get_batch_status(id_chunk)
res = self.client._compute_web_client.v2.get_task_batch(id_chunk)
for task_id, task in res.data.get("results", {}).items():
if task_id in open_futures:
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_web_client(
base_url=base_url,
app_name=app_name,
authorizer=self.authorizers[ComputeScopes.resource_server],
_deprecation_warning=False,
)

def ensure_logged_in(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,5 @@ def get_web_client(
base_url=base_url,
app_name=app_name,
authorizer=self._get_authorizer(ComputeScopes.resource_server),
_deprecation_warning=False,
)
11 changes: 11 additions & 0 deletions compute_sdk/globus_compute_sdk/sdk/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,14 @@ def delete_function(
self, function_id: UUID_LIKE_T
) -> globus_sdk.GlobusHTTPResponse:
return self.delete(f"/v2/functions/{function_id}")


class _ComputeWebClient:
def __init__(
self,
*args,
**kwargs,
):
kwargs["app_name"] = user_agent_substring(__version__)
self.v2 = globus_sdk.ComputeClientV2(*args, **kwargs)
self.v3 = globus_sdk.ComputeClientV3(*args, **kwargs)
3 changes: 2 additions & 1 deletion compute_sdk/tests/integration/test_executor_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from globus_compute_sdk import Client
from globus_compute_sdk.sdk.executor import Executor, _ResultWatcher
from globus_compute_sdk.sdk.web_client import _ComputeWebClient
from tests.utils import try_assert


Expand All @@ -17,7 +18,7 @@ def test_resultwatcher_graceful_shutdown():
service_url = os.environ["COMPUTE_INTEGRATION_TEST_WEB_URL"]
gcc = Client()
gcc.web_service_address = service_url
gcc.web_client = gcc.login_manager.get_web_client(service_url)
gcc._compute_web_client = _ComputeWebClient(service_url, app=gcc.app)
gce = Executor(client=gcc)
rw = _ResultWatcher(gce)
rw._start_consuming = mock.Mock()
Expand Down
Loading

0 comments on commit b860cb4

Please sign in to comment.