Skip to content

Commit

Permalink
feat: add get method for Experiment and ExperimentRun
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518042292
  • Loading branch information
jaycee-li authored and copybara-github committed Mar 20, 2023
1 parent 9fa3c68 commit 41cd943
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 4 deletions.
42 changes: 40 additions & 2 deletions google/cloud/aiplatform/metadata/experiment_resources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,8 +18,9 @@
import abc
from dataclasses import dataclass
import logging
from typing import Dict, List, NamedTuple, Optional, Union, Tuple, Type
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union

from google.api_core import exceptions
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
Expand Down Expand Up @@ -211,6 +212,43 @@ def create(

return self

@classmethod
def get(
cls,
experiment_name: str,
*,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["Experiment"]:
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
Args:
experiment_name (str):
Required. The name of this experiment.
project (str):
Optional. Project used to retrieve this resource.
Overrides project set in aiplatform.init.
location (str):
Optional. Location used to retrieve this resource.
Overrides location set in aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to retrieve this resource.
Overrides credentials set in aiplatform.init.
Returns:
Vertex AI experiment or None if no resource was found.
"""
try:
return cls(
experiment_name=experiment_name,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
return None

@classmethod
def get_or_create(
cls,
Expand Down
52 changes: 51 additions & 1 deletion google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -386,6 +386,56 @@ def _lookup_tensorboard_run_artifact(
metadata=tensorboard_run_artifact,
)

@classmethod
def get(
cls,
run_name: str,
*,
experiment: Optional[Union[experiment_resources.Experiment, str]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["ExperimentRun"]:
"""Gets experiment run if one exists with this run_name.
Args:
run_name (str):
Required. The name of this run.
experiment (Union[experiment_resources.Experiment, str]):
Optional. The name or instance of this experiment.
If not set, use the default experiment in `aiplatform.init`
project (str):
Optional. Project where this experiment run is located.
Overrides project set in aiplatform.init.
location (str):
Optional. Location where this experiment run is located.
Overrides location set in aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to retrieve this experiment run.
Overrides credentials set in aiplatform.init.
Returns:
Vertex AI experimentRun or None if no resource was found.
"""
experiment = experiment or metadata._experiment_tracker.experiment

if not experiment:
raise ValueError(
"experiment must be provided or "
"experiment should be set using aiplatform.init"
)

try:
return cls(
run_name=run_name,
experiment=experiment,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
return None

@classmethod
def list(
cls,
Expand Down
91 changes: 90 additions & 1 deletion tests/unit/aiplatform/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -280,6 +280,17 @@ def get_execution_mock():
yield get_execution_mock


@pytest.fixture
def get_execution_not_found_mock():
with patch.object(
MetadataServiceClient, "get_execution"
) as get_execution_not_found_mock:
get_execution_not_found_mock.side_effect = exceptions.NotFound(
"test: not found"
)
yield get_execution_not_found_mock


@pytest.fixture
def get_execution_wrong_schema_mock():
with patch.object(
Expand Down Expand Up @@ -681,6 +692,13 @@ def get_experiment_mock():
yield get_context_mock


@pytest.fixture
def get_experiment_not_found_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
get_context_mock.side_effect = exceptions.NotFound("test: not found")
yield get_context_mock


@pytest.fixture
def get_experiment_run_run_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
Expand All @@ -704,6 +722,17 @@ def get_experiment_run_mock():
yield get_context_mock


@pytest.fixture
def get_experiment_run_not_found_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
get_context_mock.side_effect = [
_EXPERIMENT_MOCK,
exceptions.NotFound("test: not found"),
]

yield get_context_mock


@pytest.fixture
def create_experiment_context_mock():
with patch.object(MetadataServiceClient, "create_context") as create_context_mock:
Expand Down Expand Up @@ -1125,6 +1154,66 @@ def test_init_experiment_wrong_schema(self):
experiment=_TEST_EXPERIMENT,
)

def test_get_experiment(self, get_experiment_mock):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

exp = aiplatform.Experiment.get(_TEST_EXPERIMENT)

assert exp.name == _TEST_EXPERIMENT
get_experiment_mock.assert_called_with(
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
)

def test_get_experiment_not_found(self, get_experiment_not_found_mock):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

exp = aiplatform.Experiment.get(_TEST_EXPERIMENT)

assert exp is None
get_experiment_not_found_mock.assert_called_with(
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures(
"get_metadata_store_mock", "get_tensorboard_run_artifact_not_found_mock"
)
def test_get_experiment_run(self, get_experiment_run_mock):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

run = aiplatform.ExperimentRun.get(_TEST_RUN, experiment=_TEST_EXPERIMENT)

assert run.name == _TEST_RUN
get_experiment_run_mock.assert_called_with(
name=f"{_TEST_CONTEXT_NAME}-{_TEST_RUN}", retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures(
"get_metadata_store_mock",
"get_tensorboard_run_artifact_not_found_mock",
"get_execution_not_found_mock",
)
def test_get_experiment_run_not_found(self, get_experiment_run_not_found_mock):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

run = aiplatform.ExperimentRun.get(_TEST_RUN, experiment=_TEST_EXPERIMENT)

assert run is None
get_experiment_run_not_found_mock.assert_called_with(
name=f"{_TEST_CONTEXT_NAME}-{_TEST_RUN}", retry=base._DEFAULT_RETRY
)

@pytest.mark.usefixtures("get_metadata_store_mock")
def test_start_run(
self,
Expand Down

0 comments on commit 41cd943

Please sign in to comment.