From aed8c7604f5d89f52c53a599330fd502d02f7877 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Fri, 14 Jul 2023 12:13:31 -0700 Subject: [PATCH] fix: require model name in ModelEvaluation.list() PiperOrigin-RevId: 548190073 --- google/cloud/aiplatform/base.py | 4 ++ .../model_evaluation/model_evaluation.py | 62 ++++++++++++++++++- .../unit/aiplatform/test_model_evaluation.py | 60 ++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 9544b843c4..c78f61285e 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -1143,6 +1143,7 @@ def _list_with_local_order( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + parent: Optional[str] = None, ) -> List[VertexAiResourceNoun]: """Private method to list all instances of this Vertex AI Resource, takes a `cls_filter` arg to filter to a particular SDK resource @@ -1179,6 +1180,8 @@ def _list_with_local_order( credentials (auth_credentials.Credentials): Optional. Custom credentials to use to retrieve list. Overrides credentials set in aiplatform.init. + parent (str): + Optional. The parent resource name if any to retrieve resource list from. Returns: List[VertexAiResourceNoun] - A list of SDK resource objects @@ -1192,6 +1195,7 @@ def _list_with_local_order( project=project, location=location, credentials=credentials, + parent=parent, ) if order_by: diff --git a/google/cloud/aiplatform/model_evaluation/model_evaluation.py b/google/cloud/aiplatform/model_evaluation/model_evaluation.py index f8553b7644..2c90e830ab 100644 --- a/google/cloud/aiplatform/model_evaluation/model_evaluation.py +++ b/google/cloud/aiplatform/model_evaluation/model_evaluation.py @@ -22,7 +22,7 @@ from google.cloud.aiplatform import models from google.protobuf import struct_pb2 -from typing import Optional +from typing import List, Optional class ModelEvaluation(base.VertexAiResourceNounWithFutureManager): @@ -91,3 +91,63 @@ def delete(self): raise NotImplementedError( "Deleting a model evaluation has not been implemented yet." ) + + @classmethod + def list( + cls, + model: str, + filter: Optional[str] = None, + order_by: Optional[str] = None, + enable_simple_view: bool = False, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["ModelEvaluation"]: + """List all ModelEvaluation resources on the provided model. + + Example Usage: + + aiplatform.ModelEvaluation.list( + model="projects/123/locations/us-central1/models/456", + ) + + aiplatform.Model.list( + model="projects/123/locations/us-central1/models/456", + order_by="create_time desc, display_name" + ) + + Args: + model (str): + Required. The resource name of the model to list evaluations for. + For example: "projects/123/locations/us-central1/models/456". + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + parent (str): + Optional. The parent resource name if any to retrieve list from. + + Returns: + List[VertexAiResourceNoun] - A list of SDK resource objects + """ + + return super()._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + parent=model, + ) diff --git a/tests/unit/aiplatform/test_model_evaluation.py b/tests/unit/aiplatform/test_model_evaluation.py index 73e1bf72b1..0a06012e42 100644 --- a/tests/unit/aiplatform/test_model_evaluation.py +++ b/tests/unit/aiplatform/test_model_evaluation.py @@ -15,10 +15,13 @@ # limitations under the License. # +import datetime import pytest from unittest import mock +from google.api_core import datetime_helpers + from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import models @@ -96,6 +99,37 @@ def mock_model_eval_get(): yield mock_get_model_eval +_TEST_MODEL_EVAL_LIST = [ + gca_model_evaluation.ModelEvaluation( + name=_TEST_MODEL_EVAL_RESOURCE_NAME, + create_time=datetime_helpers.DatetimeWithNanoseconds( + 2023, 5, 14, 16, 24, 3, 299558, tzinfo=datetime.timezone.utc + ), + ), + gca_model_evaluation.ModelEvaluation( + name=_TEST_MODEL_EVAL_RESOURCE_NAME, + create_time=datetime_helpers.DatetimeWithNanoseconds( + 2023, 6, 14, 16, 24, 3, 299558, tzinfo=datetime.timezone.utc + ), + ), + gca_model_evaluation.ModelEvaluation( + name=_TEST_MODEL_EVAL_RESOURCE_NAME, + create_time=datetime_helpers.DatetimeWithNanoseconds( + 2023, 7, 14, 16, 24, 3, 299558, tzinfo=datetime.timezone.utc + ), + ), +] + + +@pytest.fixture +def list_model_evaluations_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "list_model_evaluations" + ) as list_model_evaluations_mock: + list_model_evaluations_mock.return_value = _TEST_MODEL_EVAL_LIST + yield list_model_evaluations_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestModelEvaluation: def test_init_model_evaluation_with_only_resource_name(self, mock_model_eval_get): @@ -156,3 +190,29 @@ def test_no_delete_model_evaluation_method(self, mock_model_eval_get): with pytest.raises(NotImplementedError): my_eval.delete() + + def test_list_model_evaluations( + self, + mock_model_eval_get, + get_model_mock, + list_model_evaluations_mock, + ): + aiplatform.init(project=_TEST_PROJECT) + + metrics_list = aiplatform.ModelEvaluation.list(model=_TEST_MODEL_RESOURCE_NAME) + + assert isinstance(metrics_list[0], aiplatform.ModelEvaluation) + + def test_list_model_evaluations_with_order_by( + self, + mock_model_eval_get, + get_model_mock, + list_model_evaluations_mock, + ): + aiplatform.init(project=_TEST_PROJECT) + + metrics_list = aiplatform.ModelEvaluation.list( + model=_TEST_MODEL_RESOURCE_NAME, order_by="create_time desc" + ) + + assert metrics_list[0].create_time > metrics_list[1].create_time