Skip to content

Commit

Permalink
feat: enable passing experiment_tensorboard to init without experiment
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501298160
  • Loading branch information
sararob authored and copybara-github committed Jan 11, 2023
1 parent 2e509d0 commit 369a0cc
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 5 deletions.
16 changes: 12 additions & 4 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def init(
Example tensorboard resource name format:
"projects/123/locations/us-central1/tensorboards/456"
If `experiment_tensorboard` is provided and `experiment` is not,
the provided `experiment_tensorboard` will be set as the global Tensorboard.
Any subsequent calls to aiplatform.init() with `experiment` and without
`experiment_tensorboard` will automatically assign the global Tensorboard
to the `experiment`.
staging_bucket (str): The default staging bucket to use to stage artifacts
when making API calls. In the form gs://...
credentials (google.auth.credentials.Credentials): The default custom
Expand All @@ -106,17 +112,19 @@ def init(
Raises:
ValueError:
If experiment_description is provided but experiment is not.
If experiment_tensorboard is provided but experiment is not.
"""

if experiment_description and experiment is None:
raise ValueError(
"Experiment needs to be set in `init` in order to add experiment descriptions."
)

if experiment_tensorboard and experiment is None:
raise ValueError(
"Experiment needs to be set in `init` in order to add experiment_tensorboard."
if experiment_tensorboard:
metadata._experiment_tracker.set_tensorboard(
tensorboard=experiment_tensorboard,
project=project,
location=location,
credentials=credentials,
)

# reset metadata_service config if project or location is updated.
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/metadata/experiment_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ def resource_name(self) -> str:
"""The Metadata context resource name of this experiment."""
return self._metadata_context.resource_name

@property
def backing_tensorboard_resource_name(self) -> Optional[str]:
"""The Tensorboard resource associated with this Experiment if there is one."""
return self._metadata_context.metadata.get(
constants._BACKING_TENSORBOARD_RESOURCE_KEY
)

def delete(self, *, delete_backing_tensorboard_runs: bool = False):
"""Deletes this experiment all the experiment runs under this experiment
Expand Down
39 changes: 38 additions & 1 deletion google/cloud/aiplatform/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class _ExperimentTracker:
def __init__(self):
self._experiment: Optional[experiment_resources.Experiment] = None
self._experiment_run: Optional[experiment_run_resource.ExperimentRun] = None
self._global_tensorboard: Optional[tensorboard_resource.Tensorboard] = None

def reset(self):
"""Resets this experiment tracker, clearing the current experiment and run."""
Expand Down Expand Up @@ -235,11 +236,47 @@ def set_experiment(
experiment_name=experiment, description=description
)

if backing_tensorboard:
backing_tb = backing_tensorboard or self._global_tensorboard

current_backing_tb = experiment.backing_tensorboard_resource_name

if not current_backing_tb and backing_tb:
experiment.assign_backing_tensorboard(tensorboard=backing_tensorboard)

self._experiment = experiment

def set_tensorboard(
self,
tensorboard: Union[
tensorboard_resource.Tensorboard,
str,
],
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Sets the global Tensorboard resource for this session.
Args:
tensorboard (Union[str, aiplatform.Tensorboard]):
Required. The Tensorboard resource to set as the global Tensorboard.
project (str):
Optional. Project associated with this Tensorboard resource.
location (str):
Optional. Location associated with this Tensorboard resource.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to set this Tensorboard resource.
"""
if isinstance(tensorboard, str):
tensorboard = tensorboard_resource.Tensorboard(
tensorboard,
project=project,
location=location,
credentials=credentials,
)

self._global_tensorboard = tensorboard

def start_run(
self,
run: str,
Expand Down
46 changes: 46 additions & 0 deletions tests/system/aiplatform/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,49 @@ def test_delete_experiment(self):

with pytest.raises(exceptions.NotFound):
aiplatform.Experiment(experiment_name=self._experiment_name)

def test_init_associates_global_tensorboard_to_experiment(self, shared_state):

tensorboard = aiplatform.Tensorboard.create(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
display_name=self._make_display_name("")[:64],
)

shared_state["resources"] = [tensorboard]

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
experiment_tensorboard=tensorboard,
)

assert (
aiplatform.metadata.metadata._experiment_tracker._global_tensorboard
== tensorboard
)

new_experiment_name = self._make_display_name("")[:64]
new_experiment_resource = aiplatform.Experiment.create(
experiment_name=new_experiment_name
)

shared_state["resources"].append(new_experiment_resource)

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
experiment=new_experiment_name,
)

assert (
new_experiment_resource._lookup_backing_tensorboard().resource_name
== tensorboard.resource_name
)

assert (
new_experiment_resource._metadata_context.metadata.get(
aiplatform.metadata.constants._BACKING_TENSORBOARD_RESOURCE_KEY
)
== tensorboard.resource_name
)
57 changes: 57 additions & 0 deletions tests/unit/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
_TEST_STAGING_BUCKET = "test-bucket"
_TEST_NETWORK = "projects/12345/global/networks/myVPC"

# tensorboard
_TEST_TENSORBOARD_ID = "1028944691210842416"
_TEST_TENSORBOARD_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/{_TEST_TENSORBOARD_ID}"


@pytest.mark.usefixtures("google_auth_mock")
class TestInit:
Expand Down Expand Up @@ -115,6 +119,59 @@ def test_init_experiment_sets_experiment_with_description(
backing_tensorboard=None,
)

@patch.object(_experiment_tracker, "set_tensorboard")
def test_init_with_experiment_tensorboard_id_sets_global_tensorboard(
self, set_tensorboard_mock
):
creds = credentials.AnonymousCredentials()
initializer.global_config.init(
experiment_tensorboard=_TEST_TENSORBOARD_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=creds,
)

set_tensorboard_mock.assert_called_once_with(
tensorboard=_TEST_TENSORBOARD_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=creds,
)

@patch.object(_experiment_tracker, "set_tensorboard")
def test_init_with_experiment_tensorboard_resource_sets_global_tensorboard(
self, set_tensorboard_mock
):
initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME)

set_tensorboard_mock.assert_called_once_with(
tensorboard=_TEST_TENSORBOARD_NAME,
project=None,
location=None,
credentials=None,
)

@patch.object(_experiment_tracker, "set_tensorboard")
@patch.object(_experiment_tracker, "set_experiment")
def test_init_experiment_without_tensorboard_uses_global_tensorboard(
self,
set_tensorboard_mock,
set_experiment_mock,
):

initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME)

initializer.global_config.init(
experiment=_TEST_EXPERIMENT,
)

set_experiment_mock.assert_called_once_with(
tensorboard=_TEST_TENSORBOARD_NAME,
project=None,
location=None,
credentials=None,
)

def test_init_experiment_description_fail_without_experiment(self):
with pytest.raises(ValueError):
initializer.global_config.init(experiment_description=_TEST_DESCRIPTION)
Expand Down

0 comments on commit 369a0cc

Please sign in to comment.