Skip to content

Commit

Permalink
fix: Fixed argument name in UnmanagedContainerModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 493688203
  • Loading branch information
Ark-kun authored and copybara-github committed Dec 7, 2022
1 parent 52656ca commit d876b3a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class UnmanagedContainerModel(base_artifact.BaseArtifactSchema):
def __init__(
self,
*,
predict_schema_ta: utils.PredictSchemata,
predict_schemata: utils.PredictSchemata,
container_spec: utils.ContainerSpec,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
Expand All @@ -233,7 +233,7 @@ def __init__(
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
predict_schema_ta (PredictSchemata):
predict_schemata (PredictSchemata):
An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
container_spec (ContainerSpec):
An instance of ContainerSpec which holds the container configuration for the model.
Expand Down Expand Up @@ -262,7 +262,7 @@ def __init__(
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
extended_metadata["predictSchemata"] = predict_schema_ta.to_dict()
extended_metadata["predictSchemata"] = predict_schemata.to_dict()
extended_metadata["containerSpec"] = container_spec.to_dict()

super(UnmanagedContainerModel, self).__init__(
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/aiplatform/test_metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def test_vertex_endpoint_constructor_parameters_are_set_correctly(self):
assert artifact.schema_version == _TEST_SCHEMA_VERSION

def test_unmanaged_container_model_title_is_set_correctly(self):
predict_schema_ta = utils.PredictSchemata(
predict_schemata = utils.PredictSchemata(
instance_schema_uri="instance_uri",
prediction_schema_uri="prediction_uri",
parameters_schema_uri="parameters_uri",
Expand All @@ -827,13 +827,13 @@ def test_unmanaged_container_model_title_is_set_correctly(self):
image_uri="gcr.io/test_container_image_uri"
)
artifact = google_artifact_schema.UnmanagedContainerModel(
predict_schema_ta=predict_schema_ta,
predict_schemata=predict_schemata,
container_spec=container_spec,
)
assert artifact.schema_title == "google.UnmanagedContainerModel"

def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self):
predict_schema_ta = utils.PredictSchemata(
predict_schemata = utils.PredictSchemata(
instance_schema_uri="instance_uri",
prediction_schema_uri="prediction_uri",
parameters_schema_uri="parameters_uri",
Expand All @@ -844,7 +844,7 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self
)

artifact = google_artifact_schema.UnmanagedContainerModel(
predict_schema_ta=predict_schema_ta,
predict_schemata=predict_schemata,
container_spec=container_spec,
artifact_id=_TEST_ARTIFACT_ID,
uri=_TEST_URI,
Expand Down Expand Up @@ -1253,7 +1253,7 @@ def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_predict_schemata_to_dict_method_returns_correct_schema(self):
predict_schema_ta = utils.PredictSchemata(
predict_schemata = utils.PredictSchemata(
instance_schema_uri="instance_uri",
prediction_schema_uri="prediction_uri",
parameters_schema_uri="parameters_uri",
Expand All @@ -1264,7 +1264,7 @@ def test_predict_schemata_to_dict_method_returns_correct_schema(self):
"predictionSchemaUri": "prediction_uri",
}

assert json.dumps(predict_schema_ta.to_dict()) == json.dumps(expected_results)
assert json.dumps(predict_schemata.to_dict()) == json.dumps(expected_results)

def test_create_uri_from_resource_name_for_valid_resouce_names(self):
valid_resouce_names = [
Expand Down

0 comments on commit d876b3a

Please sign in to comment.