diff --git a/google/cloud/aiplatform/preview/vertex_ray/client_builder.py b/google/cloud/aiplatform/preview/vertex_ray/client_builder.py index 5cd3d1ee1d..734e8a5c03 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/client_builder.py +++ b/google/cloud/aiplatform/preview/vertex_ray/client_builder.py @@ -111,6 +111,18 @@ def __init__(self, address: Optional[str]) -> None: " failed to start Head node properly because custom service account isn't supported.", ) logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address) + cluster = _gapic_utils.persistent_resource_to_cluster( + persistent_resource=self.response + ) + if cluster is None: + raise ValueError( + "[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)." + ) + local_ray_verion = _validation_utils.get_local_ray_version() + if cluster.ray_version != local_ray_verion: + raise ValueError( + f"[Ray on Vertex AI]: Local runtime has Ray version {local_ray_verion}, but the cluster runtime has {cluster.ray_version}. Please ensure that the Ray versions match." + ) super().__init__(address) def connect(self) -> _VertexRayClientContext: diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py b/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py index f99efa8dfa..db4e4e6acd 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py +++ b/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py @@ -18,6 +18,7 @@ import google.auth import google.auth.transport.requests import logging +import ray import re from google.cloud.aiplatform import initializer @@ -68,6 +69,13 @@ def maybe_reconstruct_resource_name(address) -> str: return address +def get_local_ray_version(): + ray_version = ray.__version__.split(".") + if len(ray_version) == 3: + ray_version = ray_version[:2] + return "_".join(ray_version) + + def get_image_uri(ray_version, python_version, enable_cuda): """Image uri for a given ray version and python version.""" if ray_version not in ["2_4"]: diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index 97dc92260d..8290aca780 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -39,6 +39,13 @@ ResourceRuntimeSpec, ) +import pytest +import sys + +rovminversion = pytest.mark.skipif( + sys.version_info > (3, 10), reason="Requires python3.10 or lower" +) + @dataclasses.dataclass(frozen=True) class ProjectConstants: diff --git a/tests/unit/vertex_ray/test_vertex_ray_client.py b/tests/unit/vertex_ray/test_vertex_ray_client.py index a421f84ab2..9d33bc89b8 100644 --- a/tests/unit/vertex_ray/test_vertex_ray_client.py +++ b/tests/unit/vertex_ray/test_vertex_ray_client.py @@ -84,6 +84,7 @@ def setup_method(self): def teardown_method(self): aiplatform.initializer.global_pool.shutdown(wait=True) + @tc.rovminversion @pytest.mark.usefixtures("get_persistent_resource_status_running_mock") def test_init_with_full_resource_name( self, @@ -94,6 +95,7 @@ def test_init_with_full_resource_name( tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP, ) + @tc.rovminversion @pytest.mark.usefixtures( "get_persistent_resource_status_running_mock", "google_auth_mock" ) @@ -112,6 +114,7 @@ def test_init_with_cluster_name( tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP, ) + @tc.rovminversion @pytest.mark.usefixtures("get_persistent_resource_status_running_mock") def test_connect_running(self, ray_client_connect_mock): connect_result = vertex_ray.ClientBuilder( @@ -124,6 +127,7 @@ def test_connect_running(self, ray_client_connect_mock): == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID ) + @tc.rovminversion @pytest.mark.usefixtures("get_persistent_resource_status_running_no_ray_mock") def test_connect_running_no_ray(self, ray_client_connect_mock): expected_message = ( @@ -139,6 +143,7 @@ def test_connect_running_no_ray(self, ray_client_connect_mock): ray_client_connect_mock.assert_called_once_with() assert str(exception.value) == expected_message + @tc.rovminversion @pytest.mark.parametrize( "address", [