Skip to content

Commit

Permalink
feat: support dataset update (#1416)
Browse files Browse the repository at this point in the history
* feat: add update() method and system test

* fix: fix and add unit test

* remove superfluous line in system test
  • Loading branch information
jaycee-li authored Jun 9, 2022
1 parent b91db66 commit e3eb82f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 4 deletions.
68 changes: 65 additions & 3 deletions google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -597,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:

return export_data_response.exported_files

def update(self):
raise NotImplementedError("Update dataset has not been implemented yet")
def update(
    self,
    *,
    display_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    description: Optional[str] = None,
    update_request_timeout: Optional[float] = None,
) -> "_Dataset":
    """Update the dataset.

    Updatable fields:
        - ``display_name``
        - ``description``
        - ``labels``

    Args:
        display_name (str):
            Optional. The user-defined name of the Dataset.
            The name can be up to 128 characters long and can consist
            of any UTF-8 characters.
        labels (Dict[str, str]):
            Optional. Labels with user-defined metadata to organize your
            Datasets. Label keys and values can be no longer than 64
            characters (Unicode codepoints), can only contain lowercase
            letters, numeric characters, underscores and dashes.
            International characters are allowed. No more than 64 user
            labels can be associated with one Dataset (System labels are
            excluded).
            See https://goo.gl/xmQnxf for more information and examples of labels.
            System reserved label keys are prefixed with "aiplatform.googleapis.com/"
            and are immutable.
        description (str):
            Optional. The description of the Dataset.
        update_request_timeout (float):
            Optional. The timeout for the update request in seconds.

    Returns:
        dataset (Dataset):
            Updated dataset.
    """
    # Build the FieldMask from the arguments the caller explicitly set.
    # `is not None` (rather than truthiness) lets callers clear a field by
    # passing an empty value, e.g. labels={} or description="".
    update_mask = field_mask_pb2.FieldMask()
    if display_name is not None:
        update_mask.paths.append("display_name")

    if labels is not None:
        update_mask.paths.append("labels")

    if description is not None:
        update_mask.paths.append("description")

    update_dataset = gca_dataset.Dataset(
        name=self.resource_name,
        display_name=display_name,
        description=description,
        labels=labels,
    )

    # Replace the cached resource with the server's updated copy so
    # subsequent property reads reflect the change.
    self._gca_resource = self.api_client.update_dataset(
        dataset=update_dataset,
        update_mask=update_mask,
        timeout=update_request_timeout,
    )

    return self

@classmethod
def list(
Expand Down
27 changes: 26 additions & 1 deletion tests/system/aiplatform/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,8 @@
"6203215905493614592" # permanent_text_entity_extraction_dataset
)
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
_TEST_DATASET_LABELS = {"test": "labels"}
_TEST_DATASET_DESCRIPTION = "test description"
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
_TEST_FORECASTING_BQ_SOURCE = (
"bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train"
Expand Down Expand Up @@ -350,3 +352,26 @@ def test_export_data(self, storage_client, staging_bucket):
blob = bucket.get_blob(prefix)

assert blob # Verify the returned GCS export path exists

def test_update_dataset(self):
    """Create a new dataset and use update() method to change its display_name, labels, and description.
    Then confirm these fields of the dataset were successfully modified."""

    # Initialize before the try block: if create() raises, the finally
    # clause below would otherwise hit a NameError on `dataset` and mask
    # the real failure.
    dataset = None
    try:
        dataset = aiplatform.ImageDataset.create()
        labels = dataset.labels

        dataset = dataset.update(
            display_name=_TEST_DATASET_DISPLAY_NAME,
            labels=_TEST_DATASET_LABELS,
            description=_TEST_DATASET_DESCRIPTION,
            update_request_timeout=None,
        )
        # update() merges new labels with any system-populated ones, so
        # compare against the original labels overlaid with the test labels.
        labels.update(_TEST_DATASET_LABELS)

        assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME
        assert dataset.labels == labels
        assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION

    finally:
        # Clean up the dataset only if creation succeeded.
        if dataset is not None:
            dataset.delete()
46 changes: 46 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.cloud.aiplatform import schema
from google.cloud import bigquery
from google.cloud import storage
from google.protobuf import field_mask_pb2

from google.cloud.aiplatform.compat.services import dataset_service_client

Expand All @@ -59,6 +60,7 @@
_TEST_ID = "1028944691210842416"
_TEST_DISPLAY_NAME = "my_dataset_1234"
_TEST_DATA_LABEL_ITEMS = None
_TEST_DESCRIPTION = "test description"

_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
_TEST_ALT_NAME = (
Expand Down Expand Up @@ -425,6 +427,20 @@ def export_data_mock():
yield export_data_mock


@pytest.fixture
def update_dataset_mock():
    """Patch DatasetServiceClient.update_dataset to return a canned Dataset.

    The canned response carries the updated display name, labels, and
    description that the unit tests assert against.
    """
    canned_response = gca_dataset.Dataset(
        name=_TEST_NAME,
        display_name=f"update_{_TEST_DISPLAY_NAME}",
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
    )
    with patch.object(
        dataset_service_client.DatasetServiceClient, "update_dataset"
    ) as mock_update:
        mock_update.return_value = canned_response
        yield mock_update


@pytest.fixture
def list_datasets_mock():
with patch.object(
Expand Down Expand Up @@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync):

delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)

@pytest.mark.usefixtures("get_dataset_mock")
def test_update_dataset(self, update_dataset_mock):
    """update() should forward the new field values and a matching FieldMask
    to DatasetServiceClient.update_dataset."""
    aiplatform.init(project=_TEST_PROJECT)

    new_display_name = f"update_{_TEST_DISPLAY_NAME}"
    my_dataset = datasets._Dataset(dataset_name=_TEST_NAME)

    my_dataset = my_dataset.update(
        display_name=new_display_name,
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
        update_request_timeout=None,
    )

    # The service call must carry every updated field plus a FieldMask
    # listing exactly those fields.
    update_dataset_mock.assert_called_once_with(
        dataset=gca_dataset.Dataset(
            name=_TEST_NAME,
            display_name=new_display_name,
            labels=_TEST_LABELS,
            description=_TEST_DESCRIPTION,
        ),
        update_mask=field_mask_pb2.FieldMask(
            paths=["display_name", "labels", "description"]
        ),
        timeout=None,
    )


@pytest.mark.usefixtures("google_auth_mock")
class TestImageDataset:
Expand Down

0 comments on commit e3eb82f

Please sign in to comment.