Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update Model.list_model_evaluations and get_model_evaluation to use the provided version #1616

Merged
merged 5 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4571,10 +4571,14 @@ def list_model_evaluations(
self,
) -> List["model_evaluation.ModelEvaluation"]:
"""List all Model Evaluation resources associated with this model.
If this Model resource was instantiated with a version, the Model
Evaluation resources for that version will be returned. If no version
was provided when the Model resource was instantiated, Model Evaluation
resources will be returned for the default version.

Example Usage:
my_model = Model(
model_name="projects/123/locations/us-central1/models/456"
model_name="projects/123/locations/us-central1/models/456@1"
)

my_evaluations = my_model.list_model_evaluations()
Expand All @@ -4584,10 +4588,8 @@ def list_model_evaluations(
List of ModelEvaluation resources for the model.
"""

self.wait()

return model_evaluation.ModelEvaluation._list(
parent=self.resource_name,
parent=self.versioned_resource_name,
credentials=self.credentials,
)

Expand All @@ -4597,7 +4599,10 @@ def get_model_evaluation(
) -> Optional[model_evaluation.ModelEvaluation]:
"""Returns a ModelEvaluation resource and instantiates its representation.
If no evaluation_id is passed, it will return the first evaluation associated
with this model.
with this model. If the aiplatform.Model resource was instantiated with a
version, this will return a Model Evaluation from that version. If no version
was specified when instantiating the Model resource, this will return an
Evaluation from the default version.

Example usage:
my_model = Model(
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,7 +2357,7 @@ def test_update(self, update_model_mock, get_model_mock):
model=current_model_proto, update_mask=update_mask
)

def test_get_model_evaluation_with_id(
def test_get_model_evaluation_with_evaluation_id(
self,
mock_model_eval_get,
get_model_mock,
Expand All @@ -2371,6 +2371,26 @@ def test_get_model_evaluation_with_id(
name=_TEST_MODEL_EVAL_RESOURCE_NAME, retry=base._DEFAULT_RETRY
)

def test_get_model_evaluation_with_evaluation_and_instantiated_version(
self,
mock_model_eval_get,
get_model_mock,
list_model_evaluations_mock,
):
test_model = models.Model(
model_name=f"{_TEST_MODEL_RESOURCE_NAME}@{_TEST_VERSION_ID}"
)

test_model.get_model_evaluation(evaluation_id=_TEST_ID)

mock_model_eval_get.assert_called_once_with(
name=_TEST_MODEL_EVAL_RESOURCE_NAME, retry=base._DEFAULT_RETRY
)

list_model_evaluations_mock.assert_called_once_with(
request={"parent": test_model.versioned_resource_name}
)

def test_get_model_evaluation_without_id(
self,
mock_model_eval_get,
Expand Down Expand Up @@ -2402,6 +2422,23 @@ def test_list_model_evaluations(

assert len(eval_list) == len(_TEST_MODEL_EVAL_LIST)

def test_list_model_evaluations_with_version(
self,
get_model_mock,
mock_model_eval_get,
list_model_evaluations_mock,
):

test_model = models.Model(
model_name=f"{_TEST_MODEL_RESOURCE_NAME}@{_TEST_VERSION_ID}"
)

test_model.list_model_evaluations()

list_model_evaluations_mock.assert_called_once_with(
request={"parent": test_model.versioned_resource_name}
)

def test_init_with_version_in_resource_name(self, get_model_with_version):
model = models.Model(
model_name=models.ModelRegistry._get_versioned_name(
Expand Down