Skip to content

Commit

Permalink
feat: Verify client and cluster Ray versions match
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588901140
  • Loading branch information
matthew29tang authored and copybara-github committed Dec 7, 2023
1 parent 7c64672 commit 10c6ad2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/preview/vertex_ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def __init__(self, address: Optional[str]) -> None:
" failed to start Head node properly because custom service account isn't supported.",
)
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
cluster = _gapic_utils.persistent_resource_to_cluster(
persistent_resource=self.response
)
if cluster is None:
raise ValueError(
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
)
local_ray_verion = _validation_utils.get_local_ray_version()
if cluster.ray_version != local_ray_verion:
raise ValueError(
f"[Ray on Vertex AI]: Local runtime has Ray version {local_ray_verion}, but the cluster runtime has {cluster.ray_version}. Please ensure that the Ray versions match."
)
super().__init__(address)

def connect(self) -> _VertexRayClientContext:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import google.auth
import google.auth.transport.requests
import logging
import ray
import re

from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -68,6 +69,13 @@ def maybe_reconstruct_resource_name(address) -> str:
return address


def get_local_ray_version() -> str:
    """Return the locally installed Ray version as ``<major>_<minor>``.

    Reads ``ray.__version__`` and keeps only its first two dot-separated
    components, joined with an underscore (for example ``"2.4.0"`` becomes
    ``"2_4"``).

    Returns:
        The underscore-joined leading components of ``ray.__version__``.
    """
    # Slice unconditionally: truncating only when there are exactly three
    # components would let longer versions (e.g. "3.0.0.dev0") through in
    # full, yielding "3_0_0_dev0" instead of "3_0".
    return "_".join(ray.__version__.split(".")[:2])


def get_image_uri(ray_version, python_version, enable_cuda):
"""Image uri for a given ray version and python version."""
if ray_version not in ["2_4"]:
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/vertex_ray/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
ResourceRuntimeSpec,
)

import pytest
import sys

# Skip marker for tests that only run on Python 3.10 or lower.
# `sys.version_info` is a five-field tuple, so comparing it with `> (3, 10)`
# is true for every 3.10.x release as well; compare against (3, 11) so that
# Python 3.10 is not skipped.
rovminversion = pytest.mark.skipif(
    sys.version_info >= (3, 11), reason="Requires python3.10 or lower"
)


@dataclasses.dataclass(frozen=True)
class ProjectConstants:
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/vertex_ray/test_vertex_ray_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def setup_method(self):
def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

@tc.rovminversion
@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
def test_init_with_full_resource_name(
self,
Expand All @@ -94,6 +95,7 @@ def test_init_with_full_resource_name(
tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
)

@tc.rovminversion
@pytest.mark.usefixtures(
"get_persistent_resource_status_running_mock", "google_auth_mock"
)
Expand All @@ -112,6 +114,7 @@ def test_init_with_cluster_name(
tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
)

@tc.rovminversion
@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
def test_connect_running(self, ray_client_connect_mock):
connect_result = vertex_ray.ClientBuilder(
Expand All @@ -124,6 +127,7 @@ def test_connect_running(self, ray_client_connect_mock):
== tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID
)

@tc.rovminversion
@pytest.mark.usefixtures("get_persistent_resource_status_running_no_ray_mock")
def test_connect_running_no_ray(self, ray_client_connect_mock):
expected_message = (
Expand All @@ -139,6 +143,7 @@ def test_connect_running_no_ray(self, ray_client_connect_mock):
ray_client_connect_mock.assert_called_once_with()
assert str(exception.value) == expected_message

@tc.rovminversion
@pytest.mark.parametrize(
"address",
[
Expand Down

0 comments on commit 10c6ad2

Please sign in to comment.