From d876b3ad8d0129dc98de9f86567d5bf17791058b Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Wed, 7 Dec 2022 13:11:18 -0800 Subject: [PATCH] fix: Fixed argument name in UnmanagedContainerModel PiperOrigin-RevId: 493688203 --- .../metadata/schema/google/artifact_schema.py | 6 +++--- tests/unit/aiplatform/test_metadata_schema.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index 4941e42480..264eff9168 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -222,7 +222,7 @@ class UnmanagedContainerModel(base_artifact.BaseArtifactSchema): def __init__( self, *, - predict_schema_ta: utils.PredictSchemata, + predict_schemata: utils.PredictSchemata, container_spec: utils.ContainerSpec, artifact_id: Optional[str] = None, uri: Optional[str] = None, @@ -233,7 +233,7 @@ def __init__( state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, ): """Args: - predict_schema_ta (PredictSchemata): + predict_schemata (PredictSchemata): An instance of PredictSchemata which holds instance, parameter and prediction schema uris. container_spec (ContainerSpec): An instance of ContainerSpec which holds the container configuration for the model. @@ -262,7 +262,7 @@ def __init__( check the validity of state transitions. """ extended_metadata = copy.deepcopy(metadata) if metadata else {} - extended_metadata["predictSchemata"] = predict_schema_ta.to_dict() + extended_metadata["predictSchemata"] = predict_schemata.to_dict() extended_metadata["containerSpec"] = container_spec.to_dict() super(UnmanagedContainerModel, self).__init__( diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index e12af35b23..826a99b942 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -817,7 +817,7 @@ def test_vertex_endpoint_constructor_parameters_are_set_correctly(self): assert artifact.schema_version == _TEST_SCHEMA_VERSION def test_unmanaged_container_model_title_is_set_correctly(self): - predict_schema_ta = utils.PredictSchemata( + predict_schemata = utils.PredictSchemata( instance_schema_uri="instance_uri", prediction_schema_uri="prediction_uri", parameters_schema_uri="parameters_uri", @@ -827,13 +827,13 @@ def test_unmanaged_container_model_title_is_set_correctly(self): image_uri="gcr.io/test_container_image_uri" ) artifact = google_artifact_schema.UnmanagedContainerModel( - predict_schema_ta=predict_schema_ta, + predict_schemata=predict_schemata, container_spec=container_spec, ) assert artifact.schema_title == "google.UnmanagedContainerModel" def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self): - predict_schema_ta = utils.PredictSchemata( + predict_schemata = utils.PredictSchemata( instance_schema_uri="instance_uri", prediction_schema_uri="prediction_uri", parameters_schema_uri="parameters_uri", @@ -844,7 +844,7 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self ) artifact = google_artifact_schema.UnmanagedContainerModel( - predict_schema_ta=predict_schema_ta, + predict_schemata=predict_schemata, container_spec=container_spec, artifact_id=_TEST_ARTIFACT_ID, uri=_TEST_URI, @@ -1253,7 +1253,7 @@ def teardown_method(self): initializer.global_pool.shutdown(wait=True) def test_predict_schemata_to_dict_method_returns_correct_schema(self): - predict_schema_ta = utils.PredictSchemata( + predict_schemata = utils.PredictSchemata( instance_schema_uri="instance_uri", prediction_schema_uri="prediction_uri", parameters_schema_uri="parameters_uri", @@ -1264,7 +1264,7 @@ def test_predict_schemata_to_dict_method_returns_correct_schema(self): "predictionSchemaUri": "prediction_uri", } - assert json.dumps(predict_schema_ta.to_dict()) == json.dumps(expected_results) + assert json.dumps(predict_schemata.to_dict()) == json.dumps(expected_results) def test_create_uri_from_resource_name_for_valid_resouce_names(self): valid_resouce_names = [