From d06b22d1ac6197c460092739e8572b9beb08bd63 Mon Sep 17 00:00:00 2001 From: Makoto Uchida Date: Fri, 13 Jan 2023 10:06:03 -0800 Subject: [PATCH] fix: address broken unit tests in certain environments PiperOrigin-RevId: 501875885 --- .../vizier/pyvizier/study_config.py | 29 +++++++++++++++---- .../aiplatform/test_metadata_resources.py | 2 ++ tests/unit/aiplatform/test_metadata_store.py | 1 + tests/unit/aiplatform/test_utils.py | 2 ++ 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/google/cloud/aiplatform/vizier/pyvizier/study_config.py b/google/cloud/aiplatform/vizier/pyvizier/study_config.py index 0314e1442f..75b3015186 100644 --- a/google/cloud/aiplatform/vizier/pyvizier/study_config.py +++ b/google/cloud/aiplatform/vizier/pyvizier/study_config.py @@ -117,19 +117,36 @@ class SearchSpace(SearchSpace): @classmethod def from_proto(cls, proto: study_pb2.StudySpec) -> "SearchSpace": """Extracts a SearchSpace object from a StudyConfig proto.""" - parameter_configs = [] + + # For google-vizier <= 0.0.15 + if hasattr(cls, "_factory"): + parameter_configs = [] + for pc in proto.parameters: + parameter_configs.append( + proto_converters.ParameterConfigConverter.from_proto(pc) + ) + return cls._factory(parameter_configs=parameter_configs) + + result = cls() for pc in proto.parameters: - parameter_configs.append( - proto_converters.ParameterConfigConverter.from_proto(pc) - ) - return cls._factory(parameter_configs=parameter_configs) + result.add(proto_converters.ParameterConfigConverter.from_proto(pc)) + + return result @property def parameter_protos(self) -> List[study_pb2.StudySpec.ParameterSpec]: """Returns the search space as a List of ParameterConfig protos.""" + + # For google-vizier <= 0.0.15 + if isinstance(self._parameter_configs, list): + return [ + proto_converters.ParameterConfigConverter.to_proto(pc) + for pc in self._parameter_configs + ] + return [ proto_converters.ParameterConfigConverter.to_proto(pc) - for pc in self._parameter_configs + for _, pc in self._parameter_configs.items() ] diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index 76db36dfbe..72b1ce7d34 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -614,6 +614,7 @@ def list_artifact_empty_mock(): yield list_artifacts_mock +@pytest.mark.usefixtures("google_auth_mock") class TestExecution: def setup_method(self): reload(initializer) @@ -893,6 +894,7 @@ def test_query_input_and_output_artifacts( assert artifact_list[0]._gca_resource == expected_artifact +@pytest.mark.usefixtures("google_auth_mock") class TestArtifact: def setup_method(self): reload(initializer) diff --git a/tests/unit/aiplatform/test_metadata_store.py b/tests/unit/aiplatform/test_metadata_store.py index b6dfe8e032..78d1618f4a 100644 --- a/tests/unit/aiplatform/test_metadata_store.py +++ b/tests/unit/aiplatform/test_metadata_store.py @@ -134,6 +134,7 @@ def delete_metadata_store_mock(): yield delete_metadata_store_mock +@pytest.mark.usefixtures("google_auth_mock") class TestMetadataStore: def setup_method(self): reload(initializer) diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 4dfc5951c7..164dec7d4f 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -390,6 +390,7 @@ def test_wrapped_client(): ) +@pytest.mark.usefixtures("google_auth_mock") def test_client_w_override_default_version(): test_client_info = gapic_v1.client_info.ClientInfo() @@ -407,6 +408,7 @@ def test_client_w_override_default_version(): ) +@pytest.mark.usefixtures("google_auth_mock") def test_client_w_override_select_version(): test_client_info = gapic_v1.client_info.ClientInfo()