diff --git a/.kokoro/continuous/system.cfg b/.kokoro/continuous/system.cfg index f5ed200a12..203d61a329 100644 --- a/.kokoro/continuous/system.cfg +++ b/.kokoro/continuous/system.cfg @@ -2,7 +2,7 @@ env_vars: { key: "NOX_SESSION" - value: "system-3.8" + value: "system-3.11" } # Run system tests in parallel, splitting up by file diff --git a/.kokoro/presubmit/system.cfg b/.kokoro/presubmit/system.cfg index 568d482bf7..c0c769d999 100644 --- a/.kokoro/presubmit/system.cfg +++ b/.kokoro/presubmit/system.cfg @@ -3,7 +3,7 @@ # Run system tests when test files are modified env_vars: { key: "NOX_SESSION" - value: "system-3.8" + value: "system-3.11" } # Run system tests in parallel, splitting up by file diff --git a/noxfile.py b/noxfile.py index bc9dbabd5e..7871ab8b8a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -48,7 +48,7 @@ ] UNIT_TEST_EXTRAS_BY_PYTHON = {} -SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] +SYSTEM_TEST_PYTHON_VERSIONS = ["3.11"] SYSTEM_TEST_STANDARD_DEPENDENCIES = [ "mock", "pytest", diff --git a/setup.py b/setup.py index 0b16e349a0..ec6f62d641 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,12 @@ autologging_extra_require = ["mlflow>=1.27.0,<=2.1.1"] +preview_extra_require = [ + "cloudpickle < 3.0", + "google-cloud-logging < 4.0", + "importlib-metadata < 7.0; python_version<'3.8'", +] + full_extra_require = list( set( tensorboard_extra_require @@ -100,6 +106,7 @@ + prediction_extra_require + private_endpoints_extra_require + autologging_extra_require + + preview_extra_require ) ) testing_extra_require = ( @@ -107,12 +114,16 @@ + profiler_extra_require + [ "grpcio-testing", - "pytest-asyncio", - "pytest-xdist", "ipython", "kfp", - "xgboost", + "pyfakefs", + "pytest-asyncio", + "pytest-xdist", "scikit-learn", + "tensorflow >=2.3.0, < 2.13.0", + "torch >= 2.0.0; python_version>='3.8'", + "torch; python_version<'3.8'", + "xgboost", ] ) @@ -160,6 +171,7 @@ "datasets": datasets_extra_require, "private_endpoints": private_endpoints_extra_require, "autologging": autologging_extra_require, + "preview": preview_extra_require, }, python_requires=">=3.7", classifiers=[ diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt index ed7f9aed25..6c3e6c5bbc 100644 --- a/testing/constraints-3.10.txt +++ b/testing/constraints-3.10.txt @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- # This constraints file is required for unit tests. # List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf +google-api-core==1.32.0 +proto-plus==1.22.0 +protobuf==3.19.5 +mock==4.0.2 +google-cloud-storage==2.0.0 +packaging==20.0 # Increased for compatibility with MLFlow +grpcio-testing==1.34.0 diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 8f27a26d3d..9f1b48e7ae 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -9,6 +9,6 @@ google-api-core==1.32.0 proto-plus==1.22.0 protobuf==3.19.5 mock==4.0.2 -google-cloud-storage==1.32.0 +google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow grpcio-testing==1.34.0 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index ed7f9aed25..6c3e6c5bbc 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- # This constraints file is required for unit tests. # List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf +google-api-core==1.32.0 +proto-plus==1.22.0 +protobuf==3.19.5 +mock==4.0.2 +google-cloud-storage==2.0.0 +packaging==20.0 # Increased for compatibility with MLFlow +grpcio-testing==1.34.0 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index ed7f9aed25..6c3e6c5bbc 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- # This constraints file is required for unit tests. # List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf +google-api-core==1.32.0 +proto-plus==1.22.0 +protobuf==3.19.5 +mock==4.0.2 +google-cloud-storage==2.0.0 +packaging==20.0 # Increased for compatibility with MLFlow +grpcio-testing==1.34.0 diff --git a/tests/system/aiplatform/e2e_base.py b/tests/system/aiplatform/e2e_base.py index 9ff3eb5ee8..c3a8846d66 100644 --- a/tests/system/aiplatform/e2e_base.py +++ b/tests/system/aiplatform/e2e_base.py @@ -26,6 +26,7 @@ from google.api_core import exceptions from google.cloud import aiplatform +import vertexai from google.cloud import bigquery from google.cloud import resourcemanager from google.cloud import storage @@ -62,6 +63,7 @@ def _make_display_name(cls, key: str) -> str: def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) + importlib.reload(vertexai) @pytest.fixture(scope="class") def shared_state(self) -> Generator[Dict[str, Any], None, None]: diff --git a/tests/system/vertexai/test_pytorch.py b/tests/system/vertexai/test_pytorch.py new file mode 100644 index 0000000000..140a819260 --- /dev/null +++ b/tests/system/vertexai/test_pytorch.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from unittest import mock + +import vertexai +from tests.system.aiplatform import e2e_base +from vertexai.preview._workflow.executor import training +import pytest +from sklearn.datasets import load_iris +import torch +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + + +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH", + "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING", + "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@pytest.mark.usefixtures( + "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" +) +class TestRemoteExecutionPytorch(e2e_base.TestEndToEnd): + + _temp_prefix = "temp-vertexai-remote-execution" + + def test_remote_execution_pytorch(self, shared_state): + # Define the pytorch custom model + class TorchLogisticRegression(vertexai.preview.VertexModel, torch.nn.Module): + def __init__(self, input_size: int, output_size: int): + torch.nn.Module.__init__(self) + vertexai.preview.VertexModel.__init__(self) + self.linear = torch.nn.Linear(input_size, output_size) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + return self.softmax(self.linear(x)) + + @vertexai.preview.developer.mark.train() + def train(self, dataloader, num_epochs, lr): + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.parameters(), lr=lr) + + for t in range(num_epochs): + for idx, batch in enumerate(dataloader): + # move data to the same device as model + device = next(self.parameters()).device + x, y = batch[0].to(device), batch[1].to(device) + + optimizer.zero_grad() + pred = self(x) + loss = criterion(pred, y) + loss.backward() + optimizer.step() + + @vertexai.preview.developer.mark.predict() + def predict(self, X): + X = torch.tensor(X).to(torch.float32) + with torch.no_grad(): + pred = torch.argmax(self(X), dim=1) + return pred + + # Initialize vertexai + vertexai.init( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + staging_bucket=f"gs://{shared_state['staging_bucket_name']}", + ) + + # Prepare dataset + dataset = load_iris() + + X, X_retrain, y, y_retrain = train_test_split( + dataset.data, dataset.target, test_size=0.60, random_state=42 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.20, random_state=42 + ) + + transformer = StandardScaler() + X_train = transformer.fit_transform(X_train) + X_test = transformer.transform(X_test) + X_retrain = transformer.transform(X_retrain) + + train_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor(X_train).to(torch.float32), + torch.tensor(y_train), + ), + batch_size=10, + shuffle=True, + ) + + retrain_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor(X_retrain).to(torch.float32), + torch.tensor(y_retrain), + ), + batch_size=10, + shuffle=True, + ) + + # Remote CPU training on Torch custom model + vertexai.preview.init(remote=True) + + model = TorchLogisticRegression(4, 3) + model.train.vertex.remote_config.display_name = self._make_display_name( + "pytorch-cpu-training" + ) + model.train(train_loader, num_epochs=100, lr=0.05) + + # Remote prediction on Torch custom model + model.predict.vertex.remote_config.display_name = self._make_display_name( + "pytorch-prediction" + ) + model.predict(X_test) + + # Register trained model + registered_model = vertexai.preview.register(model) + shared_state["resources"] = [registered_model] + + # Load the registered model + pulled_model = vertexai.preview.from_pretrained( + model_name=registered_model.resource_name + ) + + # Uptrain the pretrained model on CPU + pulled_model.train.vertex.remote_config.display_name = self._make_display_name( + "pytorch-cpu-uptraining" + ) + pulled_model.train(retrain_loader, num_epochs=100, lr=0.05) diff --git a/tests/system/vertexai/test_sklearn.py b/tests/system/vertexai/test_sklearn.py new file mode 100644 index 0000000000..fb7d37f4d5 --- /dev/null +++ b/tests/system/vertexai/test_sklearn.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from unittest import mock + +import vertexai +from tests.system.aiplatform import e2e_base +from vertexai.preview._workflow.executor import training +import pandas as pd +import pytest +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + + +# Wrap classes +StandardScaler = vertexai.preview.remote(StandardScaler) +LogisticRegression = vertexai.preview.remote(LogisticRegression) + + +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH", + "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING", + "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@pytest.mark.usefixtures( + "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" +) +class TestRemoteExecutionSklearn(e2e_base.TestEndToEnd): + + _temp_prefix = "temp-vertexai-remote-execution" + + def test_remote_execution_sklearn(self, shared_state): + # Initialize vertexai + vertexai.init( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + staging_bucket=f"gs://{shared_state['staging_bucket_name']}", + ) + + # Prepare dataset + dataset = load_iris() + X, X_retrain, y, y_retrain = train_test_split( + dataset.data, dataset.target, test_size=0.60, random_state=42 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.20, random_state=42 + ) + + # Remote fit_transform on train dataset + vertexai.preview.init(remote=True) + + transformer = StandardScaler() + transformer.fit_transform.vertex.set_config( + display_name=self._make_display_name("fit-transform"), + ) + X_train = transformer.fit_transform(X_train) + + # Remote transform on test dataset + transformer.transform.vertex.set_config( + display_name=self._make_display_name("transform"), + ) + X_test = transformer.transform(X_test) + + # Local transform on retrain data + vertexai.preview.init(remote=False) + X_retrain = transformer.transform(X_retrain) + # Transform retrain dataset to pandas dataframe + X_retrain_df = pd.DataFrame(X_retrain, columns=dataset.feature_names) + y_retrain_df = pd.DataFrame(y_retrain, columns=["class"]) + + # Remote training on sklearn + vertexai.preview.init(remote=True) + + model = LogisticRegression(warm_start=True) + model.fit.vertex.remote_config.display_name = self._make_display_name( + "sklearn-training" + ) + model.fit(X_train, y_train) + + # Remote prediction on sklearn + model.predict.vertex.remote_config.display_name = self._make_display_name( + "sklearn-prediction" + ) + model.predict(X_test) + + # Register trained model + registered_model = vertexai.preview.register(model) + shared_state["resources"] = [registered_model] + + # Load the registered model + pulled_model = vertexai.preview.from_pretrained( + model_name=registered_model.resource_name + ) + + # Retrain model with pandas df on Vertex + pulled_model.fit(X_retrain_df, y_retrain_df) diff --git a/tests/system/vertexai/test_tensorflow.py b/tests/system/vertexai/test_tensorflow.py new file mode 100644 index 0000000000..5d2beaac7a --- /dev/null +++ b/tests/system/vertexai/test_tensorflow.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from unittest import mock + +import vertexai +from tests.system.aiplatform import e2e_base +from vertexai.preview._workflow.executor import training +import pytest +from sklearn.datasets import load_iris +import tensorflow as tf +from tensorflow import keras +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + + +# Wrap classes +keras.Sequential = vertexai.preview.remote(keras.Sequential) + + +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH", + "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@mock.patch.object( + training, + "VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING", + "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/" + f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}" + if os.environ.get("KOKORO_GIT_COMMIT") + else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723", +) +@pytest.mark.usefixtures( + "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" +) +class TestRemoteExecutionTensorflow(e2e_base.TestEndToEnd): + + _temp_prefix = "temp-vertexai-remote-execution" + + def test_remote_execution_keras(self, shared_state): + # Initialize vertexai + vertexai.init( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + staging_bucket=f"gs://{shared_state['staging_bucket_name']}", + ) + + # Prepare dataset + dataset = load_iris() + + X, X_retrain, y, y_retrain = train_test_split( + dataset.data, dataset.target, test_size=0.60, random_state=42 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.20, random_state=42 + ) + + transformer = StandardScaler() + X_train = transformer.fit_transform(X_train) + X_test = transformer.transform(X_test) + X_retrain = transformer.transform(X_retrain) + + tf_train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) + tf_train_dataset = tf_train_dataset.shuffle(buffer_size=64).batch(32) + + tf_retrain_dataset = tf.data.Dataset.from_tensor_slices((X_retrain, y_retrain)) + tf_retrain_dataset = tf_retrain_dataset.shuffle(buffer_size=64).batch(32) + + tf_test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)) + tf_prediction_test_data = tf_test_dataset + tf_remote_prediction_test_data = tf_prediction_test_data.batch(32) + + # Remote GPU training on Keras + vertexai.preview.init(remote=True) + + model = keras.Sequential( + [keras.layers.Dense(5, input_shape=(4,)), keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + model.fit.vertex.set_config( + enable_cuda=True, display_name=self._make_display_name("keras-gpu-training") + ) + model.fit(tf_train_dataset, epochs=10) + + # Remote prediction on keras + model.predict.vertex.remote_config.display_name = self._make_display_name( + "keras-prediction" + ) + model.predict(tf_remote_prediction_test_data) + + # Register trained model + registered_model = vertexai.preview.register(model) + shared_state["resources"] = [registered_model] + + # Load the registered model + pulled_model = vertexai.preview.from_pretrained( + model_name=registered_model.resource_name + ) + + # Uptrain the pretrained model on CPU + pulled_model.fit.vertex.remote_config.enable_cuda = False + pulled_model.fit.vertex.remote_config.display_name = self._make_display_name( + "keras-cpu-uptraining" + ) + pulled_model.fit(tf_retrain_dataset, epochs=10) diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py new file mode 100644 index 0000000000..2113267a49 --- /dev/null +++ b/tests/unit/vertexai/conftest.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import os +import shutil +import tempfile +from typing import Any +from unittest import mock +import uuid + +from google import auth +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials +from google.cloud.logging import Logger +from google.cloud import storage +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + custom_job as gca_custom_job_compat, +) +from google.cloud.aiplatform.compat.types import io as gca_io_compat +from google.cloud.aiplatform.compat.types import ( + job_state as gca_job_state_compat, +) +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( + PersistentResourceServiceClient, +) +import constants as test_constants +from pyfakefs import fake_filesystem_unittest +import pytest +import tensorflow.saved_model as tf_saved_model + + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_NUMBER = "12345678" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345" +_TEST_BUCKET_NAME = "gs://test_bucket" +_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir" + +_TEST_INPUTS = [ + "--arg_0=string_val_0", + "--arg_1=string_val_1", + "--arg_2=int_val_0", + "--arg_3=int_val_1", +] +_TEST_IMAGE_URI = "test_image_uri" +_TEST_MACHINE_TYPE = "test_machine_type" +_TEST_WORKER_POOL_SPEC = [ + { + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + }, + "replica_count": 1, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + } +] +_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( + display_name=_TEST_DISPLAY_NAME, + job_spec={ + "worker_pool_specs": _TEST_WORKER_POOL_SPEC, + "base_output_directory": gca_io_compat.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + }, +) + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as auth_mock: + auth_mock.return_value = ( + auth_credentials.AnonymousCredentials(), + "test-project", + ) + yield auth_mock + + +@pytest.fixture +def mock_filesystem(): + with fake_filesystem_unittest.Patcher() as patcher: + patcher.setUp() + yield patcher.fs + patcher.tearDown() + + +@pytest.fixture +def mock_storage_blob(mock_filesystem): + """Mocks the storage Blob API. + + Replaces the Blob factory method by a simpler method that records the + destination_file_uri and, instead of uploading the file to gcs, copying it + to the fake local file system. + """ + + class MockStorageBlob: + """Mocks storage.Blob.""" + + def __init__(self, destination_file_uri: str, client: Any): + del client + self.destination_file_uri = destination_file_uri + + @classmethod + def from_string(cls, destination_file_uri: str, client: Any): + if destination_file_uri.startswith("gs://"): + # Do not copy files to gs:// since it's not a valid path in the fake + # filesystem. + destination_file_uri = destination_file_uri.split("/")[-1] + return cls(destination_file_uri, client) + + def upload_from_filename(self, filename: str): + shutil.copy(filename, self.destination_file_uri) + + def download_to_filename(self, filename: str): + """To be replaced by an implementation of testing needs.""" + raise NotImplementedError + + with mock.patch.object(storage, "Blob", new=MockStorageBlob) as storage_blob: + yield storage_blob + + +@pytest.fixture +def mock_storage_blob_tmp_dir(tmp_path): + """Mocks the storage Blob API. + + Replaces the Blob factory method by a simpler method that records the + destination_file_uri and, instead of uploading the file to gcs, copying it + to a temporaray path in the local file system. + """ + + class MockStorageBlob: + """Mocks storage.Blob.""" + + def __init__(self, destination_file_uri: str, client: Any): + del client + self.destination_file_uri = destination_file_uri + + @classmethod + def from_string(cls, destination_file_uri: str, client: Any): + if destination_file_uri.startswith("gs://"): + # Do not copy files to gs:// since it's not a valid path in the fake + # filesystem. + destination_file_uri = os.fspath( + tmp_path / destination_file_uri.split("/")[-1] + ) + return cls(destination_file_uri, client) + + def upload_from_filename(self, filename: str): + shutil.copy(filename, self.destination_file_uri) + + def download_to_filename(self, filename: str): + """To be replaced by an implementation of testing needs.""" + raise NotImplementedError + + with mock.patch.object(storage, "Blob", new=MockStorageBlob) as storage_blob: + yield storage_blob + + +@pytest.fixture +def mock_gcs_upload(): + def fake_upload_to_gcs(local_filename: str, gcs_destination: str): + if gcs_destination.startswith("gs://") or gcs_destination.startswith("gcs/"): + raise ValueError("Please don't use the real gcs path with mock_gcs_upload.") + # instead of upload, just copy the file. + shutil.copyfile(local_filename, gcs_destination) + + with mock.patch( + "google.cloud.aiplatform.utils.gcs_utils.upload_to_gcs", + new=fake_upload_to_gcs, + ) as gcs_upload: + yield gcs_upload + + +@pytest.fixture +def mock_temp_dir(): + with mock.patch.object(tempfile, "TemporaryDirectory") as temp_dir_mock: + yield temp_dir_mock + + +@pytest.fixture +def mock_named_temp_file(): + with mock.patch.object(tempfile, "NamedTemporaryFile") as named_temp_file_mock: + yield named_temp_file_mock + + +@pytest.fixture +def mock_create_custom_job(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as create_custom_job_mock: + custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO) + custom_job_proto.name = _TEST_DISPLAY_NAME + custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING + create_custom_job_mock.return_value = custom_job_proto + yield create_custom_job_mock + + +@pytest.fixture +def mock_get_custom_job_succeeded(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO) + custom_job_proto.name = _TEST_DISPLAY_NAME + custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + get_custom_job_mock.return_value = custom_job_proto + yield get_custom_job_mock + + +@pytest.fixture +def mock_blob_upload_from_filename(): + with mock.patch.object(storage.Blob, "upload_from_filename") as upload_mock: + yield upload_mock + + +@pytest.fixture +def mock_blob_download_to_filename(): + with mock.patch.object(storage.Blob, "download_to_filename") as download_mock: + yield download_mock + + +@pytest.fixture +def mock_uuid(): + with mock.patch.object(uuid, "uuid4") as uuid_mock: + uuid_mock.return_value = 0 + yield uuid_mock + + +@pytest.fixture +def mock_tf_saved_model_load(): + with mock.patch.object(tf_saved_model, "load") as load_mock: + yield load_mock + + +@pytest.fixture +def mock_cloud_logging_list_entries(): + with mock.patch.object(Logger, "list_entries") as list_entries_mock: + list_entries_mock.return_value = [] + yield list_entries_mock + + +@pytest.fixture +def persistent_resource_running_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as persistent_resource_running_mock: + persistent_resource_running_mock.return_value = ( + test_constants._TEST_PERSISTENT_RESOURCE_RUNNING + ) + yield persistent_resource_running_mock + + +@pytest.fixture +def persistent_resource_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as persistent_resource_exception_mock: + persistent_resource_exception_mock.side_effect = Exception + yield persistent_resource_exception_mock + + +@pytest.fixture +def create_persistent_resource_default_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_default_mock: + create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + create_persistent_resource_lro_mock.result.return_value = ( + test_constants._TEST_REQUEST_RUNNING_DEFAULT + ) + create_persistent_resource_default_mock.return_value = ( + create_persistent_resource_lro_mock + ) + yield create_persistent_resource_default_mock diff --git a/tests/unit/vertexai/constants.py b/tests/unit/vertexai/constants.py new file mode 100644 index 0000000000..deb469f3d7 --- /dev/null +++ b/tests/unit/vertexai/constants.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vertexai.preview._workflow.shared import configs +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + PersistentResource, + ResourcePool, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_CLUSTER_NAME = "test-cluster" +_TEST_CLUSTER_CONFIG = configs.PersistentResourceConfig(name=_TEST_CLUSTER_NAME) +_TEST_CLUSTER_RESOURCE_NAME = f"{_TEST_PARENT}/persistentResources/{_TEST_CLUSTER_NAME}" + + +_TEST_PERSISTENT_RESOURCE_ERROR = PersistentResource() +_TEST_PERSISTENT_RESOURCE_ERROR.state = "ERROR" + +# move to constants +_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource() +resource_pool = ResourcePool() +resource_pool.machine_spec.machine_type = "n1-standard-4" +resource_pool.replica_count = 1 +resource_pool.disk_spec.boot_disk_type = "pd-ssd" +resource_pool.disk_spec.boot_disk_size_gb = 100 +_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool] + + +_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource() +_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING" diff --git a/tests/unit/vertexai/test_any_serializer.py b/tests/unit/vertexai/test_any_serializer.py new file mode 100644 index 0000000000..01ed54bcd3 --- /dev/null +++ b/tests/unit/vertexai/test_any_serializer.py @@ -0,0 +1,1023 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import mock +import pytest + +import cloudpickle +import json +import os +from typing import Any + +from vertexai.preview import developer +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, + serializers, + serializers_base, +) +from vertexai.preview._workflow.shared import constants + +import pandas as pd +from sklearn.linear_model import LogisticRegression +import tensorflow as tf +from tensorflow import keras +import torch + + +@pytest.fixture +def any_serializer_instance(): + return any_serializer.AnySerializer() + + +@pytest.fixture +def torch_dataloader_serializer(): + return serializers.TorchDataLoaderSerializer() + + +@pytest.fixture +def bigframe_serializer(): + return serializers.BigframeSerializer() + + +@pytest.fixture +def tf_dataset_serializer(): + return serializers.TFDatasetSerializer() + + +@pytest.fixture +def mock_keras_model_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.KerasModelSerializer._metadata.dependencies = ["keras==1.0.0"] + + with mock.patch.object( + serializers.KerasModelSerializer, "serialize", new=stateful_serialize + ) as keras_model_serialize: + yield keras_model_serialize + serializers.KerasModelSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_keras_model_deserialize(): + with mock.patch.object( + serializers.KerasModelSerializer, "deserialize", autospec=True + ) as keras_model_deserialize: + yield keras_model_deserialize + + +@pytest.fixture +def mock_sklearn_estimator_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.SklearnEstimatorSerializer._metadata.dependencies = [ + "sklearn_dependency1==1.0.0" + ] + + with mock.patch.object( + serializers.SklearnEstimatorSerializer, + "serialize", + new=stateful_serialize, + ) as sklearn_estimator_serialize: + yield sklearn_estimator_serialize + serializers.SklearnEstimatorSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_sklearn_estimator_deserialize(): + with mock.patch.object( + serializers.SklearnEstimatorSerializer, "deserialize", autospec=True + ) as sklearn_estimator_deserialize: + yield sklearn_estimator_deserialize + + +@pytest.fixture +def mock_torch_model_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.TorchModelSerializer._metadata.dependencies = ["torch==1.0.0"] + + with mock.patch.object( + serializers.TorchModelSerializer, "serialize", new=stateful_serialize + ) as torch_model_serialize: + yield torch_model_serialize + serializers.TorchModelSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_torch_model_deserialize(): + with mock.patch.object( + serializers.TorchModelSerializer, "deserialize", autospec=True + ) as torch_model_deserialize: + yield torch_model_deserialize + + +@pytest.fixture +def mock_torch_dataloader_serialize(tmp_path): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.TorchDataLoaderSerializer._metadata.dependencies = ["torch==1.0.0"] + + with mock.patch.object( + serializers.TorchDataLoaderSerializer, "serialize", new=stateful_serialize + ) as torch_dataloader_serialize: + yield torch_dataloader_serialize + serializers.TorchDataLoaderSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_torch_dataloader_deserialize(): + with mock.patch.object( + serializers.TorchDataLoaderSerializer, "deserialize", autospec=True + ) as torch_dataloader_serializer: + yield torch_dataloader_serializer + + +@pytest.fixture +def mock_tf_dataset_serialize(tmp_path): + def stateful_serialize(self, to_serialize, gcs_path): + del gcs_path + serializers.TFDatasetSerializer._metadata.dependencies = ["tensorflow==1.0.0"] + try: + to_serialize.save(str(tmp_path / "tf_dataset")) + except AttributeError: + tf.data.experimental.save(to_serialize, str(tmp_path / "tf_dataset")) + + with mock.patch.object( + serializers.TFDatasetSerializer, "serialize", new=stateful_serialize + ) as tf_dataset_serialize: + yield tf_dataset_serialize + serializers.TFDatasetSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_tf_dataset_deserialize(): + with mock.patch.object( + serializers.TFDatasetSerializer, "deserialize", autospec=True + ) as tf_dataset_serializer: + yield tf_dataset_serializer + + +@pytest.fixture +def mock_pandas_data_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.PandasDataSerializer._metadata.dependencies = ["pandas==1.0.0"] + + with mock.patch.object( + serializers.PandasDataSerializer, "serialize", new=stateful_serialize + ) as data_serialize: + yield data_serialize + serializers.PandasDataSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_pandas_data_deserialize(): + with mock.patch.object( + serializers.PandasDataSerializer, "deserialize", autospec=True + ) as pandas_data_deserialize: + yield pandas_data_deserialize + + +# TODO(b/295338623): Test correctness of Bigframes serialize/deserialize +@pytest.fixture +def mock_bigframe_deserialize_sklearn(): + with mock.patch.object( + serializers.BigframeSerializer, "_deserialize_sklearn", autospec=True + ) as bigframe_deserialize_sklearn: + yield bigframe_deserialize_sklearn + + +# TODO(b/295338623): Test correctness of Bigframes serialize/deserialize +@pytest.fixture +def mock_bigframe_deserialize_torch(): + with mock.patch.object( + serializers.BigframeSerializer, "_deserialize_torch", autospec=True + ) as bigframe_deserialize_torch: + yield bigframe_deserialize_torch + + +# TODO(b/295338623): Test correctness of Bigframes serialize/deserialize +@pytest.fixture +def mock_bigframe_deserialize_tensorflow(): + with mock.patch.object( + serializers.BigframeSerializer, "_deserialize_tensorflow", autospec=True + ) as bigframe_deserialize_tensorflow: + yield bigframe_deserialize_tensorflow + + +@pytest.fixture +def mock_cloudpickle_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.CloudPickleSerializer._metadata.dependencies = [ + "cloudpickle==1.0.0" + ] + + with mock.patch.object( + serializers.CloudPickleSerializer, "serialize", new=stateful_serialize + ) as cloudpickle_serialize: + yield cloudpickle_serialize + serializers.CloudPickleSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_cloudpickle_deserialize(): + with mock.patch.object( + serializers.CloudPickleSerializer, "deserialize", autospec=True + ) as cloudpickle_deserialize: + yield cloudpickle_deserialize + + +class TestTorchClass(torch.nn.Module): + def __init__(self, input_size=4): + super().__init__() + self.linear_relu_stack = torch.nn.Sequential( + torch.nn.Linear(input_size, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2) + ) + + def forward(self, x): + logits = self.linear_relu_stack(x) + return logits + + +class TestAnySerializer: + """Tests that AnySerializer is acting as 'controller' and router.""" + + @mock.patch.object(serializers.CloudPickleSerializer, "serialize", autospec=True) + def test_any_serializer_serialize_custom_model_with_custom_serializer( + self, mock_cloudpickle_serializer_serialize, any_serializer_instance, tmp_path + ): + # Arrange + class CustomModel: + def __init__(self, weight: int = 0): + self.weight = weight + + @developer.mark.train() + def fit(self, X_train, y_train) -> "CustomModel": + self.weight += 1 + return self + + class CustomSerializer(developer.Serializer): + _metadata = developer.SerializationMetadata() + + def serialize( + self, to_serialize: CustomModel, gcs_path: str, extra_para: Any + ) -> str: + del extra_para + return gcs_path + + def deserialize(self, serialized_gcs_path: str) -> CustomModel: + # Pretend that the model is trained + return CustomModel(weight=1) + + CustomSerializer.register_requirements(["custom_dependency==1.0.0"]) + developer.register_serializer(CustomModel, CustomSerializer) + + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + custom_model = CustomModel() + + # Act + any_serializer_instance.serialize(custom_model, fake_gcs_path, extra_para=1) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + # Metadata should record the dependency specifiers + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "CustomSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "custom_dependency==1.0.0" + ] + + # During the serialization of the CustomModel object, we also serialize + # the serializer with CloudPicleSerializer. + custom_serializer_path = tmp_path / "job_id/input/CustomSerializer" + mock_cloudpickle_serializer_serialize.assert_called_once_with( + any_serializer_instance._instances[serializers.CloudPickleSerializer], + any_serializer_instance._instances[CustomSerializer], + str(custom_serializer_path), + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_sklearn_estimator( + self, any_serializer_instance, tmp_path, mock_sklearn_estimator_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + sklearn_estimator = LogisticRegression() + + # Act + any_serializer_instance.serialize(sklearn_estimator, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "SklearnEstimatorSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "sklearn_dependency1==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_keras_model( + self, any_serializer_instance, tmp_path, mock_keras_model_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + keras_model = keras.Sequential( + [keras.layers.Dense(5, input_shape=(4,)), keras.layers.Softmax()] + ) + + # Act + any_serializer_instance.serialize(keras_model, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "KerasModelSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "keras==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_torch_model( + self, any_serializer_instance, tmp_path, mock_torch_model_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + torch_model = TestTorchClass() + + # Act + any_serializer_instance.serialize(torch_model, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "TorchModelSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "torch==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_dataframe( + self, any_serializer_instance, tmp_path, mock_pandas_data_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + df = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + + # Act + any_serializer_instance.serialize(df, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "PandasDataSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "pandas==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_general_object( + self, any_serializer_instance, tmp_path, mock_cloudpickle_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/general_object.cpkl") + os.makedirs(tmp_path / "job_id/input") + + class TestClass: + pass + + obj = TestClass() + + # Act + any_serializer_instance.serialize(obj, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_general_object.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "CloudPickleSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "cloudpickle==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_serialize_torch_dataloader( + self, any_serializer_instance, tmp_path, mock_torch_dataloader_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "dataloader") + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor([[1, 2, 3] for i in range(100)]), + torch.tensor([1] * 100), + ), + batch_size=10, + shuffle=True, + ) + + # Act + any_serializer_instance.serialize(dataloader, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"{serializers_base.SERIALIZATION_METADATA_FILENAME}_dataloader.json" + ) + with open(metadata_path, "rb") as f: + metadata = json.load(f) + + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "TorchDataLoaderSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "torch==1.0.0" + ] + + @pytest.mark.usefixtures("mock_tf_dataset_serialize") + def test_any_serializer_serialize_tf_dataset( + self, any_serializer_instance, tmp_path, tf_dataset_serializer + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "tf_dataset") + + tf_dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + + # Act + any_serializer_instance.serialize(tf_dataset, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"{serializers_base.SERIALIZATION_METADATA_FILENAME}_tf_dataset.json" + ) + with open(metadata_path, "rb") as f: + metadata = json.load(f) + + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "TFDatasetSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "tensorflow==1.0.0" + ] + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_typed_serializer_failed_falling_back_to_cloudpickle( + self, any_serializer_instance, tmp_path, mock_cloudpickle_serialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + keras_model = keras.Sequential( + [keras.layers.Dense(5, input_shape=(4,)), keras.layers.Softmax()] + ) + + with mock.patch.object( + serializers.KerasModelSerializer, "serialize", autospec=True + ) as mock_keras_model_serializer_serialize: + mock_keras_model_serializer_serialize.side_effect = Exception + # Act + any_serializer_instance.serialize(keras_model, fake_gcs_path) + + # Assert + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + + # Metadata should have the correct serializer information + with open(metadata_path, "rb") as f: + metadata = json.load(f) + assert ( + metadata[serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY] + == "CloudPickleSerializer" + ) + assert metadata[serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY] == [ + "cloudpickle==1.0.0" + ] + + def test_any_serializer_cloudpickle_serializer_failed_raise_serialization_error( + self, any_serializer_instance, tmp_path + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/general_object.cpkl") + + class TestClass: + pass + + obj = TestClass() + + with mock.patch.object( + serializers.CloudPickleSerializer, "serialize", autospec=True + ) as mock_cloudpickle_serializer_serialize: + mock_cloudpickle_serializer_serialize.side_effect = Exception + # Act & Assert + with pytest.raises(serializers_base.SerializationError): + any_serializer_instance.serialize(obj, fake_gcs_path) + + @pytest.mark.usefixtures("mock_gcs_upload") + @mock.patch.object(any_serializer, "_check_dependency_versions", autospec=True) + def test_any_serializer_deserialize_custom_model_with_custom_serializer( + self, mocked_check_dependency_versions, any_serializer_instance, tmp_path + ): + # Arrange + class CustomModel: + def __init__(self, weight: int = 0): + self.weight = weight + + @developer.mark.train() + def fit(self, X_train, y_train): + self.weight += 1 + return self + + class CustomSerializer(developer.Serializer): + _metadata = developer.SerializationMetadata() + + def serialize(self, to_serialize: CustomModel, gcs_path: str) -> str: + return gcs_path + + def deserialize(self, serialized_gcs_path: str) -> CustomModel: + # Pretend that the model is trained + return CustomModel(weight=1) # noqa: F821 + + developer.register_serializer(CustomModel, CustomSerializer) + CustomSerializer.register_requirements(["custom_dependency==1.0.0"]) + + fake_gcs_path = os.fspath(tmp_path / "job_id/input/custom_model") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_custom_model.json" + ) + + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "CustomSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + "custom_dependency==1.0.0" + ], + } + ).encode("utf-8") + ) + + custom_serializer_path = tmp_path / "job_id/input/CustomSerializer" + + # Act + with mock.patch.object( + serializers.CloudPickleSerializer, + "deserialize", + autospec=True, + return_value=CustomSerializer(), + ) as mock_cloudpickle_deserialize: + deserialized_custom_model = any_serializer_instance.deserialize( + fake_gcs_path + ) + + # Assert + del CustomModel + deserialized_custom_model.weight = 1 + # CloudPickleSerializer.deserialize() is called to deserialize the + # CustomSerializer. + mock_cloudpickle_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.CloudPickleSerializer], + serialized_gcs_path=str(custom_serializer_path), + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_sklearn_estimator( + self, any_serializer_instance, tmp_path, mock_sklearn_estimator_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "SklearnEstimatorSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_sklearn_estimator_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.SklearnEstimatorSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_keras_model( + self, any_serializer_instance, tmp_path, mock_keras_model_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "KerasModelSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_keras_model_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.KerasModelSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_torch_model( + self, any_serializer_instance, tmp_path, mock_torch_model_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/input_estimator") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_input_estimator.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "TorchModelSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_torch_model_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.TorchModelSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_dataframe( + self, any_serializer_instance, tmp_path, mock_pandas_data_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "PandasDataSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_pandas_data_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.PandasDataSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_torch_dataloader( + self, any_serializer_instance, tmp_path, mock_torch_dataloader_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/dataloader") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_dataloader.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "TorchDataLoaderSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_torch_dataloader_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.TorchDataLoaderSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_bigframe_sklearn( + self, any_serializer_instance, tmp_path, mock_bigframe_deserialize_sklearn + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "BigframeSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + serializers.SERIALIZATION_METADATA_FRAMEWORK_KEY: "sklearn", + } + ).encode("utf-8") + ) + + # Act (step 2) + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_bigframe_deserialize_sklearn.assert_called_once_with( + any_serializer_instance._instances[serializers.BigframeSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_bigframe_torch( + self, any_serializer_instance, tmp_path, mock_bigframe_deserialize_torch + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "BigframeSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + serializers.SERIALIZATION_METADATA_FRAMEWORK_KEY: "torch", + } + ).encode("utf-8") + ) + + # Act (step 2) + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_bigframe_deserialize_torch.assert_called_once_with( + any_serializer_instance._instances[serializers.BigframeSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_bigframe_tensorflow( + self, any_serializer_instance, tmp_path, mock_bigframe_deserialize_tensorflow + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "BigframeSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + serializers.SERIALIZATION_METADATA_FRAMEWORK_KEY: "tensorflow", + } + ).encode("utf-8") + ) + + # Act (step 2) + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_bigframe_deserialize_tensorflow.assert_called_once_with( + any_serializer_instance._instances[serializers.BigframeSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + def test_any_serializer_deserialize_tf_dataset( + self, any_serializer_instance, tmp_path, mock_tf_dataset_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/X") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_X.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "TFDatasetSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + "tensorflow==1.0.0" + ], + } + ).encode("utf-8") + ) + + # Act + any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_tf_dataset_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.TFDatasetSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + @pytest.mark.usefixtures("mock_gcs_upload") + def test_any_serializer_deserialize_general_object( + self, any_serializer_instance, tmp_path, mock_cloudpickle_deserialize + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/general_object.cpkl") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_general_object.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "CloudPickleSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [], + } + ).encode("utf-8") + ) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + + # Assert + mock_cloudpickle_deserialize.assert_called_once_with( + any_serializer_instance._instances[serializers.CloudPickleSerializer], + serialized_gcs_path=fake_gcs_path, + ) + + def test_any_serializer_deserialize_raise_runtime_error_when_dependency_cannot_be_imported( + self, tmp_path, any_serializer_instance + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/general_object.cpkl") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_general_object.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "CloudPickleSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + "nonexisting_module==1.0.0", + ], + } + ).encode("utf-8") + ) + + # Act & Assert + with pytest.raises(RuntimeError, match="nonexisting_module is not installed"): + _ = any_serializer_instance.deserialize(fake_gcs_path) + + @mock.patch.object(serializers, "_is_valid_gcs_path", return_value=True) + def test_any_serializer_deserialize_raises_warning_when_version_mismatched( + self, mock_gcs_path_validation, tmp_path, caplog, any_serializer_instance + ): + # Arrange + fake_gcs_path = os.fspath(tmp_path / "job_id/input/general_object.cpkl") + os.makedirs(tmp_path / "job_id/input") + metadata_path = ( + tmp_path + / f"job_id/input/{serializers_base.SERIALIZATION_METADATA_FILENAME}_general_object.json" + ) + with open(metadata_path, "wb") as f: + f.write( + json.dumps( + { + serializers_base.SERIALIZATION_METADATA_SERIALIZER_KEY: "CloudPickleSerializer", + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + "sklearn==1.0.0", + ], + } + ).encode("utf-8") + ) + with open(fake_gcs_path, "wb") as f: + f.write(cloudpickle.dumps([1, 2, 3], protocol=constants.PICKLE_PROTOCOL)) + + # Act + _ = any_serializer_instance.deserialize(fake_gcs_path) + # Assert + # The current sklearn version in google3 will changing, but it's a later + # version than 1.0.0 + with caplog.at_level(level=20, logger="vertexai.serialization_engine"): + assert "sklearn's version is" in caplog.text + assert "while the required version is ==1.0.0" in caplog.text diff --git a/tests/unit/vertexai/test_developer_mark.py b/tests/unit/vertexai/test_developer_mark.py new file mode 100644 index 0000000000..0bcbbc5a29 --- /dev/null +++ b/tests/unit/vertexai/test_developer_mark.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools + +import vertexai +from vertexai.preview._workflow import driver +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.executor import ( + remote_container_training, +) +from vertexai.preview._workflow.shared import ( + configs, +) +from vertexai.preview.developer import remote_specs +import pytest + +# RemoteConfig constants +_TEST_DISPLAY_NAME = "test_display_name" +_TEST_STAGING_BUCKET = "gs://test-staging-bucket" +_TEST_CONTAINER_URI = "gcr.io/test-image" +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_SERVICE_ACCOUNT = "test-service-account" +_TEST_WORKER_POOL_SPECS = remote_specs.WorkerPoolSpecs( + chief=remote_specs.WorkerPoolSpec( + machine_type=_TEST_MACHINE_TYPE, + ) +) + +_TEST_TRAINING_CONFIG = configs.RemoteConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + container_uri=_TEST_CONTAINER_URI, + machine_type=_TEST_MACHINE_TYPE, + service_account=_TEST_SERVICE_ACCOUNT, +) + +_TEST_TRAINING_CONFIG_WORKER_POOL = configs.RemoteConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + container_uri=_TEST_CONTAINER_URI, + worker_pool_specs=_TEST_WORKER_POOL_SPECS, + service_account=_TEST_SERVICE_ACCOUNT, +) + +# Remote training custom job constants +_TEST_IMAGE_URI = "test_image_uri" +_TEST_REPLICA_COUNT = 1 +_TEST_ACCELERATOR_COUNT = 8 +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_BOOT_DISK_TYPE = "test_boot_disk_type" +_TEST_BOOT_DISK_SIZE_GB = 10 +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + machine_type=_TEST_MACHINE_TYPE, + replica_count=_TEST_REPLICA_COUNT, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, +) +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG_WORKER_POOL = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + worker_pool_specs=_TEST_WORKER_POOL_SPECS, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" + + +class TestDeveloperMark: + def test_mark_train(self): + class TestClass(vertexai.preview.VertexModel): + @vertexai.preview.developer.mark.train() + def test_method(x, y): + return x + y + + assert isinstance(TestClass.test_method, driver.VertexRemoteFunctor) + assert TestClass.test_method.vertex == configs.VertexConfig + + test_class = TestClass() + + assert isinstance(test_class.test_method, driver.VertexRemoteFunctor) + assert isinstance(test_class.test_method.vertex, configs.VertexConfig) + + @pytest.mark.usefixtures("google_auth_mock") + def test_mark_train_with_all_args(self): + class TestClass(vertexai.preview.VertexModel): + @vertexai.preview.developer.mark.train(remote_config=_TEST_TRAINING_CONFIG) + def test_method(self, x, y): + return x + y + + test_class = TestClass() + + assert isinstance(test_class.test_method, driver.VertexRemoteFunctor) + assert ( + test_class.test_method.vertex.remote_config.display_name + == _TEST_DISPLAY_NAME + ) + assert ( + test_class.test_method.vertex.remote_config.staging_bucket + == _TEST_STAGING_BUCKET + ) + assert ( + test_class.test_method.vertex.remote_config.container_uri + == _TEST_CONTAINER_URI + ) + assert ( + test_class.test_method.vertex.remote_config.machine_type + == _TEST_MACHINE_TYPE + ) + assert ( + test_class.test_method.vertex.remote_config.service_account + == _TEST_SERVICE_ACCOUNT + ) + + @pytest.mark.usefixtures("google_auth_mock") + def test_mark_train_with_worker_pool_specs(self): + class TestClass(vertexai.preview.VertexModel): + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG_WORKER_POOL + ) + def test_method(self, x, y): + return x + y + + test_class = TestClass() + + assert isinstance(test_class.test_method, driver.VertexRemoteFunctor) + assert ( + test_class.test_method.vertex.remote_config.display_name + == _TEST_DISPLAY_NAME + ) + assert ( + test_class.test_method.vertex.remote_config.staging_bucket + == _TEST_STAGING_BUCKET + ) + assert ( + test_class.test_method.vertex.remote_config.container_uri + == _TEST_CONTAINER_URI + ) + assert ( + test_class.test_method.vertex.remote_config.worker_pool_specs + == _TEST_WORKER_POOL_SPECS + ) + + # pylint: disable=missing-function-docstring,protected-access) + @pytest.mark.parametrize( + "remote_config,expected_config", + [ + ( + _TEST_REMOTE_CONTAINER_TRAINING_CONFIG, + _TEST_REMOTE_CONTAINER_TRAINING_CONFIG, + ), + (None, configs.DistributedTrainingConfig()), + ( + _TEST_REMOTE_CONTAINER_TRAINING_CONFIG_WORKER_POOL, + _TEST_REMOTE_CONTAINER_TRAINING_CONFIG_WORKER_POOL, + ), + ], + ) + def test_mark_remote_container_train(self, remote_config, expected_config): + test_additional_data = [remote_specs._InputParameterSpec("arg_0")] + + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=test_additional_data, + remote_config=remote_config, + ) + def fit(self): + return + + assert isinstance(MockTrainer.fit, driver.VertexRemoteFunctor) + assert isinstance(MockTrainer.fit.vertex, functools.partial) + assert MockTrainer.fit.vertex.func == configs.VertexConfig + assert not MockTrainer.fit.vertex.args + assert MockTrainer.fit.vertex.keywords == { + "remote_config": expected_config, + "remote": True, + } + + test_trainer = MockTrainer() + assert isinstance(test_trainer.fit, driver.VertexRemoteFunctor) + assert test_trainer.fit.vertex.remote_config == expected_config + assert test_trainer.fit._remote_executor is remote_container_training.train + assert test_trainer.fit._remote_executor_kwargs == { + "additional_data": test_additional_data, + "image_uri": _TEST_IMAGE_URI, + } + assert test_trainer.fit.vertex.remote + + # pylint: disable=missing-function-docstring,protected-access + def test_mark_remote_container_train_override_remote_config(self): + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[], + remote_config=configs.DistributedTrainingConfig(), + ) + def fit(self): + return + + test_trainer = MockTrainer() + assert isinstance(test_trainer.fit, driver.VertexRemoteFunctor) + assert ( + test_trainer.fit.vertex.remote_config == configs.DistributedTrainingConfig() + ) + assert test_trainer.fit._remote_executor is remote_container_training.train + assert test_trainer.fit._remote_executor_kwargs == { + "additional_data": [], + "image_uri": _TEST_IMAGE_URI, + } + + # Overrides training config + test_remote_config = test_trainer.fit.vertex.remote_config + test_remote_config.display_name = _TEST_DISPLAY_NAME + test_remote_config.staging_bucket = _TEST_STAGING_BUCKET + test_remote_config.machine_type = _TEST_MACHINE_TYPE + test_remote_config.replica_count = _TEST_REPLICA_COUNT + test_remote_config.accelerator_type = _TEST_ACCELERATOR_TYPE + test_remote_config.accelerator_count = _TEST_ACCELERATOR_COUNT + test_remote_config.boot_disk_type = _TEST_BOOT_DISK_TYPE + test_remote_config.boot_disk_size_gb = _TEST_BOOT_DISK_SIZE_GB + + assert ( + test_trainer.fit.vertex.remote_config + == _TEST_REMOTE_CONTAINER_TRAINING_CONFIG + ) + + def test_mark_predict(self): + class TestClass(vertexai.preview.VertexModel): + @vertexai.preview.developer.mark.predict() + def test_method(x, y): + return x + y + + assert isinstance(TestClass.test_method, driver.VertexRemoteFunctor) + assert TestClass.test_method.vertex == configs.VertexConfig + + test_class = TestClass() + + assert isinstance(test_class.test_method, driver.VertexRemoteFunctor) + assert isinstance(test_class.test_method.vertex, configs.VertexConfig) + + def test_mark_predict_with_all_args(self): + class TestClass(vertexai.preview.VertexModel): + @vertexai.preview.developer.mark.predict( + remote_config=_TEST_TRAINING_CONFIG + ) + def test_method(self, x, y): + return x + y + + test_class = TestClass() + + assert isinstance(test_class.test_method, driver.VertexRemoteFunctor) + assert ( + test_class.test_method.vertex.remote_config.display_name + == _TEST_DISPLAY_NAME + ) + assert ( + test_class.test_method.vertex.remote_config.staging_bucket + == _TEST_STAGING_BUCKET + ) + assert ( + test_class.test_method.vertex.remote_config.container_uri + == _TEST_CONTAINER_URI + ) + assert ( + test_class.test_method.vertex.remote_config.machine_type + == _TEST_MACHINE_TYPE + ) + assert ( + test_class.test_method.vertex.remote_config.service_account + == _TEST_SERVICE_ACCOUNT + ) diff --git a/tests/unit/vertexai/test_model_utils.py b/tests/unit/vertexai/test_model_utils.py new file mode 100644 index 0000000000..7d415d97e6 --- /dev/null +++ b/tests/unit/vertexai/test_model_utils.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +from unittest import mock + +from google.cloud import aiplatform +from google.cloud.aiplatform import utils +import vertexai +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, +) +import pytest + +from sklearn.linear_model import _logistic +import tensorflow +import torch + + +# project constants +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_BUCKET = "gs://test-bucket" +_TEST_UNIQUE_NAME = "test-unique-name" + + +# framework-specific constants +_SKLEARN_MODEL = _logistic.LogisticRegression() +_TF_MODEL = tensorflow.keras.models.Model() +_PYTORCH_MODEL = torch.nn.Module() +_TEST_MODEL_GCS_URI = "gs://test_model_dir" +_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456" +_REWRAPPER = "rewrapper" + + +@pytest.fixture +def mock_serialize_model(): + with mock.patch.object( + any_serializer.AnySerializer, "serialize" + ) as mock_serialize_model: + yield mock_serialize_model + + +@pytest.fixture +def mock_vertex_model(): + model = mock.MagicMock(aiplatform.Model) + model.uri = _TEST_MODEL_GCS_URI + model.container_spec.image_uri = "us-docker.xxx/sklearn-cpu.1-0:latest" + model.labels = {"registered_by_vertex_ai": "true"} + yield model + + +@pytest.fixture +def mock_vertex_model_invalid(): + model = mock.MagicMock(aiplatform.Model) + model.uri = _TEST_MODEL_GCS_URI + model.container_spec.image_uri = "us-docker.xxx/sklearn-cpu.1-0:latest" + yield model + + +@pytest.fixture +def mock_timestamped_unique_name(): + with mock.patch.object( + utils, "timestamped_unique_name" + ) as mock_timestamped_unique_name: + mock_timestamped_unique_name.return_value = _TEST_UNIQUE_NAME + yield mock_timestamped_unique_name + + +@pytest.fixture +def mock_model_upload(mock_vertex_model): + with mock.patch.object(aiplatform.Model, "upload") as mock_model_upload: + mock_model_upload.return_value = mock_vertex_model + yield mock_model_upload + + +@pytest.fixture +def mock_get_vertex_model(mock_vertex_model): + with mock.patch.object(aiplatform, "Model") as mock_get_vertex_model: + mock_get_vertex_model.return_value = mock_vertex_model + yield mock_get_vertex_model + + +@pytest.fixture +def mock_get_vertex_model_invalid(mock_vertex_model_invalid): + with mock.patch.object(aiplatform, "Model") as mock_get_vertex_model: + mock_get_vertex_model.return_value = mock_vertex_model_invalid + yield mock_get_vertex_model + + +@pytest.fixture +def mock_deserialize_model(): + with mock.patch.object( + any_serializer.AnySerializer, "deserialize" + ) as mock_deserialize_model: + + mock_deserialize_model.side_effect = [ + _SKLEARN_MODEL, + mock.Mock(return_value=None), + ] + yield mock_deserialize_model + + +@pytest.fixture +def mock_deserialize_model_exception(): + with mock.patch.object( + any_serializer.AnySerializer, "deserialize" + ) as mock_deserialize_model_exception: + mock_deserialize_model_exception.side_effect = Exception + yield mock_deserialize_model_exception + + +@pytest.mark.usefixtures("google_auth_mock") +class TestModelUtils: + def setup_method(self): + reload(aiplatform) + reload(vertexai) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures("mock_timestamped_unique_name") + def test_register_sklearn_model(self, mock_model_upload, mock_serialize_model): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + vertex_model = vertexai.preview.register(_SKLEARN_MODEL) + + expected_display_name = ( + f"vertex-ai-registered-sklearn-model-{_TEST_UNIQUE_NAME}" + ) + expected_uri = f"{_TEST_BUCKET}/{expected_display_name}" + expected_container_uri = ( + aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="sklearn", + framework_version="1.0", + ) + ) + + assert vertex_model.uri == _TEST_MODEL_GCS_URI + mock_model_upload.assert_called_once_with( + display_name=expected_display_name, + artifact_uri=expected_uri, + serving_container_image_uri=expected_container_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + assert 2 == mock_serialize_model.call_count + mock_serialize_model.assert_has_calls( + calls=[ + mock.call( + _SKLEARN_MODEL, + f"{expected_uri}/model.pkl", + ), + ], + any_order=True, + ) + + @pytest.mark.parametrize("use_gpu", [True, False]) + @pytest.mark.usefixtures("mock_timestamped_unique_name") + def test_register_tf_model(self, mock_model_upload, mock_serialize_model, use_gpu): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + vertex_model = vertexai.preview.register(_TF_MODEL, use_gpu=use_gpu) + + expected_display_name = ( + f"vertex-ai-registered-tensorflow-model-{_TEST_UNIQUE_NAME}" + ) + expected_uri = f"{_TEST_BUCKET}/{expected_display_name}/saved_model" + expected_container_uri = ( + aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="tensorflow", + framework_version="2.11", + accelerator="gpu" if use_gpu else "cpu", + ) + ) + + assert vertex_model.uri == _TEST_MODEL_GCS_URI + mock_model_upload.assert_called_once_with( + display_name=expected_display_name, + artifact_uri=expected_uri, + serving_container_image_uri=expected_container_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + assert 2 == mock_serialize_model.call_count + mock_serialize_model.assert_has_calls( + calls=[ + mock.call( + _TF_MODEL, + f"{expected_uri}", + ), + ], + any_order=True, + ) + + @pytest.mark.parametrize("use_gpu", [True, False]) + @pytest.mark.usefixtures("mock_timestamped_unique_name") + def test_register_pytorch_model( + self, mock_model_upload, mock_serialize_model, use_gpu + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + vertex_model = vertexai.preview.register(_PYTORCH_MODEL, use_gpu=use_gpu) + + expected_display_name = ( + f"vertex-ai-registered-pytorch-model-{_TEST_UNIQUE_NAME}" + ) + expected_uri = f"{_TEST_BUCKET}/{expected_display_name}" + expected_container_uri = ( + aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="pytorch", + framework_version="1.12", + accelerator="gpu" if use_gpu else "cpu", + ) + ) + + assert vertex_model.uri == _TEST_MODEL_GCS_URI + mock_model_upload.assert_called_once_with( + display_name=expected_display_name, + artifact_uri=expected_uri, + serving_container_image_uri=expected_container_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + + assert 2 == mock_serialize_model.call_count + mock_serialize_model.assert_has_calls( + calls=[ + mock.call( + _PYTORCH_MODEL, + f"{expected_uri}/model.mar", + ), + ], + any_order=True, + ) + + @pytest.mark.usefixtures("mock_get_vertex_model") + def test_local_model_from_pretrained_succeed(self, mock_deserialize_model): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + local_model = vertexai.preview.from_pretrained(model_name=_MODEL_RESOURCE_NAME) + assert local_model == _SKLEARN_MODEL + assert 2 == mock_deserialize_model.call_count + mock_deserialize_model.assert_has_calls( + calls=[ + mock.call( + f"{_TEST_MODEL_GCS_URI}/model.pkl", + ), + ], + any_order=True, + ) + + @pytest.mark.usefixtures( + "mock_get_vertex_model_invalid", + ) + def test_local_model_from_pretrained_fail(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + with pytest.raises(ValueError): + vertexai.preview.from_pretrained(model_name=_MODEL_RESOURCE_NAME) diff --git a/tests/unit/vertexai/test_persistent_resource_util.py b/tests/unit/vertexai/test_persistent_resource_util.py new file mode 100644 index 0000000000..109b675f35 --- /dev/null +++ b/tests/unit/vertexai/test_persistent_resource_util.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( + PersistentResourceServiceClient, +) +from google.cloud.aiplatform_v1beta1.types import persistent_resource_service +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + PersistentResource, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourcePool, +) +from vertexai.preview._workflow.executor import ( + persistent_resource_util, +) +from vertexai.preview._workflow.shared import configs +import mock +import pytest + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_CLUSTER_NAME = "test-cluster" +_TEST_CLUSTER_CONFIG = configs.PersistentResourceConfig(name=_TEST_CLUSTER_NAME) +_TEST_CLUSTER_RESOURCE_NAME = f"{_TEST_PARENT}/persistentResources/{_TEST_CLUSTER_NAME}" + + +_TEST_PERSISTENT_RESOURCE_ERROR = PersistentResource() +_TEST_PERSISTENT_RESOURCE_ERROR.state = "ERROR" + +_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource() +resource_pool = ResourcePool() +resource_pool.machine_spec.machine_type = "n1-standard-4" +resource_pool.replica_count = 1 +resource_pool.disk_spec.boot_disk_type = "pd-ssd" +resource_pool.disk_spec.boot_disk_size_gb = 100 +_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool] + + +_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource() +_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING" + + +@pytest.fixture +def persistent_resource_running_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as persistent_resource_running_mock: + persistent_resource_running_mock.return_value = ( + _TEST_PERSISTENT_RESOURCE_RUNNING + ) + yield persistent_resource_running_mock + + +@pytest.fixture +def persistent_resource_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as persistent_resource_exception_mock: + persistent_resource_exception_mock.side_effect = Exception + yield persistent_resource_exception_mock + + +@pytest.fixture +def create_persistent_resource_default_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_default_mock: + create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + create_persistent_resource_lro_mock.result.return_value = ( + _TEST_REQUEST_RUNNING_DEFAULT + ) + create_persistent_resource_default_mock.return_value = ( + create_persistent_resource_lro_mock + ) + yield create_persistent_resource_default_mock + + +@pytest.fixture +def persistent_resource_error_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as persistent_resource_error_mock: + persistent_resource_error_mock.return_value = _TEST_PERSISTENT_RESOURCE_ERROR + yield persistent_resource_error_mock + + +@pytest.fixture +def create_persistent_resource_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_exception_mock: + create_persistent_resource_exception_mock.side_effect = Exception + yield create_persistent_resource_exception_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPersistentResourceUtils: + def setup_method(self): + importlib.reload(vertexai.preview.initializer) + importlib.reload(vertexai.preview) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + def test_check_persistent_resource_true(self, persistent_resource_running_mock): + expected = persistent_resource_util.check_persistent_resource( + _TEST_CLUSTER_RESOURCE_NAME + ) + + assert expected + + request = persistent_resource_service.GetPersistentResourceRequest( + name=_TEST_CLUSTER_RESOURCE_NAME, + ) + persistent_resource_running_mock.assert_called_once_with(request) + + def test_check_persistent_resource_false(self, persistent_resource_exception_mock): + with pytest.raises(Exception): + expected = persistent_resource_util.check_persistent_resource( + _TEST_CLUSTER_RESOURCE_NAME + ) + + assert not expected + + request = persistent_resource_service.GetPersistentResourceRequest( + name=_TEST_CLUSTER_RESOURCE_NAME, + ) + persistent_resource_exception_mock.assert_called_once_with(request) + + @pytest.mark.usefixtures("persistent_resource_error_mock") + def test_check_persistent_resource_error(self): + with pytest.raises(ValueError) as e: + persistent_resource_util.check_persistent_resource( + _TEST_CLUSTER_RESOURCE_NAME + ) + + e.match( + regexp=r'(\'The existing cluster `\', \'projects/test-project/locations/us-central1/persistentResources/test-cluster\', "` isn\'t running, please specify a different cluster_name.")' + ) + + @pytest.mark.usefixtures("persistent_resource_running_mock") + def test_create_persistent_resource_default_success( + self, create_persistent_resource_default_mock + ): + persistent_resource_util.create_persistent_resource(_TEST_CLUSTER_RESOURCE_NAME) + + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=_TEST_PARENT, + persistent_resource=_TEST_REQUEST_RUNNING_DEFAULT, + persistent_resource_id=_TEST_CLUSTER_NAME, + ) + + create_persistent_resource_default_mock.assert_called_with( + request, + ) + + @pytest.mark.usefixtures("create_persistent_resource_exception_mock") + def test_create_ray_cluster_state_error(self): + with pytest.raises(ValueError) as e: + persistent_resource_util.create_persistent_resource( + _TEST_CLUSTER_RESOURCE_NAME + ) + + e.match(regexp=r"Failed in cluster creation due to: ") diff --git a/tests/unit/vertexai/test_remote_container_training.py b/tests/unit/vertexai/test_remote_container_training.py new file mode 100644 index 0000000000..2b156b8ba6 --- /dev/null +++ b/tests/unit/vertexai/test_remote_container_training.py @@ -0,0 +1,586 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for _workflow/executor/remote_container_training.py. +""" + +from importlib import reload +import inspect +import os +import re +import tempfile + +import cloudpickle +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform.compat.types import ( + custom_job as gca_custom_job_compat, +) +from google.cloud.aiplatform.compat.types import io as gca_io_compat +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.executor import ( + remote_container_training, +) +from vertexai.preview._workflow.shared import configs +from vertexai.preview.developer import remote_specs +import pandas as pd +import pytest + + +# Custom job constants. +_TEST_INPUTS = [ + "--arg_0=string_val_0", + "--arg_1=string_val_1", + "--arg_2=int_val_0", + "--arg_3=int_val_1", +] +_TEST_IMAGE_URI = "test_image_uri" +_TEST_MACHINE_TYPE = "n1-standard-4" + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" + +_TEST_BUCKET_NAME = "gs://test_bucket" +_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir" + +_TEST_DISPLAY_NAME = "test_display_name" +_TEST_STAGING_BUCKET = "gs://test-staging-bucket" +_TEST_CONTAINER_URI = "gcr.io/test-image" +_TEST_REPLICA_COUNT = 1 +_TEST_ACCELERATOR_COUNT = 8 +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_BOOT_DISK_TYPE = "test_boot_disk_type" +_TEST_BOOT_DISK_SIZE_GB = 10 +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + machine_type=_TEST_MACHINE_TYPE, + replica_count=_TEST_REPLICA_COUNT, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, +) + +_TEST_WORKER_POOL_SPECS = remote_specs.WorkerPoolSpecs( + chief=remote_specs.WorkerPoolSpec( + machine_type=_TEST_MACHINE_TYPE, + replica_count=_TEST_REPLICA_COUNT, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + ) +) + +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG_WORKER_POOL = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + worker_pool_specs=_TEST_WORKER_POOL_SPECS, +) + +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG_INVALID = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + machine_type=_TEST_MACHINE_TYPE, + replica_count=_TEST_REPLICA_COUNT, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + worker_pool_specs=_TEST_WORKER_POOL_SPECS, +) + + +# pylint: disable=protected-access,missing-function-docstring +class TestRemoteContainerTrain: + """Tests for remote_container_train and helper functions.""" + + def setup_method(self): + reload(aiplatform.initializer) + reload(aiplatform) + reload(vertexai.preview.initializer) + reload(vertexai) + + def test_generate_worker_pool_specs_single_machine(self): + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + } + ] + + worker_pool_specs = remote_container_training._generate_worker_pool_specs( + image_uri=_TEST_IMAGE_URI, + inputs=_TEST_INPUTS, + machine_type=_TEST_MACHINE_TYPE, + replica_count=1, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + ) + + assert worker_pool_specs == expected_worker_pool_specs + + def test_generate_worker_pool_specs_distributed(self): + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": 100, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + }, + { + "replica_count": 3, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": 100, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + }, + ] + + worker_pool_specs = remote_container_training._generate_worker_pool_specs( + image_uri=_TEST_IMAGE_URI, + inputs=_TEST_INPUTS, + replica_count=4, + machine_type=_TEST_MACHINE_TYPE, + ) + + assert worker_pool_specs == expected_worker_pool_specs + + def test_generate_worker_pool_specs_gpu(self): + test_accelerator_type = "NVIDIA_TESLA_K80" + test_accelerator_count = 8 + + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": test_accelerator_type, + "accelerator_count": test_accelerator_count, + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": 100, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + } + ] + + worker_pool_specs = remote_container_training._generate_worker_pool_specs( + image_uri=_TEST_IMAGE_URI, + inputs=_TEST_INPUTS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=test_accelerator_count, + accelerator_type=test_accelerator_type, + ) + + assert worker_pool_specs == expected_worker_pool_specs + + def test_generate_worker_pool_specs_invalid(self): + with pytest.raises(ValueError) as e: + remote_container_training._generate_worker_pool_specs( + image_uri=_TEST_IMAGE_URI, + inputs=_TEST_INPUTS, + replica_count=0, + machine_type=_TEST_MACHINE_TYPE, + ) + expected_err_msg = "replica_count must be a positive number but is 0." + assert str(e.value) == expected_err_msg + + # pylint: disable=missing-function-docstring,protected-access + @pytest.mark.parametrize( + "remote_config", + [ + (_TEST_REMOTE_CONTAINER_TRAINING_CONFIG), + (_TEST_REMOTE_CONTAINER_TRAINING_CONFIG_WORKER_POOL), + ], + ) + @pytest.mark.usefixtures( + "google_auth_mock", "mock_uuid", "mock_get_custom_job_succeeded" + ) + def test_remote_container_train( + self, + mock_blob_upload_from_filename, + mock_create_custom_job, + mock_named_temp_file, + mock_blob_download_to_filename, + remote_config: configs.DistributedTrainingConfig, + ): + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + def __init__(self, input_0, input_1): + super().__init__() + sig = inspect.signature(self.__init__) + self._binding = sig.bind(input_0, input_1).arguments + self.output_0 = None + self.output_1 = None + + # pylint: disable=invalid-name,unused-argument,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[ + remote_specs._InputParameterSpec("input_0"), + remote_specs._InputParameterSpec( + "input_1", serializer="cloudpickle" + ), + remote_specs._InputParameterSpec("X", serializer="parquet"), + remote_specs._OutputParameterSpec("output_0"), + remote_specs._OutputParameterSpec( + "output_1", deserializer="cloudpickle" + ), + ], + remote_config=remote_config, + ) + def fit(self, X): + self.output_0 = int(self.output_0) + + def test_input_1(x): + return x + + test_trainer = MockTrainer( + input_0="test_input_0", + input_1=test_input_1, + ) + test_data = pd.DataFrame(data={"col_0": [0, 1], "col_1": [2, 3]}) + test_output_0 = 10 + + def test_output_1(x): + return x + 1 + + assert test_trainer.fit._remote_executor is remote_container_training.train + + with tempfile.TemporaryDirectory() as tmp_dir: + # Sets up file mocks + test_input_1_path = os.path.join(tmp_dir, "input_1") + test_input_1_handler = open(test_input_1_path, "wb") + + test_serialized_path = os.path.join(tmp_dir, "serialized") + test_serialized_handler = open(test_serialized_path, "wb") + + test_metadata_path = os.path.join(tmp_dir, "metadata") + test_metadata_handler = open(test_metadata_path, "wb") + + test_output_0_path = os.path.join(tmp_dir, "output_0") + with open(test_output_0_path, "w") as f: + f.write(f"{test_output_0}") + test_output_0_handler = open(test_output_0_path, "r") + + test_output_1_path = os.path.join(tmp_dir, "output_1") + with open(test_output_1_path, "wb") as f: + f.write(cloudpickle.dumps(test_output_1)) + test_output_1_handler = open(test_output_1_path, "rb") + + (mock_named_temp_file.return_value.__enter__.side_effect) = [ + test_input_1_handler, + test_serialized_handler, + test_metadata_handler, + test_output_0_handler, + test_output_1_handler, + ] + + # Calls the decorated function + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_trainer.fit(test_data) + + # Checks the created custom job and outputs + expected_inputs = [ + "--input_0=test_input_0", + f"--input_1={_TEST_STAGING_BUCKET}/input/input_1", + f"--X={_TEST_STAGING_BUCKET}/input/X", + f"--output_0={_TEST_STAGING_BUCKET}/output/output_0", + f"--output_1={_TEST_STAGING_BUCKET}/output/output_1", + ] + + assert mock_blob_upload_from_filename.call_count == 3 + assert mock_blob_download_to_filename.call_count == 2 + + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": _TEST_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": expected_inputs, + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"MockTrainer-{_TEST_DISPLAY_NAME}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + custom_job=expected_custom_job, + timeout=None, + ) + + assert test_trainer.output_0 == test_output_0 + # pylint: disable=not-callable + assert test_trainer.output_1(1) == test_output_1(1) + + # pylint: disable=missing-function-docstring,protected-access + def test_remote_container_train_invalid_additional_data(self): + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {} + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=["invalid"], + remote_config=configs.DistributedTrainingConfig( + staging_bucket=_TEST_STAGING_BUCKET + ), + ) + def fit(self): + return + + test_trainer = MockTrainer() + assert test_trainer.fit._remote_executor is remote_container_training.train + + with pytest.raises(ValueError, match="Invalid data type"): + test_trainer.fit() + + @pytest.mark.usefixtures( + "google_auth_mock", "mock_uuid", "mock_get_custom_job_succeeded" + ) + def test_remote_container_train_invalid_local(self): + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {} + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[], + remote_config=configs.DistributedTrainingConfig( + staging_bucket=_TEST_STAGING_BUCKET + ), + ) + def fit(self): + return + + test_trainer = MockTrainer() + assert test_trainer.fit._remote_executor is remote_container_training.train + test_trainer.fit.vertex.remote = False + with pytest.raises( + ValueError, + match="Remote container train is only supported for remote mode.", + ): + test_trainer.fit() + + # pylint: disable=missing-function-docstring,protected-access + @pytest.mark.usefixtures( + "google_auth_mock", "mock_uuid", "mock_get_custom_job_succeeded" + ) + def test_remote_container_train_default_config(self, mock_create_custom_job): + class MockTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {} + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[], + ) + def fit(self): + return + + test_trainer = MockTrainer() + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + ) + + test_trainer.fit() + + expected_display_name = "MockTrainer-remote-fit" + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": remote_container_training._DEFAULT_MACHINE_TYPE, + "accelerator_type": ( + remote_container_training._DEFAULT_ACCELERATOR_TYPE + ), + "accelerator_count": ( + remote_container_training._DEFAULT_ACCELERATOR_COUNT + ), + }, + "disk_spec": { + "boot_disk_type": remote_container_training._DEFAULT_BOOT_DISK_TYPE, + "boot_disk_size_gb": ( + remote_container_training._DEFAULT_BOOT_DISK_SIZE_GB + ), + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": [], + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"{expected_display_name}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + custom_job=expected_custom_job, + timeout=None, + ) + + @pytest.mark.usefixtures( + "google_auth_mock", "mock_uuid", "mock_get_custom_job_succeeded" + ) + def test_remote_container_train_job_dir(self, mock_create_custom_job): + class MockTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {"job_dir": ""} + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[remote_specs._InputParameterSpec("job_dir")], + remote_config=configs.DistributedTrainingConfig( + staging_bucket=_TEST_STAGING_BUCKET + ), + ) + def fit(self): + return + + test_trainer = MockTrainer() + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_trainer.fit() + + expected_display_name = "MockTrainer-remote-fit" + expected_job_dir = os.path.join(_TEST_STAGING_BUCKET, "custom_job") + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": remote_container_training._DEFAULT_MACHINE_TYPE, + "accelerator_type": remote_container_training._DEFAULT_ACCELERATOR_TYPE, + "accelerator_count": remote_container_training._DEFAULT_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": remote_container_training._DEFAULT_BOOT_DISK_TYPE, + "boot_disk_size_gb": remote_container_training._DEFAULT_BOOT_DISK_SIZE_GB, + }, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": [f"--job_dir={expected_job_dir}"], + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"{expected_display_name}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + custom_job=expected_custom_job, + timeout=None, + ) + + @pytest.mark.usefixtures( + "google_auth_mock", "mock_uuid", "mock_get_custom_job_succeeded" + ) + def test_remote_container_train_invalid_remote_config(self): + # pylint: disable=missing-class-docstring + class MockTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {} + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_IMAGE_URI, + additional_data=[], + remote_config=_TEST_REMOTE_CONTAINER_TRAINING_CONFIG_INVALID, + ) + def fit(self): + return + + test_trainer = MockTrainer() + assert test_trainer.fit._remote_executor is remote_container_training.train + with pytest.raises( + ValueError, + match=re.escape( + "Cannot specify both 'worker_pool_specs' and ['machine_type', 'accelerator_type', 'accelerator_count', 'replica_count', 'boot_disk_type', 'boot_disk_size_gb']." + ), + ): + test_trainer.fit() diff --git a/tests/unit/vertexai/test_remote_prediction.py b/tests/unit/vertexai/test_remote_prediction.py new file mode 100644 index 0000000000..81b2816913 --- /dev/null +++ b/tests/unit/vertexai/test_remote_prediction.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import inspect +from unittest.mock import patch + +from google.cloud import aiplatform +import vertexai +from vertexai.preview._workflow.executor import prediction +from vertexai.preview._workflow.executor import training +from vertexai.preview._workflow.shared import configs + +import pytest +from sklearn.datasets import load_iris +from sklearn.linear_model import _logistic +from sklearn.model_selection import train_test_split + + +# vertexai constants +_TEST_PROJECT = "test-project" +_TEST_PROJECT_NUMBER = 123 +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_BUCKET_NAME = "gs://test-bucket" + +# dataset constants +dataset = load_iris() +_X_TRAIN, _X_TEST, _Y_TRAIN, _Y_TEST = train_test_split( + dataset.data, dataset.target, test_size=0.2, random_state=42 +) + +# config constants +_TEST_CONTAINER_URI = "gcr.io/custom-image" +_TEST_DISPLAY_NAME = "test-display-name" + + +@pytest.fixture +def mock_remote_training(): + with patch.object(training, "remote_training") as mock_remote_training: + mock_remote_training.return_value = _Y_TEST + yield mock_remote_training + + +@pytest.mark.usefixtures("google_auth_mock") +class TestRemotePrediction: + def setup_method(self): + reload(vertexai) + reload(vertexai.preview.initializer) + reload(_logistic) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + def test_remote_prediction_sklearn(self, mock_remote_training): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + model.predict.vertex.remote = True + model.predict.vertex.remote_config.staging_bucket = _TEST_BUCKET_NAME + + model.predict(_X_TEST) + + invokable = mock_remote_training.call_args[1]["invokable"] + assert invokable.method == model.predict._method + assert invokable.bound_arguments == ( + inspect.signature(model.predict._method).bind(_X_TEST) + ) + + assert invokable.vertex_config.remote is True + + assert invokable.vertex_config.remote_config.display_name is None + assert invokable.vertex_config.remote_config.staging_bucket == _TEST_BUCKET_NAME + assert invokable.vertex_config.remote_config.container_uri is None + assert invokable.vertex_config.remote_config.machine_type is None + assert invokable.vertex_config.remote_config.service_account is None + + assert invokable.remote_executor == prediction.remote_prediction + assert invokable.remote_executor_kwargs == {} + assert invokable.instance == model + + def test_remote_prediction_with_set_config(self, mock_remote_training): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + model.predict.vertex.remote = True + + model.predict.vertex.set_config( + staging_bucket=_TEST_BUCKET_NAME, display_name=_TEST_DISPLAY_NAME + ) + + model.predict(_X_TEST) + + invokable = mock_remote_training.call_args[1]["invokable"] + + assert invokable.method == model.predict._method + assert invokable.bound_arguments == ( + inspect.signature(model.predict._method).bind(_X_TEST) + ) + + assert invokable.vertex_config.remote is True + assert isinstance(invokable.vertex_config.remote_config, configs.RemoteConfig) + + assert invokable.vertex_config.remote_config.display_name == _TEST_DISPLAY_NAME + assert invokable.vertex_config.remote_config.staging_bucket == _TEST_BUCKET_NAME + assert invokable.vertex_config.remote_config.container_uri is None + assert invokable.vertex_config.remote_config.machine_type is None + assert invokable.vertex_config.remote_config.service_account is None + + assert invokable.remote_executor == prediction.remote_prediction + assert invokable.remote_executor_kwargs == {} + assert invokable.instance == model diff --git a/tests/unit/vertexai/test_remote_specs.py b/tests/unit/vertexai/test_remote_specs.py new file mode 100644 index 0000000000..4a73985af7 --- /dev/null +++ b/tests/unit/vertexai/test_remote_specs.py @@ -0,0 +1,712 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for developer/remote_specs.py. +""" + +import json +import os +import re +import tempfile +from typing import Any, Dict, List + +import cloudpickle +import vertexai +from vertexai.preview.developer import remote_specs +import mock +import pandas as pd +import pytest +import torch + + +_TEST_BINDING = { + "arg_0": 10, + "arg_1": lambda x: x + 1, + "arg_2": pd.DataFrame(data={"col_0": [0, 1], "col_1": [2, 3]}), +} + +_TEST_MACHINE_TYPE = "n1-standard-16" +_TEST_REPLICA_COUNT = 1 +_TEST_BOOT_DISK_TYPE_DEFAULT = "pd-ssd" +_TEST_BOOT_DISK_SIZE_GB_DEFAULT = 100 + +_TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE = remote_specs.WorkerPoolSpec( + machine_type=_TEST_MACHINE_TYPE, replica_count=_TEST_REPLICA_COUNT +) + +_TEST_WORKER_POOL_SPEC_MACHINE_TYPE = { + "machine_spec": {"machine_type": _TEST_MACHINE_TYPE}, + "replica_count": _TEST_REPLICA_COUNT, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, +} + +_TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC = { + "machine_spec": {"machine_type": _TEST_MACHINE_TYPE}, + "replica_count": _TEST_REPLICA_COUNT, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, + "container_spec": { + "image_uri": "test-image", + "command": ["python3", "run.py"], + "args": [], + }, +} + +_TEST_CLUSTER_SPEC_CHIEF_STR = '{"cluster":{"workerpool0":["cmle-training-workerpool0-1d969d3ba6-0:2222"],"workerpool1":["cmle-training-workerpool1-1d969d3ba6-0:2222"]},"environment":"cloud","task":{"type":"workerpool0","index":0}}' +_TEST_CLUSTER_SPEC_WORKER_STR = '{"cluster":{"workerpool0":["cmle-training-workerpool0-1d969d3ba6-0:2222"],"workerpool1":["cmle-training-workerpool1-1d969d3ba6-0:2222"]},"environment":"cloud","task":{"type":"workerpool1","index":0}}' + +_TEST_OUTPUT_PATH = "gs://test-bucket/output" + + +def _get_vertex_cluster_spec(task_type: str = "workerpool0", task_index: int = 0): + # pylint: disable=protected-access,missing-function-docstring + return { + "cluster": { + remote_specs._CHIEF: ["cmle-training-workerpool0-id-0:2222"], + remote_specs._WORKER: [ + "cmle-training-workerpool1-id-0:2222", + "cmle-training-workerpool1-id-1:2222", + "cmle-training-workerpool1-id-2:2222", + ], + remote_specs._SERVER: [ + "cmle-training-workerpool2-id-0:2222", + "cmle-training-workerpool2-id-1:2222", + "cmle-training-workerpool2-id-2:2222", + ], + remote_specs._EVALUATOR: ["cmle-training-workerpool3-id-0:2222"], + }, + remote_specs._TASK: { + remote_specs._TYPE: task_type, + remote_specs._INDEX: task_index, + }, + } + + +class TestRemoteSpec: + """Tests for parameter spec classes and helper function(s).""" + + # pylint: disable=protected-access,missing-function-docstring + @pytest.mark.parametrize( + "name,expected_argument_name", + [ + ("self.a", "a"), + ("a.b.c", "c"), + ("_arg_0", "arg_0"), + ("__arg_0", "__arg_0"), + ("arg_0", "arg_0"), + ], + ) + def test_get_argument_name(self, name: str, expected_argument_name: str): + argument_name = remote_specs._get_argument_name(name) + assert argument_name == expected_argument_name + + # pylint: disable=missing-function-docstring,protected-access + @pytest.mark.parametrize( + "name", + [ + ("."), + (".."), + ("_"), + ], + ) + def test_get_argument_name_invalid(self, name: str): + err_msg = f"Failed to get argument name from name {name}." + with pytest.raises(ValueError) as e: + remote_specs._get_argument_name(name) + assert re.match(err_msg, str(e.value)) + + def test_input_parameter_spec_default(self): + param_spec = remote_specs._InputParameterSpec("arg_0") + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "arg_0" + assert param_spec.serializer == "literal" + + def test_input_parameter_spec_argument_name(self): + param_spec = remote_specs._InputParameterSpec("arg_0", argument_name="input_0") + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "input_0" + assert param_spec.serializer == "literal" + + def test_input_parameter_spec_argument_name_empty(self): + err_msg = "Input parameter name cannot be empty" + with pytest.raises(ValueError) as e: + remote_specs._InputParameterSpec("") + assert re.match(err_msg, str(e.value)) + + @pytest.mark.parametrize("serializer", ["literal", "parquet", "cloudpickle"]) + def test_input_parameter_spec_serializer_valid(self, serializer: str): + param_spec = remote_specs._InputParameterSpec("arg_0", serializer=serializer) + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "arg_0" + assert param_spec.serializer == serializer + + def test_input_parameter_spec_serializer_invalid(self): + err_msg = "Invalid serializer" + with pytest.raises(ValueError) as e: + remote_specs._InputParameterSpec("arg_0", serializer="invalid") + assert re.match(err_msg, str(e.value)) + + def test_input_format_arg_literal(self): + test_spec = remote_specs._InputParameterSpec("arg_0", serializer="literal") + assert test_spec.format_arg("", _TEST_BINDING) == _TEST_BINDING["arg_0"] + + # pylint: disable=redefined-outer-name + @pytest.mark.usefixtures("google_auth_mock") + def test_input_format_arg_cloudpickle( + self, mock_named_temp_file, mock_blob_upload_from_filename + ): + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "tmp") + tmp_handler = open(tmp_path, "wb") + (mock_named_temp_file.return_value.__enter__.return_value) = tmp_handler + + spec = remote_specs._InputParameterSpec("arg_1", serializer="cloudpickle") + assert ( + spec.format_arg("gs://bucket/path", _TEST_BINDING) + == "gs://bucket/path/arg_1" + ) + mock_blob_upload_from_filename.assert_called_once() + + with open(tmp_path, "rb") as f: + assert cloudpickle.loads(f.read())(1) == _TEST_BINDING["arg_1"](1) + + @pytest.mark.usefixtures("google_auth_mock") + def test_input_format_arg_parquet( + self, mock_named_temp_file, mock_blob_upload_from_filename + ): + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_serialized_path = os.path.join(tmp_dir, "serialized") + tmp_serialized_handler = open(tmp_serialized_path, "wb") + + tmp_metadata_path = os.path.join(tmp_dir, "metadata") + tmp_handler = open(tmp_metadata_path, "wb") + (mock_named_temp_file.return_value.__enter__.side_effect) = [ + tmp_serialized_handler, + tmp_handler, + ] + + spec = remote_specs._InputParameterSpec("arg_2", serializer="parquet") + assert ( + spec.format_arg("gs://bucket/path", _TEST_BINDING) + == "gs://bucket/path/arg_2" + ) + assert mock_blob_upload_from_filename.call_count == 2 + + upload_calls = mock_blob_upload_from_filename.call_args_list + + metadata_path = upload_calls[1][1]["filename"] + + assert metadata_path == tmp_metadata_path + expected_metadata = { + "col_0": { + "dtype": "int64", + "feature_type": "dense", + }, + "col_1": { + "dtype": "int64", + "feature_type": "dense", + }, + } + with open(tmp_metadata_path, "rb") as f: + assert cloudpickle.loads(f.read()) == expected_metadata + + @pytest.mark.parametrize( + "spec,binding,msg", + [ + ( + remote_specs._InputParameterSpec("arg_4"), + _TEST_BINDING, + "Input arg_4 not found in binding", + ), + ( + remote_specs._InputParameterSpec("arg", serializer="parquet"), + {"arg": 10}, + "Parquet serializer is only supported for", + ), + ( + remote_specs._InputParameterSpec("arg_0"), + _TEST_BINDING, + "Unsupported serializer:", + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock") + def test_input_format_arg_invalid(self, spec, binding, msg): + if msg == "Unsupported serializer:": + spec.serializer = "invalid" + with pytest.raises(ValueError, match=msg): + spec.format_arg("gs://bucket/path", binding) + + def test_output_parameter_spec_default(self): + param_spec = remote_specs._OutputParameterSpec("arg_0") + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "arg_0" + assert param_spec.deserializer == "literal" + + def test_output_parameter_spec_argument_name(self): + param_spec = remote_specs._OutputParameterSpec("arg_0", argument_name="input_0") + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "input_0" + assert param_spec.deserializer == "literal" + + def test_output_parameter_spec_argument_name_empty(self): + err_msg = "Output parameter name cannot be empty" + with pytest.raises(ValueError) as e: + remote_specs._OutputParameterSpec("") + assert re.match(err_msg, str(e.value)) + + @pytest.mark.parametrize("deserializer", ["literal", "cloudpickle"]) + def test_output_parameter_spec_serializer_valid(self, deserializer): + param_spec = remote_specs._OutputParameterSpec( + "arg_0", deserializer=deserializer + ) + assert param_spec.name == "arg_0" + assert param_spec.argument_name == "arg_0" + assert param_spec.deserializer == deserializer + + def test_output_parameter_spec_deserializer_invalid(self): + err_msg = "Invalid deserializer" + with pytest.raises(ValueError) as e: + remote_specs._OutputParameterSpec("arg_0", deserializer="invalid") + assert re.match(err_msg, str(e.value)) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_output_literal( + self, mock_named_temp_file, mock_blob_download_to_filename + ): + spec = remote_specs._OutputParameterSpec( + "arg_0", deserializer=remote_specs._LITERAL + ) + test_path = "gs://bucket/path" + test_val = "output" + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "tmp_path") + mock_temp_file = mock_named_temp_file.return_value.__enter__() + mock_temp_file.name = tmp_path + + # Writes to a file to be read from. + with open(tmp_path, "w") as f: + f.write(test_val) + + # Tests reading literal output from GCS. + assert spec.deserialize_output(test_path) == test_val + mock_blob_download_to_filename.assert_called_once_with(filename=tmp_path) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_output_cloudpickle( + self, mock_named_temp_file, mock_blob_download_to_filename + ): + spec = remote_specs._OutputParameterSpec( + "arg_1", deserializer=remote_specs._CLOUDPICKLE + ) + test_path = "gs://bucket/path" + test_val = cloudpickle.dumps(_TEST_BINDING["arg_1"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "tmp_path") + mock_handler = mock_named_temp_file.return_value.__enter__() + mock_handler.name = tmp_path + + # Writes to a file to be read from. + with open(tmp_path, "wb") as f: + f.write(test_val) + + # Tests the deserialized output function works as expected. + with open(tmp_path, "rb") as f: + mock_handler.read = f.read + # Verifies that calling the functions return the same results. + assert spec.deserialize_output(test_path)(1) == _TEST_BINDING["arg_1"]( + 1 + ) + mock_blob_download_to_filename.assert_called_once_with( + filename=tmp_path + ) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_output_invalid(self): + spec = remote_specs._OutputParameterSpec("arg_0") + spec.deserializer = "invalid" + with pytest.raises(ValueError, match="Unsupported deserializer:"): + spec.deserialize_output("gs://bucket/path") + + def test_gen_gcs_path(self): + base_dir = "gs://test_bucket" + name = "test_name" + expected_path = "gs://test_bucket/test_name" + assert remote_specs._gen_gcs_path(base_dir, name) == expected_path + + def test_gen_gcs_path_invalid(self): + base_dir = "test_bucket" + name = "test_name" + with pytest.raises(ValueError): + remote_specs._gen_gcs_path(base_dir, name) + + def test_gen_gcs_path_remove_suffix(self): + base_dir = "gs://test_bucket" + name = "test_name/" + expected_path = "gs://test_bucket/test_name" + assert remote_specs._gen_gcs_path(base_dir, name) == expected_path + + def test_generate_feature_metadata(self): + df = pd.DataFrame( + { + "col_int": [1, 2, 3], + "col_float": [0.1, 0.2, 0.3], + 0: [0, 1, 0], + "ignored_cat_0": ["0", "1", "0"], + "ignored_cat_1": ["a", "b", "c"], + "ignored_type": ["d", "e", "f"], + } + ) + df["col_int"] = df["col_int"].astype("int64") + df["col_float"] = df["col_float"].astype("float64") + df[0] = df[0].astype("category") + df["ignored_cat_0"] = df["ignored_cat_0"].astype("category") + df["ignored_cat_1"] = df["ignored_cat_1"].astype("category") + df["ignored_type"] = df["ignored_type"].astype("object") + + # ignored_cat and ignored_type do not have feature metadata + expected__metadata = { + "col_int": { + "dtype": "int64", + "feature_type": "dense", + }, + "col_float": { + "dtype": "float64", + "feature_type": "dense", + }, + "0": { + "dtype": "int64", + "feature_type": "dense", + "categories": [0, 1], + }, + } + + original_df = df.copy(deep=True) + assert remote_specs._generate_feature_metadata(df) == expected__metadata + + # Checks that the original dataframe is not modified + assert df.equals(original_df) + + def test_generate_feature_metadata_invalid(self): + with pytest.raises(ValueError, match="Generating feature metadata is"): + remote_specs._generate_feature_metadata([0, 1, 2]) + + +class TestClusterSpec: + """Tests for cluster spec classes and other distributed training helper functions.""" + + # pylint: disable=protected-access,missing-function-docstring + def test_invalid_cluster_info(self): + cluster = { + remote_specs._CHIEF: ["cmle-training-workerpool0-id-0:2222"], + "worker": ["cmle-training-workerpool1-id-0:2222"], + } + + err_msg = "Invalid task type: worker." + with pytest.raises(ValueError) as e: + remote_specs._Cluster(cluster) + assert re.match(err_msg, str(e.value)) + + def test_task_types(self): + cluster = remote_specs._Cluster(_get_vertex_cluster_spec()["cluster"]) + assert cluster.task_types == [ + remote_specs._CHIEF, + remote_specs._WORKER, + remote_specs._SERVER, + remote_specs._EVALUATOR, + ] + + @pytest.mark.parametrize( + "task_type,expected_num_tasks", + [ + (remote_specs._CHIEF, 1), + (remote_specs._WORKER, 3), + (remote_specs._SERVER, 3), + (remote_specs._EVALUATOR, 1), + ], + ) + def test_get_num_tasks(self, task_type, expected_num_tasks): + cluster = remote_specs._Cluster(_get_vertex_cluster_spec()["cluster"]) + assert cluster.get_num_tasks(task_type) == expected_num_tasks + + @pytest.mark.parametrize( + "task_type,expected_task_addresses", + [ + (remote_specs._CHIEF, ["cmle-training-workerpool0-id-0:2222"]), + ( + remote_specs._WORKER, + [ + "cmle-training-workerpool1-id-0:2222", + "cmle-training-workerpool1-id-1:2222", + "cmle-training-workerpool1-id-2:2222", + ], + ), + ( + remote_specs._SERVER, + [ + "cmle-training-workerpool2-id-0:2222", + "cmle-training-workerpool2-id-1:2222", + "cmle-training-workerpool2-id-2:2222", + ], + ), + (remote_specs._EVALUATOR, ["cmle-training-workerpool3-id-0:2222"]), + ], + ) + def test_get_task_addresses(self, task_type, expected_task_addresses): + cluster = remote_specs._Cluster(_get_vertex_cluster_spec()["cluster"]) + assert cluster.get_task_addresses(task_type) == expected_task_addresses + + @pytest.mark.parametrize( + "cluster_spec,expected_rank", + [ + ( + remote_specs._ClusterSpec( + _get_vertex_cluster_spec(remote_specs._CHIEF, 0) + ), + 0, + ), + ( + remote_specs._ClusterSpec( + _get_vertex_cluster_spec(remote_specs._WORKER, 2) + ), + 3, + ), + ( + remote_specs._ClusterSpec( + _get_vertex_cluster_spec(remote_specs._SERVER, 1) + ), + 5, + ), + ( + remote_specs._ClusterSpec( + _get_vertex_cluster_spec(remote_specs._EVALUATOR, 0) + ), + 7, + ), + ], + ) + def test_get_rank(self, cluster_spec, expected_rank): + assert cluster_spec.get_rank() == expected_rank + + def test_get_world_size(self): + cluster_spec = remote_specs._ClusterSpec(_get_vertex_cluster_spec()) + assert cluster_spec.get_world_size() == 8 + + def test_get_chief_address_port(self): + cluster_spec = remote_specs._ClusterSpec(_get_vertex_cluster_spec()) + assert cluster_spec.get_chief_address_port() == ( + "cmle-training-workerpool0-id-0", + 2222, + ) + + +# pylint: disable=protected-access +class TestWorkerPoolSpecs: + """Tests for worker pool spec classes and related functions.""" + + @pytest.mark.parametrize( + "worker_pool_specs,expected_spec", + [ + ( + remote_specs.WorkerPoolSpecs(_TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE), + [_TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC], + ), + ( + remote_specs.WorkerPoolSpecs( + _TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE, + evaluator=_TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE, + ), + [ + _TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC, + {}, + {}, + _TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC, + ], + ), + ( + remote_specs.WorkerPoolSpecs( + _TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE, + server=_TEST_WORKER_POOL_SPEC_OBJ_MACHINE_TYPE, + ), + [ + _TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC, + {}, + _TEST_WORKER_POOL_SPEC_MACHINE_TYPE_CONTAINER_SPEC, + ], + ), + ], + ) + def test_prepare_worker_pool_specs( + self, + worker_pool_specs: remote_specs.WorkerPoolSpecs, + expected_spec: List[Dict[str, Any]], + ): + assert ( + remote_specs._prepare_worker_pool_specs( + worker_pool_specs, "test-image", ["python3", "run.py"], [] + ) + == expected_spec + ) + + @pytest.mark.parametrize( + "cluster_spec_str,expected_output_path", + [ + ( + _TEST_CLUSTER_SPEC_CHIEF_STR, + os.path.join(_TEST_OUTPUT_PATH, "output_estimator"), + ), + ( + _TEST_CLUSTER_SPEC_WORKER_STR, + os.path.join(_TEST_OUTPUT_PATH, "temp/workerpool1_0"), + ), + ("", os.path.join(_TEST_OUTPUT_PATH, "output_estimator")), + ], + ) + def test_get_output_path_for_distributed_training( + self, cluster_spec_str, expected_output_path + ): + with mock.patch.dict( + os.environ, {remote_specs._CLUSTER_SPEC: cluster_spec_str}, clear=True + ): + with mock.patch("os.makedirs"): + output_path = remote_specs._get_output_path_for_distributed_training( + _TEST_OUTPUT_PATH, "output_estimator" + ) + assert output_path == expected_output_path + + # Temporarily remove these tests since they require tensorflow >= 2.12.0 + # but in our external test environment tf 2.12 is not available due to conflict + # TODO(jayceeli) Add these tests back once we fix the external environment issue. + + # def test_set_keras_distributed_strategy_enable_distributed_multi_worker(self): + # model = tf.keras.Sequential( + # [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + # ) + # model.compile(optimizer="adam", loss="mean_squared_error") + # with mock.patch.dict( + # os.environ, + # {remote_specs._CLUSTER_SPEC: _TEST_CLUSTER_SPEC_CHIEF_STR}, + # clear=True, + # ): + # strategy = remote_specs._get_keras_distributed_strategy(True, None) + # updated_model = remote_specs._set_keras_distributed_strategy( + # model, strategy + # ) + + # assert updated_model.get_config() == model.get_config() + # assert updated_model.get_compile_config() == model.get_compile_config() + # assert "CollectiveAllReduceStrategy" in str( + # type(updated_model.distribute_strategy) + # ) + + # def test_set_keras_distributed_strategy_enable_distributed_multi_gpu(self): + # model = tf.keras.Sequential( + # [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + # ) + # model.compile(optimizer="adam", loss="mean_squared_error") + # # no cluster_spec is set for single worker training + # strategy = remote_specs._get_keras_distributed_strategy(True, None) + # updated_model = remote_specs._set_keras_distributed_strategy(model, strategy) + + # assert updated_model.get_config() == model.get_config() + # assert updated_model.get_compile_config() == model.get_compile_config() + # assert "MirroredStrategy" in str(type(updated_model.distribute_strategy)) + + # def test_set_keras_distributed_strategy_multi_gpu(self): + # model = tf.keras.Sequential( + # [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + # ) + # model.compile(optimizer="adam", loss="mean_squared_error") + # # no cluster_spec is set for single worker training + # strategy = remote_specs._get_keras_distributed_strategy(False, 3) + # updated_model = remote_specs._set_keras_distributed_strategy(model, strategy) + + # assert updated_model.get_config() == model.get_config() + # assert updated_model.get_compile_config() == model.get_compile_config() + # assert "MirroredStrategy" in str(type(updated_model.distribute_strategy)) + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch.object(torch.distributed, "init_process_group") + @mock.patch("torch.nn.parallel.DistributedDataParallel") + def test_setup_pytorch_distributed_training( + self, + mock_distributed_data_parallel, + mock_init_process_group, + ): + class TestClass(vertexai.preview.VertexModel, torch.nn.Module): + def __init__(self): + torch.nn.Module.__init__(self) + vertexai.preview.VertexModel.__init__(self) + self.linear = torch.nn.Linear(4, 3) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + return self.softmax(self.linear(x)) + + @vertexai.preview.developer.mark.train() + def test_method(self): + return + + model = TestClass() + setattr( + model, + "cluster_spec", + remote_specs._ClusterSpec(json.loads(_TEST_CLUSTER_SPEC_CHIEF_STR)), + ) + setattr(model, "_enable_cuda", False) + + output = remote_specs.setup_pytorch_distributed_training(model) + + mock_init_process_group.assert_called_once_with( + backend="gloo", rank=0, world_size=2 + ) + mock_distributed_data_parallel.assert_called_once_with(model) + + assert ( + os.getenv(remote_specs._MASTER_ADDR) + == "cmle-training-workerpool0-1d969d3ba6-0" + ) + assert os.getenv(remote_specs._MASTER_PORT) == "2222" + assert next(output.parameters()).is_cpu + + @mock.patch.dict(os.environ, {}, clear=True) + def test_setup_pytorch_distributed_training_no_cluster_spec(self): + class TestClass(vertexai.preview.VertexModel, torch.nn.Module): + def __init__(self): + torch.nn.Module.__init__(self) + vertexai.preview.VertexModel.__init__(self) + self.linear = torch.nn.Linear(4, 3) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + return self.softmax(self.linear(x)) + + @vertexai.preview.developer.mark.train() + def test_method(self): + return + + model = TestClass() + + assert model == remote_specs.setup_pytorch_distributed_training(model) diff --git a/tests/unit/vertexai/test_remote_training.py b/tests/unit/vertexai/test_remote_training.py new file mode 100644 index 0000000000..6752e023bf --- /dev/null +++ b/tests/unit/vertexai/test_remote_training.py @@ -0,0 +1,1714 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from importlib import reload +import os +import re +from unittest import mock +from unittest.mock import patch + +import cloudpickle +from google.api_core import exceptions +from google.cloud import aiplatform +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + custom_job as gca_custom_job_compat, +) +from google.cloud.aiplatform.compat.types import execution as gca_execution +from google.cloud.aiplatform.compat.types import io as gca_io_compat +from google.cloud.aiplatform.compat.types import ( + job_state as gca_job_state_compat, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard as gca_tensorboard, +) +from google.cloud.aiplatform.metadata import constants as metadata_constants +from google.cloud.aiplatform_v1 import ( + Context as GapicContext, + MetadataServiceClient, + MetadataStore as GapicMetadataStore, + TensorboardServiceClient, +) +import vertexai +from vertexai.preview._workflow.executor import ( + training, +) +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, + serializers_base, +) +from vertexai.preview._workflow.shared import configs +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) +from vertexai.preview.developer import remote_specs +import numpy as np +import pytest +import sklearn +from sklearn.datasets import load_iris +from sklearn.linear_model import _logistic +from sklearn.model_selection import train_test_split +import tensorflow as tf + + +# Manually set tensorflow version for b/295580335 +tf.__version__ = "2.12.0" + + +# vertexai constants +_TEST_PROJECT = "test-project" +_TEST_PROJECT_NUMBER = 123 +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_BUCKET_NAME = "gs://test-bucket" +_TEST_UNIQUE_NAME = "test-unique-name" +_TEST_REMOTE_JOB_NAME = f"remote-job-{_TEST_UNIQUE_NAME}" +_TEST_REMOTE_JOB_BASE_PATH = os.path.join(_TEST_BUCKET_NAME, _TEST_REMOTE_JOB_NAME) +_TEST_EXPERIMENT = "test-experiment" +_TEST_EXPERIMENT_RUN = "test-experiment-run" + +# dataset constants +dataset = load_iris() +_X_TRAIN, _X_TEST, _Y_TRAIN, _Y_TEST = train_test_split( + dataset.data, dataset.target, test_size=0.2, random_state=42 +) + +# custom job constants +_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/12345" +_TEST_UPGRADE_PIP_COMMAND = ( + "export PIP_ROOT_USER_ACTION=ignore && " "pip install --upgrade pip && " +) +_TEST_BASE_DEPS = f"'{training.VERTEX_AI_DEPENDENCY_PATH}' 'absl-py==1.4.0' " +_TEST_CUSTOM_COMMAND = "apt-get update && " "apt-get install -y git && " +_TEST_DEPS = ( + f"'scikit-learn=={sklearn.__version__}' " + f"'numpy=={np.__version__}' " + f"'cloudpickle=={cloudpickle.__version__}' " +) +_TEST_USER_DEPS = ( + f"'torch_cv' " + f"'xgboost==1.6.0' " + f"'numpy' " + f"'scikit-learn=={sklearn.__version__}' " + f"'cloudpickle=={cloudpickle.__version__}' " +) +_TEST_TRAINING_COMMAND = ( + "python3 -m vertexai.preview._workflow.executor.training_script " + "--pass_through_int_args= " + "--pass_through_float_args= " + "--pass_through_str_args= " + "--pass_through_bool_args= " + f"--input_path={os.path.join(_TEST_REMOTE_JOB_BASE_PATH, 'input').replace('gs://', '/gcs/', 1)} " + f"--output_path={os.path.join(_TEST_REMOTE_JOB_BASE_PATH, 'output').replace('gs://', '/gcs/', 1)} " + "--method_name=fit " + f"--arg_names=X,y " + "--enable_cuda=False " + "--enable_distributed=False " + "--accelerator_count=0" +) + +_TEST_AUTOLOG_COMMAND = ( + _TEST_UPGRADE_PIP_COMMAND + + "pip install " + + _TEST_BASE_DEPS.replace( + training.VERTEX_AI_DEPENDENCY_PATH, + training.VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING, + ) + + _TEST_DEPS + + "&& " + + _TEST_TRAINING_COMMAND + + " --enable_autolog" +) + +_TEST_WORKER_POOL_SPEC = [ + { + "machine_spec": { + "machine_type": "n1-standard-4", + }, + "replica_count": 1, + "container_spec": { + "image_uri": f"python:{supported_frameworks._get_python_minor_version()}", + "command": ["sh", "-c"] + + [ + _TEST_UPGRADE_PIP_COMMAND + + "pip install " + + _TEST_BASE_DEPS + + _TEST_DEPS + + "&& " + + _TEST_TRAINING_COMMAND + ], + "args": [], + }, + } +] +_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( + display_name=_TEST_REMOTE_JOB_NAME, + job_spec={ + "worker_pool_specs": _TEST_WORKER_POOL_SPEC, + "base_output_directory": gca_io_compat.GcsDestination( + output_uri_prefix=_TEST_REMOTE_JOB_BASE_PATH + ), + }, +) + +# RemoteConfig constants +_TEST_TRAINING_CONFIG_DISPLAY_NAME = "test-training-config-display-name" +_TEST_TRAINING_CONFIG_STAGING_BUCKET = "gs://test-training-config-staging-bucket" +_TEST_TRAINING_CONFIG_CONTAINER_URI = "gcr.io/custom-image" +_TEST_TRAINING_CONFIG_MACHINE_TYPE = "n1-highmem-4" +_TEST_TRAINING_CONFIG_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_TRAINING_CONFIG_ACCELERATOR_COUNT = 4 +_TEST_REQUIREMENTS = ["torch_cv", "xgboost==1.6.0", "numpy"] +_TEST_CUSTOM_COMMANDS = ["apt-get update", "apt-get install -y git"] + +_TEST_BOOT_DISK_TYPE = "test_boot_disk_type" +_TEST_BOOT_DISK_SIZE_GB = 10 +_TEST_TRAINING_CONFIG_WORKER_POOL_SPECS = remote_specs.WorkerPoolSpecs( + chief=remote_specs.WorkerPoolSpec( + machine_type=_TEST_TRAINING_CONFIG_MACHINE_TYPE, + replica_count=1, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + ) +) +_TEST_TRAINING_CONFIG_WORKER_POOL_SPECS_GPU = remote_specs.WorkerPoolSpecs( + chief=remote_specs.WorkerPoolSpec( + machine_type=_TEST_TRAINING_CONFIG_MACHINE_TYPE, + accelerator_count=_TEST_TRAINING_CONFIG_ACCELERATOR_COUNT, + accelerator_type=_TEST_TRAINING_CONFIG_ACCELERATOR_TYPE, + replica_count=1, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + ) +) + +_TEST_CONTEXT_ID = _TEST_EXPERIMENT +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" +_TEST_EXPERIMENT_DESCRIPTION = "test-experiment-description" +_TEST_ID = "1028944691210842416" +_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}" +_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_EXPERIMENT_RUN}" +_TEST_EXPERIMENT_RUN_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_EXECUTION_ID}" +_TEST_METADATASTORE = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) + +_EXPERIMENT_MOCK = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_EXPERIMENT_DESCRIPTION, + schema_title=metadata_constants.SYSTEM_EXPERIMENT, + schema_version=metadata_constants.SCHEMA_VERSIONS[ + metadata_constants.SYSTEM_EXPERIMENT + ], + metadata={**metadata_constants.EXPERIMENT_METADATA}, +) + +_EXPERIMENT_MOCK.metadata[ + metadata_constants._BACKING_TENSORBOARD_RESOURCE_KEY +] = _TEST_TENSORBOARD_NAME + +_EXPERIMENT_RUN_MOCK = GapicContext( + name=_TEST_EXPERIMENT_RUN_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT_RUN, + schema_title=metadata_constants.SYSTEM_EXPERIMENT_RUN, + schema_version=metadata_constants.SCHEMA_VERSIONS[ + metadata_constants.SYSTEM_EXPERIMENT_RUN + ], + metadata={ + metadata_constants._PARAM_KEY: {}, + metadata_constants._METRIC_KEY: {}, + metadata_constants._STATE_KEY: gca_execution.Execution.State.RUNNING.name, + }, +) + + +_TEST_DEFAULT_TENSORBOARD_NAME = "test-tensorboard-default-name" + +_TEST_DEFAULT_TENSORBOARD_GCA = gca_tensorboard.Tensorboard( + name=_TEST_DEFAULT_TENSORBOARD_NAME, + is_default=True, +) + + +@pytest.fixture +def list_default_tensorboard_mock(): + with patch.object( + TensorboardServiceClient, "list_tensorboards" + ) as list_default_tensorboard_mock: + list_default_tensorboard_mock.side_effect = [ + [_TEST_DEFAULT_TENSORBOARD_GCA], + ] + yield list_default_tensorboard_mock + + +def _get_custom_job_proto( + display_name=None, + staging_bucket=None, + container_uri=None, + machine_type=None, + accelerator_type=None, + accelerator_count=None, + replica_count=None, + boot_disk_type=None, + boot_disk_size_gb=None, + service_account=None, + experiment=None, + experiment_run=None, + autolog_enabled=False, + cuda_enabled=False, + distributed_enabled=False, + model=None, + user_requirements=False, + custom_commands=False, +): + job = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO) + if display_name: + job.display_name = display_name + if container_uri: + job.job_spec.worker_pool_specs[0].container_spec.image_uri = container_uri + job.job_spec.worker_pool_specs[0].container_spec.command[-1] = ( + _TEST_UPGRADE_PIP_COMMAND + + "pip install " + + _TEST_BASE_DEPS + + "&& " + + _TEST_TRAINING_COMMAND + ) + if user_requirements: + job.job_spec.worker_pool_specs[0].container_spec.command[-1] = ( + _TEST_UPGRADE_PIP_COMMAND + + "pip install " + + _TEST_BASE_DEPS + + _TEST_USER_DEPS + + "&& " + + _TEST_TRAINING_COMMAND + ) + if custom_commands: + job.job_spec.worker_pool_specs[0].container_spec.command[-1] = ( + _TEST_UPGRADE_PIP_COMMAND.replace("&& ", f"&& {_TEST_CUSTOM_COMMAND}", 1) + + "pip install " + + _TEST_BASE_DEPS + + _TEST_DEPS + + "&& " + + _TEST_TRAINING_COMMAND + ) + if autolog_enabled: + job.job_spec.worker_pool_specs[0].container_spec.command[ + -1 + ] = _TEST_AUTOLOG_COMMAND + if isinstance(model, tf.Module): + command = job.job_spec.worker_pool_specs[0].container_spec.command + for i, s in enumerate(command): + s = s.replace( + f"scikit-learn=={sklearn.__version__}", f"tensorflow=={tf.__version__}" + ) + s = s.replace("--arg_names=X,y", "--arg_names=x,y") + command[i] = s + job.job_spec.worker_pool_specs[0].container_spec.command = command + if cuda_enabled: + if not container_uri: + job.job_spec.worker_pool_specs[ + 0 + ].container_spec.image_uri = supported_frameworks._get_gpu_container_uri( + model + ) + job.job_spec.worker_pool_specs[0].machine_spec.machine_type = "n1-standard-16" + job.job_spec.worker_pool_specs[ + 0 + ].machine_spec.accelerator_type = "NVIDIA_TESLA_P100" + job.job_spec.worker_pool_specs[0].machine_spec.accelerator_count = 1 + command = job.job_spec.worker_pool_specs[0].container_spec.command + job.job_spec.worker_pool_specs[0].container_spec.command = [ + s.replace("--enable_cuda=False", "--enable_cuda=True") for s in command + ] + if distributed_enabled: + command = job.job_spec.worker_pool_specs[0].container_spec.command + job.job_spec.worker_pool_specs[0].container_spec.command = [ + s.replace("--enable_distributed=False", "--enable_distributed=True") + for s in command + ] + if machine_type: + job.job_spec.worker_pool_specs[0].machine_spec.machine_type = machine_type + if accelerator_type: + job.job_spec.worker_pool_specs[ + 0 + ].machine_spec.accelerator_type = accelerator_type + if accelerator_count: + job.job_spec.worker_pool_specs[ + 0 + ].machine_spec.accelerator_count = accelerator_count + if not distributed_enabled: + command = job.job_spec.worker_pool_specs[0].container_spec.command + job.job_spec.worker_pool_specs[0].container_spec.command = [ + s.replace( + "--accelerator_count=0", + f"--accelerator_count={accelerator_count}", + ) + for s in command + ] + if replica_count: + job.job_spec.worker_pool_specs[0].replica_count = replica_count + if boot_disk_type: + job.job_spec.worker_pool_specs[0].disk_spec.boot_disk_type = boot_disk_type + if boot_disk_size_gb: + job.job_spec.worker_pool_specs[ + 0 + ].disk_spec.boot_disk_size_gb = boot_disk_size_gb + if staging_bucket: + job.job_spec.base_output_directory = gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(staging_bucket, _TEST_REMOTE_JOB_NAME) + ) + command = job.job_spec.worker_pool_specs[0].container_spec.command + job.job_spec.worker_pool_specs[0].container_spec.command = [ + s.replace(_TEST_BUCKET_NAME[5:], staging_bucket[5:]) for s in command + ] + if service_account: + job.job_spec.service_account = service_account + if experiment: + env = job.job_spec.worker_pool_specs[0].container_spec.env + env.append({"name": metadata_constants.ENV_EXPERIMENT_KEY, "value": experiment}) + if experiment_run: + env = job.job_spec.worker_pool_specs[0].container_spec.env + env.append( + {"name": metadata_constants.ENV_EXPERIMENT_RUN_KEY, "value": experiment_run} + ) + return job + + +@pytest.fixture +def mock_timestamped_unique_name(): + with patch.object(utils, "timestamped_unique_name") as mock_timestamped_unique_name: + mock_timestamped_unique_name.return_value = _TEST_UNIQUE_NAME + yield mock_timestamped_unique_name + + +@pytest.fixture +def mock_autolog_enabled(): + with patch.object( + utils.autologging_utils, "_is_autologging_enabled" + ) as autolog_enabled: + autolog_enabled.return_value = True + yield autolog_enabled + + +@pytest.fixture +def mock_autolog_disabled(): + with patch.object( + utils.autologging_utils, "_is_autologging_enabled" + ) as autolog_disabled: + autolog_disabled.return_value = False + yield autolog_disabled + + +@pytest.fixture +def mock_get_project_number(): + with patch.object( + utils.resource_manager_utils, "get_project_number" + ) as mock_get_project_number: + mock_get_project_number.return_value = _TEST_PROJECT_NUMBER + yield mock_get_project_number + + +@pytest.fixture +def mock_get_experiment_run(): + with patch.object(MetadataServiceClient, "get_context") as mock_get_experiment_run: + mock_get_experiment_run.side_effect = [ + _EXPERIMENT_MOCK, + _EXPERIMENT_RUN_MOCK, + _EXPERIMENT_RUN_MOCK, + ] + + yield mock_get_experiment_run + + +@pytest.fixture +def mock_get_metadata_store(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as mock_get_metadata_store: + mock_get_metadata_store.return_value = GapicMetadataStore( + name=_TEST_METADATASTORE, + ) + yield mock_get_metadata_store + + +@pytest.fixture +def get_artifact_not_found_mock(): + with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: + get_artifact_mock.side_effect = exceptions.NotFound("") + yield get_artifact_mock + + +# we've tested AnySerializer in `test_serializers.py` +# so here we mock the SDK methods directly +@pytest.fixture +def mock_any_serializer_serialize_sklearn(): + with patch.object( + any_serializer.AnySerializer, + "serialize", + side_effect=[ + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"scikit-learn=={sklearn.__version__}" + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + ], + ) as mock_any_serializer_serialize: + yield mock_any_serializer_serialize + + +@pytest.fixture +def mock_any_serializer_sklearn( + mock_any_serializer_serialize_sklearn, mock_any_serializer_deserialize_sklearn +): + with patch.object( + any_serializer, + "AnySerializer", + ) as mock_any_serializer_obj: + model = _logistic.LogisticRegression() + model.fit(_X_TRAIN, _Y_TRAIN) + mock_any_serializer_obj.return_value.deserialize = ( + mock_any_serializer_deserialize_sklearn + ) + mock_any_serializer_obj.return_value.serialize = ( + mock_any_serializer_serialize_sklearn + ) + yield mock_any_serializer_obj + + +@pytest.fixture +def mock_any_serializer_deserialize_sklearn(): + with patch.object( + any_serializer.AnySerializer, "deserialize" + ) as mock_any_serializer_deserialize_sklearn: + model = _logistic.LogisticRegression() + returned_model = model.fit(_X_TRAIN, _Y_TRAIN) + mock_any_serializer_deserialize_sklearn.side_effect = [model, returned_model] + yield mock_any_serializer_deserialize_sklearn + + +@pytest.fixture +def mock_any_serializer_keras( + mock_any_serializer_serialize_keras, mock_any_serializer_deserialize_keras +): + with patch.object( + any_serializer, + "AnySerializer", + ) as mock_any_serializer_obj: + model = _logistic.LogisticRegression() + model.fit(_X_TRAIN, _Y_TRAIN) + mock_any_serializer_obj.return_value.deserialize = ( + mock_any_serializer_deserialize_keras + ) + mock_any_serializer_obj.return_value.serialize = ( + mock_any_serializer_serialize_keras + ) + yield mock_any_serializer_obj + + +@pytest.fixture +def mock_any_serializer_serialize_keras(): + with patch.object( + any_serializer.AnySerializer, + "serialize", + side_effect=[ + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"tensorflow=={tf.__version__}" + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + ], + ) as mock_any_serializer_serialize: + yield mock_any_serializer_serialize + + +@pytest.fixture +def mock_any_serializer_deserialize_keras(): + with patch.object( + any_serializer.AnySerializer, "deserialize" + ) as mock_any_serializer_deserialize_keras: + model = tf.keras.Sequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + returned_history = model.fit(_X_TRAIN, _Y_TRAIN) + mock_any_serializer_deserialize_keras.side_effect = [model, returned_history] + yield mock_any_serializer_deserialize_keras + + +@pytest.fixture +def mock_create_custom_job(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as mock_create_custom_job: + custom_job_proto = _get_custom_job_proto() + custom_job_proto.name = _TEST_CUSTOM_JOB_NAME + custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING + mock_create_custom_job.return_value = custom_job_proto + yield mock_create_custom_job + + +@pytest.fixture +def mock_get_custom_job(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as mock_get_custom_job: + custom_job_proto = _get_custom_job_proto() + custom_job_proto.name = _TEST_CUSTOM_JOB_NAME + custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + mock_get_custom_job.return_value = custom_job_proto + yield mock_get_custom_job + + +@pytest.fixture +def update_context_mock(): + with patch.object(MetadataServiceClient, "update_context") as update_context_mock: + update_context_mock.side_effect = [_EXPERIMENT_RUN_MOCK] * 4 + yield update_context_mock + + +@pytest.fixture +def aiplatform_autolog_mock(): + with patch.object(aiplatform, "autolog") as aiplatform_autolog_mock: + yield aiplatform_autolog_mock + + +# unittest `assert_any_call` method doesn't work when arguments contain `np.ndarray` +# https://stackoverflow.com/questions/56644729/mock-assert-mock-calls-with-a-numpy-array-as-argument-raises-valueerror-and-np +# tentatively runtime patch `assert_any_call` to solve this issue +def assert_any_call_for_numpy(self, **kwargs): + """Used by vertexai Serializer mock, only check kwargs.""" + found = False + for call in self.call_args_list: + equal = True + actual_kwargs = call[1] + for k, v in actual_kwargs.items(): + if k not in kwargs: + equal = False + break + try: + equal = v == kwargs[k] + except ValueError: + equal = False + equal = equal.all() if isinstance(equal, np.ndarray) else equal + if not equal: + break + + if equal and len(actual_kwargs) == len(kwargs): + found = True + break + + if not found: + raise AssertionError(f"{kwargs} not found.") + + +mock.Mock.assert_any_call = assert_any_call_for_numpy + +# TODO(zhenyiqi) fix external unit test failure caused by this method +training._add_indirect_dependency_versions = lambda x: x + + +@pytest.mark.usefixtures("google_auth_mock", "mock_cloud_logging_list_entries") +class TestRemoteTraining: + def setup_method(self): + reload(aiplatform.initializer) + reload(aiplatform) + reload(vertexai.preview.initializer) + reload(vertexai) + reload(_logistic) + reload(tf.keras) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto() + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data") + ), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn_with_user_requirements( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + model.fit.vertex.remote_config.requirements = _TEST_REQUIREMENTS + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto(user_requirements=True) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data") + ), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn_with_custom_commands( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + model.fit.vertex.remote_config.custom_commands = _TEST_CUSTOM_COMMANDS + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto(custom_commands=True) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data") + ), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn_with_remote_configs( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + # set all training configs + model.fit.vertex.remote_config.display_name = _TEST_TRAINING_CONFIG_DISPLAY_NAME + model.fit.vertex.remote_config.staging_bucket = ( + _TEST_TRAINING_CONFIG_STAGING_BUCKET + ) + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE + + model.fit(_X_TRAIN, _Y_TRAIN) + + remote_job_base_path = os.path.join( + _TEST_TRAINING_CONFIG_STAGING_BUCKET, _TEST_REMOTE_JOB_NAME + ) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(remote_job_base_path, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + display_name=_TEST_TRAINING_CONFIG_DISPLAY_NAME, + staging_bucket=_TEST_TRAINING_CONFIG_STAGING_BUCKET, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type=_TEST_TRAINING_CONFIG_MACHINE_TYPE, + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(remote_job_base_path, "output/output_estimator") + ), + mock.call(os.path.join(remote_job_base_path, "output/output_data")), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn_with_worker_pool_specs( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + # set all training configs + model.fit.vertex.remote_config.display_name = _TEST_TRAINING_CONFIG_DISPLAY_NAME + model.fit.vertex.remote_config.staging_bucket = ( + _TEST_TRAINING_CONFIG_STAGING_BUCKET + ) + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + model.fit.vertex.remote_config.worker_pool_specs = ( + _TEST_TRAINING_CONFIG_WORKER_POOL_SPECS + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + remote_job_base_path = os.path.join( + _TEST_TRAINING_CONFIG_STAGING_BUCKET, _TEST_REMOTE_JOB_NAME + ) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(remote_job_base_path, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + display_name=_TEST_TRAINING_CONFIG_DISPLAY_NAME, + staging_bucket=_TEST_TRAINING_CONFIG_STAGING_BUCKET, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type="n1-standard-4", + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(remote_job_base_path, "output/output_estimator") + ), + mock.call(os.path.join(remote_job_base_path, "output/output_data")), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", + "mock_get_custom_job", + "mock_any_serializer_deserialize_sklearn", + "mock_autolog_disabled", + ) + def test_remote_training_sklearn_with_set_config( + self, + mock_any_serializer_serialize_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + # set training config via dict + model.fit.vertex.set_config( + display_name=_TEST_TRAINING_CONFIG_DISPLAY_NAME, + staging_bucket=_TEST_TRAINING_CONFIG_STAGING_BUCKET, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + worker_pool_specs=_TEST_TRAINING_CONFIG_WORKER_POOL_SPECS, + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + remote_job_base_path = os.path.join( + _TEST_TRAINING_CONFIG_STAGING_BUCKET, _TEST_REMOTE_JOB_NAME + ) + + # check that model is serialized correctly + mock_any_serializer_serialize_sklearn.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(remote_job_base_path, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_serialize_sklearn.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/X"), + ) + mock_any_serializer_serialize_sklearn.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(remote_job_base_path, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + display_name=_TEST_TRAINING_CONFIG_DISPLAY_NAME, + staging_bucket=_TEST_TRAINING_CONFIG_STAGING_BUCKET, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type="n1-standard-4", + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", + "mock_get_custom_job", + "mock_any_serializer_sklearn", + "mock_autolog_disabled", + ) + def test_set_config_raises_with_unsupported_arg( + self, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + # RemoteConfig doesn't have `boot_disk_type`, only DistributedTrainingConfig + with pytest.raises(ValueError): + model.fit.vertex.set_config(boot_disk_type=_TEST_BOOT_DISK_TYPE) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_sklearn_with_invalid_remote_config( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + # set all training configs + model.fit.vertex.remote_config.display_name = _TEST_TRAINING_CONFIG_DISPLAY_NAME + model.fit.vertex.remote_config.staging_bucket = ( + _TEST_TRAINING_CONFIG_STAGING_BUCKET + ) + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + model.fit.vertex.remote_config.worker_pool_specs = ( + _TEST_TRAINING_CONFIG_WORKER_POOL_SPECS + ) + model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE + + with pytest.raises( + ValueError, + match=re.escape( + "Cannot specify both 'worker_pool_specs' and ['machine_type', 'accelerator_type', 'accelerator_count', 'replica_count', 'boot_disk_type', 'boot_disk_size_gb']." + ), + ): + model.fit(_X_TRAIN, _Y_TRAIN) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_gpu_training_keras( + self, + mock_any_serializer_keras, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + tf.keras.VertexSequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.VertexSequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + model.fit.vertex.remote_config.enable_cuda = True + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/x"), + ) + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto(cuda_enabled=True, model=model) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_keras.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data"), + model=model, + ), + ] + ) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_gpu_training_keras_with_remote_configs( + self, + mock_any_serializer_keras, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + tf.keras.VertexSequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.VertexSequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + + model.fit.vertex.remote_config.enable_cuda = True + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE + model.fit.vertex.remote_config.accelerator_type = ( + _TEST_TRAINING_CONFIG_ACCELERATOR_TYPE + ) + model.fit.vertex.remote_config.accelerator_count = ( + _TEST_TRAINING_CONFIG_ACCELERATOR_COUNT + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/x"), + ) + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + cuda_enabled=True, + model=model, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type=_TEST_TRAINING_CONFIG_MACHINE_TYPE, + accelerator_type=_TEST_TRAINING_CONFIG_ACCELERATOR_TYPE, + accelerator_count=_TEST_TRAINING_CONFIG_ACCELERATOR_COUNT, + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_keras.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data"), + model=model, + ), + ] + ) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_keras_with_worker_pool_specs( + self, + mock_any_serializer_keras, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + tf.keras.VertexSequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.VertexSequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + + model.fit.vertex.remote_config.enable_distributed = True + model.fit.vertex.remote_config.enable_cuda = True + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + model.fit.vertex.remote_config.worker_pool_specs = ( + _TEST_TRAINING_CONFIG_WORKER_POOL_SPECS_GPU + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/x"), + ) + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + cuda_enabled=True, + model=model, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type=_TEST_TRAINING_CONFIG_MACHINE_TYPE, + accelerator_type=_TEST_TRAINING_CONFIG_ACCELERATOR_TYPE, + accelerator_count=_TEST_TRAINING_CONFIG_ACCELERATOR_COUNT, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + distributed_enabled=True, + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_keras.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data"), + model=model, + ), + ] + ) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_keras_distributed_cuda_no_worker_pool_specs( + self, + mock_any_serializer_keras, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + tf.keras.VertexSequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.VertexSequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + + model.fit.vertex.remote_config.enable_distributed = True + model.fit.vertex.remote_config.enable_cuda = True + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/x"), + ) + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + cuda_enabled=True, + model=model, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type="n1-standard-16", + accelerator_type="NVIDIA_TESLA_P100", + accelerator_count=1, + boot_disk_type="pd-ssd", + boot_disk_size_gb=100, + distributed_enabled=True, + ) + + expected_custom_job.job_spec.worker_pool_specs = [ + expected_custom_job.job_spec.worker_pool_specs[0], + expected_custom_job.job_spec.worker_pool_specs[0], + ] + + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_keras.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data"), + model=model, + ), + ] + ) + + @pytest.mark.usefixtures( + "mock_timestamped_unique_name", "mock_get_custom_job", "mock_autolog_disabled" + ) + def test_remote_training_keras_distributed_no_cuda_no_worker_pool_specs( + self, + mock_any_serializer_keras, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + tf.keras.VertexSequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.VertexSequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + + model.fit.vertex.remote_config.enable_distributed = True + model.fit.vertex.remote_config.enable_cuda = False + model.fit.vertex.remote_config.container_uri = ( + _TEST_TRAINING_CONFIG_CONTAINER_URI + ) + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/x"), + ) + mock_any_serializer_keras.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # check that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + model=model, + container_uri=_TEST_TRAINING_CONFIG_CONTAINER_URI, + machine_type="n1-standard-4", + boot_disk_type="pd-ssd", + boot_disk_size_gb=100, + distributed_enabled=True, + ) + expected_custom_job.job_spec.worker_pool_specs = [ + expected_custom_job.job_spec.worker_pool_specs[0], + expected_custom_job.job_spec.worker_pool_specs[0], + ] + + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_keras.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data"), + model=model, + ), + ] + ) + + @pytest.mark.usefixtures( + "list_default_tensorboard_mock", + "mock_timestamped_unique_name", + "mock_get_custom_job", + "mock_get_project_number", + "mock_get_experiment_run", + "mock_get_metadata_store", + "get_artifact_not_found_mock", + "update_context_mock", + "mock_autolog_disabled", + ) + def test_remote_training_sklearn_with_experiment( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + experiment=_TEST_EXPERIMENT, + ) + vertexai.preview.init(remote=True) + + vertexai.preview.start_run(_TEST_EXPERIMENT_RUN, resume=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + model.fit.vertex.remote_config.service_account = "GCE" + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + service_account=f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com", + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data") + ), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + @pytest.mark.usefixtures( + "list_default_tensorboard_mock", + "mock_timestamped_unique_name", + "mock_get_custom_job", + "mock_get_experiment_run", + "mock_get_metadata_store", + "get_artifact_not_found_mock", + "update_context_mock", + "aiplatform_autolog_mock", + "mock_autolog_enabled", + ) + def test_remote_training_sklearn_with_experiment_autolog_enabled( + self, + mock_any_serializer_sklearn, + mock_create_custom_job, + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + experiment=_TEST_EXPERIMENT, + ) + vertexai.preview.init(remote=True, autolog=True) + + vertexai.preview.start_run(_TEST_EXPERIMENT_RUN, resume=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + model.fit.vertex.remote_config.service_account = "custom-sa" + + model.fit(_X_TRAIN, _Y_TRAIN) + + # check that model is serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=model, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"), + ) + + # check that args are serialized correctly + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_X_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"), + ) + mock_any_serializer_sklearn.return_value.serialize.assert_any_call( + to_serialize=_Y_TRAIN, + gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"), + ) + + # ckeck that CustomJob is created correctly + expected_custom_job = _get_custom_job_proto( + service_account="custom-sa", + experiment=_TEST_EXPERIMENT, + experiment_run=_TEST_EXPERIMENT_RUN, + autolog_enabled=True, + ) + mock_create_custom_job.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + # check that trained model is deserialized correctly + mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls( + [ + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator") + ), + mock.call( + os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data") + ), + ] + ) + + # change to `vertexai.preview.init(remote=False)` to use local prediction + vertexai.preview.init(remote=False) + + # check that local model is updated in place + # `model.score` raises NotFittedError if the model is not updated + model.score(_X_TEST, _Y_TEST) + + def test_get_service_account_custom_service_account(self): + config = configs.RemoteConfig() + config.service_account = "custom-sa" + + service_account = training._get_service_account(config, autolog=True) + + assert service_account == "custom-sa" + + @pytest.mark.usefixtures( + "mock_get_project_number", + ) + def test_get_service_account_gce_service_account(self): + config = configs.RemoteConfig() + config.service_account = "GCE" + + service_account = training._get_service_account(config, autolog=True) + + assert ( + service_account + == f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com" + ) + + def test_get_service_account_empty_sa_autolog_enabled(self): + config = configs.RemoteConfig() + # config.service_account is empty + + with pytest.raises(ValueError): + training._get_service_account(config, autolog=True) + + def test_get_service_account_empty_sa_autolog_disabled(self): + config = configs.RemoteConfig() + # config.service_account is empty + + service_account = training._get_service_account(config, autolog=False) + + assert service_account is None diff --git a/tests/unit/vertexai/test_serializers.py b/tests/unit/vertexai/test_serializers.py new file mode 100644 index 0000000000..3bc161d79e --- /dev/null +++ b/tests/unit/vertexai/test_serializers.py @@ -0,0 +1,1306 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import json +import os +import pickle +import types +from unittest.mock import ANY + +import cloudpickle +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform.utils import gcs_utils +from vertexai.preview._workflow.serialization_engine import ( + any_serializer as any_serializer_lib, +) +from vertexai.preview._workflow.serialization_engine import ( + serializers, +) +from vertexai.preview._workflow.shared import constants +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) + +import mock +import numpy as np +import pandas as pd +from pyfakefs import fake_filesystem_unittest +import pytest +from sklearn.linear_model import _logistic +import tensorflow as tf +from tensorflow import keras +import torch + + +@pytest.fixture +def mock_isvalid_gcs_path(): + """Allow using a local path in test.""" + with mock.patch.object( + serializers, + "_is_valid_gcs_path", + autospec=True, + return_value=True, + ) as always_return_true_mock_path_check: + yield always_return_true_mock_path_check + + +@pytest.fixture +def cloudpickle_serializer(): + return serializers.CloudPickleSerializer() + + +@pytest.fixture +def any_serializer(): + return any_serializer_lib.AnySerializer() + + +@pytest.fixture +def sklearn_estimator_serializer(): + return serializers.SklearnEstimatorSerializer() + + +@pytest.fixture +def keras_model_serializer(): + return serializers.KerasModelSerializer() + + +@pytest.fixture +def keras_history_callback_serializer(): + return serializers.KerasHistoryCallbackSerializer() + + +@pytest.fixture +def torch_model_serializer(): + return serializers.TorchModelSerializer() + + +@pytest.fixture +def pandas_data_serializer(): + return serializers.PandasDataSerializer() + + +@pytest.fixture +def torch_dataloader_serializer(): + return serializers.TorchDataLoaderSerializer() + + +@pytest.fixture +def bigframe_serializer(): + return serializers.BigframeSerializer() + + +@pytest.fixture +def tf_dataset_serializer(): + return serializers.TFDatasetSerializer() + + +@pytest.fixture +def mock_keras_model_deserialize(): + with mock.patch.object( + serializers.KerasModelSerializer, "deserialize", autospec=True + ) as keras_model_deserialize: + yield keras_model_deserialize + + +@pytest.fixture +def mock_sklearn_estimator_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.SklearnEstimatorSerializer._metadata.dependencies = [ + "sklearn_dependency1==1.0.0" + ] + + with mock.patch.object( + serializers.SklearnEstimatorSerializer, + "serialize", + new=stateful_serialize, + ) as sklearn_estimator_serialize: + yield sklearn_estimator_serialize + serializers.SklearnEstimatorSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_sklearn_estimator_deserialize(): + with mock.patch.object( + serializers.SklearnEstimatorSerializer, "deserialize", autospec=True + ) as sklearn_estimator_deserialize: + yield sklearn_estimator_deserialize + + +@pytest.fixture +def mock_torch_model_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.TorchModelSerializer._metadata.dependencies = ["torch==1.0.0"] + + with mock.patch.object( + serializers.TorchModelSerializer, "serialize", new=stateful_serialize + ) as torch_model_serialize: + yield torch_model_serialize + serializers.TorchModelSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_torch_model_deserialize(): + with mock.patch.object( + serializers.TorchModelSerializer, "deserialize", autospec=True + ) as torch_model_deserialize: + yield torch_model_deserialize + + +@pytest.fixture +def mock_torch_dataloader_serialize(tmp_path): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.TorchDataLoaderSerializer._metadata.dependencies = ["torch==1.0.0"] + + with mock.patch.object( + serializers.TorchDataLoaderSerializer, "serialize", new=stateful_serialize + ) as torch_dataloader_serialize: + yield torch_dataloader_serialize + serializers.TorchDataLoaderSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_torch_dataloader_deserialize(): + with mock.patch.object( + serializers.TorchDataLoaderSerializer, "deserialize", autospec=True + ) as torch_dataloader_serializer: + yield torch_dataloader_serializer + + +@pytest.fixture +def mock_download_from_gcs(tmp_path, torch_dataloader_serializer): + def fake_download_from_gcs(serialized_gcs_path, temp_dir): + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor([[1, 2, 3] for i in range(100)]), + torch.tensor([1] * 100), + ), + batch_size=10, + shuffle=True, + ) + torch_dataloader_serializer._serialize_to_local( + dataloader, os.fspath(tmp_path / temp_dir) + ) + + with mock.patch.object( + gcs_utils, "download_from_gcs", new=fake_download_from_gcs + ) as download_from_gcs: + yield download_from_gcs + + +@pytest.fixture +def mock_tf_dataset_serialize(tmp_path): + def stateful_serialize(self, to_serialize, gcs_path): + del gcs_path + serializers.TFDatasetSerializer._metadata.dependencies = ["tensorflow==1.0.0"] + try: + to_serialize.save(str(tmp_path / "tf_dataset")) + except AttributeError: + tf.data.experimental.save(to_serialize, str(tmp_path / "tf_dataset")) + + with mock.patch.object( + serializers.TFDatasetSerializer, "serialize", new=stateful_serialize + ) as tf_dataset_serialize: + yield tf_dataset_serialize + serializers.TFDatasetSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_tf_dataset_deserialize(): + with mock.patch.object( + serializers.TFDatasetSerializer, "deserialize", autospec=True + ) as tf_dataset_serializer: + yield tf_dataset_serializer + + +@pytest.fixture +def mock_pandas_data_serialize(): + def stateful_serialize(self, to_serialize, gcs_path): + del self, to_serialize, gcs_path + serializers.PandasDataSerializer._metadata.dependencies = ["pandas==1.0.0"] + + with mock.patch.object( + serializers.PandasDataSerializer, "serialize", new=stateful_serialize + ) as data_serialize: + yield data_serialize + serializers.PandasDataSerializer._metadata.dependencies = [] + + +@pytest.fixture +def mock_pandas_data_deserialize(): + with mock.patch.object( + serializers.PandasDataSerializer, "deserialize", autospec=True + ) as pandas_data_deserialize: + yield pandas_data_deserialize + + +@pytest.fixture +def mock_bigframe_deserialize_sklearn(): + with mock.patch.object( + serializers.BigframeSerializer, "_deserialize_sklearn", autospec=True + ) as bigframe_deserialize_sklearn: + yield bigframe_deserialize_sklearn + + +@pytest.fixture +def mock_keras_save_model(): + with mock.patch.object(keras.models.Sequential, "save") as keras_save_model: + yield keras_save_model + + +@pytest.fixture +def mock_keras_load_model(): + with mock.patch.object(keras.models, "load_model") as keras_load_model: + yield keras_load_model + + +@pytest.fixture +def mock_torch_save_model(): + with mock.patch.object(torch, "save", autospec=True) as torch_save_model: + yield torch_save_model + + +@pytest.fixture +def mock_torch_load_model(): + with mock.patch.object(torch, "load", autospec=True) as torch_load_model: + yield torch_load_model + + +@pytest.fixture +def mock_upload_to_gcs(): + with mock.patch.object(gcs_utils, "upload_to_gcs", autospec=True) as upload_to_gcs: + yield upload_to_gcs + + +@pytest.fixture +def mock_json_dump(): + with mock.patch.object(json, "dump", autospec=True) as json_dump: + yield json_dump + + +@pytest.fixture +def mock_cloudpickle_dump(): + with mock.patch.object(cloudpickle, "dump", autospec=True) as cloudpickle_dump: + yield cloudpickle_dump + + +class TestTorchClass(torch.nn.Module): + def __init__(self, input_size=4): + super().__init__() + self.linear_relu_stack = torch.nn.Sequential( + torch.nn.Linear(input_size, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2) + ) + + def forward(self, x): + logits = self.linear_relu_stack(x) + return logits + + +class TestSklearnEstimatorSerializer: + def setup_method(self): + reload(vertexai) + reload(vertexai.preview.initializer) + reload(_logistic) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_serialize_path_start_with_gs(self, sklearn_estimator_serializer): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + train_y = np.dot(train_x, np.array([1, 2])) + 3 + sklearn_estimator = _logistic.LogisticRegression() + sklearn_estimator.fit(train_x, train_y) + + # Act + sklearn_estimator_serializer.serialize(sklearn_estimator, fake_gcs_uri) + + # Assert + # The serialized file is written to a local path "fake_gcs_uri" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri.split("/")[-1], "rb") as f: + restored_estimator = pickle.load(f) + + assert isinstance(restored_estimator, _logistic.LogisticRegression) + assert sklearn_estimator.get_params() == restored_estimator.get_params() + assert (sklearn_estimator.coef_ == restored_estimator.coef_).all() + + def test_serialize_path_start_with_gcs(self, sklearn_estimator_serializer): + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri" + + train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + train_y = np.dot(train_x, np.array([1, 2])) + 3 + sklearn_estimator = _logistic.LogisticRegression() + sklearn_estimator.fit(train_x, train_y) + + # Act + with fake_filesystem_unittest.Patcher() as filesystem: + filesystem.fs.create_file(fake_gcs_uri) + sklearn_estimator_serializer.serialize(sklearn_estimator, fake_gcs_uri) + + # Assert + # The serialized file is written to a local path "fake_gcs_uri" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri, "rb") as f: + restored_estimator = pickle.load(f) + + assert isinstance(restored_estimator, _logistic.LogisticRegression) + assert sklearn_estimator.get_params() == restored_estimator.get_params() + assert (sklearn_estimator.coef_ == restored_estimator.coef_).all() + + def test_serialize_invalid_gcs_path(self, sklearn_estimator_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + train_y = np.dot(train_x, np.array([1, 2])) + 3 + sklearn_estimator = _logistic.LogisticRegression() + sklearn_estimator.fit(train_x, train_y) + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + sklearn_estimator_serializer.serialize(sklearn_estimator, fake_gcs_uri) + + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_deserialize_path_start_with_gs( + self, sklearn_estimator_serializer, mock_storage_blob + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + train_y = np.dot(train_x, np.array([1, 2])) + 3 + sklearn_estimator = _logistic.LogisticRegression() + sklearn_estimator.fit(train_x, train_y) + + def fake_download_file_from_gcs(self, filename): + with open(filename, "wb") as f: + pickle.dump(sklearn_estimator, f) + + mock_storage_blob.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob + ) + + # Act + restored_estimator = sklearn_estimator_serializer.deserialize(fake_gcs_uri) + + # Assert + assert isinstance(restored_estimator, _logistic.LogisticRegression) + assert sklearn_estimator.get_params() == restored_estimator.get_params() + assert (sklearn_estimator.coef_ == restored_estimator.coef_).all() + + def test_deserialize_path_start_with_gcs(self, sklearn_estimator_serializer): + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri" + + train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + train_y = np.dot(train_x, np.array([1, 2])) + 3 + sklearn_estimator = _logistic.LogisticRegression() + sklearn_estimator.fit(train_x, train_y) + + with fake_filesystem_unittest.Patcher() as filesystem: + filesystem.fs.create_file(fake_gcs_uri) + with open(fake_gcs_uri, "wb") as f: + pickle.dump(sklearn_estimator, f) + # Act + restored_estimator = sklearn_estimator_serializer.deserialize(fake_gcs_uri) + + # Assert + assert isinstance(restored_estimator, _logistic.LogisticRegression) + assert sklearn_estimator.get_params() == restored_estimator.get_params() + assert (sklearn_estimator.coef_ == restored_estimator.coef_).all() + + def test_deserialize_invalid_gcs_path(self, sklearn_estimator_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + sklearn_estimator_serializer.deserialize(fake_gcs_uri) + + +class TestKerasModelSerializer: + def test_serialize_gcs_path(self, keras_model_serializer, mock_keras_save_model): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + keras_model = keras.Sequential( + [keras.layers.Dense(8, input_shape=(2,)), keras.layers.Dense(4)] + ) + + # Act + keras_model_serializer.serialize(keras_model, fake_gcs_uri) + + # Assert + mock_keras_save_model.assert_called_once_with(fake_gcs_uri) + + def test_serialize_invalid_gcs_path(self, keras_model_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + keras_model = keras.Sequential( + [keras.layers.Dense(8, input_shape=(2,)), keras.layers.Dense(4)] + ) + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + keras_model_serializer.serialize(keras_model, fake_gcs_uri) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_gcs_path(self, keras_model_serializer, mock_keras_load_model): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + # Act + _ = keras_model_serializer.deserialize(fake_gcs_uri) + + # Assert + mock_keras_load_model.assert_called_once_with(fake_gcs_uri) + + def test_deserialize_invalid_gcs_path(self, keras_model_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + keras_model_serializer.deserialize(fake_gcs_uri) + + +class TestKerasHistoryCallbackSerializer: + @pytest.mark.usefixtures("mock_isvalid_gcs_path") + def test_serialize_gcs_path(self, keras_history_callback_serializer, tmp_path): + # Arrange + fake_gcs_uri = tmp_path / "fake_gcs_uri" + + keras_model = keras.Sequential( + [keras.layers.Dense(8, input_shape=(2,)), keras.layers.Dense(4)] + ) + history = keras.callbacks.History() + history.history = {"loss": [1.0, 0.5, 0.2]} + history.params = {"verbose": 1, "epochs": 3, "steps": 1} + history.epoch = [0, 1, 2] + history.model = keras_model + + # Act + keras_history_callback_serializer.serialize(history, str(fake_gcs_uri)) + + with open(tmp_path / "fake_gcs_uri", "rb") as f: + deserialized = cloudpickle.load(f) + + assert "model" not in deserialized + assert deserialized["history"]["loss"] == history.history["loss"] + assert deserialized["params"] == history.params + assert deserialized["epoch"] == history.epoch + + def test_serialize_invalid_gcs_path(self, keras_history_callback_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + history = keras.callbacks.History() + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + keras_history_callback_serializer.serialize(history, fake_gcs_uri) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_gcs_path( + self, + keras_history_callback_serializer, + mock_storage_blob_tmp_dir, + tmp_path, + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + _ = keras.Sequential( + [keras.layers.Dense(8, input_shape=(2,)), keras.layers.Dense(4)] + ) + history = keras.callbacks.History() + history.history = {"loss": [1.0, 0.5, 0.2]} + history.params = {"verbose": 1, "epochs": 3, "steps": 1} + history.epoch = [0, 1, 2] + + def fake_download_file_from_gcs(self, filename): + with open(tmp_path / filename, "wb") as f: + cloudpickle.dump( + history.__dict__, + f, + ) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_history = keras_history_callback_serializer.deserialize(fake_gcs_uri) + + # Assert + assert isinstance(restored_history, keras.callbacks.History) + assert restored_history.model is None + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_gcs_path_with_model( + self, + keras_history_callback_serializer, + mock_storage_blob_tmp_dir, + tmp_path, + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + keras_model = keras.Sequential( + [keras.layers.Dense(8, input_shape=(2,)), keras.layers.Dense(4)] + ) + history = keras.callbacks.History() + history.history = {"loss": [1.0, 0.5, 0.2]} + history.params = {"verbose": 1, "epochs": 3, "steps": 1} + history.epoch = [0, 1, 2] + + def fake_download_file_from_gcs(self, filename): + with open(tmp_path / filename, "wb") as f: + cloudpickle.dump( + history.__dict__, + f, + ) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_history = keras_history_callback_serializer.deserialize( + fake_gcs_uri, model=keras_model + ) + + # Assert + assert isinstance(restored_history, keras.callbacks.History) + assert restored_history.model == keras_model + + def test_deserialize_invalid_gcs_path(self, keras_history_callback_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + keras_history_callback_serializer.deserialize(fake_gcs_uri) + + +class TestTorchModelSerializer: + def test_serialize_path_start_with_gs( + self, torch_model_serializer, mock_torch_save_model, mock_upload_to_gcs + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + torch_model = TestTorchClass() + + # Act + torch_model_serializer.serialize(torch_model, fake_gcs_uri) + + # Assert + mock_torch_save_model.assert_called_once_with( + torch_model, + ANY, + pickle_module=cloudpickle, + pickle_protocol=constants.PICKLE_PROTOCOL, + ) + + mock_upload_to_gcs.assert_called_once_with(ANY, fake_gcs_uri) + + def test_serialize_path_start_with_gcs( + self, torch_model_serializer, mock_torch_save_model + ): + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri" + + torch_model = TestTorchClass() + + # Act + + torch_model_serializer.serialize(torch_model, fake_gcs_uri) + + # Assert + mock_torch_save_model.assert_called_once_with( + torch_model, + fake_gcs_uri, + pickle_module=cloudpickle, + pickle_protocol=constants.PICKLE_PROTOCOL, + ) + + def test_serialize_invalid_gcs_path(self, torch_model_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + torch_model = TestTorchClass() + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + torch_model_serializer.serialize(torch_model, fake_gcs_uri) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_path_start_with_gs( + self, torch_model_serializer, mock_storage_blob_tmp_dir, tmp_path + ): + # TorchModelSerializer only supports torch>=2.0, which supports python>=3.8 + # Skip this test for python 3.7 + if supported_frameworks._get_python_minor_version() == "3.7": + return + + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + torch_model = TestTorchClass() + + def fake_download_file_from_gcs(self, filename): + torch.save( + torch_model, + os.fspath(tmp_path / filename), + pickle_module=cloudpickle, + pickle_protocol=constants.PICKLE_PROTOCOL, + ) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_model = torch_model_serializer.deserialize(fake_gcs_uri) + + # Assert + assert isinstance(restored_model, TestTorchClass) + assert str(torch_model.state_dict()) == str(restored_model.state_dict()) + + def test_deserialize_path_start_with_gcs( + self, torch_model_serializer, mock_torch_load_model + ): + # TorchModelSerializer only supports torch>=2.0, which supports python>=3.8 + # Skip this test for python 3.7 + if supported_frameworks._get_python_minor_version() == "3.7": + return + + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri" + + # Act + _ = torch_model_serializer.deserialize(fake_gcs_uri) + + # Assert + mock_torch_load_model.assert_called_once_with( + fake_gcs_uri, + map_location=None, + ) + + def test_deserialize_invalid_gcs_path(self, torch_model_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + torch_model_serializer.deserialize(fake_gcs_uri) + + +class TestTorchDataLoaderSerializer: + def test_serialize_dataloader( + self, + torch_dataloader_serializer, + mock_json_dump, + mock_cloudpickle_dump, + mock_upload_to_gcs, + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor([[1, 2, 3] for i in range(100)]), + torch.tensor([1] * 100), + ), + batch_size=10, + shuffle=True, + ) + + # Act + torch_dataloader_serializer.serialize(dataloader, fake_gcs_uri) + + # Assert + mock_json_dump.assert_called_once_with( + { + "batch_size": dataloader.batch_size, + "num_workers": dataloader.num_workers, + "pin_memory": dataloader.pin_memory, + "drop_last": dataloader.drop_last, + "timeout": dataloader.timeout, + "prefetch_factor": dataloader.prefetch_factor, + "persistent_workers": dataloader.persistent_workers, + "pin_memory_device": dataloader.pin_memory_device, + "generator_device": None, + }, + ANY, + ) + + assert mock_cloudpickle_dump.call_count == 4 + + mock_upload_to_gcs.assert_called_once_with(ANY, fake_gcs_uri) + + def test_serialize_invalid_gcs_path(self, torch_dataloader_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor([[1, 2, 3] for i in range(100)]), + torch.tensor([1] * 100), + ), + batch_size=10, + shuffle=True, + ) + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + torch_dataloader_serializer.serialize(dataloader, fake_gcs_uri) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_dataloader( + self, torch_dataloader_serializer, mock_download_from_gcs + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri" + + expected_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.tensor([[1, 2, 3] for i in range(100)]), + torch.tensor([1] * 100), + ), + batch_size=10, + shuffle=True, + ) + + # Act + dataloader = torch_dataloader_serializer.deserialize(fake_gcs_uri) + + # Assert + assert dataloader.batch_size == expected_dataloader.batch_size + assert dataloader.num_workers == expected_dataloader.num_workers + assert dataloader.pin_memory == expected_dataloader.pin_memory + assert dataloader.drop_last == expected_dataloader.drop_last + assert dataloader.timeout == expected_dataloader.timeout + assert dataloader.prefetch_factor == expected_dataloader.prefetch_factor + assert dataloader.persistent_workers == expected_dataloader.persistent_workers + + def test_deserialize_invalid_gcs_path(self, torch_dataloader_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + torch_dataloader_serializer.deserialize(fake_gcs_uri) + + +class TestCloudPickleSerializer: + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_serialize_func(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + def function_to_be_serialized(): + return "return_str" + + # Act + cloudpickle_serializer.serialize(function_to_be_serialized, fake_gcs_uri) + + # Assert + del function_to_be_serialized + # The serialized file is written to a local path "fake_gcs_uri.cpkl" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri.split("/")[-1], "rb") as f: + restored_fn = cloudpickle.load(f) + assert restored_fn() == "return_str" + + def test_serialize_func_path_start_with_gcs(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri.cpkl" + + def function_to_be_serialized(): + return "return_str" + + # Act + with fake_filesystem_unittest.Patcher() as filesystem: + filesystem.fs.create_file(fake_gcs_uri) + cloudpickle_serializer.serialize(function_to_be_serialized, fake_gcs_uri) + + # Assert + del function_to_be_serialized + # The serialized file is written to a local path "fake_gcs_uri.cpkl" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri, "rb") as f: + restored_fn = cloudpickle.load(f) + assert restored_fn() == "return_str" + + def test_serialize_invalid_gcs_path(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri.cpkl" + + def function_to_be_serialized(): + return "return_str" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + cloudpickle_serializer.serialize(function_to_be_serialized, fake_gcs_uri) + + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_serialize_object(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + class TestClass: + def test_method(self): + return "return_str" + + test_object = TestClass() + # Act + cloudpickle_serializer.serialize(test_object, fake_gcs_uri) + + # Assert + del test_object + # The serialized file is written to a local path "fake_gcs_uri.cpkl" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri.split("/")[-1], "rb") as f: + restored_object = cloudpickle.load(f) + assert restored_object.test_method() == "return_str" + + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_serialize_class(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + class TestClass: + def test_method(self): + return "return_str" + + # Act + cloudpickle_serializer.serialize(TestClass, fake_gcs_uri) + + # Assert + del TestClass + # The serialized file is written to a local path "fake_gcs_uri.cpkl" via + # mock_upload_to_gcs for hermicity. + with open(fake_gcs_uri.split("/")[-1], "rb") as f: + restored_class = cloudpickle.load(f) + assert restored_class().test_method() == "return_str" + + @pytest.mark.usefixtures("mock_storage_blob", "google_auth_mock") + def test_deserialize_func(self, cloudpickle_serializer, mock_storage_blob): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + def test_function(): + return "return_str" + + def fake_download_file_from_gcs(self, filename): + with open(filename, "wb") as f: + cloudpickle.dump(test_function, f) + + mock_storage_blob.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob + ) + + # Act + restored_fn = cloudpickle_serializer.deserialize(fake_gcs_uri) + + # Assert + assert restored_fn() == "return_str" + + def test_deserialize_func_path_start_with_gcs(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "/gcs/staging-bucket/fake_gcs_uri.cpkl" + + def test_function(): + return "return_str" + + with fake_filesystem_unittest.Patcher() as filesystem: + filesystem.fs.create_file(fake_gcs_uri) + with open(fake_gcs_uri, "wb") as f: + cloudpickle.dump(test_function, f) + # Act + restored_fn = cloudpickle_serializer.deserialize(fake_gcs_uri) + + # Assert + assert restored_fn() == "return_str" + + def test_deserialize_func_invalid_gcs_path(self, cloudpickle_serializer): + # Arrange + fake_gcs_uri = "fake_gcs_uri.cpkl" + + def test_function(): + return "return_str" + + # Act + with pytest.raises(ValueError, match=f"Invalid gcs path: {fake_gcs_uri}"): + cloudpickle_serializer.serialize(test_function, fake_gcs_uri) + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_object(self, cloudpickle_serializer, mock_storage_blob): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + class TestClass: + def test_method(self): + return "return_str" + + def fake_download_file_from_gcs(self, filename: str): + with open(filename, "wb") as f: + cloudpickle.dump(TestClass(), f) + + mock_storage_blob.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob + ) + + # Act + restored_object = cloudpickle_serializer.deserialize(fake_gcs_uri) + + # Assert + assert restored_object.test_method() == "return_str" + + @pytest.mark.usefixtures("google_auth_mock") + def test_deserialize_class(self, cloudpickle_serializer, mock_storage_blob): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.cpkl" + + class TestClass: + def test_method(self): + return "return_str" + + def fake_download_file_from_gcs(self, filename): + with open(filename, "wb") as f: + cloudpickle.dump(TestClass, f) + + mock_storage_blob.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob + ) + + # Act + restored_class = cloudpickle_serializer.deserialize(fake_gcs_uri) + + # Assert + assert restored_class().test_method() == "return_str" + + +class TestTFDatasetSerializer: + @pytest.mark.usefixtures("mock_tf_dataset_serialize") + def test_serialize_tf_dataset(self, tf_dataset_serializer, tmp_path): + # Arrange + fake_gcs_uri = "gs://staging-bucket/tf_dataset" + tf_dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3])) + + # Act + tf_dataset_serializer.serialize(tf_dataset, fake_gcs_uri) + + # Assert + try: + loaded_dataset = tf.data.Dataset.load(str(tmp_path / "tf_dataset")) + except AttributeError: + loaded_dataset = tf.data.experimental.load(str(tmp_path / "tf_dataset")) + for original_ele, loaded_ele in zip(tf_dataset, loaded_dataset): + assert original_ele == loaded_ele + + def test_deserialize_tf_dataset(self, tf_dataset_serializer, tmp_path): + # Arrange + tf_dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3])) + try: + tf_dataset.save(str(tmp_path / "tf_dataset")) + except AttributeError: + tf.data.experimental.save(tf_dataset, str(tmp_path / "tf_dataset")) + + # Act + loaded_dataset = tf_dataset_serializer.deserialize(str(tmp_path / "tf_dataset")) + + # Assert + for original_ele, loaded_ele in zip(tf_dataset, loaded_dataset): + assert original_ele == loaded_ele + + +class TestPandasDataSerializer: + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_serialize_float_only_default_index_dataframe( + self, pandas_data_serializer, tmp_path + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame(np.zeros(shape=[3, 3]), columns=["col1", "col2", "col3"]) + + # Act + pandas_data_serializer.serialize(df, fake_gcs_uri) + + # Assert + # For hermicity, The serialized file is written to a local path + # "tmp_path/fake_gcs_uri.parquet" via mock_storage_blob_tmp_dir. + parquet_file_path = os.fspath(tmp_path / fake_gcs_uri.split("/")[-1]) + restored_df = pd.read_parquet(parquet_file_path) + + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_serialize_float_only_str_index(self, pandas_data_serializer, tmp_path): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 3]), + columns=["col1", "col2", "col3"], + index=["row1", "row2", "row3"], + ) + + # Act + pandas_data_serializer.serialize(df, fake_gcs_uri) + + # Assert + # For hermicity, The serialized file is written to a local path + # "tmp_path/fake_gcs_uri.parquet" via mock_storage_blob_tmp_dir. + parquet_file_path = os.fspath(tmp_path / fake_gcs_uri.split("/")[-1]) + restored_df = pd.read_parquet(parquet_file_path) + + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_serialize_common_typed_columns_with_nan( + self, pandas_data_serializer, tmp_path + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 4]), + columns=["str_col", "float_col", "bool_col", "timestamp_col"], + ) + + # object type + df["str_col"] = ["a", np.nan, "b"] + # float type + df["float_clo"] = [1.0, np.nan, np.nan] + # object type + df["bool_col"] = [True, False, np.nan] + # object type + df["timestamp_col"] = [ + pd.Timestamp("20110101"), + np.nan, + pd.Timestamp("20110101"), + ] + + # Act + pandas_data_serializer.serialize(df, fake_gcs_uri) + + # Assert + # For hermicity, The serialized file is written to a local path + # "tmp_path/fake_gcs_uri.parquet" via mock_storage_blob_tmp_dir. + parquet_file_path = os.fspath(tmp_path / fake_gcs_uri.split("/")[-1]) + restored_df = pd.read_parquet(parquet_file_path) + + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_serialize_common_typed_columns_with_none( + self, pandas_data_serializer, tmp_path + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 8]), + columns=[ + "str_to_object_col", + "str_col", + "float_col", + "int_to_float_col", + "int_col", + "bool_to_object_col", + "bool_col", + "timestamp_col", + ], + ) + + df["str_to_object_col"] = ["a", None, "b"] + df["str_col"] = ["a", "b", "c"] + + df["float_col"] = [1.0, None, None] # None -> NaN + + df["int_to_float_col"] = [1, 2, None] # None -> NaN + df["int_col"] = [1, 2, 3] + + df["bool_to_object_col"] = [True, False, None] + df["bool_col"] = [True, False, True] + + df["timestamp_col"] = [ + pd.Timestamp("20110101"), + None, + pd.Timestamp("20110101"), + ] # None -> NaT + + # Act + pandas_data_serializer.serialize(df, fake_gcs_uri) + + # Assert + # For hermicity, The serialized file is written to a local path + # "tmp_path/fake_gcs_uri.parquet" via mock_storage_blob_tmp_dir. + parquet_file_path = os.fspath(tmp_path / fake_gcs_uri.split("/")[-1]) + restored_df = pd.read_parquet(parquet_file_path) + + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_deserialize_all_floats_cols( + self, pandas_data_serializer, mock_storage_blob_tmp_dir + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame(np.zeros(shape=[3, 3]), columns=["col1", "col2", "col3"]) + + def fake_download_file_from_gcs(self, filename): + df.to_parquet(filename) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_df = pandas_data_serializer.deserialize(fake_gcs_uri) + + # Assert + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_deserialize_all_floats_cols_str_index( + self, pandas_data_serializer, mock_storage_blob_tmp_dir + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 3]), + columns=["col1", "col2", "col3"], + index=["row1", "row2", "row3"], + ) + + def fake_download_file_from_gcs(self, filename): + df.to_parquet(filename) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_df = pandas_data_serializer.deserialize(fake_gcs_uri) + + # Assert + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_deserialize_common_types_with_none( + self, pandas_data_serializer, mock_storage_blob_tmp_dir + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 8]), + columns=[ + "str_to_object_col", + "str_col", + "float_col", + "int_to_float_col", + "int_col", + "bool_to_object_col", + "bool_col", + "timestamp_col", + ], + ) + + df["str_to_object_col"] = ["a", None, "b"] + df["str_col"] = ["a", "b", "c"] + + df["float_col"] = [1.0, None, None] # None -> NaN + + df["int_to_float_col"] = [1, 2, None] # None -> NaN + df["int_col"] = [1, 2, 3] + + df["bool_to_object_col"] = [True, False, None] + df["bool_col"] = [True, False, True] + + df["timestamp_col"] = [ + pd.Timestamp("20110101"), + None, + pd.Timestamp("20110101"), + ] # None -> NaT + + def fake_download_file_from_gcs(self, filename): + df.to_parquet(filename) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_df = pandas_data_serializer.deserialize(fake_gcs_uri) + + # Assert + pd.testing.assert_frame_equal(df, restored_df) + + @pytest.mark.usefixtures("mock_storage_blob_tmp_dir", "google_auth_mock") + def test_deserialize_common_types_with_nan( + self, pandas_data_serializer, mock_storage_blob_tmp_dir + ): + # Arrange + fake_gcs_uri = "gs://staging-bucket/fake_gcs_uri.parquet" + + df = pd.DataFrame( + np.zeros(shape=[3, 4]), + columns=["str_col", "float_col", "bool_col", "timestamp_col"], + ) + + # object type + df["str_col"] = ["a", np.nan, "b"] + # float type + df["float_clo"] = [1.0, np.nan, np.nan] + # object type + df["bool_col"] = [True, False, np.nan] + # object type + df["timestamp_col"] = [ + pd.Timestamp("20110101"), + np.nan, + pd.Timestamp("20110101"), + ] + + def fake_download_file_from_gcs(self, filename): + df.to_parquet(filename) + + mock_storage_blob_tmp_dir.download_to_filename = types.MethodType( + fake_download_file_from_gcs, mock_storage_blob_tmp_dir + ) + + # Act + restored_df = pandas_data_serializer.deserialize(fake_gcs_uri) + + # Assert + pd.testing.assert_frame_equal(df, restored_df) diff --git a/tests/unit/vertexai/test_tabnet_trainer.py b/tests/unit/vertexai/test_tabnet_trainer.py new file mode 100644 index 0000000000..d6efce0d3e --- /dev/null +++ b/tests/unit/vertexai/test_tabnet_trainer.py @@ -0,0 +1,812 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import os +import re +from unittest.mock import Mock + +from google.cloud import aiplatform +from google.cloud.aiplatform.compat.types import ( + custom_job as gca_custom_job_compat, +) +from google.cloud.aiplatform.compat.types import io as gca_io_compat +from vertexai.preview._workflow.executor import ( + remote_container_training, +) +from vertexai.preview.tabular_models import ( + tabnet_trainer, +) +from vertexai.preview._workflow.shared import configs +import pandas as pd +import pytest +import tensorflow as tf + +_TEST_STAGING_BUCKET = "gs://test_staging_bucket" +_TEST_JOB_DIR = "gs://test_job_dir" +_TEST_TARGET_COLUMN = "target" +_TEST_MODEL_TYPE_CLASSIFICATION = "classification" +_TEST_MODEL_TYPE_REGRESSION = "regression" +_TEST_LEARNING_RATE = 0.01 +_TEST_DATA = pd.DataFrame(data={"col_0": [0, 1], "col_1": [2, 3]}) +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DISPLAY_NAME = "test" +_TEST_MACHINE_TYPE = "n1-highmem-8" +_TEST_ACCELERATOR_COUNT = 8 +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_BOOT_DISK_TYPE = "test_boot_disk_type" +_TEST_BOOT_DISK_SIZE_GB = 10 + + +class TestTabNetTrainer: + def setup_method(self): + reload(aiplatform.initializer) + reload(aiplatform) + + @pytest.mark.usefixtures( + "google_auth_mock", + "mock_uuid", + "mock_get_custom_job_succeeded", + "mock_blob_upload_from_filename", + ) + def test_tabnet_trainer_default( + self, + mock_create_custom_job, + mock_blob_download_to_filename, + mock_tf_saved_model_load, + ): + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + test_tabnet_trainer.fit.vertex.remote_config.staging_bucket = ( + _TEST_STAGING_BUCKET + ) + expected_binding = { + "model_type": _TEST_MODEL_TYPE_CLASSIFICATION, + "target_column": _TEST_TARGET_COLUMN, + "learning_rate": _TEST_LEARNING_RATE, + "enable_profiler": False, + "job_dir": "", + "cache_data": "auto", + "seed": 1, + "large_category_dim": 1, + "large_category_thresh": 300, + "yeo_johnson_transform": False, + "weight_column": "", + "max_steps": -1, + "max_train_secs": -1, + "measurement_selection_type": "BEST_MEASUREMENT", + "optimization_metric": "", + "eval_steps": 0, + "batch_size": 100, + "eval_frequency_secs": 600, + "feature_dim": 64, + "feature_dim_ratio": 0.5, + "num_decision_steps": 6, + "relaxation_factor": 1.5, + "decay_every": 100.0, + "decay_rate": 0.95, + "gradient_thresh": 2000.0, + "sparsity_loss_weight": 0.00001, + "batch_momentum": 0.95, + "batch_size_ratio": 0.25, + "num_transformer_layers": 4, + "num_transformer_layers_ratio": 0.25, + "class_weight": 1.0, + "loss_function_type": "default", + "alpha_focal_loss": 0.25, + "gamma_focal_loss": 2.0, + "is_remote_trainer": True, + } + assert test_tabnet_trainer._binding == expected_binding + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + test_tabnet_trainer.fit(training_data=_TEST_DATA, validation_data=_TEST_DATA) + + expected_display_name = "TabNetTrainer-fit" + expected_job_dir = os.path.join(_TEST_STAGING_BUCKET, "custom_job") + expected_args = [ + f"--model_type={_TEST_MODEL_TYPE_CLASSIFICATION}", + f"--target_column={_TEST_TARGET_COLUMN}", + f"--learning_rate={_TEST_LEARNING_RATE}", + f"--job_dir={expected_job_dir}", + "--enable_profiler=False", + "--cache_data=auto", + "--seed=1", + "--large_category_dim=1", + "--large_category_thresh=300", + "--yeo_johnson_transform=False", + "--weight_column=", + "--max_steps=-1", + "--max_train_secs=-1", + "--measurement_selection_type=BEST_MEASUREMENT", + "--optimization_metric=", + "--eval_steps=0", + "--batch_size=100", + "--eval_frequency_secs=600", + "--feature_dim=64", + "--feature_dim_ratio=0.5", + "--num_decision_steps=6", + "--relaxation_factor=1.5", + "--decay_every=100.0", + "--decay_rate=0.95", + "--gradient_thresh=2000.0", + "--sparsity_loss_weight=1e-05", + "--batch_momentum=0.95", + "--batch_size_ratio=0.25", + "--num_transformer_layers=4", + "--num_transformer_layers_ratio=0.25", + "--class_weight=1.0", + "--loss_function_type=default", + "--alpha_focal_loss=0.25", + "--gamma_focal_loss=2.0", + "--is_remote_trainer=True", + f"--training_data_path={_TEST_STAGING_BUCKET}/input/training_data_path", + f"--validation_data_path={_TEST_STAGING_BUCKET}/input/validation_data_path", + f"--output_model_path={_TEST_STAGING_BUCKET}/output/output_model_path", + ] + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": "c2-standard-16", + "accelerator_type": remote_container_training._DEFAULT_ACCELERATOR_TYPE, + "accelerator_count": remote_container_training._DEFAULT_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": 100, + }, + "container_spec": { + "image_uri": tabnet_trainer._TABNET_TRAINING_IMAGE, + "args": [], + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"{expected_display_name}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once() + + assert ( + mock_create_custom_job.call_args[1]["parent"] + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + ) + assert not mock_create_custom_job.call_args[1]["timeout"] + test_custom_job = mock_create_custom_job.call_args[1]["custom_job"] + + test_args = test_custom_job.job_spec.worker_pool_specs[0].container_spec.args + assert set(test_args) == set(expected_args) + + test_custom_job.job_spec.worker_pool_specs[0].container_spec.args = [] + assert test_custom_job == expected_custom_job + + mock_blob_download_to_filename.assert_called_once() + mock_tf_saved_model_load.assert_called_once() + + @pytest.mark.usefixtures( + "google_auth_mock", + "mock_uuid", + "mock_get_custom_job_succeeded", + "mock_blob_upload_from_filename", + ) + def test_tabnet_trainer_all_args( + self, + mock_create_custom_job, + mock_blob_download_to_filename, + mock_tf_saved_model_load, + ): + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + job_dir=_TEST_JOB_DIR, + enable_profiler=True, + cache_data="test", + seed=2, + large_category_dim=2, + large_category_thresh=10, + yeo_johnson_transform=True, + weight_column="weight", + max_steps=5, + max_train_secs=600, + measurement_selection_type="LAST_MEASUREMENT", + optimization_metric="rmse", + eval_steps=1, + batch_size=10, + eval_frequency_secs=60, + feature_dim=8, + feature_dim_ratio=0.1, + num_decision_steps=3, + relaxation_factor=1.2, + decay_every=10.0, + decay_rate=0.9, + gradient_thresh=200.0, + sparsity_loss_weight=0.01, + batch_momentum=0.9, + batch_size_ratio=0.2, + num_transformer_layers=2, + num_transformer_layers_ratio=0.2, + class_weight=1.2, + loss_function_type="rmse", + alpha_focal_loss=0.2, + gamma_focal_loss=2.5, + ) + expected_binding = { + "model_type": _TEST_MODEL_TYPE_CLASSIFICATION, + "target_column": _TEST_TARGET_COLUMN, + "learning_rate": _TEST_LEARNING_RATE, + "job_dir": _TEST_JOB_DIR, + "enable_profiler": True, + "cache_data": "test", + "seed": 2, + "large_category_dim": 2, + "large_category_thresh": 10, + "yeo_johnson_transform": True, + "weight_column": "weight", + "max_steps": 5, + "max_train_secs": 600, + "measurement_selection_type": "LAST_MEASUREMENT", + "optimization_metric": "rmse", + "eval_steps": 1, + "batch_size": 10, + "eval_frequency_secs": 60, + "feature_dim": 8, + "feature_dim_ratio": 0.1, + "num_decision_steps": 3, + "relaxation_factor": 1.2, + "decay_every": 10.0, + "decay_rate": 0.9, + "gradient_thresh": 200.0, + "sparsity_loss_weight": 0.01, + "batch_momentum": 0.9, + "batch_size_ratio": 0.2, + "num_transformer_layers": 2, + "num_transformer_layers_ratio": 0.2, + "class_weight": 1.2, + "loss_function_type": "rmse", + "alpha_focal_loss": 0.2, + "gamma_focal_loss": 2.5, + "is_remote_trainer": True, + } + assert test_tabnet_trainer._binding == expected_binding + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + test_tabnet_trainer.fit.vertex.remote_config.staging_bucket = ( + _TEST_STAGING_BUCKET + ) + test_tabnet_trainer.fit.vertex.remote_config.machine_type = _TEST_MACHINE_TYPE + test_tabnet_trainer.fit.vertex.remote_config.display_name = _TEST_DISPLAY_NAME + ( + test_tabnet_trainer.fit.vertex.remote_config.boot_disk_type + ) = _TEST_BOOT_DISK_TYPE + ( + test_tabnet_trainer.fit.vertex.remote_config.boot_disk_size_gb + ) = _TEST_BOOT_DISK_SIZE_GB + ( + test_tabnet_trainer.fit.vertex.remote_config.accelerator_type + ) = _TEST_ACCELERATOR_TYPE + ( + test_tabnet_trainer.fit.vertex.remote_config.accelerator_count + ) = _TEST_ACCELERATOR_COUNT + test_tabnet_trainer.fit(training_data=_TEST_DATA, validation_data=_TEST_DATA) + + expected_display_name = "TabNetTrainer-test" + expected_args = [ + f"--model_type={_TEST_MODEL_TYPE_CLASSIFICATION}", + f"--target_column={_TEST_TARGET_COLUMN}", + f"--learning_rate={_TEST_LEARNING_RATE}", + f"--job_dir={_TEST_JOB_DIR}", + "--enable_profiler=True", + "--cache_data=test", + "--seed=2", + "--large_category_dim=2", + "--large_category_thresh=10", + "--yeo_johnson_transform=True", + "--weight_column=weight", + "--max_steps=5", + "--max_train_secs=600", + "--measurement_selection_type=LAST_MEASUREMENT", + "--optimization_metric=rmse", + "--eval_steps=1", + "--batch_size=10", + "--eval_frequency_secs=60", + "--feature_dim=8", + "--feature_dim_ratio=0.1", + "--num_decision_steps=3", + "--relaxation_factor=1.2", + "--decay_every=10.0", + "--decay_rate=0.9", + "--gradient_thresh=200.0", + "--sparsity_loss_weight=0.01", + "--batch_momentum=0.9", + "--batch_size_ratio=0.2", + "--num_transformer_layers=2", + "--num_transformer_layers_ratio=0.2", + "--class_weight=1.2", + "--loss_function_type=rmse", + "--alpha_focal_loss=0.2", + "--gamma_focal_loss=2.5", + "--is_remote_trainer=True", + f"--training_data_path={_TEST_STAGING_BUCKET}/input/training_data_path", + f"--validation_data_path={_TEST_STAGING_BUCKET}/input/validation_data_path", + f"--output_model_path={_TEST_STAGING_BUCKET}/output/output_model_path", + ] + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": _TEST_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB, + }, + "container_spec": { + "image_uri": tabnet_trainer._TABNET_TRAINING_IMAGE, + "args": [], + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"{expected_display_name}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once() + + assert ( + mock_create_custom_job.call_args[1]["parent"] + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + ) + assert not mock_create_custom_job.call_args[1]["timeout"] + test_custom_job = mock_create_custom_job.call_args[1]["custom_job"] + + test_args = test_custom_job.job_spec.worker_pool_specs[0].container_spec.args + assert set(test_args) == set(expected_args) + + test_custom_job.job_spec.worker_pool_specs[0].container_spec.args = [] + assert test_custom_job == expected_custom_job + + mock_blob_download_to_filename.assert_called_once() + mock_tf_saved_model_load.assert_called_once() + + @pytest.mark.usefixtures( + "google_auth_mock", + "mock_uuid", + "mock_get_custom_job_succeeded", + "mock_blob_upload_from_filename", + ) + def test_tabnet_trainer_all_args_with_set_config_method( + self, + mock_create_custom_job, + mock_blob_download_to_filename, + mock_tf_saved_model_load, + ): + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + job_dir=_TEST_JOB_DIR, + enable_profiler=True, + cache_data="test", + seed=2, + large_category_dim=2, + large_category_thresh=10, + yeo_johnson_transform=True, + weight_column="weight", + max_steps=5, + max_train_secs=600, + measurement_selection_type="LAST_MEASUREMENT", + optimization_metric="rmse", + eval_steps=1, + batch_size=10, + eval_frequency_secs=60, + feature_dim=8, + feature_dim_ratio=0.1, + num_decision_steps=3, + relaxation_factor=1.2, + decay_every=10.0, + decay_rate=0.9, + gradient_thresh=200.0, + sparsity_loss_weight=0.01, + batch_momentum=0.9, + batch_size_ratio=0.2, + num_transformer_layers=2, + num_transformer_layers_ratio=0.2, + class_weight=1.2, + loss_function_type="rmse", + alpha_focal_loss=0.2, + gamma_focal_loss=2.5, + ) + expected_binding = { + "model_type": _TEST_MODEL_TYPE_CLASSIFICATION, + "target_column": _TEST_TARGET_COLUMN, + "learning_rate": _TEST_LEARNING_RATE, + "job_dir": _TEST_JOB_DIR, + "enable_profiler": True, + "cache_data": "test", + "seed": 2, + "large_category_dim": 2, + "large_category_thresh": 10, + "yeo_johnson_transform": True, + "weight_column": "weight", + "max_steps": 5, + "max_train_secs": 600, + "measurement_selection_type": "LAST_MEASUREMENT", + "optimization_metric": "rmse", + "eval_steps": 1, + "batch_size": 10, + "eval_frequency_secs": 60, + "feature_dim": 8, + "feature_dim_ratio": 0.1, + "num_decision_steps": 3, + "relaxation_factor": 1.2, + "decay_every": 10.0, + "decay_rate": 0.9, + "gradient_thresh": 200.0, + "sparsity_loss_weight": 0.01, + "batch_momentum": 0.9, + "batch_size_ratio": 0.2, + "num_transformer_layers": 2, + "num_transformer_layers_ratio": 0.2, + "class_weight": 1.2, + "loss_function_type": "rmse", + "alpha_focal_loss": 0.2, + "gamma_focal_loss": 2.5, + "is_remote_trainer": True, + } + assert test_tabnet_trainer._binding == expected_binding + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + test_tabnet_trainer.fit.vertex.set_config( + staging_bucket=_TEST_STAGING_BUCKET, + machine_type=_TEST_MACHINE_TYPE, + display_name=_TEST_DISPLAY_NAME, + boot_disk_type=_TEST_BOOT_DISK_TYPE, + boot_disk_size_gb=_TEST_BOOT_DISK_SIZE_GB, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + + assert isinstance( + test_tabnet_trainer.fit.vertex.remote_config, + configs.DistributedTrainingConfig, + ) + + test_tabnet_trainer.fit(training_data=_TEST_DATA, validation_data=_TEST_DATA) + + expected_display_name = "TabNetTrainer-test" + expected_args = [ + f"--model_type={_TEST_MODEL_TYPE_CLASSIFICATION}", + f"--target_column={_TEST_TARGET_COLUMN}", + f"--learning_rate={_TEST_LEARNING_RATE}", + f"--job_dir={_TEST_JOB_DIR}", + "--enable_profiler=True", + "--cache_data=test", + "--seed=2", + "--large_category_dim=2", + "--large_category_thresh=10", + "--yeo_johnson_transform=True", + "--weight_column=weight", + "--max_steps=5", + "--max_train_secs=600", + "--measurement_selection_type=LAST_MEASUREMENT", + "--optimization_metric=rmse", + "--eval_steps=1", + "--batch_size=10", + "--eval_frequency_secs=60", + "--feature_dim=8", + "--feature_dim_ratio=0.1", + "--num_decision_steps=3", + "--relaxation_factor=1.2", + "--decay_every=10.0", + "--decay_rate=0.9", + "--gradient_thresh=200.0", + "--sparsity_loss_weight=0.01", + "--batch_momentum=0.9", + "--batch_size_ratio=0.2", + "--num_transformer_layers=2", + "--num_transformer_layers_ratio=0.2", + "--class_weight=1.2", + "--loss_function_type=rmse", + "--alpha_focal_loss=0.2", + "--gamma_focal_loss=2.5", + "--is_remote_trainer=True", + f"--training_data_path={_TEST_STAGING_BUCKET}/input/training_data_path", + f"--validation_data_path={_TEST_STAGING_BUCKET}/input/validation_data_path", + f"--output_model_path={_TEST_STAGING_BUCKET}/output/output_model_path", + ] + expected_worker_pool_specs = [ + { + "replica_count": 1, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": _TEST_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB, + }, + "container_spec": { + "image_uri": tabnet_trainer._TABNET_TRAINING_IMAGE, + "args": [], + }, + } + ] + expected_custom_job = gca_custom_job_compat.CustomJob( + display_name=f"{expected_display_name}-0", + job_spec=gca_custom_job_compat.CustomJobSpec( + worker_pool_specs=expected_worker_pool_specs, + base_output_directory=gca_io_compat.GcsDestination( + output_uri_prefix=os.path.join(_TEST_STAGING_BUCKET, "custom_job"), + ), + ), + ) + mock_create_custom_job.assert_called_once() + + assert ( + mock_create_custom_job.call_args[1]["parent"] + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + ) + assert not mock_create_custom_job.call_args[1]["timeout"] + test_custom_job = mock_create_custom_job.call_args[1]["custom_job"] + + test_args = test_custom_job.job_spec.worker_pool_specs[0].container_spec.args + assert set(test_args) == set(expected_args) + + test_custom_job.job_spec.worker_pool_specs[0].container_spec.args = [] + assert test_custom_job == expected_custom_job + + mock_blob_download_to_filename.assert_called_once() + mock_tf_saved_model_load.assert_called_once() + + def test_tabnet_trainer_predict_classification(self): + test_col_0 = [1.0, 3.0, 5.0] + test_col_1 = [2, 4, 6] + test_col_cat = [0, 1, 0] + + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + test_tabnet_trainer.model = Mock() + mock_serving_default = Mock() + mock_serving_default.return_value = { + "scores": tf.constant([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]]), + "classes": tf.constant([[0, 1], [0, 1], [0, 1]]), + } + expected_predict_results = pd.DataFrame({_TEST_TARGET_COLUMN: [1, 0, 1]}) + test_tabnet_trainer.model.signatures = {"serving_default": mock_serving_default} + test_data = pd.DataFrame( + {"col_0": test_col_0, "col_1": test_col_1, "col_cat": test_col_cat} + ) + test_data["col_cat"] = test_data["col_cat"].astype("category") + test_predict_results = test_tabnet_trainer.predict(test_data) + assert test_predict_results.equals(expected_predict_results) + + mock_serving_default.assert_called_once() + + assert not mock_serving_default.call_args[0] + assert list(mock_serving_default.call_args[1].keys()) == [ + "col_0", + "col_1", + "col_cat", + ] + + expected_input_col_0 = tf.constant(test_col_0, dtype=tf.float64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_0"], expected_input_col_0) + .numpy() + .all() + ) + expected_input_col_1 = tf.constant(test_col_1, dtype=tf.int64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_1"], expected_input_col_1) + .numpy() + .all() + ) + expected_input_col_cat = tf.constant(test_col_cat, dtype=tf.int64) + assert ( + tf.equal( + mock_serving_default.call_args[1]["col_cat"], expected_input_col_cat + ) + .numpy() + .all() + ) + + def test_tabnet_trainer_predict_regression(self): + test_col_0 = [1.0, 3.0, 5.0] + test_col_1 = [2, 4, 6] + test_col_cat = [0, 1, 0] + + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_REGRESSION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + test_tabnet_trainer.model = Mock() + mock_serving_default = Mock() + mock_serving_default.return_value = { + "value": tf.constant([[0.1], [0.2], [0.3]], dtype=tf.float64) + } + expected_predict_results = pd.DataFrame({_TEST_TARGET_COLUMN: [0.1, 0.2, 0.3]}) + test_tabnet_trainer.model.signatures = {"serving_default": mock_serving_default} + test_data = pd.DataFrame( + {"col_0": test_col_0, "col_1": test_col_1, "col_cat": test_col_cat} + ) + test_data["col_cat"] = test_data["col_cat"].astype("category") + test_predict_results = test_tabnet_trainer.predict(test_data) + assert test_predict_results.equals(expected_predict_results) + + mock_serving_default.assert_called_once() + + assert not mock_serving_default.call_args[0] + assert list(mock_serving_default.call_args[1].keys()) == [ + "col_0", + "col_1", + "col_cat", + ] + + expected_input_col_0 = tf.constant(test_col_0, dtype=tf.float64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_0"], expected_input_col_0) + .numpy() + .all() + ) + expected_input_col_1 = tf.constant(test_col_1, dtype=tf.int64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_1"], expected_input_col_1) + .numpy() + .all() + ) + expected_input_col_cat = tf.constant(test_col_cat, dtype=tf.int64) + assert ( + tf.equal( + mock_serving_default.call_args[1]["col_cat"], expected_input_col_cat + ) + .numpy() + .all() + ) + + def test_tabnet_trainer_predict_load_model(self, mock_tf_saved_model_load): + test_col_0 = [1.0, 3.0, 5.0] + test_col_1 = [2, 4, 6] + test_col_cat = [0, 1, 0] + + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + test_tabnet_trainer.output_model_path = ( + f"{_TEST_STAGING_BUCKET}/output/output_model_path" + ) + mock_serving_default = Mock() + mock_serving_default.return_value = { + "scores": tf.constant([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]]), + "classes": tf.constant([[0, 1], [0, 1], [0, 1]]), + } + expected_predict_results = pd.DataFrame({_TEST_TARGET_COLUMN: [1, 0, 1]}) + mock_tf_saved_model_load.return_value.signatures = { + "serving_default": mock_serving_default + } + test_data = pd.DataFrame( + {"col_0": test_col_0, "col_1": test_col_1, "col_cat": test_col_cat} + ) + test_data["col_cat"] = test_data["col_cat"].astype("category") + test_predict_results = test_tabnet_trainer.predict(test_data) + assert test_predict_results.equals(expected_predict_results) + + mock_tf_saved_model_load.assert_called_once_with( + test_tabnet_trainer.output_model_path + ) + mock_serving_default.assert_called_once() + + assert not mock_serving_default.call_args[0] + assert list(mock_serving_default.call_args[1].keys()) == [ + "col_0", + "col_1", + "col_cat", + ] + + expected_input_col_0 = tf.constant(test_col_0, dtype=tf.float64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_0"], expected_input_col_0) + .numpy() + .all() + ) + expected_input_col_1 = tf.constant(test_col_1, dtype=tf.int64) + assert ( + tf.equal(mock_serving_default.call_args[1]["col_1"], expected_input_col_1) + .numpy() + .all() + ) + expected_input_col_cat = tf.constant(test_col_cat, dtype=tf.int64) + assert ( + tf.equal( + mock_serving_default.call_args[1]["col_cat"], expected_input_col_cat + ) + .numpy() + .all() + ) + + def test_tabnet_trainer_predict_no_trained_model(self): + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + err_msg = re.escape("No trained model. Please call .fit first.") + with pytest.raises(ValueError, match=err_msg): + test_tabnet_trainer.predict(pd.DataFrame()) + + self.output_model_path = None + with pytest.raises(ValueError, match=err_msg): + test_tabnet_trainer.predict(pd.DataFrame()) + + def test_tabnet_trainer_predict_invalid_model_type(self): + test_invalid_model_type = "invalid" + test_tabnet_trainer = tabnet_trainer.TabNetTrainer( + model_type=test_invalid_model_type, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + ) + test_tabnet_trainer.model = Mock() + test_tabnet_trainer.model.signatures = {"serving_default": Mock()} + err_msg = f"Unsupported model type: {test_invalid_model_type}" + with pytest.raises(ValueError, match=err_msg): + test_tabnet_trainer.predict(pd.DataFrame()) + + def test_tabnet_trainer_invalid_gcs_path(self): + test_invalid_path = "invalid_gcs_path" + err_msg = re.escape( + f"Invalid GCS path {test_invalid_path}. Please provide a valid GCS path starting with 'gs://'" + ) + with pytest.raises(ValueError, match=err_msg): + tabnet_trainer.TabNetTrainer( + model_type=_TEST_MODEL_TYPE_CLASSIFICATION, + target_column=_TEST_TARGET_COLUMN, + learning_rate=_TEST_LEARNING_RATE, + job_dir=test_invalid_path, + ) diff --git a/tests/unit/vertexai/test_vizier_hyperparameter_tuner.py b/tests/unit/vertexai/test_vizier_hyperparameter_tuner.py new file mode 100644 index 0000000000..1fe8b7d03f --- /dev/null +++ b/tests/unit/vertexai/test_vizier_hyperparameter_tuner.py @@ -0,0 +1,1850 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for hyperparameter_tuning/vizier_hyperparameter_tuner.py. +""" + +import concurrent +from importlib import reload +from unittest import mock + +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform_v1.services.vizier_service import ( + VizierServiceClient, +) +from google.cloud.aiplatform_v1.types.study import Measurement +from google.cloud.aiplatform_v1.types.study import Trial +from google.cloud.aiplatform_v1.types.vizier_service import ( + SuggestTrialsResponse, +) +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.driver import ( + VertexRemoteFunctor, +) +from vertexai.preview._workflow.executor import training +from vertexai.preview._workflow.shared import configs +from vertexai.preview.developer import remote_specs +from vertexai.preview.hyperparameter_tuning import ( + VizierHyperparameterTuner, +) +import numpy as np +import pandas as pd +import pytest +from sklearn.linear_model import _logistic +import sklearn.metrics +import tensorflow as tf + + +_TEST_PARAMETER_SPEC = { + "parameter_id": "x", + "double_value_spec": {"min_value": -10.0, "max_value": 10.0}, +} +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_STUDY_NAME_PREFIX = "test_study" + +_TRAIN_COL_0 = np.array([0.1] * 100) +_TRAIN_COL_1 = np.array([0.2] * 100) +_TEST_COL_0 = np.array([0.3] * 100) +_TEST_COL_1 = np.array([0.4] * 100) +_TRAIN_TARGET = np.array([1] * 100) +_TEST_TARGET = np.array([1] * 100) +_TEST_X_TRAIN = pd.DataFrame({"col_0": _TRAIN_COL_0, "col_1": _TRAIN_COL_1}) +_TEST_Y_TRAIN = pd.DataFrame( + { + "target": _TRAIN_TARGET, + } +) +_TEST_TRAINING_DATA = pd.DataFrame( + {"col_0": _TRAIN_COL_0, "col_1": _TRAIN_COL_0, "target": _TRAIN_TARGET} +) +_TEST_X_TEST = pd.DataFrame({"col_0": _TEST_COL_0, "col_1": _TEST_COL_1}) +_TEST_Y_TEST_CLASSIFICATION_BINARY = pd.DataFrame( + { + "target": np.array([0] * 50 + [1] * 50), + } +) +_TEST_Y_PRED_CLASSIFICATION_BINARY = pd.DataFrame( + { + "target": np.array([0] * 30 + [1] * 70), + } +) +_TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS = pd.DataFrame( + { + "target": np.array([1] * 25 + [2] * 25 + [3] * 25 + [4] * 25), + } +) +_TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS = pd.DataFrame( + { + "target": np.array([1] * 25 + [2] * 25 + [4] * 25 + [8] * 25), + } +) +_TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS = pd.DataFrame( + {"target": np.array([0, 1, 0, 1, 0])} +) +_TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS = pd.DataFrame( + {"target": np.array([0.01, 0.56, 0.03, 0.65, 0.74])} +) +_TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED = pd.DataFrame( + {"target": np.array([0, 1, 0, 1, 1])} +) +_TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS_KERAS = pd.DataFrame( + {"target": np.array([0, 1, 2, 1, 2])} +) +_TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS_KERAS = pd.DataFrame( + { + "target_0": [0.98, 0.02, 0.01, 0.02, 0.02], + "target_1": [0.01, 0.97, 0.34, 0.96, 0.95], + "target_2": [0.01, 0.01, 0.65, 0.02, 0.03], + } +) +_TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS_KERAS_TRANSFORMED = pd.DataFrame( + {"target": np.array([0, 1, 2, 1, 1])} +) +_TEST_Y_TEST_REGRESSION = pd.DataFrame( + { + "target": np.array([0.6] * 100), + } +) +_TEST_Y_PRED_REGRESSION = pd.DataFrame( + { + "target": np.array([0.8] * 100), + } +) +_TEST_CUSTOM_METRIC_VALUE = 0.5 +_TEST_VALIDATION_DATA = pd.DataFrame( + { + "col_0": _TEST_COL_0, + "col_1": _TEST_COL_1, + "target": _TEST_Y_TEST_CLASSIFICATION_BINARY["target"], + } +) + +_TEST_DISPLAY_NAME = "test_display_name" +_TEST_STAGING_BUCKET = "gs://test-staging-bucket" +_TEST_CONTAINER_URI = "gcr.io/test-image" +_TEST_CONTAINER_URI = "gcr.io/test-image" +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_SERVICE_ACCOUNT = "test-service-account" +_TEST_TRAINING_CONFIG = configs.RemoteConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, + container_uri=_TEST_CONTAINER_URI, + machine_type=_TEST_MACHINE_TYPE, + service_account=_TEST_SERVICE_ACCOUNT, +) +_TEST_REMOTE_CONTAINER_TRAINING_CONFIG = configs.DistributedTrainingConfig( + display_name=_TEST_DISPLAY_NAME, + staging_bucket=_TEST_STAGING_BUCKET, +) +_TEST_TRIAL_NAME = "projects/123/locations/us/central1/studies/123/trials/1" +_TEST_TRIAL_STAGING_BUCKET = ( + _TEST_STAGING_BUCKET + "/projects-123-locations-us-central1-studies-123-trials/1" +) + + +@pytest.fixture +def mock_create_study(): + with mock.patch.object(VizierServiceClient, "create_study") as create_study_mock: + create_study_mock.return_value.name = "test_study" + yield create_study_mock + + +@pytest.fixture +def mock_suggest_trials(): + with mock.patch.object( + VizierServiceClient, "suggest_trials" + ) as suggest_trials_mock: + yield suggest_trials_mock + + +@pytest.fixture +def mock_list_trials(): + with mock.patch.object(VizierServiceClient, "list_trials") as list_trials_mock: + list_trials_mock.return_value.trials = [ + Trial( + name="trial_0", + final_measurement=Measurement( + metrics=[Measurement.Metric(metric_id="accuracy", value=0.5)] + ), + state=Trial.State.SUCCEEDED, + ), + Trial( + name="trial_1", + final_measurement=Measurement( + metrics=[Measurement.Metric(metric_id="accuracy", value=0.34)] + ), + state=Trial.State.SUCCEEDED, + ), + Trial( + name="trial_2", + final_measurement=Measurement( + metrics=[Measurement.Metric(metric_id="accuracy", value=0.99)] + ), + state=Trial.State.SUCCEEDED, + ), + Trial( + name="trial_3", + final_measurement=Measurement( + metrics=[Measurement.Metric(metric_id="accuracy", value=1.0)] + ), + state=Trial.State.STOPPING, + ), + ] + yield list_trials_mock + + +@pytest.fixture +def mock_complete_trial(): + with mock.patch.object( + VizierServiceClient, "complete_trial" + ) as complete_trial_mock: + yield complete_trial_mock + + +@pytest.fixture +def mock_binary_classifier(): + model = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_CLASSIFICATION_BINARY + yield model + + +@pytest.fixture +def mock_multi_class_classifier(): + model = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS + yield model + + +@pytest.fixture +def mock_regressor(): + model = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_REGRESSION + return model + + +@pytest.fixture +def mock_model_custom_metric(): + model = mock.Mock() + model.score.return_value = _TEST_CUSTOM_METRIC_VALUE + yield model + + +@pytest.fixture +def mock_executor_map(): + with mock.patch.object( + concurrent.futures.ThreadPoolExecutor, "map" + ) as executor_map_mock: + yield executor_map_mock + + +@pytest.fixture +def mock_keras_classifier(): + with mock.patch("tensorflow.keras.Sequential", autospec=True) as keras_mock: + yield keras_mock + + +class TestTrainerA(remote.VertexModel): + def predict(self, x_test): + return + + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG, + ) + def train(self, x, y): + return + + +def get_test_trainer_a(): + model = TestTrainerA() + model.predict = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_CLASSIFICATION_BINARY + return model + + +class TestTrainerB(remote.VertexModel): + def predict(self, x_test): + return + + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG, + ) + def train(self, x_train, y_train, x_test, y_test): + return + + +def get_test_trainer_b(): + model = TestTrainerB() + model.predict = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_CLASSIFICATION_BINARY + return model + + +class TestRemoteContainerTrainer(remote.VertexModel): + def __init__(self): + super().__init__() + self._binding = {} + + def predict(self, x_test): + return + + # pylint: disable=invalid-name,missing-function-docstring + @vertexai.preview.developer.mark._remote_container_train( + image_uri=_TEST_CONTAINER_URI, + additional_data=[ + remote_specs._InputParameterSpec( + "training_data", + argument_name="training_data_path", + serializer="parquet", + ), + remote_specs._InputParameterSpec( + "validation_data", + argument_name="validation_data_path", + serializer="parquet", + ), + ], + remote_config=_TEST_REMOTE_CONTAINER_TRAINING_CONFIG, + ) + def fit(self, training_data, validation_data): + return + + +def get_test_remote_container_trainer(): + model = TestRemoteContainerTrainer() + model.predict = mock.Mock() + model.predict.return_value = _TEST_Y_PRED_CLASSIFICATION_BINARY + return model + + +class TestVizierHyperparameterTuner: + def setup_method(self): + reload(aiplatform.initializer) + reload(aiplatform) + reload(vertexai.preview.initializer) + reload(vertexai) + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid") + def test_vizier_hyper_parameter_tuner(self, mock_create_study): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_model_name = "test_model" + test_max_trial_count = 16 + test_parallel_trial_count = 4 + test_hparam_space = [_TEST_PARAMETER_SPEC] + test_metric_id = "rmse" + test_metric_goal = "MINIMIZE" + test_max_failed_trial_count = 12 + test_search_algorithm = "RANDOM_SEARCH" + test_project = "custom-project" + test_location = "custom-location" + + def get_model_func(): + model = mock.Mock() + model.name = test_model_name + return model + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=test_max_trial_count, + parallel_trial_count=test_parallel_trial_count, + hparam_space=test_hparam_space, + metric_id=test_metric_id, + metric_goal=test_metric_goal, + max_failed_trial_count=test_max_failed_trial_count, + search_algorithm=test_search_algorithm, + project=test_project, + location=test_location, + study_display_name_prefix=_TEST_STUDY_NAME_PREFIX, + ) + assert test_tuner.get_model_func().name == test_model_name + assert test_tuner.max_trial_count == test_max_trial_count + assert test_tuner.parallel_trial_count == test_parallel_trial_count + assert test_tuner.hparam_space == test_hparam_space + assert test_tuner.metric_id == test_metric_id + assert test_tuner.metric_goal == test_metric_goal + assert test_tuner.max_failed_trial_count == test_max_failed_trial_count + assert test_tuner.search_algorithm == test_search_algorithm + assert test_tuner.vertex == configs.VertexConfig() + + expected_study_name = f"{_TEST_STUDY_NAME_PREFIX}_0" + expected_study_config = { + "display_name": expected_study_name, + "study_spec": { + "algorithm": test_search_algorithm, + "parameters": test_hparam_space, + "metrics": [{"metric_id": test_metric_id, "goal": test_metric_goal}], + }, + } + expected_parent = f"projects/{test_project}/locations/{test_location}" + mock_create_study.assert_called_once_with( + parent=expected_parent, study=expected_study_config + ) + assert isinstance(test_tuner.vizier_client, VizierServiceClient) + assert test_tuner.study == mock_create_study.return_value + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid") + def test_vizier_hyper_parameter_tuner_default(self, mock_create_study): + test_model_name = "test_model" + test_max_trial_count = 16 + test_parallel_trial_count = 4 + test_hparam_space = [_TEST_PARAMETER_SPEC] + + def get_model_func(): + model = mock.Mock() + model.name = test_model_name + return model + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=test_max_trial_count, + parallel_trial_count=test_parallel_trial_count, + hparam_space=test_hparam_space, + ) + assert test_tuner.get_model_func().name == test_model_name + assert test_tuner.max_trial_count == test_max_trial_count + assert test_tuner.parallel_trial_count == test_parallel_trial_count + assert test_tuner.hparam_space == test_hparam_space + assert test_tuner.metric_id == "accuracy" + assert test_tuner.metric_goal == "MAXIMIZE" + assert test_tuner.max_failed_trial_count == 0 + assert test_tuner.search_algorithm == "ALGORITHM_UNSPECIFIED" + assert test_tuner.vertex == configs.VertexConfig() + + expected_study_name = "vizier_hyperparameter_tuner_study_0" + expected_study_config = { + "display_name": expected_study_name, + "study_spec": { + "algorithm": "ALGORITHM_UNSPECIFIED", + "parameters": test_hparam_space, + "metrics": [{"metric_id": "accuracy", "goal": "MAXIMIZE"}], + }, + } + expected_parent = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + mock_create_study.assert_called_once_with( + parent=expected_parent, study=expected_study_config + ) + assert isinstance(test_tuner.vizier_client, VizierServiceClient) + assert test_tuner.study == mock_create_study.return_value + + def test_vizier_hyper_parameter_tuner_error(self): + def get_model_func(): + return + + test_invalid_metric_id = "invalid_metric_id" + with pytest.raises(ValueError, match="Unsupported metric_id"): + VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[_TEST_PARAMETER_SPEC], + metric_id=test_invalid_metric_id, + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_best_models(self, mock_list_trials): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_max_trial_count = 16 + test_parallel_trial_count = 4 + test_hparam_space = [_TEST_PARAMETER_SPEC] + + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=test_max_trial_count, + parallel_trial_count=test_parallel_trial_count, + hparam_space=test_hparam_space, + ) + test_tuner.models["trial_0"] = get_model_func() + test_tuner.models["trial_1"] = get_model_func() + test_tuner.models["trial_2"] = get_model_func() + test_tuner.models["trial_3"] = get_model_func() + assert test_tuner.get_best_models(2) == [ + test_tuner.models["trial_2"], + test_tuner.models["trial_0"], + ] + mock_list_trials.assert_called_once_with({"parent": "test_study"}) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_create_train_and_test_split_x_and_y(self): + x = pd.DataFrame( + { + "col_0": np.array([0.1] * 100), + "col_1": np.array([0.2] * 100), + } + ) + y = pd.DataFrame( + { + "target": np.array([0.3] * 100), + } + ) + + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + x_train, x_test, y_train, y_test = test_tuner._create_train_and_test_splits( + x, y + ) + assert x_train.shape == (75, 2) + assert list(x_train.columns) == ["col_0", "col_1"] + assert x_test.shape == (25, 2) + assert list(x_test.columns) == ["col_0", "col_1"] + assert y_train.shape == (75, 1) + assert list(y_train.columns) == ["target"] + assert y_test.shape == (25, 1) + assert list(y_test.columns) == ["target"] + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_create_train_and_test_split_only_x(self): + x = pd.DataFrame( + { + "col_0": np.array([0.1] * 100), + "col_1": np.array([0.2] * 100), + "target": np.array([0.3] * 100), + } + ) + + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + x_train, x_test, y_train, y_test = test_tuner._create_train_and_test_splits( + x, "target", test_fraction=0.2 + ) + assert x_train.shape == (80, 3) + assert list(x_train.columns) == ["col_0", "col_1", "target"] + assert x_test.shape == (20, 2) + assert list(x_test.columns) == ["col_0", "col_1"] + assert not y_train + assert y_test.shape == (20, 1) + assert list(y_test.columns) == ["target"] + + @pytest.mark.parametrize( + "test_fraction", + [-0.2, 0, 1, 1.2], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_create_train_and_test_split_invalid_test_fraction(self, test_fraction): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + + err_msg = f"test_fraction must be greater than 0 and less than 1 but was {test_fraction}" + with pytest.raises(ValueError, match=err_msg): + test_tuner._create_train_and_test_splits( + pd.DataFrame(), pd.DataFrame(), test_fraction=test_fraction + ) + + @pytest.mark.parametrize( + "metric_id,expected_value", + [ + ( + "roc_auc", + sklearn.metrics.roc_auc_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ), + ), + ( + "f1", + sklearn.metrics.f1_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ), + ), + ( + "precision", + sklearn.metrics.precision_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ), + ), + ( + "recall", + sklearn.metrics.recall_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ), + ), + ( + "accuracy", + sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ), + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_model_binary_classification( + self, + metric_id, + expected_value, + mock_binary_classifier, + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id=metric_id, + metric_goal="MAXIMIZE", + ) + test_model, test_value = test_tuner._evaluate_model( + mock_binary_classifier, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + ) + assert test_value == expected_value + assert test_model == mock_binary_classifier + + @pytest.mark.parametrize( + "metric_id,expected_value", + [ + ( + "accuracy", + sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS, + _TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS, + ), + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_model_multi_class_classification( + self, + metric_id, + expected_value, + mock_multi_class_classifier, + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id=metric_id, + metric_goal="MAXIMIZE", + ) + test_model, test_value = test_tuner._evaluate_model( + mock_multi_class_classifier, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS, + ) + assert test_value == expected_value + assert test_model == mock_multi_class_classifier + + @pytest.mark.parametrize( + "metric_id,metric_goal,expected_value", + [ + ( + "mae", + "MINIMIZE", + sklearn.metrics.mean_absolute_error( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION + ), + ), + ( + "mape", + "MINIMIZE", + sklearn.metrics.mean_absolute_percentage_error( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION + ), + ), + ( + "r2", + "MAXIMIZE", + sklearn.metrics.r2_score( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION + ), + ), + ( + "rmse", + "MINIMIZE", + sklearn.metrics.mean_squared_error( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION, squared=False + ), + ), + ( + "rmsle", + "MINIMIZE", + sklearn.metrics.mean_squared_log_error( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION, squared=False + ), + ), + ( + "mse", + "MINIMIZE", + sklearn.metrics.mean_squared_error( + _TEST_Y_TEST_REGRESSION, _TEST_Y_PRED_REGRESSION + ), + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_model_regression( + self, + metric_id, + metric_goal, + expected_value, + mock_regressor, + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id=metric_id, + metric_goal=metric_goal, + ) + test_model, test_value = test_tuner._evaluate_model( + mock_regressor, _TEST_X_TEST, _TEST_Y_TEST_REGRESSION + ) + assert test_value == expected_value + assert test_model == mock_regressor + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_model_custom_metric( + self, + mock_model_custom_metric, + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id="custom", + ) + test_model, test_value = test_tuner._evaluate_model( + mock_model_custom_metric, + _TEST_X_TEST, + _TEST_Y_TEST_REGRESSION, + ) + assert test_value == _TEST_CUSTOM_METRIC_VALUE + assert test_model == mock_model_custom_metric + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_model_invalid(self): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + test_tuner.metric_id = "invalid_metric_id" + with pytest.raises(ValueError, match="Unsupported metric_id"): + test_tuner._evaluate_model( + "model", + pd.DataFrame(), + pd.DataFrame(), + ) + + @pytest.mark.parametrize( + "metric_id,expected_value", + [ + ( + "roc_auc", + sklearn.metrics.roc_auc_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED, + ), + ), + ( + "f1", + sklearn.metrics.f1_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED, + ), + ), + ( + "precision", + sklearn.metrics.precision_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED, + ), + ), + ( + "recall", + sklearn.metrics.recall_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED, + ), + ), + ( + "accuracy", + sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS_TRANSFORMED, + ), + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_keras_model_binary_classification( + self, metric_id, expected_value, mock_keras_classifier + ): + def get_model_func(): + return mock.Mock + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id=metric_id, + metric_goal="MAXIMIZE", + ) + mock_keras_classifier.predict.return_value = ( + _TEST_Y_PRED_CLASSIFICATION_BINARY_KERAS + ) + + test_model, test_value = test_tuner._evaluate_model( + mock_keras_classifier, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY_KERAS, + ) + assert test_value == expected_value + assert test_model == mock_keras_classifier + + @pytest.mark.parametrize( + "metric_id,expected_value", + [ + ( + "accuracy", + sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS_KERAS, + _TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS_KERAS_TRANSFORMED, + ), + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_evaluate_keras_model_multi_class_classification( + self, + metric_id, + expected_value, + mock_keras_classifier, + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + metric_id=metric_id, + metric_goal="MAXIMIZE", + ) + mock_keras_classifier.predict.return_value = ( + _TEST_Y_PRED_CLASSIFICATION_MULTI_CLASS_KERAS + ) + test_model, test_value = test_tuner._evaluate_model( + mock_keras_classifier, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_MULTI_CLASS_KERAS, + ) + assert test_value == expected_value + assert test_model == mock_keras_classifier + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_add_model_and_report_trial_metrics_feasible( + self, mock_binary_classifier, mock_complete_trial + ): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + test_trial_name = "trial_0" + test_model = mock_binary_classifier + test_metric_value = 1.0 + test_tuner._add_model_and_report_trial_metrics( + test_trial_name, + (test_model, test_metric_value), + ) + mock_complete_trial.assert_called_once_with( + { + "name": test_trial_name, + "final_measurement": { + "metrics": [{"metric_id": "accuracy", "value": test_metric_value}] + }, + } + ) + assert test_tuner.models == {test_trial_name: mock_binary_classifier} + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_add_model_and_report_trial_metrics_infeasible(self, mock_complete_trial): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=1, + parallel_trial_count=1, + hparam_space=[], + ) + test_trial_name = "trial_0" + test_tuner._add_model_and_report_trial_metrics( + test_trial_name, + None, + ) + mock_complete_trial.assert_called_once_with( + {"name": test_trial_name, "trial_infeasible": True} + ) + assert test_tuner.models == {} + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_suggest_trials(self, mock_suggest_trials): + test_parallel_trial_count = 4 + + def get_model_func(): + return + + mock_suggest_trials.return_value.result.return_value.trials = [ + Trial(name="trial_0"), + Trial(name="trial_1"), + Trial(name="trial_2"), + Trial(name="trial_3"), + ] + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=test_parallel_trial_count, + hparam_space=[_TEST_PARAMETER_SPEC], + ) + test_suggested_trials = test_tuner._suggest_trials(test_parallel_trial_count) + + expected_suggest_trials_request = { + "parent": "test_study", + "suggestion_count": test_parallel_trial_count, + "client_id": "client", + } + mock_suggest_trials.assert_called_once_with(expected_suggest_trials_request) + assert test_suggested_trials == mock_suggest_trials().result().trials + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_set_model_parameters(self): + def get_model_func(penalty: str, C: float, dual=True): + return _logistic.LogisticRegression(penalty=penalty, C=C, dual=dual) + + hparam_space = [ + { + "parameter_id": "penalty", + "categorical_value_spec": {"values": ["l1", "l2"]}, + }, + { + "parameter_id": "C", + "discrete_value_spec": {"values": [0.002, 0.01, 0.03]}, + }, + {"parameter_id": "extra_1", "discrete_value_spec": {"values": [1, 2, 3]}}, + ] + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=hparam_space, + ) + trial = Trial( + name="trial_1", + parameters=[ + Trial.Parameter(parameter_id="penalty", value="elasticnet"), + Trial.Parameter(parameter_id="C", value=0.05), + Trial.Parameter(parameter_id="extra_1", value=1.0), + ], + ) + model, model_runtime_parameters = test_tuner._set_model_parameters( + trial, fixed_runtime_params={"extra_2": 5} + ) + + assert model.C == 0.05 + assert model.dual + assert model.penalty == "elasticnet" + assert model_runtime_parameters == {"extra_1": 1, "extra_2": 5} + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_set_model_parameters_no_runtime_params(self): + def get_model_func(penalty: str, C: float, dual=True): + return _logistic.LogisticRegression(penalty=penalty, C=C, dual=dual) + + hparam_space = [ + { + "parameter_id": "penalty", + "categorical_value_spec": {"values": ["l1", "l2"]}, + }, + { + "parameter_id": "C", + "discrete_value_spec": {"values": [0.002, 0.01, 0.03]}, + }, + {"parameter_id": "extra_1", "discrete_value_spec": {"values": [1, 2, 3]}}, + ] + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=hparam_space, + ) + trial = Trial( + name="trial_1", + parameters=[ + Trial.Parameter(parameter_id="penalty", value="elasticnet"), + Trial.Parameter(parameter_id="C", value=0.05), + Trial.Parameter(parameter_id="dual", value=False), + ], + ) + model, model_runtime_parameters = test_tuner._set_model_parameters(trial) + + assert model.C == 0.05 + assert not model.dual + assert model.penalty == "elasticnet" + assert not model_runtime_parameters + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_vertex_model_train_method_and_params(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=False) + + class TestVertexModel(remote.VertexModel): + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG, + ) + def train( + self, + x, + y, + x_train, + y_train, + x_test, + y_test, + training_data, + validation_data, + X, + X_train, + X_test, + ): + return + + def get_model_func(): + return TestVertexModel() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + ( + test_train_method, + test_data_params, + ) = test_tuner._get_vertex_model_train_method_and_params( + test_model, + _TEST_X_TRAIN, + _TEST_Y_TRAIN, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_TRIAL_NAME, + ) + assert test_train_method == test_model.train + assert set(test_data_params.keys()) == set( + [ + "x", + "y", + "x_train", + "y_train", + "x_test", + "y_test", + "training_data", + "validation_data", + "X", + "X_train", + "X_test", + ] + ) + assert test_data_params["x"].equals(_TEST_X_TRAIN) + assert test_data_params["y"].equals(_TEST_Y_TRAIN) + assert test_data_params["x_train"].equals(_TEST_X_TRAIN) + assert test_data_params["y_train"].equals(_TEST_Y_TRAIN) + assert test_data_params["x_test"].equals(_TEST_X_TEST) + assert test_data_params["y_test"].equals(_TEST_Y_TEST_CLASSIFICATION_BINARY) + assert test_data_params["training_data"].equals(_TEST_X_TRAIN) + assert test_data_params["validation_data"].equals(_TEST_VALIDATION_DATA) + assert test_data_params["X"].equals(_TEST_X_TRAIN) + assert test_data_params["X_train"].equals(_TEST_X_TRAIN) + assert test_data_params["X_test"].equals(_TEST_X_TEST) + + # staging_bucket is not overriden in local mode. + assert ( + test_train_method.vertex.remote_config.staging_bucket + == _TEST_STAGING_BUCKET + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_vertex_model_train_method_and_params_no_y_train(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + class TestVertexModel(remote.VertexModel): + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG, + ) + def train(self, training_data, validation_data): + return + + def get_model_func(): + return TestVertexModel() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + ( + test_train_method, + test_data_params, + ) = test_tuner._get_vertex_model_train_method_and_params( + test_model, + _TEST_TRAINING_DATA, + None, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_TRIAL_NAME, + ) + assert test_train_method == test_model.train + assert set(test_data_params.keys()) == set(["training_data", "validation_data"]) + assert test_data_params["training_data"].equals(_TEST_TRAINING_DATA) + assert test_data_params["validation_data"].equals(_TEST_VALIDATION_DATA) + + @pytest.mark.parametrize( + "get_model_func", [get_test_trainer_a, get_test_remote_container_trainer] + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_vertex_model_train_method_and_params_remote_staging_bucket( + self, get_model_func + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + ) + vertexai.preview.init(remote=True) + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + test_train_method, _ = test_tuner._get_vertex_model_train_method_and_params( + test_model, + _TEST_X_TRAIN, + _TEST_Y_TRAIN, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_TRIAL_NAME, + ) + assert ( + test_train_method.vertex.remote_config.staging_bucket + == _TEST_TRIAL_STAGING_BUCKET + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_vertex_model_train_method_and_params_no_remote_executable(self): + class TestVertexModel(remote.VertexModel): + def train(self, x, y): + return + + @vertexai.preview.developer.mark.predict() + def predict(self, x): + return + + def get_model_func(): + return TestVertexModel() + + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + with pytest.raises(ValueError, match="No remote executable train method"): + test_tuner._get_vertex_model_train_method_and_params( + test_model, + _TEST_X_TRAIN, + _TEST_Y_TRAIN, + _TEST_X_TEST, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_TRIAL_NAME, + ) + + @pytest.mark.parametrize( + "get_model_func,x_train,y_train,x_test,y_test", + [ + ( + get_test_trainer_a, + _TEST_X_TRAIN, + None, + _TEST_Y_TRAIN, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + ), + ( + get_test_trainer_b, + _TEST_X_TRAIN, + _TEST_X_TEST, + _TEST_Y_TRAIN, + _TEST_Y_TEST_CLASSIFICATION_BINARY, + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_run_trial_vertex_model_train( + self, + get_model_func, + x_train, + y_train, + x_test, + y_test, + ): + # For unit tests only test local mode. + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=False) + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_trial = Trial(name="trial_0", parameters=[]) + test_model, test_metric_value = test_tuner._run_trial( + x_train=x_train, + y_train=y_train, + x_test=x_test, + y_test=y_test, + trial=test_trial, + ) + assert isinstance(test_model, type(get_model_func())) + test_model.predict.assert_called_once_with(x_test) + assert test_metric_value == sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ) + + @pytest.mark.usefixtures( + "google_auth_mock", + "mock_create_study", + "mock_blob_upload_from_filename", + "mock_create_custom_job", + "mock_get_custom_job_succeeded", + ) + def test_run_trial_vertex_model_remote_container_train(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=True) + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_test_remote_container_trainer, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_trial = Trial(name="trial_0", parameters=[]) + test_model, test_metric_value = test_tuner._run_trial( + x_train=_TEST_TRAINING_DATA, + y_train=None, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_CLASSIFICATION_BINARY, + trial=test_trial, + ) + assert isinstance(test_model, TestRemoteContainerTrainer) + test_model.predict.assert_called_once_with(_TEST_X_TEST) + assert test_metric_value == sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_run_trial_infeasible(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=True) + + class TestTrainer(remote.VertexModel): + @vertexai.preview.developer.mark.train( + remote_config=_TEST_TRAINING_CONFIG, + ) + def train(self, x_train, y_train, x_test, y_test): + raise RuntimeError() + + def get_model_func(): + return TestTrainer() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_trial = Trial(name="trial_0", parameters=[]) + trial_output = test_tuner._run_trial( + x_train=_TEST_X_TRAIN, + y_train=_TEST_Y_TRAIN, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_REGRESSION, + trial=test_trial, + ) + assert not trial_output + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_run_trial_unsupported_model_type(self): + def get_model_func(): + return mock.Mock() + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_trial = Trial(name="trial_0", parameters=[]) + with pytest.raises(ValueError, match="Unsupported model type"): + test_tuner._run_trial( + x_train=_TEST_X_TRAIN, + y_train=_TEST_Y_TRAIN, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_REGRESSION, + trial=test_trial, + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_fit( + self, + mock_executor_map, + mock_suggest_trials, + mock_complete_trial, + ): + def get_model_func(): + return + + mock_suggest_trials.return_value.result.side_effect = [ + SuggestTrialsResponse( + trials=[Trial(name="trial_1"), Trial(name="trial_2")] + ), + SuggestTrialsResponse( + trials=[ + Trial(name="trial_3"), + Trial(name="trial_4"), + ] + ), + ] + model_1, model_2, model_3, model_4 = (mock.Mock() for _ in range(4)) + mock_executor_map.side_effect = [ + [(model_1, 0.01), (model_2, 0.03)], + [(model_3, 0.02), (model_4, 0.05)], + ] + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=4, + parallel_trial_count=2, + hparam_space=[], + ) + test_tuner.fit(x=_TEST_X_TEST, y=_TEST_Y_TEST_CLASSIFICATION_BINARY) + + assert mock_suggest_trials.call_count == 2 + assert mock_executor_map.call_count == 2 + # check fixed_runtime_params in first executor.map call is empty + assert not mock_executor_map.call_args_list[0][0][1][0][6] + assert mock_complete_trial.call_count == 4 + assert test_tuner.models == { + "trial_1": model_1, + "trial_2": model_2, + "trial_3": model_3, + "trial_4": model_4, + } + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_fit_varying_parallel_trial_count_and_fixed_runtime_params( + self, + mock_executor_map, + mock_suggest_trials, + mock_complete_trial, + ): + def get_model_func(): + return + + mock_suggest_trials.return_value.result.side_effect = [ + SuggestTrialsResponse( + trials=[Trial(name="trial_1"), Trial(name="trial_2")] + ), + SuggestTrialsResponse( + trials=[ + Trial(name="trial_3"), + Trial(name="trial_4"), + ] + ), + SuggestTrialsResponse( + trials=[ + Trial(name="trial_5"), + ] + ), + ] + model_1, model_2, model_3, model_4, model_5 = (mock.Mock() for _ in range(5)) + mock_executor_map.side_effect = [ + [(model_1, 0.01), (model_2, 0.03)], + [(model_3, 0.02), (model_4, 0.05)], + [(model_5, 0.06)], + ] + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=5, + parallel_trial_count=2, + hparam_space=[], + ) + test_tuner.fit( + x=_TEST_X_TEST, + y=_TEST_Y_TEST_CLASSIFICATION_BINARY, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_CLASSIFICATION_BINARY, + num_epochs=5, + ) + + assert mock_suggest_trials.call_count == 3 + assert mock_executor_map.call_count == 3 + # check fixed_runtime_params in first executor.map call is non-empty + assert mock_executor_map.call_args_list[0][0][1][0][6] == {"num_epochs": 5} + assert mock_complete_trial.call_count == 5 + assert test_tuner.models == { + "trial_1": model_1, + "trial_2": model_2, + "trial_3": model_3, + "trial_4": model_4, + "trial_5": model_5, + } + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_fit_max_failed_trial_count( + self, + mock_executor_map, + mock_suggest_trials, + mock_complete_trial, + ): + def get_model_func(): + return + + mock_suggest_trials.return_value.result.return_value = SuggestTrialsResponse( + trials=[Trial(name="trial_1")] + ) + + mock_executor_map.return_value = [None] + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=2, + parallel_trial_count=1, + hparam_space=[], + max_failed_trial_count=1, + ) + + with pytest.raises( + ValueError, match="Maximum number of failed trials reached." + ): + test_tuner.fit( + x=_TEST_X_TEST, + y=_TEST_Y_TEST_CLASSIFICATION_BINARY, + num_epochs=5, + ) + + assert mock_suggest_trials.call_count == 1 + assert mock_executor_map.call_count == 1 + # check fixed_runtime_params in first executor.map call is non-empty + assert mock_executor_map.call_args_list[0][0][1][0][6] == {"num_epochs": 5} + assert mock_complete_trial.call_count == 1 + assert not test_tuner.models + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_fit_all_trials_failed( + self, + mock_executor_map, + mock_suggest_trials, + mock_complete_trial, + ): + def get_model_func(): + return + + mock_suggest_trials.return_value.result.side_effect = [ + SuggestTrialsResponse(trials=[Trial(name="trial_1")]), + SuggestTrialsResponse(trials=[Trial(name="trial_2")]), + ] + + mock_executor_map.return_value = [None] + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=2, + parallel_trial_count=1, + hparam_space=[], + max_failed_trial_count=0, + ) + + with pytest.raises(ValueError, match="All trials failed."): + test_tuner.fit( + x=_TEST_X_TEST, + y=_TEST_Y_TEST_CLASSIFICATION_BINARY, + ) + + assert mock_suggest_trials.call_count == 2 + assert mock_executor_map.call_count == 2 + assert mock_complete_trial.call_count == 2 + assert not test_tuner.models + + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_get_model_param_type_mapping(self): + hparam_space = [ + { + "parameter_id": "penalty", + "categorical_value_spec": {"values": ["l1", "l2"]}, + }, + { + "parameter_id": "C", + "discrete_value_spec": {"values": [0.002, 0.01, 0.03]}, + }, + { + "parameter_id": "epochs", + "integer_value_spec": {"min_value": 1, "max_value": 5.0}, + }, + { + "parameter_id": "learning_rate", + "double_value_spec": {"min_value": 1, "max_value": 5}, + }, + ] + test_tuner = VizierHyperparameterTuner( + get_model_func=None, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=hparam_space, + ) + expected_mapping = { + "penalty": str, + "C": float, + "epochs": int, + "learning_rate": float, + } + + assert expected_mapping == test_tuner._get_model_param_type_mapping() + + @pytest.mark.parametrize( + "test_get_model_func,expected_fixed_init_params", + [ + (lambda x, y: None, {"x": _TEST_X_TRAIN, "y": _TEST_Y_TRAIN}), + (lambda X, y: None, {"X": _TEST_X_TRAIN, "y": _TEST_Y_TRAIN}), + ( + lambda x_train, y_train: None, + {"x_train": _TEST_X_TRAIN, "y_train": _TEST_Y_TRAIN}, + ), + ( + lambda X_train, y_train: None, + {"X_train": _TEST_X_TRAIN, "y_train": _TEST_Y_TRAIN}, + ), + ], + ) + @pytest.mark.usefixtures("google_auth_mock", "mock_uuid", "mock_create_study") + def test_fit_get_model_func_params( + self, + test_get_model_func, + expected_fixed_init_params, + mock_executor_map, + mock_suggest_trials, + mock_complete_trial, + ): + mock_suggest_trials.return_value.result.side_effect = [ + SuggestTrialsResponse(trials=[Trial(name="trial_1")]), + SuggestTrialsResponse(trials=[Trial(name="trial_2")]), + SuggestTrialsResponse(trials=[Trial(name="trial_3")]), + SuggestTrialsResponse(trials=[Trial(name="trial_4")]), + ] + model_1, model_2, model_3, model_4 = (mock.Mock() for _ in range(4)) + mock_executor_map.side_effect = [ + [(model_1, 0.01)], + [(model_2, 0.03)], + [(model_3, 0.02)], + [(model_4, 0.05)], + ] + test_tuner = VizierHyperparameterTuner( + get_model_func=test_get_model_func, + max_trial_count=4, + parallel_trial_count=1, + hparam_space=[], + ) + test_tuner.fit( + x=_TEST_X_TRAIN, + y=_TEST_Y_TRAIN, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_CLASSIFICATION_BINARY, + ) + + assert mock_suggest_trials.call_count == 4 + assert mock_executor_map.call_count == 4 + # check fixed_runtime_params in first executor.map call is empty + assert not mock_executor_map.call_args_list[0][0][1][0][6] + assert mock_complete_trial.call_count == 4 + assert test_tuner.models == { + "trial_1": model_1, + "trial_2": model_2, + "trial_3": model_3, + "trial_4": model_4, + } + + test_map_args = [call_args[0] for call_args in mock_executor_map.call_args_list] + test_fixed_init_params = [] + for map_args in test_map_args: + test_fixed_init_params.append( + [trial_inputs[5] for trial_inputs in map_args[1]] + ) + assert test_fixed_init_params == [ + [expected_fixed_init_params], + [expected_fixed_init_params], + [expected_fixed_init_params], + [expected_fixed_init_params], + ] + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_lightning_train_method_and_params_local(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def get_model_func(): + return + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=4, + parallel_trial_count=2, + hparam_space=[], + ) + test_model = { + "model": mock.Mock(), + "trainer": mock.Mock(), + "train_dataloaders": mock.Mock(), + } + ( + test_train_method, + test_params, + ) = test_tuner._get_lightning_train_method_and_params(test_model, "") + assert test_train_method == test_model["trainer"].fit + assert test_params == { + "model": test_model["model"], + "train_dataloaders": test_model["train_dataloaders"], + } + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_lightning_train_method_and_params_remote(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + ) + vertexai.preview.init(remote=True) + + def get_model_func(): + return + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=4, + parallel_trial_count=2, + hparam_space=[], + ) + + class TestTrainer: + def fit(self, model, train_dataloaders): + pass + + test_model = { + "model": mock.Mock(), + "trainer": mock.Mock(), + "train_dataloaders": mock.Mock(), + } + + test_model["trainer"].fit = VertexRemoteFunctor( + TestTrainer().fit, remote_executor=training.remote_training + ) + ( + test_train_method, + test_params, + ) = test_tuner._get_lightning_train_method_and_params( + test_model, _TEST_TRIAL_NAME + ) + assert test_params == { + "model": test_model["model"], + "train_dataloaders": test_model["train_dataloaders"], + } + assert test_train_method == test_model["trainer"].fit + assert ( + test_train_method.vertex.remote_config.staging_bucket + == _TEST_TRIAL_STAGING_BUCKET + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_run_trial_lightning( + self, + ): + # For unit tests only test local mode. + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_lightning_model = { + "model": mock.Mock(), + "trainer": mock.Mock(), + "train_dataloaders": mock.Mock(), + } + test_lightning_model[ + "model" + ].predict.return_value = _TEST_Y_PRED_CLASSIFICATION_BINARY + + def get_model_func(): + return test_lightning_model + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_trial = Trial(name="trial_0", parameters=[]) + test_trained_model, test_metric_value = test_tuner._run_trial( + x_train=_TEST_X_TRAIN, + y_train=_TEST_Y_TRAIN, + x_test=_TEST_X_TEST, + y_test=_TEST_Y_TEST_CLASSIFICATION_BINARY, + trial=test_trial, + fixed_runtime_params={"ckpt_path": "test_ckpt_path"}, + ) + assert test_trained_model == test_lightning_model + test_lightning_model["trainer"].fit.assert_called_once_with( + model=test_lightning_model["model"], + train_dataloaders=test_lightning_model["train_dataloaders"], + ckpt_path="test_ckpt_path", + ) + test_lightning_model["model"].predict.assert_called_once_with(_TEST_X_TEST) + assert test_metric_value == sklearn.metrics.accuracy_score( + _TEST_Y_TEST_CLASSIFICATION_BINARY, + _TEST_Y_PRED_CLASSIFICATION_BINARY, + ) + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_keras_train_method_and_params(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=True) + + def get_model_func(): + tf.keras.Sequential = vertexai.preview.remote(tf.keras.Sequential) + model = tf.keras.Sequential( + [tf.keras.layers.Dense(5, input_shape=(4,)), tf.keras.layers.Softmax()] + ) + model.compile(optimizer="adam", loss="mean_squared_error") + model.fit.vertex.remote_config.staging_bucket = _TEST_STAGING_BUCKET + return model + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + test_train_method, data_params = test_tuner._get_train_method_and_params( + test_model, + _TEST_X_TRAIN, + _TEST_Y_TRAIN, + _TEST_TRIAL_NAME, + params=["x", "y"], + ) + assert test_train_method._remote_executor == training.remote_training + assert ( + test_train_method.vertex.remote_config.staging_bucket + == _TEST_TRIAL_STAGING_BUCKET + ) + assert data_params == {"x": _TEST_X_TRAIN, "y": _TEST_Y_TRAIN} + + @pytest.mark.usefixtures("google_auth_mock", "mock_create_study") + def test_get_sklearn_train_method_and_params(self): + vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + vertexai.preview.init(remote=True) + + def get_model_func(): + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression(penalty="l1") + model.fit.vertex.remote_config.staging_bucket = _TEST_STAGING_BUCKET + return model + + test_tuner = VizierHyperparameterTuner( + get_model_func=get_model_func, + max_trial_count=16, + parallel_trial_count=4, + hparam_space=[], + ) + test_model = get_model_func() + (test_train_method, data_params,) = test_tuner._get_train_method_and_params( + test_model, + _TEST_X_TRAIN, + _TEST_Y_TRAIN, + _TEST_TRIAL_NAME, + params=["X", "y"], + ) + assert test_train_method._remote_executor == training.remote_training + assert ( + test_train_method.vertex.remote_config.staging_bucket + == _TEST_TRIAL_STAGING_BUCKET + ) + assert data_params == {"X": _TEST_X_TRAIN, "y": _TEST_Y_TRAIN} diff --git a/vertexai/__init__.py b/vertexai/__init__.py index 5eff2b4391..59c5887403 100644 --- a/vertexai/__init__.py +++ b/vertexai/__init__.py @@ -15,7 +15,9 @@ """The vertexai module.""" from google.cloud.aiplatform import init +from vertexai import preview __all__ = [ "init", + "preview", ] diff --git a/vertexai/preview/__init__.py b/vertexai/preview/__init__.py new file mode 100644 index 0000000000..882597841d --- /dev/null +++ b/vertexai/preview/__init__.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform import metadata +from vertexai.preview import developer +from vertexai.preview import hyperparameter_tuning +from vertexai.preview import initializer +from vertexai.preview import tabular_models +from vertexai.preview._workflow.driver import ( + remote as remote_decorator, +) +from vertexai.preview._workflow.shared import ( + model_utils, +) + + +global_config = initializer.global_config +init = global_config.init +remote = remote_decorator.remote +VertexModel = remote_decorator.VertexModel +register = model_utils.register +from_pretrained = model_utils.from_pretrained + +# For Vertex AI Experiment. + +# ExperimentRun manipulation. +start_run = metadata.metadata._experiment_tracker.start_run +end_run = metadata.metadata._experiment_tracker.end_run +get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df + +# Experiment logging. +log_params = metadata.metadata._experiment_tracker.log_params +log_metrics = metadata.metadata._experiment_tracker.log_metrics +log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics +log_classification_metrics = ( + metadata.metadata._experiment_tracker.log_classification_metrics +) + + +__all__ = ( + "init", + "remote", + "VertexModel", + "register", + "from_pretrained", + "start_run", + "end_run", + "get_experiment_df", + "log_params", + "log_metrics", + "log_time_series_metrics", + "log_classification_metrics", + "developer", + "hyperparameter_tuning", + "tabular_models", +) diff --git a/vertexai/preview/_workflow/__init__.py b/vertexai/preview/_workflow/__init__.py new file mode 100644 index 0000000000..875d5556f2 --- /dev/null +++ b/vertexai/preview/_workflow/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The vertexai _workflow module.""" diff --git a/vertexai/preview/_workflow/driver/__init__.py b/vertexai/preview/_workflow/driver/__init__.py new file mode 100644 index 0000000000..18babaf74c --- /dev/null +++ b/vertexai/preview/_workflow/driver/__init__.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +import inspect +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, TypeVar + +from google.cloud.aiplatform import jobs +import vertexai +from vertexai.preview._workflow import launcher +from vertexai.preview._workflow import shared +from vertexai.preview._workflow.executor import ( + training, + prediction, +) +from vertexai.preview._workflow.executor import ( + remote_container_training, +) + +ModelBase = TypeVar("ModelBase") +ModelVertexSubclass = TypeVar("ModelVertexSubclass", bound=ModelBase) + +_WRAPPED_CLASS_PREFIX = "_Vertex" + + +class VertexRemoteFunctor: + """Functor to be used to wrap methods for remote execution.""" + + def __init__( + self, + method: Callable[..., Any], + remote_executor: Callable[..., Any], + remote_executor_kwargs: Optional[Dict[str, Any]] = None, + ): + """Wraps a method into VertexRemoteFunctor so that the method is remotely executable. + + Example Usage: + ``` + functor = VertexRemoteFunctor(LogisticRegression.fit, training.remote_training) + setattr(LogisticRegression, "fit", functor) + + model = LogisticRegression() + model.fit.vertex.remote_config.staging_bucket = REMOTE_JOB_BUCKET + model.fit.vertex.remote=True + model.fit(X_train, y_train) + ``` + + Args: + method (Callable[..., Any]): + Required. The method to be wrapped. + remote_executor (Callable[..., Any]): + Required. The remote executor for the method. + remote_executor_kwargs (Dict[str, Any]): + Optional. kwargs used in remote executor. + """ + self._method = method + # TODO(b/278074360) Consider multiple levels of configurations. + if inspect.ismethod(method): + # For instance method, instantiate vertex config directly. + self.vertex = shared.configs.VertexConfig() + else: + # For function, instantiate vertex config later, when the method is + # bounded to an instance. + self.vertex = shared.configs.VertexConfig + self._remote_executor = remote_executor + self._remote_executor_kwargs = remote_executor_kwargs or {} + functools.update_wrapper(self, method) + + def __get__(self, instance, owner) -> Any: + # For class and instance method that already instantiate a new functor, + # return self directly + if (instance is None) or isinstance(self.vertex, shared.configs.VertexConfig): + return self + + # Instantiate a new functor for the instance method + functor_with_instance_bound_method = self.__class__( + self._method.__get__(instance, owner), + self._remote_executor, + self._remote_executor_kwargs, + ) + functor_with_instance_bound_method.vertex = self.vertex() + setattr(instance, self._method.__name__, functor_with_instance_bound_method) + return functor_with_instance_bound_method + + def __call__(self, *args, **kwargs) -> Any: + bound_args = inspect.signature(self._method).bind(*args, **kwargs) + + # NOTE: may also need to handle the case of + # bound_args.arguments.get("self"), + + invokable = shared._Invokable( + instance=getattr(self._method, "__self__"), + method=self._method, + bound_arguments=bound_args, + remote_executor=self._remote_executor, + remote_executor_kwargs=self._remote_executor_kwargs, + vertex_config=self.vertex, + ) + + return _workflow_driver.invoke(invokable) + + +def _supported_member_iter(instance: Any) -> Iterator[Tuple[str, Callable[..., Any]]]: + """Iterates through known method names and returns matching methods.""" + for attr_name in shared.supported_frameworks.REMOTE_TRAINING_OVERRIDE_LIST: + attr_value = getattr(instance, attr_name, None) + if attr_value: + yield attr_name, attr_value, training.remote_training, None + + for attr_name in shared.supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: + attr_value = getattr(instance, attr_name, None) + if attr_value: + yield attr_name, attr_value, prediction.remote_prediction, None + + +def _patch_class(cls: Type[ModelBase]) -> Type[ModelVertexSubclass]: + """Creates a new class that inherited from original class and add Vertex remote execution support.""" + + if hasattr(cls, "_wrapped_by_vertex"): + return cls + + new_cls = type( + f"{_WRAPPED_CLASS_PREFIX}{cls.__name__}", (cls,), {"_wrapped_by_vertex": True} + ) + for ( + attr_name, + attr_value, + remote_executor, + remote_executor_kwargs, + ) in _supported_member_iter(cls): + setattr( + new_cls, + attr_name, + VertexRemoteFunctor(attr_value, remote_executor, remote_executor_kwargs), + ) + + return new_cls + + +def _rewrapper( + instance: Any, + wrapped_class: Any, + config_map: Dict[str, shared.configs.VertexConfig], +): + """Rewraps in place instances after remote execution has completed. + + Args: + instance (Any): + Required. Instance to rewrap. + wrapped_class (Any): + Required. The class type that the instance will be wrapped into. + config_map (Dict[str, shared.configs.VertexConfig]): + Required. Instance of config before unwrapping. Maintains + the config after wrapping. + """ + instance.__class__ = wrapped_class + for attr_name, ( + vertex_config, + remote_executor, + remote_executor_kwargs, + ) in config_map.items(): + method = getattr(instance, attr_name) + if isinstance(method, VertexRemoteFunctor): + method.vertex = vertex_config + setattr(instance, attr_name, method) + else: + functor = VertexRemoteFunctor( + method, remote_executor, remote_executor_kwargs + ) + functor.vertex = vertex_config + setattr(instance, attr_name, functor) + + +def _unwrapper(instance: Any) -> Callable[..., Any]: + """Unwraps all Vertex functor method. + + This should be done before locally executing or remotely executing. + """ + current_class = instance.__class__ + super_class = current_class.__mro__[1] + wrapped_in_place = ( + current_class.__name__ != f"{_WRAPPED_CLASS_PREFIX}{super_class.__name__}" + ) + + config_map = dict() + + for ( + attr_name, + attr_value, + remote_executor, + remote_executor_kwargs, + ) in _supported_member_iter(instance): + # NOTE: This additional check may be unnessecary in the current + # implementation but will be more robust to future changes. + # ie: framework specific method name patching + if isinstance(attr_value, VertexRemoteFunctor): + config_map[attr_name] = ( + attr_value.vertex, + remote_executor, + remote_executor_kwargs, + ) + if wrapped_in_place: + setattr(instance, attr_name, attr_value._method) + + if not wrapped_in_place: + instance.__class__ = super_class + + return functools.partial( + _rewrapper, wrapped_class=current_class, config_map=config_map + ) + + +class _WorkFlowDriver: + def __init__(self): + self._launcher = launcher._WorkflowLauncher() + + def invoke(self, invokable: shared._Invokable) -> Any: + """ + Wrapper should forward implementation to this method. + + NOTE: Not threadsafe w.r.t the instance. + """ + + rewrapper = None + # unwrap + if ( + invokable.instance is not None + and invokable.remote_executor is not remote_container_training.train + ): + rewrapper = _unwrapper(invokable.instance) + + result = self._launch(invokable) + + # rewrap the original instance + if rewrapper and invokable.instance is not None: + rewrapper(invokable.instance) + # also rewrap the result if the result is an estimator not a dataset + if rewrapper and isinstance(result, type(invokable.instance)): + rewrapper(result) + + if hasattr(result, "state") and result.state in jobs._JOB_ERROR_STATES: + raise RuntimeError("Remote job failed with:\n%s" % result.error) + + return result + + def _launch(self, invokable: shared._Invokable) -> Any: + """ + Launches an invokable. + """ + return self._launcher.launch( + invokable=invokable, global_remote=vertexai.preview.global_config.remote + ) + + +_workflow_driver = _WorkFlowDriver() diff --git a/vertexai/preview/_workflow/driver/remote.py b/vertexai/preview/_workflow/driver/remote.py new file mode 100644 index 0000000000..303e8ef135 --- /dev/null +++ b/vertexai/preview/_workflow/driver/remote.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import inspect +from typing import Any, Callable, Dict, Optional, Type + +from vertexai.preview._workflow import driver +from vertexai.preview._workflow.executor import ( + training, +) +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) +from vertexai.preview.developer import remote_specs + + +def remote_method_decorator( + method: Callable[..., Any], + remote_executor: Callable[..., Any], + remote_executor_kwargs: Optional[Dict[str, Any]] = None, +) -> Callable[..., Any]: + """Wraps methods as Functor object to support configuration on method.""" + return driver.VertexRemoteFunctor(method, remote_executor, remote_executor_kwargs) + + +def remote_class_decorator(cls: Type) -> Type: + """Add Vertex attributes to a class object.""" + + if not supported_frameworks._is_oss(cls): + raise ValueError( + f"Class {cls.__name__} not supported. " + "Currently support remote execution on " + f"{supported_frameworks.REMOTE_FRAMEWORKS} classes." + ) + + return driver._patch_class(cls) + + +def remote(cls_or_method: Any) -> Any: + """Takes a class or method and add Vertex remote execution support. + + ex: + ``` + + LogisticRegression = vertexai.preview.remote(LogisticRegression) + model = LogisticRegression() + model.fit.vertex.remote_config.staging_bucket = REMOTE_JOB_BUCKET + model.fit.vertex.remote=True + model.fit(X_train, y_train) + ``` + + Args: + cls_or_method (Any): + Required. A class or method that will be added Vertex remote + execution support. + + Returns: + A class or method that can be executed remotely. + """ + if inspect.isclass(cls_or_method): + return remote_class_decorator(cls_or_method) + else: + return remote_method_decorator(cls_or_method, training.remote_training) + + +class VertexModel(metaclass=abc.ABCMeta): + """mixin class that can be used to add Vertex AI remote execution to a custom model.""" + + def __init__(self): + vertex_wrapper = False + for _, attr_value in inspect.getmembers(self): + if isinstance(attr_value, driver.VertexRemoteFunctor): + vertex_wrapper = True + break + # TODO(b/279631878) Remove this check once we support more decorators. + if not vertex_wrapper: + raise ValueError( + "No method is enabled for Vertex remote training. Please decorator " + "your training methods with `@vertexai.preview.developer.mark.train`." + ) + self._cluster_spec = None + + @property + def cluster_spec(self): + return self._cluster_spec + + @cluster_spec.setter + def cluster_spec(self, cluster_spec: remote_specs._ClusterSpec): + self._cluster_spec = cluster_spec diff --git a/vertexai/preview/_workflow/executor/__init__.py b/vertexai/preview/_workflow/executor/__init__.py new file mode 100644 index 0000000000..32001a602c --- /dev/null +++ b/vertexai/preview/_workflow/executor/__init__.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any + +from vertexai.preview._workflow import shared +from vertexai.preview._workflow.executor import ( + remote_container_training, + training, + prediction, +) + + +class _WorkflowExecutor: + """Executes an invokable either locally or remotely.""" + + def local_execute(self, invokable: shared._Invokable) -> Any: + if invokable.remote_executor is remote_container_training.train: + raise ValueError( + "Remote container train is only supported for remote mode." + ) + return invokable.method( + *invokable.bound_arguments.args, **invokable.bound_arguments.kwargs + ) + + def remote_execute(self, invokable: shared._Invokable) -> Any: + if invokable.remote_executor not in ( + remote_container_training.train, + training.remote_training, + prediction.remote_prediction, + ): + raise ValueError(f"{invokable.remote_executor} is not supported.") + + return invokable.remote_executor(invokable) + + +_workflow_executor = _WorkflowExecutor() diff --git a/vertexai/preview/_workflow/executor/persistent_resource_util.py b/vertexai/preview/_workflow/executor/persistent_resource_util.py new file mode 100644 index 0000000000..de464a7776 --- /dev/null +++ b/vertexai/preview/_workflow/executor/persistent_resource_util.py @@ -0,0 +1,208 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import time +from typing import Optional + +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core.client_options import ClientOptions +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( + PersistentResourceServiceClient, +) +from google.cloud.aiplatform_v1beta1.types import persistent_resource_service +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + PersistentResource, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourcePool, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import ( + GetPersistentResourceRequest, +) + + +GAPIC_VERSION = aiplatform.__version__ +_LOGGER = base.Logger(__name__) + +_DEFAULT_REPLICA_COUNT = 1 +_DEFAULT_MACHINE_TYPE = "n1-standard-4" +_DEFAULT_DISK_TYPE = "pd-ssd" +_DEFAULT_DISK_SIZE_GB = 100 + + +def _create_persistent_resource_client(location: Optional[str] = "us-central1"): + + client_info = gapic_v1.client_info.ClientInfo( + gapic_version=GAPIC_VERSION, + ) + + api_endpoint = f"{location}-aiplatform.googleapis.com" + + return PersistentResourceServiceClient( + client_options=ClientOptions(api_endpoint=api_endpoint), + client_info=client_info, + ) + + +def check_persistent_resource(cluster_resource_name: str) -> bool: + """Helper method to check if a persistent resource exists or not. + + Args: + cluster_resource_name: Persistent Resource name. Has the form: + ``projects/my-project/locations/my-region/persistentResource/cluster-name``. + + Returns: + True if a Persistent Resource exists. + + Raises: + ValueError: if existing cluster is not RUNNING. + """ + # Parse resource name to get the location. + locataion = cluster_resource_name.split("/")[3] + client = _create_persistent_resource_client(locataion) + request = GetPersistentResourceRequest( + name=cluster_resource_name, + ) + try: + response = client.get_persistent_resource(request) + except exceptions.NotFound: + return False + + if response.state != PersistentResource.State.RUNNING: + raise ValueError( + "The existing cluster `", + cluster_resource_name, + "` isn't running, please specify a different cluster_name.", + ) + return True + + +def _default_persistent_resource() -> PersistentResource: + """Default persistent resource.""" + # Currently the service accepts only one resource_pool config and image_uri. + resource_pools = [] + resource_pool = ResourcePool() + resource_pool.replica_count = _DEFAULT_REPLICA_COUNT + resource_pool.machine_spec.machine_type = _DEFAULT_MACHINE_TYPE + resource_pool.disk_spec.boot_disk_type = _DEFAULT_DISK_TYPE + resource_pool.disk_spec.boot_disk_size_gb = _DEFAULT_DISK_SIZE_GB + resource_pools.append(resource_pool) + + return PersistentResource(resource_pools=resource_pools) + + +# TODO(b/294600649) +def _polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta: + """Computes a delay to the next attempt to poll the Vertex service. + + This does bounded exponential backoff, starting with $time_scale. + If $time_scale == 0, it starts with a small time interval, less than + 1 second. + + Args: + num_attempts: The number of times have we polled and found that the + desired result was not yet available. + time_scale: The shortest polling interval, in seconds, or zero. Zero is + treated as a small interval, less than 1 second. + + Returns: + A recommended delay interval, in seconds. + """ + # The polling schedule is slow initially , and then gets faster until 4 + # attempts (after that the sleeping time remains the same). + small_interval = 30.0 # Seconds + interval = max(time_scale, small_interval) * 0.76 ** min(num_attempts, 4) + return datetime.timedelta(seconds=interval) + + +def _get_persistent_resource(cluster_resource_name: str): + """Get persistent resource. + + Args: + cluster_resource_name: + "projects//locations//persistentResources/". + + Returns: + aiplatform_v1beta1.PersistentResource if state is RUNNING. + + Raises: + ValueError: Invalid cluster resource name. + RuntimeError: Service returns error. + RuntimeError: Cluster resource state is STOPPING. + RuntimeError: Cluster resource state is ERROR. + """ + + # Parse resource name to get the location. + locataion = cluster_resource_name.split("/")[3] + client = _create_persistent_resource_client(locataion) + request = GetPersistentResourceRequest( + name=cluster_resource_name, + ) + + num_attempts = 0 + while True: + try: + response = client.get_persistent_resource(request) + except exceptions.NotFound as e: + raise ValueError("Invalid cluster_resource_name (404 not found).") from e + if response.error.message: + raise RuntimeError("Cluster returned an error.", response.error.message) + + print("Cluster State =", response.state) + if response.state == PersistentResource.State.RUNNING: + return response + elif response.state == PersistentResource.State.STOPPING: + raise RuntimeError("The cluster is stopping.") + elif response.state == PersistentResource.State.ERROR: + raise RuntimeError("The cluster encountered an error.") + # Polling decay + sleep_time = _polling_delay(num_attempts=num_attempts, time_scale=90.0) + num_attempts += 1 + print( + "Waiting for cluster provisioning; attempt {}; sleeping for {} seconds".format( + num_attempts, sleep_time + ) + ) + time.sleep(sleep_time.total_seconds()) + + +def create_persistent_resource(cluster_resource_name: str): + """Create a default persistent resource.""" + locataion = cluster_resource_name.split("/")[3] + parent = "/".join(cluster_resource_name.split("/")[:4]) + cluster_name = cluster_resource_name.split("/")[-1] + + client = _create_persistent_resource_client(locataion) + + persistent_resource = _default_persistent_resource() + + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=parent, + persistent_resource=persistent_resource, + persistent_resource_id=cluster_name, + ) + + try: + _ = client.create_persistent_resource(request) + except Exception as e: + raise ValueError("Failed in cluster creation due to: ", e) from e + + # Check cluster creation progress + response = _get_persistent_resource(cluster_resource_name) + _LOGGER.info(response) diff --git a/vertexai/preview/_workflow/executor/prediction.py b/vertexai/preview/_workflow/executor/prediction.py new file mode 100644 index 0000000000..0f011bda26 --- /dev/null +++ b/vertexai/preview/_workflow/executor/prediction.py @@ -0,0 +1,36 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from vertexai.preview._workflow import ( + shared, +) +from vertexai.preview._workflow.executor import ( + training, +) + + +def remote_prediction(invokable: shared._Invokable): + """Wrapper function that makes a method executable by Vertex CustomJob.""" + predictions = training.remote_training(invokable=invokable) + return predictions + + +def _online_prediction(invokable: shared._Invokable): + # TODO(b/283292903) Implement online prediction method + raise ValueError("Online prediction is not currently supported.") + + +def _batch_prediction(invokable: shared._Invokable): + # TODO(b/283289019) Implement batch prediction method + raise ValueError("Batch prediction is not currently supported.") diff --git a/vertexai/preview/_workflow/executor/remote_container_training.py b/vertexai/preview/_workflow/executor/remote_container_training.py new file mode 100644 index 0000000000..cb5e3f0983 --- /dev/null +++ b/vertexai/preview/_workflow/executor/remote_container_training.py @@ -0,0 +1,218 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Remote container training and helper functions. +""" +from typing import Any, Dict, List +import uuid + +from google.cloud import aiplatform +from google.cloud.aiplatform.utils import worker_spec_utils +import vertexai +from vertexai.preview._workflow import shared +from vertexai.preview.developer import remote_specs + + +_CUSTOM_JOB_DIR = "custom_job" +_INPUT_DIR = "input" +_OUTPUT_DIR = "output" + +# job_dir container argument name +_JOB_DIR = "job_dir" + +# Worker pool specs default value constants +_DEFAULT_REPLICA_COUNT: int = 1 +_DEFAULT_MACHINE_TYPE: str = "n1-standard-4" +_DEFAULT_ACCELERATOR_COUNT: int = 0 +_DEFAULT_ACCELERATOR_TYPE: str = "ACCELERATOR_TYPE_UNSPECIFIED" +_DEFAULT_BOOT_DISK_TYPE: str = "pd-ssd" +_DEFAULT_BOOT_DISK_SIZE_GB: int = 100 + +# Custom job default name +_DEFAULT_DISPLAY_NAME = "remote-fit" + + +def _generate_worker_pool_specs( + image_uri: str, + inputs: List[str], + replica_count: int = _DEFAULT_REPLICA_COUNT, + machine_type: str = _DEFAULT_MACHINE_TYPE, + accelerator_count: int = _DEFAULT_ACCELERATOR_COUNT, + accelerator_type: str = _DEFAULT_ACCELERATOR_TYPE, + boot_disk_type: str = _DEFAULT_BOOT_DISK_TYPE, + boot_disk_size_gb: int = _DEFAULT_BOOT_DISK_SIZE_GB, +) -> List[Dict[str, Any]]: + """Helper function to generate worker pool specs for CustomJob. + + TODO(b/278786170): Use customized worker_pool_specs to specify + replica_count, machine types, number/type of worker pools, etc. for + distributed training. + + Args: + image_uri (str): + Required. The docker image uri for CustomJob. + inputs (List[str]): + Required. A list of inputs for CustomJob. Each item would look like + "--arg_0=value_for_arg_0". + replica_count (int): + Optional. The number of worker replicas. Assigns 1 chief replica and + replica_count - 1 worker replicas. + machine_type (str): + Optional. The type of machine to use for training. + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + accelerator_type (str): + Optional. Hardware accelerator type. One of + ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 + boot_disk_type (str): + Optional. Type of the boot disk (default is `pd-ssd`). + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Optional. Size in GB of the boot disk (default is 100GB). + boot disk size must be within the range of [100, 64000]. + + Returns: + A list of worker pool specs in the form of dictionaries. For + replica = 1, there is one worker pool spec. For replica > 1, there are + two worker pool specs. + + Raises: + ValueError if replica_count is less than 1. + """ + if replica_count < 1: + raise ValueError( + "replica_count must be a positive number but is " f"{replica_count}." + ) + + # pylint: disable=protected-access + worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + ).pool_specs + + # Attach a container_spec to each worker pool spec. + for spec in worker_pool_specs: + spec["container_spec"] = { + "image_uri": image_uri, + "args": inputs, + } + + return worker_pool_specs + + +# pylint: disable=protected-access +def train(invokable: shared._Invokable): + """Wrapper function that runs remote container training.""" + training_config = invokable.vertex_config.remote_config + + # user can specify either worker_pool_specs OR machine_type, replica_count etc. + remote_specs._verify_specified_remote_config_values( + training_config.worker_pool_specs, + training_config.machine_type, + training_config.replica_count, + training_config.accelerator_type, + training_config.accelerator_count, + training_config.boot_disk_type, + training_config.boot_disk_size_gb, + ) + + staging_bucket = ( + training_config.staging_bucket or vertexai.preview.global_config.staging_bucket + ) + if not staging_bucket: + raise ValueError( + "No default staging bucket set. " + "Please call `vertexai.init(staging_bucket='gs://my-bucket')." + ) + input_dir = remote_specs._gen_gcs_path(staging_bucket, _INPUT_DIR) + output_dir = remote_specs._gen_gcs_path(staging_bucket, _OUTPUT_DIR) + + # Creates a complete set of binding. + instance_binding = invokable.instance._binding + binding = invokable.bound_arguments.arguments + for arg in instance_binding: + binding[arg] = instance_binding[arg] + + # If a container accepts a job_dir argument and the user does not specify + # it, set job_dir based on the staging bucket. + if _JOB_DIR in binding and not binding[_JOB_DIR]: + binding[_JOB_DIR] = remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR) + + # Formats arguments. + formatted_args = {} + output_specs = [] + for data in invokable.remote_executor_kwargs["additional_data"]: + if isinstance(data, remote_specs._InputParameterSpec): + formatted_args[data.argument_name] = data.format_arg(input_dir, binding) + elif isinstance(data, remote_specs._OutputParameterSpec): + formatted_args[data.argument_name] = remote_specs._gen_gcs_path( + output_dir, data.argument_name + ) + output_specs.append(data) + else: + raise ValueError(f"Invalid data type {type(data)}.") + inputs = [f"--{key}={val}" for key, val in formatted_args.items()] + + # Launches a custom job. + display_name = training_config.display_name or _DEFAULT_DISPLAY_NAME + if training_config.worker_pool_specs: + worker_pool_specs = remote_specs._prepare_worker_pool_specs( + worker_pool_specs=training_config.worker_pool_specs, + image_uri=invokable.remote_executor_kwargs["image_uri"], + args=inputs, + ) + else: + worker_pool_specs = _generate_worker_pool_specs( + image_uri=invokable.remote_executor_kwargs["image_uri"], + inputs=inputs, + replica_count=(training_config.replica_count or _DEFAULT_REPLICA_COUNT), + machine_type=(training_config.machine_type or _DEFAULT_MACHINE_TYPE), + accelerator_count=( + training_config.accelerator_count or _DEFAULT_ACCELERATOR_COUNT + ), + accelerator_type=( + training_config.accelerator_type or _DEFAULT_ACCELERATOR_TYPE + ), + boot_disk_type=(training_config.boot_disk_type or _DEFAULT_BOOT_DISK_TYPE), + boot_disk_size_gb=( + training_config.boot_disk_size_gb or _DEFAULT_BOOT_DISK_SIZE_GB + ), + ) + + job = aiplatform.CustomJob( + display_name=f"{invokable.instance.__class__.__name__}-{display_name}" + f"-{uuid.uuid4()}", + worker_pool_specs=worker_pool_specs, + base_output_dir=remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR), + staging_bucket=remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR), + ) + job.run() + + # Sets output values from the custom job. + for data in output_specs: + deserialized_output = data.deserialize_output( + formatted_args[data.argument_name] + ) + invokable.instance.__setattr__(data.name, deserialized_output) + + # Calls the decorated function for post-processing. + return invokable.method( + *invokable.bound_arguments.args, **invokable.bound_arguments.kwargs + ) diff --git a/vertexai/preview/_workflow/executor/training.py b/vertexai/preview/_workflow/executor/training.py new file mode 100644 index 0000000000..10f5ad2575 --- /dev/null +++ b/vertexai/preview/_workflow/executor/training.py @@ -0,0 +1,741 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import collections +import datetime +import inspect +import logging +import os +import re +import time +from typing import Any, Dict, List, Optional, Set, Tuple, Union +import warnings + +from google.api_core import exceptions as api_exceptions +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import base +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata import metadata +from google.cloud.aiplatform.utils import resource_manager_utils +from vertexai.preview._workflow import shared +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, +) +from vertexai.preview._workflow.serialization_engine import ( + serializers_base, +) +from vertexai.preview._workflow.shared import constants +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) +from vertexai.preview.developer import remote_specs +from packaging import version + + +try: + from importlib import metadata as importlib_metadata +except ImportError: + import importlib_metadata + +try: + import bigframes as bf + + BigframesData = bf.dataframe.DataFrame +except ImportError: + bf = None + BigframesData = Any + + +try: + from google.cloud import logging as cloud_logging +except ImportError: + cloud_logging = None + + +_LOGGER = base.Logger("vertexai.remote_execution") +_LOG_POLL_INTERVAL = 5 + + +# TODO(b/271855597) Serialize all input args +PASS_THROUGH_ARG_TYPES = [str, int, float, bool] + +VERTEX_AI_DEPENDENCY_PATH = ( + f"google-cloud-aiplatform[preview]=={aiplatform.__version__}" +) +VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING = ( + f"google-cloud-aiplatform[preview,autologging]=={aiplatform.__version__}" +) + +_DEFAULT_GPU_WORKER_POOL_SPECS = remote_specs.WorkerPoolSpecs( + remote_specs.WorkerPoolSpec(1, "n1-standard-16", 1, "NVIDIA_TESLA_P100"), + remote_specs.WorkerPoolSpec(1, "n1-standard-16", 1, "NVIDIA_TESLA_P100"), +) +_DEFAULT_CPU_WORKER_POOL_SPECS = remote_specs.WorkerPoolSpecs( + remote_specs.WorkerPoolSpec(1, "n1-standard-4"), + remote_specs.WorkerPoolSpec(1, "n1-standard-4"), +) + + +def _get_package_name(requirement: str) -> str: + """Given a requirement specification, returns the package name.""" + return re.match("[a-zA-Z-_]+", requirement).group() + + +def _get_package_extras(requirement: str) -> Set: + """Given a requirement specification, returns the extra component in it.""" + # searching for patterns like [extra1,extra2,...] + extras = re.search(r"\[.*\]", requirement) + if extras: + return set([extra.strip() for extra in extras.group()[1:-1].split(",")]) + return set() + + +def _add_indirect_dependency_versions(direct_requirements: List[str]) -> List[str]: + """Helper method to get versions of libraries in the dep tree.""" + versions = {} + dependencies_and_extras = collections.deque([]) + direct_deps_packages = set() + for direct_requirement in direct_requirements: + package_name = _get_package_name(direct_requirement) + extras = _get_package_extras(direct_requirement) + direct_deps_packages.add(package_name) + try: + versions[package_name] = importlib_metadata.version(package_name) + dependencies_and_extras.append((package_name, extras)) + except importlib_metadata.PackageNotFoundError: + pass + + while dependencies_and_extras: + dependency, extras = dependencies_and_extras.popleft() + child_requirements = importlib_metadata.requires(dependency) + if not child_requirements: + continue + for child_requirement in child_requirements: + child_dependency = _get_package_name(child_requirement) + child_dependency_extras = _get_package_extras(child_requirement) + if child_dependency not in versions: + if "extra" in child_requirement: + # Matching patter "extra == 'extra_component'" in a requirement + # specification like + # "dependency_name (>=1.0.0) ; extra == 'extra_component'" + extra_component = ( + re.search(r"extra == .*", child_requirement) + .group()[len("extra == ") :] + .strip("'") + ) + # If the corresponding extra_component is not in the needed + # extras set of the parent dependency, skip this package + if extra_component not in extras: + continue + try: + versions[child_dependency] = importlib_metadata.version( + child_dependency + ) + dependencies_and_extras.append( + (child_dependency, child_dependency_extras) + ) + except importlib_metadata.PackageNotFoundError: + pass + + return [ + "==".join([package_name, package_version]) if package_version else package_name + for package_name, package_version in versions.items() + if package_name not in direct_deps_packages + ] + direct_requirements + + +def _create_worker_pool_specs( + machine_type: str, + command: str, + image_uri: str, + replica_count: int = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, +) -> List[Dict[str, Any]]: + """Helper method to create worker pool specs for CustomJob.""" + worker_pool_specs = [ + { + "machine_spec": { + "machine_type": machine_type, + "accelerator_type": accelerator_type, + "accelerator_count": accelerator_count, + }, + "replica_count": replica_count, + "container_spec": { + "image_uri": image_uri, + "command": command, + "args": [], + }, + } + ] + return worker_pool_specs + + +def _get_worker_pool_specs( + config: shared.configs.RemoteConfig, image_uri: str, command: List[str] +) -> List[Dict[str, Any]]: + """Helper method to return worker_pool_specs based on user specification in training config.""" + if config.enable_distributed: + if config.worker_pool_specs: + # validate user-specified worker_pool_specs support distributed training. + # must be single worker, multi-GPU OR multi-worker, single/multi-GPU + if ( + config.worker_pool_specs.chief.accelerator_count < 2 + and not config.worker_pool_specs.worker + ): + raise ValueError( + "`enable_distributed=True` in Vertex config, but `worker_pool_specs` do not support distributed training." + ) + return remote_specs._prepare_worker_pool_specs( + config.worker_pool_specs, image_uri, command, args=[] + ) + else: + default_worker_pool_specs = ( + _DEFAULT_GPU_WORKER_POOL_SPECS + if config.enable_cuda + else _DEFAULT_CPU_WORKER_POOL_SPECS + ) + return remote_specs._prepare_worker_pool_specs( + default_worker_pool_specs, image_uri, command, args=[] + ) + + if config.worker_pool_specs: + warnings.warn( + "config.worker_pool_specs will not take effect since `enable_distributed=False`." + ) + + if config.enable_cuda: + default_machine_type = "n1-standard-16" + default_accelerator_type = "NVIDIA_TESLA_P100" + default_accelerator_count = 1 + else: + default_machine_type = "n1-standard-4" + default_accelerator_type = None + default_accelerator_count = None + + machine_type = config.machine_type or default_machine_type + accelerator_type = config.accelerator_type or default_accelerator_type + accelerator_count = config.accelerator_count or default_accelerator_count + + return _create_worker_pool_specs( + machine_type=machine_type, + command=command, + image_uri=image_uri, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + ) + + +def _common_update_model_inplace(old_estimator, new_estimator): + for attr_name, attr_value in new_estimator.__dict__.items(): + if not attr_name.startswith("__") and not inspect.ismethod( + getattr(old_estimator, attr_name, None) + ): + setattr(old_estimator, attr_name, attr_value) + + +def _update_sklearn_model_inplace(old_estimator, new_estimator): + _common_update_model_inplace(old_estimator, new_estimator) + + +def _update_torch_model_inplace(old_estimator, new_estimator): + # make sure estimators are on the same device + device = next(old_estimator.parameters()).device + new_estimator.to(device) + _common_update_model_inplace(old_estimator, new_estimator) + + +def _update_lightning_trainer_inplace(old_estimator, new_estimator): + _common_update_model_inplace(old_estimator, new_estimator) + + +def _update_keras_model_inplace(old_estimator, new_estimator): + import tensorflow as tf + + @tf.__internal__.tracking.no_automatic_dependency_tracking + def _no_tracking_setattr(instance, name, value): + setattr(instance, name, value) + + for attr_name, attr_value in new_estimator.__dict__.items(): + if not attr_name.startswith("__") and not inspect.ismethod( + getattr(old_estimator, attr_name, None) + ): + # for Keras model, we update self's attributes with a decorated + # setattr. See b/277939758 for the details. + _no_tracking_setattr(old_estimator, attr_name, attr_value) + + +def _get_service_account( + config: shared.configs.RemoteConfig, + autolog: bool, +) -> Optional[str]: + """Helper method to get service account from RemoteConfig.""" + service_account = ( + config.service_account or vertexai.preview.global_config.service_account + ) + if service_account: + if service_account.lower() == "gce": + project = vertexai.preview.global_config.project + project_number = resource_manager_utils.get_project_number(project) + return f"{project_number}-compute@developer.gserviceaccount.com" + else: + return service_account + else: + if autolog: + raise ValueError( + "Service account has to be provided for autologging. You can " + "either use your own service account by setting " + "`model..vertex.remote_config.service_account = `, " + "or use the GCE service account by setting " + "`model..vertex.remote_config.service_account = 'GCE'`." + ) + else: + return None + + +def _dedupe_requirements(requirements: List[str]) -> List[str]: + """Helper method to deduplicate requirements by the package name. + + Args: + requirements (List[str]): + Required. A list of python packages. Can be either "my-package" or + "my-package==1.0.0". + + Returns: + A list of unique python packages. if duplicate in the original list, will + keep the first one. + """ + res = [] + req_names = set() + for req in requirements: + req_name = req.split("==")[0] + if req_name not in req_names: + req_names.add(req_name) + res.append(req) + + return res + + +def _get_remote_logs( + job_id: str, + logger: "google.cloud.logging.Logger", # noqa: F821 + log_time: datetime.datetime, + log_level: str = "INFO", + is_training_log: bool = False, +) -> Tuple[datetime.datetime, bool]: + """Helper method to get CustomJob logs from Cloud Logging. + + Args: + job_id (str): + Required. The resource id of the CustomJob. + logger (cloud_logging.Logger): + Required. A google-cloud-logging Logger object corresponding to the + CustomJob. + log_time (datetime.datetime): + Required. Logs generated after this time will get pulled. + log_level (str): + Optional. Logs greater than or equal to this level will get pulled. + Default is `INFO` level. + is_training_log (bool): + Optional. Indicates if logs after the `log_time` are training logs. + + Returns: + A tuple indicates the end time of logs and whether the training log has + started. + """ + filter_msg = [ + f"resource.labels.job_id={job_id}", + f"severity>={log_level}", + f'timestamp>"{log_time.isoformat()}"', + ] + filter_msg = " AND ".join(filter_msg) + try: + entries = logger.list_entries( + filter_=filter_msg, order_by=cloud_logging.ASCENDING + ) + except api_exceptions.PermissionDenied as e: + _LOGGER.warning( + f"Failed to get logs due to: {e}. " + "Remote execution logging is disabled. " + "Please add 'Logging Admin' role to your principal." + ) + return None, None + + for entry in entries: + log_time = entry.timestamp + message = entry.payload["message"] + if constants._START_EXECUTION_MSG in message: + is_training_log = True + if is_training_log: + _LOGGER.log(getattr(logging, entry.severity), message) + if constants._END_EXECUTION_MSG in message: + is_training_log = False + + return log_time, is_training_log + + +def _get_remote_logs_until_complete( + job: Union[str, aiplatform.CustomJob], + start_time: Optional[datetime.datetime] = None, + system_logs: bool = False, +): + """Helper method to get CustomJob logs in real time until the job is complete. + + Args: + job (Union[str, aiplatform.CustomJob]): + Required. A CustomJob ID or `aiplatform.CustomJob` object. + start_time (datetime.datetime): + Optional. Get logs generated after this start time. Default is the + start time of the CustomJob or the current time. + system_logs (bool): + Optional. If set to True, all the logs from remote job will be logged + locally. Otherwise, only training logs will be shown. + + """ + if isinstance(job, str): + job = aiplatform.CustomJob.get(job) + + if not cloud_logging: + _LOGGER.warning( + "google-cloud-logging is not installed, remote execution logging is disabled. " + "To enable logs, call `pip install google-cloud-aiplatform[preview]`." + ) + while job.state not in jobs._JOB_COMPLETE_STATES: + time.sleep(_LOG_POLL_INTERVAL) + + return + + logging_client = cloud_logging.Client(project=job.project) + # TODO(b/295375379): support remote distributed training logs + logger = logging_client.logger("workerpool0-0") + + previous_time = ( + start_time or job.start_time or datetime.datetime.now(tz=datetime.timezone.utc) + ) + is_training_log = system_logs + + while job.state not in jobs._JOB_COMPLETE_STATES: + if previous_time: + previous_time, is_training_log = _get_remote_logs( + job_id=job.name, + logger=logger, + log_time=previous_time, + log_level="INFO", + is_training_log=is_training_log, + ) + time.sleep(_LOG_POLL_INTERVAL) + + if previous_time: + _get_remote_logs( + job_id=job.name, + logger=logger, + log_time=previous_time, + log_level="INFO", + is_training_log=is_training_log, + ) + + +def remote_training(invokable: shared._Invokable): + """Wrapper function that makes a method executable by Vertex CustomJob.""" + + self = invokable.instance + method = invokable.method + method_name = method.__name__ + bound_args = invokable.bound_arguments + config = invokable.vertex_config.remote_config + + autolog = vertexai.preview.global_config.autolog + service_account = _get_service_account(config, autolog=autolog) + if autolog: + vertex_requirements = [ + VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING, + "absl-py==1.4.0", + ] + else: + vertex_requirements = [ + VERTEX_AI_DEPENDENCY_PATH, + "absl-py==1.4.0", + ] + if bf: + vertex_requirements.append("bigframes==0.1.1") + + requirements = [] + + enable_cuda = config.enable_cuda + + # TODO(b/274979556): consider other approaches to pass around the primitives + pass_through_int_args = {} + pass_through_float_args = {} + pass_through_str_args = {} + pass_through_bool_args = {} + serialized_args = {} + + for arg_name, arg_value in bound_args.arguments.items(): + if arg_name == "self": + pass + elif isinstance(arg_value, int): + pass_through_int_args[arg_name] = arg_value + elif isinstance(arg_value, float): + pass_through_float_args[arg_name] = arg_value + elif isinstance(arg_value, str): + pass_through_str_args[arg_name] = arg_value + elif isinstance(arg_value, bool): + pass_through_bool_args[arg_name] = arg_value + else: + serialized_args[arg_name] = arg_value + + # set base gcs path for the remote job + staging_bucket = ( + config.staging_bucket or vertexai.preview.global_config.staging_bucket + ) + if not staging_bucket: + raise ValueError( + "No default staging bucket set. " + "Please call `vertexai.init(staging_bucket='gs://my-bucket')." + ) + remote_job = f"remote-job-{utils.timestamped_unique_name()}" + remote_job_base_path = os.path.join(staging_bucket, remote_job) + remote_job_input_path = os.path.join(remote_job_base_path, "input") + remote_job_output_path = os.path.join(remote_job_base_path, "output") + + detected_framework = None + if supported_frameworks._is_sklearn(self): + detected_framework = "sklearn" + elif supported_frameworks._is_keras(self): + detected_framework = "tensorflow" + # TODO(b/295580335): Investigate Tensorflow 2.13 GPU Hanging + import tensorflow as tf + + accelerator_count = config.accelerator_count if config.accelerator_count else 0 + if ( + version.Version(tf.__version__).base_version >= "2.13.0" + and accelerator_count > 1 + ): + raise ValueError( + f"Currently Tensorflow {tf.__version__} doesn't support multi-gpu training." + ) + elif supported_frameworks._is_torch(self): + detected_framework = "torch" + # TODO(b/296944997): Support remote training on torch<2 + import torch + + if version.Version(torch.__version__).base_version < "2.0.0": + raise ValueError( + f"Currently Vertex remote training doesn't support torch {torch.__version__}. " + "Please use torch>=2.0.0" + ) + + # serialize the estimator + serializer = any_serializer.AnySerializer() + serialization_metadata = serializer.serialize( + to_serialize=self, + gcs_path=os.path.join(remote_job_input_path, "input_estimator"), + ) + requirements += serialization_metadata[ + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY + ] + # serialize args + for arg_name, arg_value in serialized_args.items(): + if supported_frameworks._is_bigframe(arg_value): + serialization_metadata = serializer.serialize( + to_serialize=arg_value, + gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"), + framework=detected_framework, + ) + else: + serialization_metadata = serializer.serialize( + to_serialize=arg_value, + gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"), + ) + # serializer.get_dependencies() must be run after serializer.serialize() + requirements += serialization_metadata[ + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY + ] + + # execute the method in CustomJob + # set training configuration + display_name = config.display_name or remote_job + + # get or generate worker_pool_specs + # user can specify either worker_pool_specs OR machine_type etc. + remote_specs._verify_specified_remote_config_values( + config.worker_pool_specs, + config.machine_type, + config.accelerator_type, + config.accelerator_count, + ) + + if not config.container_uri: + container_uri = ( + supported_frameworks._get_cpu_container_uri() + if not enable_cuda + else supported_frameworks._get_gpu_container_uri(self) + ) + requirements = _dedupe_requirements( + vertex_requirements + config.requirements + requirements + ) + else: + container_uri = config.container_uri + requirements = _dedupe_requirements(vertex_requirements + config.requirements) + + requirements = _add_indirect_dependency_versions(requirements) + command = ["export PIP_ROOT_USER_ACTION=ignore &&"] + if config.custom_commands: + custom_commands = [f"{command} &&" for command in config.custom_commands] + command.extend(custom_commands) + if requirements: + command.append("pip install --upgrade pip &&") + requirements = [f"'{requirement}'" for requirement in requirements] + command.append(f"pip install {' '.join(requirements)} &&") + + pass_through_bool_args_flag_value = ",".join( + f"{key}={value}" for key, value in pass_through_bool_args.items() + ) + pass_through_int_args_flag_value = ",".join( + f"{key}={value}" for key, value in pass_through_int_args.items() + ) + pass_through_float_args_flag_value = ",".join( + f"{key}={value}" for key, value in pass_through_float_args.items() + ) + pass_through_str_args_flag_value = ",".join( + f"{key}={value}" for key, value in pass_through_str_args.items() + ) + + autolog_command = " --enable_autolog" if autolog else "" + + training_command = ( + "python3 -m " + "vertexai.preview._workflow.executor.training_script " + f"--pass_through_int_args={pass_through_int_args_flag_value} " + f"--pass_through_float_args={pass_through_float_args_flag_value} " + f"--pass_through_str_args={pass_through_str_args_flag_value} " + f"--pass_through_bool_args={pass_through_bool_args_flag_value} " + f"--input_path={remote_job_input_path.replace('gs://', '/gcs/', 1)} " + f"--output_path={remote_job_output_path.replace('gs://', '/gcs/', 1)} " + f"--method_name={method_name} " + + f"--arg_names={','.join(list(serialized_args.keys()))} " + + f"--enable_cuda={enable_cuda} " + + f"--enable_distributed={config.enable_distributed} " + # For distributed training. Use this to infer tf.distribute strategy for Keras training. + # Keras single worker, multi-gpu needs to be compiled with tf.distribute.MirroredStrategy. + # Keras multi-worker needs to be compiled with tf.distribute.MultiWorkerMirroredStrategy. + + f"--accelerator_count={0 if not config.accelerator_count else config.accelerator_count}" + + autolog_command + ) + command.append(training_command) + # Temporary fix for git not installed in pytorch cuda image + # Remove it once SDK 2.0 is release and don't need to be installed from git + if container_uri == "pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime": + command = ["apt-get update && apt-get install -y git &&"] + command + + command = ["sh", "-c", " ".join(command)] + + # create & run the CustomJob + + # disable CustomJob logs + logging.getLogger("google.cloud.aiplatform.jobs").disabled = True + try: + job = aiplatform.CustomJob( + display_name=display_name, + project=vertexai.preview.global_config.project, + location=vertexai.preview.global_config.location, + worker_pool_specs=_get_worker_pool_specs(config, container_uri, command), + base_output_dir=remote_job_base_path, + staging_bucket=remote_job_base_path, + ) + + job.submit( + service_account=service_account, + # TODO(jayceeli) Remove this check when manual logging is supported. + experiment=metadata._experiment_tracker.experiment if autolog else None, + experiment_run=metadata._experiment_tracker.experiment_run + if autolog + else None, + ) + job.wait_for_resource_creation() + + _LOGGER.info(f"Remote job created. View the job: {job._dashboard_uri()}") + + _get_remote_logs_until_complete( + job=job, + system_logs=config.enable_full_logs, + ) + except Exception as e: + raise e + finally: + # enable CustomJob logs after remote training job is done + logging.getLogger("google.cloud.aiplatform.jobs").disabled = False + + if job.state in jobs._JOB_ERROR_STATES: + return job + + add_model_to_history_obj = False + # retrieve the result from gcs to local + if method_name in supported_frameworks.REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST: + estimator = serializer.deserialize( + os.path.join(remote_job_output_path, "output_estimator"), + ) + + if supported_frameworks._is_sklearn(self): + _update_sklearn_model_inplace(self, estimator) + + elif supported_frameworks._is_keras(self): + add_model_to_history_obj = True + _update_keras_model_inplace(self, estimator) + + elif supported_frameworks._is_torch(self): + _update_torch_model_inplace(self, estimator) + + elif supported_frameworks._is_lightning(self): + _update_lightning_trainer_inplace(self, estimator) + # deserialize and update the trained model as well + trained_model = serializer.deserialize( + os.path.join(remote_job_output_path, "output_estimator", "model") + ) + _update_torch_model_inplace(serialized_args["model"], trained_model) + else: + # if it's a custom model, update the model object by iterating its + # attributes. A custom model is any class that has a method + # decorated by @vertexai.preview.developer.mark.train (and optionally + # another method decorated by @vertexai.preview.developer.mark.predict). + _common_update_model_inplace(self, estimator) + + if method_name in supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: + predictions = serializer.deserialize( + os.path.join(remote_job_output_path, "output_predictions") + ) + return predictions + + # Note: "output_data" refers to general output from the executed method, not + # just a transformed data. + try: + # TODO b/296584472: figure out a general mechanism to populate + # inter-object references. + if add_model_to_history_obj: + output_data = serializer.deserialize( + os.path.join(remote_job_output_path, "output_data"), model=self + ) + else: + output_data = serializer.deserialize( + os.path.join(remote_job_output_path, "output_data") + ) + return output_data + except Exception as e: + _LOGGER.warning( + f"Fail to deserialize the output due to error {e}, " "returning None." + ) + return None diff --git a/vertexai/preview/_workflow/executor/training_script.py b/vertexai/preview/_workflow/executor/training_script.py new file mode 100644 index 0000000000..981e5d7eb7 --- /dev/null +++ b/vertexai/preview/_workflow/executor/training_script.py @@ -0,0 +1,234 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Training script to be run in Vertex CustomJob. +""" + +# import modules +import os + +from absl import app +from absl import flags +import vertexai +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, + serializers_base, +) +from vertexai.preview._workflow.shared import ( + constants, + supported_frameworks, +) +from vertexai.preview.developer import remote_specs + + +try: + # This line ensures a tensorflow model to be loaded by cloudpickle correctly + # We put it in a try clause since not all models are tensorflow and if it is + # a tensorflow model, the dependency should've been installed and therefore + # import should work. + import tensorflow as tf # noqa: F401 +except ImportError: + pass + + +os.environ["_IS_VERTEX_REMOTE_TRAINING"] = "True" + +print(constants._START_EXECUTION_MSG) + +_ARGS = flags.DEFINE_list( + "arg_names", [], "Argument names of those to be deserialized." +) +# TODO(b/274979556): consider other approaches to pass around the primitives +_PASS_THROUGH_INT_ARGS = flags.DEFINE_list( + "pass_through_int_args", [], "Pass-through integer arguments." +) +_PASS_THROUGH_FLOAT_ARGS = flags.DEFINE_list( + "pass_through_float_args", [], "Pass-through float arguments." +) +_PASS_THROUGH_BOOL_ARGS = flags.DEFINE_list( + "pass_through_bool_args", [], "Pass-through bool arguments." +) +_PASS_THROUGH_STR_ARGS = flags.DEFINE_list( + "pass_through_str_args", [], "Pass-through string arguments." +) +_METHOD_NAME = flags.DEFINE_string("method_name", None, "Method being called") + +_INPUT_PATH = flags.DEFINE_string("input_path", None, "input path.") +_OUTPUT_PATH = flags.DEFINE_string("output_path", None, "output path.") +_ENABLE_AUTOLOG = flags.DEFINE_bool("enable_autolog", False, "enable autolog.") +_ENABLE_CUDA = flags.DEFINE_bool("enable_cuda", False, "enable cuda.") +_ENABLE_DISTRIBUTED = flags.DEFINE_bool( + "enable_distributed", False, "enable distributed training." +) +_ACCELERATOR_COUNT = flags.DEFINE_integer( + "accelerator_count", + 0, + "accelerator count for single worker, multi-gpu training.", +) + + +# pylint: disable=protected-access +def main(argv): + del argv + + # set cuda for tensorflow & pytorch + try: + import tensorflow + + if not _ENABLE_CUDA.value: + tensorflow.config.set_visible_devices([], "GPU") + except ImportError: + pass + + try: + import torch + + torch.set_default_device("cuda" if _ENABLE_CUDA.value else "cpu") + except ImportError: + torch = None + + strategy = None + try: + from tensorflow import keras # noqa: F401 + + # distribute strategy must be initialized at the beginning of the program + # to avoid RuntimeError: "Collective ops must be configured at program startup" + strategy = remote_specs._get_keras_distributed_strategy( + _ENABLE_DISTRIBUTED.value, _ACCELERATOR_COUNT.value + ) + + except ImportError: + pass + + if _ENABLE_AUTOLOG.value: + vertexai.preview.init(autolog=True) + + # retrieve the estimator + serializer = any_serializer.AnySerializer() + estimator = serializer.deserialize( + os.path.join(_INPUT_PATH.value, "input_estimator") + ) + + if strategy and supported_frameworks._is_keras(estimator): + # Single worker, multi-gpu will be compiled with tf.distribute.MirroredStrategy. + # Multi-worker will be compiled with tf.distribute.MultiWorkerMirroredStrategy. + # Single worker CPU/GPU will be returned as is. + estimator = remote_specs._set_keras_distributed_strategy(estimator, strategy) + + if supported_frameworks._is_lightning(estimator): + from lightning.pytorch.trainer.connectors.accelerator_connector import ( + _AcceleratorConnector, + ) + + # Re-instantiate accelerator connecotor in remote environment. Most of configs + # like strategy, devices will be automatically handled by + # the _AcceleratorConnector class. + # accelerator and num_nodes need to be manually set. + accelerator = "gpu" if _ENABLE_CUDA.value else "cpu" + num_nodes = ( + remote_specs._get_cluster_spec().get_world_size() + if _ENABLE_DISTRIBUTED.value + else 1 + ) + estimator._accelerator_connector = _AcceleratorConnector( + accelerator=accelerator, + num_nodes=num_nodes, + ) + + # retrieve seriliazed_args + kwargs = {} + for arg_name in _ARGS.value: + arg_value = serializer.deserialize(os.path.join(_INPUT_PATH.value, arg_name)) + + if supported_frameworks._is_torch_dataloader(arg_value): + # update gpu setting in dataloader for pytorch model gpu training + # lightning will automatically handle the data so no need to update + if supported_frameworks._is_torch(estimator) and _ENABLE_CUDA.value: + arg_value.pin_memory = True + arg_value.pin_memory_device = "cuda" + arg_value.generator = torch.Generator("cuda") + if hasattr(arg_value.sampler, "generator"): + setattr(arg_value.sampler, "generator", arg_value.generator) + # make sure the torch default device is the same as + # dataloader generator's device + torch.set_default_device( + arg_value.generator.device.type if arg_value.generator else "cpu" + ) + + kwargs[arg_name] = arg_value + + for arg_name_and_arg_value in _PASS_THROUGH_INT_ARGS.value: + arg_name, arg_value = arg_name_and_arg_value.split("=") + kwargs[arg_name] = int(arg_value) + for arg_name_and_arg_value in _PASS_THROUGH_FLOAT_ARGS.value: + arg_name, arg_value = arg_name_and_arg_value.split("=") + kwargs[arg_name] = float(arg_value) + for arg_name_and_arg_value in _PASS_THROUGH_BOOL_ARGS.value: + arg_name, arg_value = arg_name_and_arg_value.split("=") + kwargs[arg_name] = bool(arg_value) + for arg_name_and_arg_value in _PASS_THROUGH_STR_ARGS.value: + arg_name, arg_value = arg_name_and_arg_value.split("=") + kwargs[arg_name] = arg_value + + # for all custom trainers, set cluster_spec if available + if ( + isinstance(estimator, vertexai.preview.VertexModel) + and _ENABLE_DISTRIBUTED.value + ): + setattr(estimator, "cluster_spec", remote_specs._get_cluster_spec()) + if supported_frameworks._is_torch(estimator): + # need to know if GPU training is enabled for the + # optional remote_specs.setup_pytorch_distributed_training() + # function that a user can call in train() + setattr(estimator, "_enable_cuda", _ENABLE_CUDA.value) + + output = getattr(estimator, _METHOD_NAME.value)(**kwargs) + + # serialize the output + os.makedirs(_OUTPUT_PATH.value, exist_ok=True) + + if ( + _METHOD_NAME.value + in supported_frameworks.REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST + ): + # for distributed training, chief saves output to specified output + # directory while non-chief workers save output to temp directory. + output_path = remote_specs._get_output_path_for_distributed_training( + _OUTPUT_PATH.value, "output_estimator" + ) + serializer.serialize(estimator, output_path) + + # for pytorch lightning trainer, we want to serialize the trained model as well + if "model" in _ARGS.value: + serializer.serialize(kwargs["model"], os.path.join(output_path, "model")) + + # for remote prediction + if _METHOD_NAME.value in supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: + serializer.serialize( + output, os.path.join(_OUTPUT_PATH.value, "output_predictions") + ) + + output_path = remote_specs._get_output_path_for_distributed_training( + _OUTPUT_PATH.value, "output_data" + ) + try: + serializer.serialize(output, output_path) + except serializers_base.SerializationError as e: + print(f"failed to serialize the output due to {e}") + + print(constants._END_EXECUTION_MSG) + + +if __name__ == "__main__": + app.run(main) diff --git a/vertexai/preview/_workflow/launcher/__init__.py b/vertexai/preview/_workflow/launcher/__init__.py new file mode 100644 index 0000000000..709dbd017e --- /dev/null +++ b/vertexai/preview/_workflow/launcher/__init__.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any + +from vertexai.preview._workflow import executor +from vertexai.preview._workflow import shared + + +class _WorkflowLauncher: + """Launches workflows either locally or remotely.""" + + def launch(self, invokable: shared._Invokable, global_remote: bool): + + local_remote = invokable.vertex_config.remote + + if local_remote or (local_remote is None and global_remote): + result = self._remote_launch(invokable) + else: + for _, arg in invokable.bound_arguments.arguments.items(): + if "bigframes" in repr(type(arg)): + raise ValueError( + "Bigframes not supported if vertexai.preview.init(remote=False)" + ) + result = self._local_launch(invokable) + return result + + def _remote_launch(self, invokable: shared._Invokable) -> Any: + result = executor._workflow_executor.remote_execute(invokable) + # TODO(b/277343861) workflow tracking goes here + # E.g., initializer.global_config.workflow.add_remote_step(invokable, result) + + return result + + def _local_launch(self, invokable: shared._Invokable) -> Any: + result = executor._workflow_executor.local_execute(invokable) + # TODO(b/277343861) workflow tracking goes here + # E.g., initializer.global_config.workflow.add_local_step(invokable, result) + + return result diff --git a/vertexai/preview/_workflow/serialization_engine/__init__.py b/vertexai/preview/_workflow/serialization_engine/__init__.py new file mode 100644 index 0000000000..b24e67a831 --- /dev/null +++ b/vertexai/preview/_workflow/serialization_engine/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/vertexai/preview/_workflow/serialization_engine/any_serializer.py b/vertexai/preview/_workflow/serialization_engine/any_serializer.py new file mode 100644 index 0000000000..aac34d698c --- /dev/null +++ b/vertexai/preview/_workflow/serialization_engine/any_serializer.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=line-too-long, bad-continuation,protected-access +"""Defines the Serializer classes.""" +import json +import os +import pathlib +import tempfile +from typing import Any, Dict, Union, List, TypeVar, Type + +from google.cloud.aiplatform import base +from google.cloud.aiplatform.utils import gcs_utils +from vertexai.preview._workflow.serialization_engine import ( + serializers, + serializers_base, +) +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) + +from packaging import requirements + +# TODO(b/272263750): use the centralized module and usage pattern to guard these +# imports +# pylint: disable=g-import-not-at-top +try: + import pandas as pd + import bigframes as bf + + PandasData = pd.DataFrame + BigframesData = bf.dataframe.DataFrame +except ImportError: + pd = None + bf = None + PandasData = Any + BigframesData = Any + +try: + import pandas as pd + + PandasData = pd.DataFrame +except ImportError: + pd = None + PandasData = Any + +try: + import sklearn + + SklearnEstimator = sklearn.base.BaseEstimator +except ImportError: + sklearn = None + SklearnEstimator = Any + +try: + from tensorflow import keras + import tensorflow as tf + + KerasModel = keras.models.Model + TFDataset = tf.data.Dataset +except ImportError: + keras = None + tf = None + KerasModel = Any + TFDataset = Any + +try: + import torch + + TorchModel = torch.nn.Module + TorchDataLoader = torch.utils.data.DataLoader +except ImportError: + torch = None + TorchModel = Any + TorchDataLoader = Any + +try: + import lightning.pytorch as pl + + LightningTrainer = pl.Trainer +except ImportError: + pl = None + LightningTrainer = Any + + +T = TypeVar("T") + +Types = Union[ + PandasData, + BigframesData, + SklearnEstimator, + KerasModel, + TorchModel, + LightningTrainer, +] + +_LOGGER = base.Logger("vertexai.serialization_engine") + +SERIALIZATION_METADATA_FILENAME = "serialization_metadata" +SERIALIZATION_METADATA_SERIALIZER_KEY = "serializer" +SERIALIZATION_METADATA_DEPENDENCIES_KEY = "dependencies" +SERIALIZATION_METADATA_FRAMEWORK_KEY = "framework" + +_LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/" + + +def _get_metadata_path_from_file_gcs_uri(gcs_uri: str) -> str: + gcs_pathlibpath = pathlib.Path(gcs_uri) + prefix = _get_uri_prefix(gcs_uri=gcs_uri) + return os.path.join( + prefix, + f"{SERIALIZATION_METADATA_FILENAME}_{gcs_pathlibpath.stem}.json", + ) + + +def _get_uri_prefix(gcs_uri: str) -> str: + """Gets the directory of the gcs_uri. + + Example: + 1) file uri: + _get_uri_prefix("gs:///directory/file.extension") == "gs:// + /directory/" + 2) folder uri: + _get_uri_prefix("gs:///parent_dir/dir") == "gs:/// + parent_dir/" + Args: + gcs_uri: A string starting with "gs://" that represent a gcs uri. + Returns: + The parent gcs directory in string format. + """ + # For tensorflow, the uri may be "gs://my-bucket/saved_model/" + if gcs_uri.endswith("/"): + gcs_uri = gcs_uri[:-1] + gcs_pathlibpath = pathlib.Path(gcs_uri) + file_name = gcs_pathlibpath.name + return gcs_uri[: -len(file_name)] + + +def _check_dependency_versions(required_packages: List[str]): + for package in required_packages: + requirement = requirements.Requirement(package) + package_name = requirement.name + current_version = supported_frameworks._get_version_for_package(package_name) + if not requirement.specifier.contains(current_version): + _LOGGER.warning( + "%s's version is %s, while the required version is %s", + package_name, + current_version, + requirement.specifier, + ) + + +def _get_custom_serializer_path_from_file_gcs_uri( + gcs_uri: str, serializer_name: str +) -> str: + prefix = _get_uri_prefix(gcs_uri=gcs_uri) + return os.path.join(prefix, f"{serializer_name}") + + +class AnySerializer(serializers_base.Serializer): + """A serializer that can routes any object to their own serializer.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="AnySerializer") + ) + + def __init__(self): + super().__init__() + # Register with default serializers + AnySerializer._register(object, serializers.CloudPickleSerializer) + if sklearn: + AnySerializer._register( + sklearn.base.BaseEstimator, serializers.SklearnEstimatorSerializer + ) + if keras: + AnySerializer._register( + keras.models.Model, serializers.KerasModelSerializer + ) + AnySerializer._register( + keras.callbacks.History, serializers.KerasHistoryCallbackSerializer + ) + if tf: + AnySerializer._register(tf.data.Dataset, serializers.TFDatasetSerializer) + if torch: + AnySerializer._register(torch.nn.Module, serializers.TorchModelSerializer) + AnySerializer._register( + torch.utils.data.DataLoader, serializers.TorchDataLoaderSerializer + ) + if pl: + AnySerializer._register(pl.Trainer, serializers.LightningTrainerSerializer) + if bf: + AnySerializer._register( + bf.dataframe.DataFrame, serializers.BigframeSerializer + ) + if pd: + AnySerializer._register(pd.DataFrame, serializers.PandasDataSerializer) + + @classmethod + def _get_custom_serializer(cls, type_cls): + return cls._custom_serialization_scheme.get(type_cls) + + @classmethod + def _get_predefined_serializer(cls, type_cls): + return cls._serialization_scheme.get(type_cls) + + def serialize(self, to_serialize: T, gcs_path: str, **kwargs) -> Dict[str, Any]: + """Simplified version of serialize().""" + metadata_path = _get_metadata_path_from_file_gcs_uri(gcs_path) + # TODO(b/277906396): consider implementing object-level serialization. + + for i, step_type in enumerate( + to_serialize.__class__.__mro__ + to_serialize.__class__.__mro__ + ): + # Iterate through the custom serialization scheme first. + if ( + i < len(to_serialize.__class__.__mro__) + and step_type not in AnySerializer._custom_serialization_scheme + ) or ( + i >= len(to_serialize.__class__.__mro__) + and step_type not in AnySerializer._serialization_scheme + ): + continue + elif i < len(to_serialize.__class__.__mro__): + serializer = AnySerializer._get_custom_serializer( + step_type + ).get_instance() # pytype: disable=attribute-error + # If the Serializer is a custom Serializer, serialize the + # Custom Serializer first. + serializer_path = _get_custom_serializer_path_from_file_gcs_uri( + gcs_path, serializer.__class__.__name__ + ) + serializers.CloudPickleSerializer().serialize( + serializer, serializer_path + ) + else: + serializer = AnySerializer._get_predefined_serializer( + step_type + ).get_instance() + + try: + serializer.serialize( + to_serialize=to_serialize, gcs_path=gcs_path, **kwargs + ) + except Exception as e: # pylint: disable=broad-exception-caught + if serializer.__class__.__name__ != "CloudPickleSerializer": + _LOGGER.warning( + "Failed to serialize %s with %s due to error %s", + to_serialize.__class__.__name__, + serializer.__class__.__name__, + e, + ) + # Falling back to Serializers of super classes + continue + else: + raise serializers_base.SerializationError from e + + metadata = serializer._metadata.to_dict() + serializers_base.write_and_upload_data( + json.dumps(metadata).encode(), metadata_path + ) + + return metadata + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> T: + """Routes the corresponding Serializer based on the metadata.""" + metadata_path = _get_metadata_path_from_file_gcs_uri(serialized_gcs_path) + + if metadata_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(metadata_path, temp_file.name) + with open(temp_file.name, mode="rb") as f: + metadata = json.load(f) + else: + with open(metadata_path, mode="rb") as f: + metadata = json.load(f) + + _LOGGER.debug( + "deserializing from %s, metadata is %s", serialized_gcs_path, metadata + ) + + serializer_cls_name = metadata[SERIALIZATION_METADATA_SERIALIZER_KEY] + packages = metadata[SERIALIZATION_METADATA_DEPENDENCIES_KEY] + _check_dependency_versions(packages) + serializer_class = getattr( + serializers, serializer_cls_name, None + ) or globals().get(serializer_cls_name) + if not serializer_class: + # Serializer is an unregistered custom Serializer. + # Deserialize serializer. + serializer_path = _get_custom_serializer_path_from_file_gcs_uri( + serialized_gcs_path, serializer_cls_name + ) + serializer = serializers.CloudPickleSerializer().deserialize( + serialized_gcs_path=serializer_path + ) + else: + serializer = serializer_class.get_instance() + + # TODO(b/277906396): implement object-level serialization. + if SERIALIZATION_METADATA_FRAMEWORK_KEY in metadata: + serializer.__class__._metadata = serializers.BigframeSerializationMetadata( + **metadata + ) + else: + serializer.__class__._metadata = serializers_base.SerializationMetadata( + **metadata + ) + + obj = serializer.deserialize(serialized_gcs_path=serialized_gcs_path, **kwargs) + if not serializer_class: + # Register the serializer + AnySerializer.register_custom(obj.__class__, serializer.__class__) + AnySerializer._instances[serializer.__class__] = serializer + return obj + + +def register_serializer( + to_serialize_type: Type[Any], serializer_cls: Type[serializers_base.Serializer] +): + """Registers a Serializer for a specific type. + + Example Usage: + + ``` + import vertexai + + # define a custom Serializer + class KerasCustomSerializer( + vertexai.preview.developer.Serializer): + _metadata = vertexai.preview.developer.SerializationMetadata() + + def serialize(self, to_serialize, gcs_path): + ... + def deserialize(self, gcs_path): + ... + + KerasCustomSerializer.register_requirements( + ['library1==1.0.0', 'library2<2.0']) + vertexai.preview.developer.register_serializer( + keras.models.Model, KerasCustomSerializer) + ``` + + Args: + to_serialize_type: The class that is supposed to be serialized with + the to-be-registered custom Serializer. + serializer_cls: The custom Serializer to be registered. + """ + any_serializer = AnySerializer() + any_serializer.register_custom( + to_serialize_type=to_serialize_type, serializer_cls=serializer_cls + ) diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py new file mode 100644 index 0000000000..b9ce12a825 --- /dev/null +++ b/vertexai/preview/_workflow/serialization_engine/serializers.py @@ -0,0 +1,1207 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=line-too-long, bad-continuation,protected-access +"""Defines the Serializer classes.""" + +import dataclasses +import functools +import json +import os +import pickle +import shutil +import tempfile +from typing import Any, Optional, Union +import uuid + +from google.cloud.aiplatform.utils import gcs_utils +from vertexai.preview._workflow.shared import constants +from vertexai.preview._workflow.shared import ( + data_serializer_utils, + supported_frameworks, +) +from vertexai.preview._workflow.serialization_engine import ( + serializers_base, +) + +try: + import cloudpickle +except ImportError: + cloudpickle = None + +SERIALIZATION_METADATA_FRAMEWORK_KEY = "framework" + +# TODO(b/272263750): use the centralized module and usage pattern to guard these +# imports +# pylint: disable=g-import-not-at-top +try: + import pandas as pd + import bigframes as bf + + PandasData = pd.DataFrame + BigframesData = bf.dataframe.DataFrame +except ImportError: + pd = None + bf = None + PandasData = Any + BigframesData = Any + +try: + import pandas as pd + import pyarrow as pa + import pyarrow.parquet as pq + + PandasData = pd.DataFrame +except ImportError: + pd = None + pa = None + pq = None + PandasData = Any + +try: + import sklearn + + SklearnEstimator = sklearn.base.BaseEstimator +except ImportError: + sklearn = None + SklearnEstimator = Any + +try: + from tensorflow import keras + import tensorflow as tf + + KerasModel = keras.models.Model + TFDataset = tf.data.Dataset +except ImportError: + keras = None + tf = None + KerasModel = Any + TFDataset = Any + +try: + import torch + + TorchModel = torch.nn.Module + TorchDataLoader = torch.utils.data.DataLoader + TorchTensor = torch.tensor +except ImportError: + torch = None + TorchModel = Any + TorchDataLoader = Any + TorchTensor = Any + +try: + import lightning.pytorch as pl + + LightningTrainer = pl.Trainer +except ImportError: + pl = None + LightningTrainer = Any + + +Types = Union[ + PandasData, + BigframesData, + SklearnEstimator, + KerasModel, + TorchModel, + LightningTrainer, +] + +_LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/" + + +def _is_valid_gcs_path(path: str) -> bool: + """checks if a path is a valid gcs path. + + Args: + path (str): + Required. A file path. + + Returns: + A boolean that indicates whether the path is a valid gcs path. + """ + return path.startswith(("gs://", "/gcs/", "gcs/")) + + +def _load_torch_model(path: str, map_location: "torch.device") -> TorchModel: + try: + return torch.load(path, map_location=map_location) + except Exception: + return torch.load(path, map_location=torch.device("cpu")) + + +class KerasModelSerializer(serializers_base.Serializer): + """A serializer for tensorflow.keras.models.Model objects.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="KerasModelSerializer") + ) + + def serialize( + self, to_serialize: KerasModel, gcs_path: str, **kwargs + ) -> str: # pytype: disable=invalid-annotation + """Serializes a tensorflow.keras.models.Model to a gcs path. + + Args: + to_serialize (keras.models.Model): + Required. A Keras Model object. + gcs_path (str): + Required. A GCS uri that the model will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + KerasModelSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_tensorflow_model(to_serialize) + ) + to_serialize.save(gcs_path) + + return gcs_path + + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> KerasModel: # pytype: disable=invalid-annotation + """Deserialize a tensorflow.keras.models.Model given the gcs file name. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A Keras Model. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + ImportError: if tensorflow is not installed. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + try: + from tensorflow import keras + + return keras.models.load_model(serialized_gcs_path) + except ImportError as e: + raise ImportError("tensorflow is not installed.") from e + + +class KerasHistoryCallbackSerializer(serializers_base.Serializer): + """A serializer for tensorflow.keras.callbacks.History objects.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata( + serializer="KerasHistoryCallbackSerializer" + ) + ) + + def serialize(self, to_serialize, gcs_path: str, **kwargs): + """Serializes a keras History callback to a gcs path. + + Args: + to_serialize (keras.callbacks.History): + Required. A History object. + gcs_path (str): + Required. A GCS uri that History object will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + KerasHistoryCallbackSerializer._metadata.dependencies = ["cloudpickle"] + + to_serialize_dict = to_serialize.__dict__ + del to_serialize_dict["model"] + with open(gcs_path, "wb") as f: + cloudpickle.dump(to_serialize_dict, f) + + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs): + """Deserialize a keras.callbacks.History given the gcs file name. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A keras.callbacks.History object. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + """ + + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + model = kwargs.get("model", None) + # Only "model" is needed. + del kwargs + + history_dict = {} + if serialized_gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(serialized_gcs_path, temp_file.name) + with open(temp_file.name, mode="rb") as f: + history_dict = cloudpickle.load(f) + else: + with open(serialized_gcs_path, mode="rb") as f: + history_dict = cloudpickle.load(f) + + history_obj = keras.callbacks.History() + + for attr_name, attr_value in history_dict.items(): + setattr(history_obj, attr_name, attr_value) + + if model: + history_obj.set_model(model) + + return history_obj + + +class SklearnEstimatorSerializer(serializers_base.Serializer): + """A serializer that uses pickle to save/load sklearn estimators.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="SklearnEstimatorSerializer") + ) + + def serialize(self, to_serialize: SklearnEstimator, gcs_path: str, **kwargs) -> str: + """Serializes a sklearn estimator to a gcs path. + + Args: + to_serialize (sklearn.base.BaseEstimator): + Required. A sklearn estimator. + gcs_path (str): + Required. A GCS uri that the estimator will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + SklearnEstimatorSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_sklearn_model(to_serialize) + ) + serialized = pickle.dumps(to_serialize, protocol=constants.PICKLE_PROTOCOL) + serializers_base.write_and_upload_data(data=serialized, gcs_filename=gcs_path) + + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> SklearnEstimator: + """Deserialize a sklearn estimator given the gcs file name. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A sklearn estimator. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + if serialized_gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(serialized_gcs_path, temp_file.name) + with open(temp_file.name, mode="rb") as f: + obj = pickle.load(f) + else: + with open(serialized_gcs_path, mode="rb") as f: + obj = pickle.load(f) + + return obj + + +class TorchModelSerializer(serializers_base.Serializer): + """A serializer for torch.nn.Module objects.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="TorchModelSerializer") + ) + + def serialize(self, to_serialize: TorchModel, gcs_path: str, **kwargs) -> str: + """Serializes a torch.nn.Module to a gcs path. + + Args: + to_serialize (torch.nn.Module): + Required. A PyTorch model object. + gcs_path (str): + Required. A GCS uri that the model will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + TorchModelSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_torch_model(to_serialize) + ) + + if gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + torch.save( + to_serialize, + temp_file.name, + pickle_module=cloudpickle, + pickle_protocol=constants.PICKLE_PROTOCOL, + ) + gcs_utils.upload_to_gcs(temp_file.name, gcs_path) + else: + torch.save( + to_serialize, + gcs_path, + pickle_module=cloudpickle, + pickle_protocol=constants.PICKLE_PROTOCOL, + ) + + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> TorchModel: + """Deserialize a torch.nn.Module given the gcs file name. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A torch.nn.Module model. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + ImportError: if torch is not installed. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + try: + import torch + except ImportError as e: + raise ImportError("torch is not installed.") from e + + map_location = ( + torch._GLOBAL_DEVICE_CONTEXT.device + if torch._GLOBAL_DEVICE_CONTEXT + else None + ) + + if serialized_gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(serialized_gcs_path, temp_file.name) + model = _load_torch_model(temp_file.name, map_location=map_location) + else: + model = _load_torch_model(serialized_gcs_path, map_location=map_location) + + return model + + +# TODO(b/289386023) Add unit tests for LightningTrainerSerialzier +class LightningTrainerSerializer(serializers_base.Serializer): + """A serializer for lightning.pytorch.Trainer objects.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="LightningTrainerSerializer") + ) + + def _serialize_to_local(self, to_serialize: LightningTrainer, path: str): + """Serializes a lightning.pytorch.Trainer to a local path. + + Args: + to_serialize (lightning.pytorch.Trainer): + Required. A lightning trainer object. + path (str): + Required. A local_path that the trainer will be saved to. + """ + # In remote environment, we store local accelerator connector and default root + # dir as attributes when we deserialize the trainer. And we need to serialize + # them in order to retrieve in local environment. + if getattr(to_serialize, "_vertex_local_accelerator_connector", None): + with open(f"{path}/local_accelerator_connector", "wb") as f: + cloudpickle.dump( + to_serialize._vertex_local_accelerator_connector, + f, + protocol=constants.PICKLE_PROTOCOL, + ) + delattr(to_serialize, "_vertex_local_accelerator_connector") + else: + with open(f"{path}/local_accelerator_connector", "wb") as f: + cloudpickle.dump( + to_serialize._accelerator_connector, + f, + protocol=constants.PICKLE_PROTOCOL, + ) + + if getattr(to_serialize, "_vertex_local_default_root_dir", None): + with open(f"{path}/local_default_root_dir", "wb") as f: + cloudpickle.dump( + to_serialize._vertex_local_default_root_dir, + f, + protocol=constants.PICKLE_PROTOCOL, + ) + delattr(to_serialize, "_vertex_local_default_root_dir") + else: + with open(f"{path}/local_default_root_dir", "wb") as f: + cloudpickle.dump( + to_serialize._default_root_dir, + f, + protocol=constants.PICKLE_PROTOCOL, + ) + + with open(f"{path}/trainer", "wb") as f: + cloudpickle.dump(to_serialize, f, protocol=constants.PICKLE_PROTOCOL) + + if os.path.exists(to_serialize.logger.root_dir): + shutil.copytree( + to_serialize.logger.root_dir, + f"{path}/{to_serialize.logger.name}", + dirs_exist_ok=True, + ) + + def serialize(self, to_serialize: LightningTrainer, gcs_path: str, **kwargs) -> str: + """Serializes a lightning.pytorch.Trainer to a gcs path. + + Args: + to_serialize (lightning.pytorch.Trainer): + Required. A lightning trainer object. + gcs_path (str): + Required. A GCS uri that the trainer will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + LightningTrainerSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_lightning_model(to_serialize) + + supported_frameworks._get_cloudpickle_deps() + ) + + if gcs_path.startswith("gs://"): + with tempfile.TemporaryDirectory() as temp_dir: + self._serialize_to_local(to_serialize, temp_dir) + gcs_utils.upload_to_gcs(temp_dir, gcs_path) + else: + os.makedirs(gcs_path) + self._serialize_to_local(to_serialize, gcs_path) + + return gcs_path + + def _deserialize_from_local(self, path: str) -> LightningTrainer: + """Deserialize a lightning.pytorch.Trainer given a local path. + + Args: + path (str): + Required. A local path to the serialized trainer. + + Returns: + A lightning.pytorch.Trainer object. + """ + with open(f"{path}/trainer", "rb") as f: + trainer = cloudpickle.load(f) + + if os.getenv("_IS_VERTEX_REMOTE_TRAINING") == "True": + # Store the logs in the cwd of remote environment. + trainer._default_root_dir = _LIGHTNING_ROOT_DIR + for logger in trainer.loggers: + # for TensorBoardLogger + if getattr(logger, "_root_dir", None): + logger._root_dir = trainer.default_root_dir + # for CSVLogger + if getattr(logger, "_save_dir", None): + logger._save_dir = trainer.default_root_dir + + # Store local accelerator connector and root dir as attributes, so that + # we can retrieve them in local environment. + with open(f"{path}/local_accelerator_connector", "rb") as f: + trainer._vertex_local_accelerator_connector = cloudpickle.load(f) + + with open(f"{path}/local_default_root_dir", "rb") as f: + trainer._vertex_local_default_root_dir = cloudpickle.load(f) + else: + with open(f"{path}/local_accelerator_connector", "rb") as f: + trainer._accelerator_connector = cloudpickle.load(f) + + with open(f"{path}/local_default_root_dir", "rb") as f: + trainer._default_root_dir = cloudpickle.load(f) + + for logger in trainer.loggers: + if getattr(logger, "_root_dir", None): + logger._root_dir = trainer.default_root_dir + if getattr(logger, "_save_dir", None): + logger._save_dir = trainer.default_root_dir + + for callback in trainer.checkpoint_callbacks: + callback.dirpath = os.path.join( + trainer.default_root_dir, + callback.dirpath.replace(_LIGHTNING_ROOT_DIR, ""), + ) + if callback.best_model_path: + callback.best_model_path = os.path.join( + trainer.default_root_dir, + callback.best_model_path.replace(_LIGHTNING_ROOT_DIR, ""), + ) + if callback.kth_best_model_path: + callback.kth_best_model_path = os.path.join( + trainer.default_root_dir, + callback.kth_best_model_path.replace(_LIGHTNING_ROOT_DIR, ""), + ) + if callback.last_model_path: + callback.last_model_path = os.path.join( + trainer.default_root_dir, + callback.last_model_path.replace(_LIGHTNING_ROOT_DIR, ""), + ) + + if os.path.exists(f"{path}/{trainer.logger.name}"): + shutil.copytree( + f"{path}/{trainer.logger.name}", + trainer.logger.root_dir, + dirs_exist_ok=True, + ) + + return trainer + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> LightningTrainer: + """Deserialize a lightning.pytorch.Trainer given the gcs path. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A lightning.pytorch.Trainer object. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + if serialized_gcs_path.startswith("gs://"): + with tempfile.TemporaryDirectory() as temp_dir: + gcs_utils.download_from_gcs(serialized_gcs_path, temp_dir) + trainer = self._deserialize_from_local(temp_dir) + else: + trainer = self._deserialize_from_local(serialized_gcs_path) + + return trainer + + +class TorchDataLoaderSerializer(serializers_base.Serializer): + """A serializer for torch.utils.data.DataLoader objects.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="TorchDataLoaderSerializer") + ) + + def _serialize_to_local(self, to_serialize: TorchDataLoader, path: str): + """Serializes a torch.utils.data.DataLoader to a local path. + + Args: + to_serialize (torch.utils.data.DataLoader): + Required. A pytorch dataloader object. + path (str): + Required. A local_path that the dataloader will be saved to. + """ + # save objects by cloudpickle + with open(f"{path}/dataset.cpkl", "wb") as f: + cloudpickle.dump( + to_serialize.dataset, f, protocol=constants.PICKLE_PROTOCOL + ) + + with open(f"{path}/collate_fn.cpkl", "wb") as f: + cloudpickle.dump( + to_serialize.collate_fn, f, protocol=constants.PICKLE_PROTOCOL + ) + + with open(f"{path}/worker_init_fn.cpkl", "wb") as f: + cloudpickle.dump( + to_serialize.worker_init_fn, f, protocol=constants.PICKLE_PROTOCOL + ) + + # save (str, int, float, bool) values into a json file + pass_through_args = { + "num_workers": to_serialize.num_workers, + "pin_memory": to_serialize.pin_memory, + "timeout": to_serialize.timeout, + "prefetch_factor": to_serialize.prefetch_factor, + "persistent_workers": to_serialize.persistent_workers, + "pin_memory_device": to_serialize.pin_memory_device, + } + + # dataloader.generator is a torch.Generator object that defined in c++ + # it cannot be serialized by cloudpickle, so we store its device information + # and re-instaintiate a new Generator object with this device when deserializing + pass_through_args["generator_device"] = ( + to_serialize.generator.device.type if to_serialize.generator else None + ) + + # batch_sampler option is mutually exclusive with batch_size, shuffle, + # sampler, and drop_last. + # for default batch sampler we store batch_size, drop_last, and sampler object + # but not batch sampler object. + if isinstance(to_serialize.batch_sampler, torch.utils.data.BatchSampler): + pass_through_args["batch_size"] = to_serialize.batch_size + pass_through_args["drop_last"] = to_serialize.drop_last + + with open(f"{path}/sampler.cpkl", "wb") as f: + cloudpickle.dump( + to_serialize.sampler, f, protocol=constants.PICKLE_PROTOCOL + ) + # otherwise we only serialize batch sampler and skip batch_size, drop_last, + # and sampler object. + else: + with open(f"{path}/batch_sampler.cpkl", "wb") as f: + cloudpickle.dump( + to_serialize.batch_sampler, f, protocol=constants.PICKLE_PROTOCOL + ) + + with open(f"{path}/pass_through_args.json", "w") as f: + json.dump(pass_through_args, f) + + def serialize(self, to_serialize: TorchDataLoader, gcs_path: str, **kwargs) -> str: + """Serializes a torch.utils.data.DataLoader to a gcs path. + + Args: + to_serialize (torch.utils.data.DataLoader): + Required. A pytorch dataloader object. + gcs_path (str): + Required. A GCS uri that the dataloader will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + TorchDataLoaderSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_torch_dataloader(to_serialize) + ) + + if gcs_path.startswith("gs://"): + with tempfile.TemporaryDirectory() as temp_dir: + self._serialize_to_local(to_serialize, temp_dir) + gcs_utils.upload_to_gcs(temp_dir, gcs_path) + else: + os.makedirs(gcs_path) + self._serialize_to_local(to_serialize, gcs_path) + + return gcs_path + + def _deserialize_from_local(self, path: str) -> TorchDataLoader: + """Deserialize a torch.utils.data.DataLoader given a local path. + + Args: + path (str): + Required. A local path to the serialized dataloader. + + Returns: + A torch.utils.data.DataLoader object. + + Raises: + ImportError: if torch is not installed. + """ + try: + import torch + except ImportError as e: + raise ImportError( + f"torch is not installed and required to deserialize the file from {path}." + ) from e + + with open(f"{path}/pass_through_args.json", "r") as f: + kwargs = json.load(f) + + # re-instantiate Generator + if kwargs["generator_device"] is not None: + kwargs["generator"] = torch.Generator( + kwargs["generator_device"] if torch.cuda.is_available() else "cpu" + ) + kwargs.pop("generator_device") + + with open(f"{path}/dataset.cpkl", "rb") as f: + kwargs["dataset"] = cloudpickle.load(f) + + with open(f"{path}/collate_fn.cpkl", "rb") as f: + kwargs["collate_fn"] = cloudpickle.load(f) + + with open(f"{path}/worker_init_fn.cpkl", "rb") as f: + kwargs["worker_init_fn"] = cloudpickle.load(f) + + try: + with open(f"{path}/sampler.cpkl", "rb") as f: + kwargs["sampler"] = cloudpickle.load(f) + except FileNotFoundError: + pass + + try: + with open(f"{path}/batch_sampler.cpkl", "rb") as f: + kwargs["batch_sampler"] = cloudpickle.load(f) + except FileNotFoundError: + pass + + return torch.utils.data.DataLoader(**kwargs) + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> TorchDataLoader: + """Deserialize a torch.utils.data.DataLoader given the gcs path. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A torch.utils.data.DataLoader object. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + ImportError: if torch is not installed. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + if serialized_gcs_path.startswith("gs://"): + with tempfile.TemporaryDirectory() as temp_dir: + gcs_utils.download_from_gcs(serialized_gcs_path, temp_dir) + dataloader = self._deserialize_from_local(temp_dir) + else: + dataloader = self._deserialize_from_local(serialized_gcs_path) + + return dataloader + + +class TFDatasetSerializer(serializers_base.Serializer): + """Serializer responsible for serializing/deserializing a tf.data.Dataset.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="TFDatasetSerializer") + ) + + def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str: + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + TFDatasetSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_tensorflow_model(to_serialize) + ) + + try: + to_serialize.save(gcs_path) + except AttributeError: + tf.data.experimental.save(to_serialize, gcs_path) + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> TFDataset: + del kwargs + try: + deserialized = tf.data.Dataset.load(serialized_gcs_path) + except AttributeError: + deserialized = tf.data.experimental.load(serialized_gcs_path) + return deserialized + + +class PandasDataSerializer(serializers_base.Serializer): + """Serializer for pandas DataFrames.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="PandasDataSerializer") + ) + + def serialize(self, to_serialize: PandasData, gcs_path: str, **kwargs) -> str: + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + PandasDataSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_pandas_dataframe(to_serialize) + ) + + if gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + to_serialize.to_parquet(temp_file.name) + temp_file.flush() + temp_file.seek(0) + + gcs_utils.upload_to_gcs(temp_file.name, gcs_path) + else: + to_serialize.to_parquet(gcs_path) + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> PandasData: + del kwargs + try: + import pandas as pd + except ImportError as e: + raise ImportError( + f"pandas is not installed and required to deserialize the file from {serialized_gcs_path}." + ) from e + + if serialized_gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(serialized_gcs_path, temp_file.name) + return pd.read_parquet(temp_file.name) + else: + return pd.read_parquet(serialized_gcs_path) + + +class PandasDataSerializerDev(serializers_base.Serializer): + """Serializer responsible for serializing/deserializing a pandas DataFrame.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="PandasDataSerializerDev") + ) + + def __init__(self): + super().__init__() + self.helper = data_serializer_utils._Helper() + + def serialize(self, to_serialize: PandasData, gcs_path: str, **kwargs) -> str: + del kwargs + PandasDataSerializerDev._metadata.dependencies = ( + supported_frameworks._get_deps_if_pandas_dataframe(to_serialize) + ) + try: + if not ( + isinstance(to_serialize.index, pd.MultiIndex) + or isinstance(to_serialize.columns, pd.MultiIndex) + ): + self.helper.create_placeholder_col_names(to_serialize) + self.helper.cast_int_to_str( + to_serialize, action=data_serializer_utils.ActionType.CAST_COL_NAME + ) + self.helper.cast_int_to_str( + to_serialize, action=data_serializer_utils.ActionType.CAST_ROW_INDEX + ) + self.helper.cast_int_to_str( + to_serialize, + action=data_serializer_utils.ActionType.CAST_CATEGORICAL, + ) + table = pa.Table.from_pandas(to_serialize) + custom_metadata = { + data_serializer_utils.df_restore_func_metadata_key.encode(): json.dumps( + self.helper.restore_df_actions + ).encode(), + data_serializer_utils.df_restore_func_args_metadata_key.encode(): json.dumps( + self.helper.restore_df_actions_args + ).encode(), + **table.schema.metadata, + } + table = table.replace_schema_metadata(custom_metadata) + + with tempfile.TemporaryDirectory() as temp_dir: + fp = os.path.join(temp_dir, f"{uuid.uuid4()}.parquet") + pq.write_table(table, fp, compression="GZIP") + gcs_utils.upload_to_gcs(fp, gcs_path) + finally: + # undo ad-hoc mutations on the dataframe + self.helper.restore_df_actions.reverse() + self.helper.restore_df_actions_args.reverse() + for func_str, args in zip( + self.helper.restore_df_actions, self.helper.restore_df_actions_args + ): + func = getattr(self.helper, func_str) + func(to_serialize, *args) if len(args) > 0 else func(to_serialize) + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> PandasData: + del kwargs + restored_table = pq.read_table(serialized_gcs_path) + restored_df = restored_table.to_pandas() + + # get custom metadata + restore_func_array_json = restored_table.schema.metadata[ + data_serializer_utils.df_restore_func_metadata_key.encode() + ] + restore_func_array = json.loads(restore_func_array_json) + restore_func_array_args_json = restored_table.schema.metadata[ + data_serializer_utils.df_restore_func_args_metadata_key.encode() + ] + restore_func_array_args = json.loads(restore_func_array_args_json) + restore_func_array.reverse() + restore_func_array_args.reverse() + + for func_str, args in zip(restore_func_array, restore_func_array_args): + func = getattr(self.helper, func_str) + func(restored_df, *args) if len(args) > 0 else func(restored_df) + return restored_df + + +@dataclasses.dataclass +class BigframeSerializationMetadata(serializers_base.SerializationMetadata): + """Metadata of BigframeSerializer class. + + Stores extra framework attribute + """ + + framework: Optional[str] = None + + def to_dict(self): + dct = super().to_dict() + dct.update({SERIALIZATION_METADATA_FRAMEWORK_KEY: self.framework}) + return dct + + +class BigframeSerializer(serializers_base.Serializer): + """Serializer responsible for serializing/deserializing a BigFrames DataFrame. + + Serialization: All frameworks serialize bigframes.dataframe.DataFrame -> parquet (GCS) + Deserialization: Framework specific deserialize methods are called + """ + + _metadata: serializers_base.SerializationMetadata = BigframeSerializationMetadata( + serializer="BigframeSerializer", framework=None + ) + + def serialize( + self, + to_serialize: Union[BigframesData, PandasData], + gcs_path: str, + **kwargs, + ) -> str: + # All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet) + # Record the framework in metadata for deserialization + BigframeSerializer._metadata.framework = kwargs.get("framework") + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + BigframeSerializer._metadata.dependencies = ( + supported_frameworks._get_deps_if_bigframe(to_serialize) + ) + + # Check if index.name is default and set index.name if not + if to_serialize.index.name and to_serialize.index.name != "index": + raise ValueError("Index name must be 'index'") + if to_serialize.index.name is None: + to_serialize.index.name = "index" + + # Convert BigframesData to Parquet (GCS) + parquet_gcs_path = gcs_path + "/*" # path is required to contain '*' + to_serialize.to_parquet(parquet_gcs_path, index=True) + + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> Union[PandasData, BigframesData]: + del kwargs + + detected_framework = BigframeSerializer._metadata.framework + if detected_framework == "sklearn": + return self._deserialize_sklearn(serialized_gcs_path) + elif detected_framework == "torch": + return self._deserialize_torch(serialized_gcs_path) + elif detected_framework == "tensorflow": + return self._deserialize_tensorflow(serialized_gcs_path) + else: + raise ValueError(f"Unsupported framework: {detected_framework}") + + def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData: + """Sklearn deserializes parquet (GCS) --> pandas.DataFrame + + By default, sklearn returns a numpy array which uses CloudPickleSerializer. + If a bigframes.dataframe.DataFrame is desired for the return type, + b/291147206 (cl/548228568) is required + """ + # Deserialization at remote environment + try: + import pandas as pd + except ImportError as e: + raise ImportError( + f"pandas is not installed and required to deserialize the file from {serialized_gcs_path}." + ) from e + + # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/ + # pd.read_parquet auto-merges a directory of parquet files + pd_dataframe = pd.read_parquet(serialized_gcs_path) + + # Drop index now that ordering is guaranteed + if "index" in pd_dataframe.columns: + pd_dataframe.drop(columns=["index"], inplace=True) + + return pd_dataframe + + def _deserialize_torch(self, serialized_gcs_path: str) -> TorchTensor: + """Torch deserializes parquet (GCS) --> torch.tensor + + Assumes one parquet file is created. + """ + # Deserialization at remote environment + try: + from torchdata.datapipes.iter import FileLister + except ImportError as e: + raise ImportError( + f"torchdata is not installed and required to deserialize the file from {serialized_gcs_path}." + ) from e + + # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/ + # TODO(b/295335262): Implement torch lazy read + source_dp = FileLister(serialized_gcs_path, masks="") + parquet_df_dp = source_dp.load_parquet_as_df() + + def preprocess(torch_df): + torch_df = torch_df.drop("index") + df_tensor = torch_df.to_tensor() + + # Convert from TorchStruct to Tensor + cols = [] + for i in range(len(df_tensor)): + col = df_tensor[i].values + col = col[:, None] + cols.append(col) + deserialized_tensor = torch.cat(cols, 1) + return deserialized_tensor + + parquet_df_dp = parquet_df_dp.map(preprocess) + + def reduce_tensors(a, b): + return torch.concat((a, b), axis=0) + + return functools.reduce(reduce_tensors, list(parquet_df_dp)) + + def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset: + """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset + + Assumes one parquet file is created. + """ + # Deserialization at remote environment + try: + import tensorflow_io as tfio + except ImportError as e: + raise ImportError( + f"tensorflow_io is not installed and required to deserialize the file from {serialized_gcs_path}." + ) from e + + # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/ + # TODO(b/296475384): Handle multiple parquet shards + if len(os.listdir(serialized_gcs_path + "/")) > 1: + raise RuntimeError( + "Large datasets which are serialized into sharded parquet are not yet supported (b/296475384)" + ) + + single_parquet_gcs_path = serialized_gcs_path + "/" + "000000000000" + ds = tfio.IODataset.from_parquet(single_parquet_gcs_path) + + # TODO(b/296474656) Parquet must have "target" column for y + def map_fn(row): + target = row[b"target"] + row = { + k: tf.expand_dims(v, -1) + for k, v in row.items() + if k != b"target" and k != b"index" + } + + def reduce_fn(a, b): + return tf.concat((a, b), axis=0) + + return functools.reduce(reduce_fn, row.values()), target + + # TODO(b/295535730): Remove hardcoded batch_size of 32 + return ds.map(map_fn).batch(32) + + +class CloudPickleSerializer(serializers_base.Serializer): + """Serializer that uses cloudpickle to serialize the object.""" + + _metadata: serializers_base.SerializationMetadata = ( + serializers_base.SerializationMetadata(serializer="CloudPickleSerializer") + ) + + def serialize(self, to_serialize: Any, gcs_path: str, **kwargs) -> str: + """Use cloudpickle to serialize a python object to a gcs file path. + + Args: + to_serialize (Any): + Required. A python object. + gcs_path (str): + Required. A GCS uri that the estimator will be saved to. + + Returns: + The GCS uri. + + Raises: + ValueError: if `gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(gcs_path): + raise ValueError(f"Invalid gcs path: {gcs_path}") + + CloudPickleSerializer._metadata.dependencies = ( + supported_frameworks._get_estimator_requirement(to_serialize) + ) + serialized = cloudpickle.dumps(to_serialize, protocol=constants.PICKLE_PROTOCOL) + serializers_base.write_and_upload_data(data=serialized, gcs_filename=gcs_path) + return gcs_path + + def deserialize(self, serialized_gcs_path: str, **kwargs) -> Any: + """Use cloudpickle to deserialize a python object given the object's gcs file path. + + Args: + serialized_gcs_path (str): + Required. A GCS path to the serialized file. + + Returns: + A python object. + + Raises: + ValueError: if `serialized_gcs_path` is not a valid GCS uri. + """ + del kwargs + if not _is_valid_gcs_path(serialized_gcs_path): + raise ValueError(f"Invalid gcs path: {serialized_gcs_path}") + + if serialized_gcs_path.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(serialized_gcs_path, temp_file.name) + with open(temp_file.name, mode="rb") as f: + obj = cloudpickle.load(f) + else: + with open(serialized_gcs_path, mode="rb") as f: + obj = cloudpickle.load(f) + + return obj diff --git a/vertexai/preview/_workflow/serialization_engine/serializers_base.py b/vertexai/preview/_workflow/serialization_engine/serializers_base.py new file mode 100644 index 0000000000..a1add75355 --- /dev/null +++ b/vertexai/preview/_workflow/serialization_engine/serializers_base.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=line-too-long, bad-continuation,protected-access +"""Defines the Serializer classes.""" + +import abc +import dataclasses +import os +import pathlib +import tempfile +from typing import Any, Dict, List, Optional, Type, TypeVar, Union + +from google.cloud.aiplatform.utils import gcs_utils + + +# TODO(b/272263750): use the centralized module and usage pattern to guard these +# imports +# pylint: disable=g-import-not-at-top +try: + import pandas as pd + import bigframes as bf + + PandasData = pd.DataFrame + BigframesData = bf.dataframe.DataFrame +except ImportError: + pd = None + bf = None + PandasData = Any + BigframesData = Any + +try: + import pandas as pd + import pyarrow as pa + import pyarrow.parquet as pq + + PandasData = pd.DataFrame +except ImportError: + pd = None + pa = None + pq = None + PandasData = Any + +try: + import sklearn + + SklearnEstimator = sklearn.base.BaseEstimator +except ImportError: + sklearn = None + SklearnEstimator = Any + +try: + from tensorflow import keras + import tensorflow as tf + + KerasModel = keras.models.Model + TFDataset = tf.data.Dataset +except ImportError: + keras = None + tf = None + KerasModel = Any + TFDataset = Any + +try: + import torch + + TorchModel = torch.nn.Module + TorchDataLoader = torch.utils.data.DataLoader +except ImportError: + torch = None + TorchModel = Any + TorchDataLoader = Any + +try: + import lightning.pytorch as pl + + LightningTrainer = pl.Trainer +except ImportError: + pl = None + LightningTrainer = Any + + +Types = Union[ + PandasData, + BigframesData, + SklearnEstimator, + KerasModel, + TorchModel, + LightningTrainer, +] +T = TypeVar("T") +SERIALIZATION_METADATA_FILENAME = "serialization_metadata" +SERIALIZATION_METADATA_SERIALIZER_KEY = "serializer" +SERIALIZATION_METADATA_DEPENDENCIES_KEY = "dependencies" + + +@dataclasses.dataclass +class SerializationMetadata: + """Metadata of Serializer classes. + + This is supposed to be a class attribute named `_metadata` of the Serializer + class. + + Example Usage: + ``` + import vertexai + + # define a custom Serializer + class KerasCustomSerializer( + vertexai.preview.developer.Serializer): + # make a metadata + _metadata = vertexai.preview.developer.SerializationMetadata() + + def serialize(self, to_serialize, gcs_path): + ... + def deserialize(self, gcs_path): + ... + ``` + """ + + serializer: Optional[str] = None + dependencies: List[str] = dataclasses.field(default_factory=list) + + def to_dict(self): + return { + SERIALIZATION_METADATA_SERIALIZER_KEY: self.serializer, + SERIALIZATION_METADATA_DEPENDENCIES_KEY: self.dependencies, + } + + +class SerializationError(Exception): + """Raised when the object fails to be serialized.""" + + pass + + +def write_and_upload_data(data: bytes, gcs_filename: str): + """Writes data to a local temp file and uploads the file to gcs. + + Args: + data (bytes): + Required. Bytes data to write. + gcs_filename (str): + Required. A gcs file path. + """ + if gcs_filename.startswith("gs://"): + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(data) + temp_file.flush() + temp_file.seek(0) + + gcs_utils.upload_to_gcs(temp_file.name, gcs_filename) + else: + with open(gcs_filename, mode="wb") as f: + f.write(data) + + +def _get_uri_prefix(gcs_uri: str) -> str: + """Gets the directory of the gcs_uri. + + Example: + 1) file uri: + _get_uri_prefix("gs:///directory/file.extension") == "gs:// + /directory/" + 2) folder uri: + _get_uri_prefix("gs:///parent_dir/dir") == "gs:/// + parent_dir/" + Args: + gcs_uri: A string starting with "gs://" that represent a gcs uri. + Returns: + The parent gcs directory in string format. + """ + # For tensorflow, the uri may be "gs://my-bucket/saved_model/" + if gcs_uri.endswith("/"): + gcs_uri = gcs_uri[:-1] + gcs_pathlibpath = pathlib.Path(gcs_uri) + file_name = gcs_pathlibpath.name + return gcs_uri[: -len(file_name)] + + +def _get_metadata_path_from_file_gcs_uri(gcs_uri: str) -> str: + gcs_pathlibpath = pathlib.Path(gcs_uri) + prefix = _get_uri_prefix(gcs_uri=gcs_uri) + return os.path.join( + prefix, + f"{SERIALIZATION_METADATA_FILENAME}_{gcs_pathlibpath.stem}.json", + ) + + +def _get_custom_serializer_path_from_file_gcs_uri( + gcs_uri: str, serializer_name: str +) -> str: + prefix = _get_uri_prefix(gcs_uri=gcs_uri) + return os.path.join(prefix, f"{serializer_name}") + + +def _load_torch_model(path: str, map_location: "torch.device") -> TorchModel: + try: + return torch.load(path, map_location=map_location) + except Exception: + return torch.load(path, map_location=torch.device("cpu")) + + +class Serializer(metaclass=abc.ABCMeta): + """Abstract class of serializers. + + custom Serializers should be subclasses of this class. + Example Usage: + + ``` + import vertexai + + # define a custom Serializer + class KerasCustomSerializer( + vertexai.preview.developer.Serializer): + _metadata = vertexai.preview.developer.SerializationMetadata() + + def serialize(self, to_serialize, gcs_path): + ... + def deserialize(self, gcs_path): + ... + + KerasCustomSerializer.register_requirements( + ['library1==1.0.0', 'library2<2.0']) + vertexai.preview.developer.register_serializer( + keras.models.Model, KerasCustomSerializer) + ``` + """ + + _serialization_scheme: Dict[Type[Any], Optional[Type["Serializer"]]] = {} + _custom_serialization_scheme: Dict[Type[Any], Optional[Type["Serializer"]]] = {} + # _instances holds the instance of each Serializer for each type. + _instances: Dict[Type["Serializer"], "Serializer"] = {} + _metadata: SerializationMetadata = SerializationMetadata() + + def __new__(cls): + try: + import cloudpickle # noqa:F401 + except ImportError as e: + raise ImportError( + "cloudpickle is not installed. Please call `pip install google-cloud-aiplatform[preview]`." + ) from e + + if cls not in Serializer._instances: + Serializer._instances[cls] = super().__new__(cls) + if cls._metadata.serializer is None: + cls._metadata.serializer = cls.__name__ + return Serializer._instances[cls] + + @abc.abstractmethod + def serialize( + self, + to_serialize: T, + gcs_path: str, + **kwargs, + ) -> Union[Dict[str, Any], str]: # pytype: disable=invalid-annotation + raise NotImplementedError + + @abc.abstractmethod + def deserialize( + self, + serialized_gcs_path: str, + **kwargs, + ) -> T: # pytype: disable=invalid-annotation + raise NotImplementedError + + @classmethod + def _register( + cls, to_serialize_type: Type[Any], serializer_cls: Type["Serializer"] + ): + cls._serialization_scheme[to_serialize_type] = serializer_cls + + @classmethod + def register_custom( + cls, to_serialize_type: Type[Any], serializer_cls: Type["Serializer"] + ): + """Registers a custom serializer for a specific type. + + Example Usage: + ``` + # define a custom Serializer + class KerasCustomSerializer(serialization_engine.Serializer): + _metadata = serialization_engine.SerializationMetadata() + def serialize(self, to_serialize, gcs_path): + ... + def deserialize(self, gcs_path): + ... + + any_serializer = serialization_engine.AnySerializer() + any_serializer.register_custom(keras.models.Model, KerasCustomSerializer) + ``` + Args: + to_serialize_type: The class that is supposed to be serialized with + the to-be-registered custom Serializer. + serializer_cls: The custom Serializer to be registered. + """ + cls._custom_serialization_scheme[to_serialize_type] = serializer_cls + + @classmethod + def get_instance(cls) -> "Serializer": + if cls not in Serializer._instances: + Serializer._instances[cls] = cls() + return Serializer._instances[cls] + + @classmethod + def _dedupe_deps(cls): + # TODO(b/282719450): Consider letting the later specifier to overwrite + # earlier specifier for the same package, and automatically detecting + # the version if version is not specified. + cls._metadata.dependencies = list(dict.fromkeys(cls._metadata.dependencies)) + + @classmethod + def register_requirement(cls, required_package: str): + # TODO(b/280648121) Consider allowing the user to register the + # installation command so that we support installing packages not + # covered by PyPI in the remote machine. + cls._metadata.dependencies.append(required_package) + cls._dedupe_deps() + + @classmethod + def register_requirements(cls, requirements: List[str]): + cls._metadata.dependencies.extend(requirements) + cls._dedupe_deps() diff --git a/vertexai/preview/_workflow/shared/__init__.py b/vertexai/preview/_workflow/shared/__init__.py new file mode 100644 index 0000000000..7d0a81ae8f --- /dev/null +++ b/vertexai/preview/_workflow/shared/__init__.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import dataclasses +import inspect +from typing import Any, Callable, Dict, Optional + +from vertexai.preview._workflow.shared import configs + + +@dataclasses.dataclass(frozen=True) +class _Invokable: + """Represents a single invokable method. + + method: The method to invoke. + bound_arguments: The arguments to use to invoke the method. + vertex_config: User-specified configs for Vertex services. + remote_executor: The executor that execute the method remotely. + remote_executor_kwargs: kwargs used in the remote executor. + instance: The instance the method is bound. + """ + + method: Callable[..., Any] + bound_arguments: inspect.BoundArguments + vertex_config: configs.VertexConfig + remote_executor: Callable[..., Any] + remote_executor_kwargs: Optional[Dict[str, Any]] = None + instance: Optional[Any] = None diff --git a/vertexai/preview/_workflow/shared/configs.py b/vertexai/preview/_workflow/shared/configs.py new file mode 100644 index 0000000000..07153508f4 --- /dev/null +++ b/vertexai/preview/_workflow/shared/configs.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +from typing import List, Optional + + +@dataclasses.dataclass +class _BaseConfig: + """A class that holds configuration that can be shared across different remote services. + + Attributes: + display_name (str): + The display name of the remote job. + staging_bucket (str): + Base GCS directory of the remote job. All the input and + output artifacts will be saved here. If not provided a timestamped + directory in the default staging bucket will be used. + container_uri (str): + Uri of the training container image to use for remote job. + Support images in Artifact Registry, Container Registry, or Docker Hub. + machine_type (str): + The type of machine to use for remote training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_A100, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, + NVIDIA_TESLA_K80, NVIDIA_TESLA_T4, NVIDIA_TESLA_P4 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + worker_pool_specs (vertexai.preview.developer.remote_specs.WorkerPoolSpecs): + The worker pool specs configuration for a remote job. + """ + + display_name: Optional[str] = None + staging_bucket: Optional[str] = None + container_uri: Optional[str] = None + machine_type: Optional[str] = None + accelerator_type: Optional[str] = None + accelerator_count: Optional[int] = None + worker_pool_specs: Optional[ + "vertexai.preview.developer.remote_specs.WorkerPoolSpecs" # noqa: F821 + ] = None + + +@dataclasses.dataclass +class RemoteConfig(_BaseConfig): + """A class that holds the configuration for Vertex remote training. + + Attributes: + enable_cuda (bool): + When set to True, Vertex will automatically choose a GPU image and + accelerators for the remote job and train the model on cuda devices. + You can also specify the image and accelerators by yourself through + `container_uri`, `accelerator_type`, `accelerator_count`. + Supported frameworks: keras, torch.nn.Module + Default configs: + container_uri="pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime" or "tensorflow/tensorflow:2.12.0-gpu" + machine_type="n1-standard-16" + accelerator_type="NVIDIA_TESLA_P100" + accelerator_count=1 + enable_distributed (bool): + When set to True, Vertex will automatically choose a GPU or CPU + distributed training configuration depending on the value of `enable_cuda`. + You can also specify a custom configuration by yourself through `worker_pool_specs`. + Supported frameworks: keras (requires TensorFlow >= 2.12.0), torch.nn.Module + Default configs: + If `enable_cuda` = True, for both the `chief` and `worker` specs: + machine_type="n1-standard-16" + accelerator_type="NVIDIA_TESLA_P100" + accelerator_count=1 + If `enable_cuda` = False, for both the `chief` and `worker` specs: + machine_type="n1-standard-4" + replica_count=1 + enable_full_logs (bool): + When set to True, all the logs from the remote job will be shown locally. + Otherwise, only training related logs will be shown. + service_account (str): + Specifies the service account for running the remote job. To use + autologging feature, you need to set it to "gce", which refers + to the GCE service account, or set it to another service account. + Please make sure your own service account has the Storage Admin role + and Vertex AI User role. + requirements (List[str]): + List of python packages dependencies that will be installed in the remote + job environment. In most cases Vertex will handle the installation of + dependencies that are required for running the remote job. You can use + this field to specify extra packages to install in the remote environment. + custom_commands (List[str]): + List of custom commands to be run in the remote job environment. + These commands will be run before the requirements are installed. + """ + + enable_cuda: bool = False + enable_distributed: bool = False + enable_full_logs: bool = False + service_account: Optional[str] = None + requirements: List[str] = dataclasses.field(default_factory=list) + custom_commands: List[str] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass +class DistributedTrainingConfig(_BaseConfig): + """A class that holds the configs for a distributed training remote job. + + Attributes: + replica_count (int): + The number of worker replicas. Assigns 1 chief replica and + replica_count - 1 worker replicas. + boot_disk_type (str): + Type of the boot disk (default is `pd-ssd`). + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk (default is 100GB). + boot disk size must be within the range of [100, 64000]. + """ + + replica_count: Optional[int] = None + boot_disk_type: Optional[str] = None + boot_disk_size_gb: Optional[int] = None + + +@dataclasses.dataclass +class VertexConfig: + """A class that holds the configuration for the method wrapped by Vertex. + + Attributes: + remote (bool): + Whether or not this method will be executed remotely on Vertex. If not + set, Vertex will check the remote setting in `vertexai.preview.init(...)` + remote_config (RemoteConfig): + A class that holds the configuration for the remote job. + """ + + remote: Optional[bool] = None + remote_config: RemoteConfig = dataclasses.field(default_factory=RemoteConfig) + + def set_config( + self, + display_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + container_uri: Optional[str] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + worker_pool_specs: Optional[ + "vertexai.preview.developer.remote_specs.WorkerPoolSpecs" # noqa: F821 + ] = None, + enable_cuda: bool = False, + enable_distributed: bool = False, + enable_full_logs: bool = False, + service_account: Optional[str] = None, + requirements: List[str] = [], + custom_commands: List[str] = [], + replica_count: Optional[int] = None, + boot_disk_type: Optional[str] = None, + boot_disk_size_gb: Optional[int] = None, + ): + """Sets configuration attributes for a remote job. + + Calling this will overwrite any previously set job configuration attributes. + + Example usage: + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET_NAME, + ) + vertexai.preview.init(remote=True) + + LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression) + model = LogisticRegression() + + model.fit.vertex.set_config( + display_name="my-display-name", + staging_bucket="gs://my-bucket", + container_uri="gcr.io/custom-image, + ) + + Args: + display_name (str): + The display name of the remote job. + staging_bucket (str): + Base GCS directory of the remote job. All the input and + output artifacts will be saved here. If not provided a timestamped + directory in the default staging bucket will be used. + container_uri (str): + Uri of the training container image to use for remote job. + Support images in Artifact Registry, Container Registry, or Docker Hub. + machine_type (str): + The type of machine to use for remote training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_A100, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, + NVIDIA_TESLA_K80, NVIDIA_TESLA_T4, NVIDIA_TESLA_P4 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + worker_pool_specs (vertexai.preview.developer.remote_specs.WorkerPoolSpecs): + The worker pool specs configuration for a remote job. + enable_cuda (bool): + When set to True, Vertex will automatically choose a GPU image and + accelerators for the remote job and train the model on cuda devices. + This parameter is specifically for TrainingConfig. + You can also specify the image and accelerators by yourself through + `container_uri`, `accelerator_type`, `accelerator_count`. + Supported frameworks: keras, torch.nn.Module + Default configs: + container_uri="pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime" or "tensorflow/tensorflow:2.12.0-gpu" + machine_type="n1-standard-16" + accelerator_type="NVIDIA_TESLA_P100" + accelerator_count=1 + enable_distributed (bool): + When set to True, Vertex will automatically choose a GPU or CPU + distributed training configuration depending on the value of `enable_cuda`. + You can also specify a custom configuration by yourself through `worker_pool_specs`. + This parameter is specifically for TrainingConfig. + Supported frameworks: keras (requires TensorFlow >= 2.12.0), torch.nn.Module + Default configs: + If `enable_cuda` = True, for both the `chief` and `worker` specs: + machine_type="n1-standard-16" + accelerator_type="NVIDIA_TESLA_P100" + accelerator_count=1 + If `enable_cuda` = False, for both the `chief` and `worker` specs: + machine_type="n1-standard-4" + replica_count=1 + enable_full_logs (bool): + When set to True, all the logs from the remote job will be shown locally. + Otherwise, only training related logs will be shown. + service_account (str): + Specifies the service account for running the remote job. To use + autologging feature, you need to set it to "gce", which refers + to the GCE service account, or set it to another service account. + Please make sure your own service account has the Storage Admin role + and Vertex AI User role. This parameter is specifically for TrainingConfig. + requirements (List[str]): + List of python packages dependencies that will be installed in the remote + job environment. In most cases Vertex will handle the installation of + dependencies that are required for running the remote job. You can use + this field to specify extra packages to install in the remote environment. + This parameter is specifically for TrainingConfig. + custom_commands (List[str]): + List of custom commands to be run in the remote job environment. + These commands will be run before the requirements are installed. + replica_count (int): + The number of worker replicas. Assigns 1 chief replica and + replica_count - 1 worker replicas. This is specifically for + DistributedTrainingConfig. + boot_disk_type (str): + Type of the boot disk (default is `pd-ssd`). + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). This is specifically for + DistributedTrainingConfig. + boot_disk_size_gb (int): + Size in GB of the boot disk (default is 100GB). + boot disk size must be within the range of [100, 64000]. This is specifically for + DistributedTrainingConfig. + """ + + # locals() contains a 'self' key in addition to function args + kwargs = locals() + + config = self.remote_config.__class__() + + for config_arg in kwargs: + if hasattr(config, config_arg): + setattr(config, config_arg, kwargs[config_arg]) + + # raise if a value was passed for an unsupported config attribute (i.e. boot_disk_type on TrainingConfig) + elif config_arg != "self" and kwargs[config_arg]: + raise ValueError( + f"{type(self.remote_config)} has no attribute {config_arg}." + ) + + self.remote_config = config + + +@dataclasses.dataclass +class PersistentResourceConfig: + """A class that holds persistent resource configuration during initialization. + + Attributes: + name (str): + The cluster name of the remote job. This value may be up to 63 + characters, and valid characters are `[a-z0-9_-]`. The first character + cannot be a number or hyphen. + """ + + name: Optional[str] = None diff --git a/vertexai/preview/_workflow/shared/constants.py b/vertexai/preview/_workflow/shared/constants.py new file mode 100644 index 0000000000..8cf36a2cca --- /dev/null +++ b/vertexai/preview/_workflow/shared/constants.py @@ -0,0 +1,21 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Constants used by vertexai.""" + +PICKLE_PROTOCOL = 4 + +_START_EXECUTION_MSG = "Start remote execution on Vertex..." +_END_EXECUTION_MSG = "Remote execution is completed." diff --git a/vertexai/preview/_workflow/shared/data_serializer_utils.py b/vertexai/preview/_workflow/shared/data_serializer_utils.py new file mode 100644 index 0000000000..43836e5919 --- /dev/null +++ b/vertexai/preview/_workflow/shared/data_serializer_utils.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Any, Union +from enum import Enum + +try: + import pandas as pd + + PandasData = pd.DataFrame +except ImportError: + pd = None + PandasData = Any + +df_restore_func_metadata_key = "restore_df_actions" +df_restore_func_args_metadata_key = "restore_df_actions_args" + + +class ActionType(str, Enum): + CAST_COL_NAME = "CAST_COL_NAME" + CAST_ROW_INDEX = "CAST_ROW_INDEX" + CAST_CATEGORICAL = "CAST_CATEGORICAL" + + +class _Helper: + def __init__(self): + if not pd: + raise ImportError( + "pandas is not installed and required for Pandas Serializer." + ) + self.restore_df_actions = [] + self.restore_df_actions_args = [] + self.restore_func_metadata_key = "restore_df_actions" + self.restore_func_args_metadata_key = "restore_df_actions_args" + + def create_placeholder_col_names(self, df: PandasData): + """Creates placeholder column names for dataframes without column names. + + Args: + df (pd.DataFrame): + Required. This is the dataframe to serialize. + """ + if isinstance(df.columns, pd.RangeIndex): + df.columns = [str(x) for x in df.columns] + self.restore_df_actions.append("remove_placeholder_col_names") + self.restore_df_actions_args.append([]) + + def remove_placeholder_col_names(self, df: PandasData): + df.columns = pd.RangeIndex(start=0, stop=len(df.columns), step=1) + + def _append_to_temp_indices( + self, temp_indices: List[str], name: Any, action: ActionType + ): + """ + This function is a helper for the cast_int_to_str function. + + Args: + temp_indices (List[str]): a temporary array of indices that keeps track + of the original values of the column or row indices. + + name (Any): the name of the column or row. Note that this could be any type, + but Vertex only handles integer-to-string casting. Users who attempt to + serialize Pandas dataframes with non-string or non-integer column/row indices + will encounter a runtime error from pyarrow. + + action (ActionType): the enum that tells the deserialization function + at runtime whether a row or a column index is being cast back. + """ + if isinstance(name, int): + temp_indices.append(str(name)) + self.restore_df_actions.append("cast_str_to_int") + self.restore_df_actions_args.append([action, str(name)]) + else: + temp_indices.append(name) + + def cast_int_to_str(self, df: PandasData, action: ActionType): + """ + This function casts integers to strings depending on the action type. + + In the cases of casting integer-indexed columns or rows, the function + will modify the dataframe and append to restore_df_actions that will cast + the column and row indices back to their original data types. + + In the case of handling categorical columns, the function will keep track + of the column names with integers being the primitive data type, preserve + their orders if the column is ordered, and add relevant metadata to the + restore_df_actions and restore_df_actions_args arrays. + + Args: + df (pd.DataFrame): + Required. This is the dataframe to serialize. + action (enum.Enum): + Required. One of [CAST_COL_NAME, CAST_ROW_NAME, CAST_CATEGORICAL] + """ + temp_indices = [] + if action == ActionType.CAST_COL_NAME: + for i in range(len(df.columns)): + self._append_to_temp_indices(temp_indices, df.columns[i], action) + df.columns = temp_indices + elif action == ActionType.CAST_ROW_INDEX: + for i in range(len(df.index)): + self._append_to_temp_indices(temp_indices, df.index[i], action) + df.index = temp_indices + elif action == ActionType.CAST_CATEGORICAL: + columns_to_cast = [] + column_orders = [] + columns_to_reorder = [] + for col_name in df.select_dtypes(include=["category"]): + if df[col_name].cat.ordered: + column_orders.append(df[col_name].cat.categories.values.tolist()) + columns_to_reorder.append(col_name) + # cast the columns with integers as categories + try: + int(df.at[df[col_name].first_valid_index(), col_name]) + columns_to_cast.append(col_name) + # pass on the columns that are non-integers + except ValueError: + pass + self.restore_df_actions.append("restore_category_order") + self.restore_df_actions_args.append([columns_to_reorder, column_orders]) + + self.restore_df_actions.append("cast_str_to_int") + self.restore_df_actions_args.append([action, columns_to_cast]) + + @staticmethod + def cast_str_to_int( + df: PandasData, + action: ActionType, + index_name_or_columns: Union[List[str], str] = None, + ): + """ + This function is used by the deserialization function to undo any temp + workarounds applied to the dataframe during serialization. + + Args: + df (pd.DataFrame): + Required. This is the dataframe to deserialize. + action (enum.Enum): + Required. One of [CAST_COL_NAME, CAST_ROW_NAME, CAST_CATEGORICAL] + index_name_or_columns (Union[List[str], str]): + Required. This is the list of index names to cast back to int + in the case of restoring row or column indices. In the case of + categorical columns, this is the list of column names to restore. + """ + restored_indices = [] + if action == ActionType.CAST_COL_NAME: + for i in range(len(df.columns)): + if df.columns[i] == index_name_or_columns: + restored_indices.append(int(index_name_or_columns)) + else: + restored_indices.append(df.columns[i]) + df.columns = restored_indices + elif action == ActionType.CAST_ROW_INDEX: + for i in range(len(df.index)): + if df.index[i] == index_name_or_columns: + restored_indices.append(int(index_name_or_columns)) + else: + restored_indices.append(df.index[i]) + df.index = restored_indices + elif action == ActionType.CAST_CATEGORICAL: + for column in index_name_or_columns: + df[column] = df[column].astype("int", errors="ignore") + df[column] = df[column].astype("category") + + @staticmethod + def restore_category_order( + df: PandasData, columns: List[str], categories: List[Any] + ): + for (column, category) in zip(columns, categories): + df[column] = df[column].cat.set_categories( + new_categories=category, ordered=True + ) diff --git a/vertexai/preview/_workflow/shared/model_utils.py b/vertexai/preview/_workflow/shared/model_utils.py new file mode 100644 index 0000000000..36261e4acd --- /dev/null +++ b/vertexai/preview/_workflow/shared/model_utils.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Model utils. + +Push trained model from local to Model Registry, and pull Model Registry model +to local for uptraining. +""" + +import os +from typing import Any, Union + +from google.cloud import aiplatform +from google.cloud.aiplatform import utils +import vertexai +from vertexai.preview._workflow import driver +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, + serializers_base, +) + +_SKLEARN_FILE_NAME = "model.pkl" +_TF_DIR_NAME = "saved_model" +_PYTORCH_FILE_NAME = "model.mar" +_REWRAPPER_NAME = "rewrapper" + + +def _register_sklearn_model( + model: "sklearn.base.BaseEstimator", # noqa: F821 + serializer: serializers_base.Serializer, + staging_bucket: str, + rewrapper: Any, +) -> aiplatform.Model: + """Register sklearn model.""" + unique_model_name = ( + f"vertex-ai-registered-sklearn-model-{utils.timestamped_unique_name()}" + ) + gcs_dir = os.path.join(staging_bucket, unique_model_name) + # serialize rewrapper + file_path = os.path.join(gcs_dir, _REWRAPPER_NAME) + serializer.serialize(rewrapper, file_path) + # serialize model + file_path = os.path.join(gcs_dir, _SKLEARN_FILE_NAME) + serializer.serialize(model, file_path) + + container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="sklearn", + framework_version="1.0", + ) + + vertex_model = aiplatform.Model.upload( + display_name=unique_model_name, + artifact_uri=gcs_dir, + serving_container_image_uri=container_image_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + + return vertex_model + + +def _register_tf_model( + model: "tensorflow.Module", # noqa: F821 + serializer: serializers_base.Serializer, + staging_bucket: str, + rewrapper: Any, + use_gpu: bool = False, +) -> aiplatform.Model: + """Register TensorFlow model.""" + unique_model_name = ( + f"vertex-ai-registered-tensorflow-model-{utils.timestamped_unique_name()}" + ) + gcs_dir = os.path.join(staging_bucket, unique_model_name) + # serialize rewrapper + file_path = os.path.join(gcs_dir, _TF_DIR_NAME + "/" + _REWRAPPER_NAME) + serializer.serialize(rewrapper, file_path) + # serialize model + file_path = os.path.join(gcs_dir, _TF_DIR_NAME) + serializer.serialize(model, file_path) + + container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="tensorflow", + framework_version="2.11", + accelerator="gpu" if use_gpu else "cpu", + ) + + vertex_model = aiplatform.Model.upload( + display_name=unique_model_name, + artifact_uri=file_path, + serving_container_image_uri=container_image_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + + return vertex_model + + +def _register_pytorch_model( + model: "torch.nn.Module", # noqa: F821 + serializer: serializers_base.Serializer, + staging_bucket: str, + rewrapper: Any, + use_gpu: bool = False, +) -> aiplatform.Model: + """Register PyTorch model.""" + unique_model_name = ( + f"vertex-ai-registered-pytorch-model-{utils.timestamped_unique_name()}" + ) + gcs_dir = os.path.join(staging_bucket, unique_model_name) + + # serialize rewrapper + file_path = os.path.join(gcs_dir, _REWRAPPER_NAME) + serializer.serialize(rewrapper, file_path) + + # This archive model is required for using prediction pre-built container + archive_file_path = os.path.join(gcs_dir, _PYTORCH_FILE_NAME) + serializer.serialize(model, archive_file_path) + + container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri( + framework="pytorch", + framework_version="1.12", + accelerator="gpu" if use_gpu else "cpu", + ) + + vertex_model = aiplatform.Model.upload( + display_name=unique_model_name, + artifact_uri=gcs_dir, + serving_container_image_uri=container_image_uri, + labels={"registered_by_vertex_ai": "true"}, + sync=True, + ) + + return vertex_model + + +def register( + model: Union[ + "sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module" # noqa: F821 + ], + use_gpu: bool = False, +) -> aiplatform.Model: + """Registers a model and returns a Model representing the registered Model resource. + + Args: + model (Union["sklearn.base.BaseEstimator", "tensorflow.Module", "torch.nn.Module"]): + Required. An OSS model. Supported frameworks: sklearn, tensorflow, pytorch. + use_gpu (bool): + Optional. Whether to use GPU for model serving. Default to False. + + Returns: + vertex_model (aiplatform.Model): + Instantiated representation of the registered model resource. + + Raises: + ValueError: if default staging bucket is not set + or if the framework is not supported. + """ + staging_bucket = vertexai.preview.global_config.staging_bucket + if not staging_bucket: + raise ValueError( + "A default staging bucket is required to upload the model file. " + "Please call `vertexai.init(staging_bucket='gs://my-bucket')." + ) + + # Unwrap VertexRemoteFunctor before upload to Model Registry. + rewrapper = driver._unwrapper(model) + + serializer = any_serializer.AnySerializer() + # TODO(b/279812300) + if model.__module__.startswith("sklearn"): + return _register_sklearn_model(model, serializer, staging_bucket, rewrapper) + + elif model.__module__.startswith("keras") or ( + hasattr(model, "_tracking_metadata") + ): # pylint: disable=protected-access + return _register_tf_model(model, serializer, staging_bucket, rewrapper, use_gpu) + + elif "torch" in model.__module__ or (hasattr(model, "state_dict")): + return _register_pytorch_model( + model, serializer, staging_bucket, rewrapper, use_gpu + ) + + else: + raise ValueError("Support uploading PyTorch, scikit-learn and TensorFlow only.") + + +def from_pretrained( + *, + model_name: str, +) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821 + """Pulls a model from Model Registry for retraining. + + Args: + model_name (str): + Required. The resource ID or fully qualified resource name of a registered model. + Format: "12345678910" or + "projects/123/locations/us-central1/models/12345678910@1". + + Returns: + model: local model for uptraining. + + Raises: + ValueError: If registered model is not registered through `vertexai.preview.register` + """ + + project = vertexai.preview.global_config.project + location = vertexai.preview.global_config.location + credentials = vertexai.preview.global_config.credentials + + vertex_model = aiplatform.Model( + model_name, project=project, location=location, credentials=credentials + ) + if vertex_model.labels.get("registered_by_vertex_ai") != "true": + raise ValueError( + f"The model {model_name} is not registered through `vertexai.preview.register`." + ) + artifact_uri = vertex_model.uri + + # sklearn, TF, PyTorch model extensions for retraining. + # PyTorch serv will need model.mar + if "tf" in vertex_model.container_spec.image_uri: + model_file = "" + elif "sklearn" in vertex_model.container_spec.image_uri: + model_file = _SKLEARN_FILE_NAME + elif "pytorch" in vertex_model.container_spec.image_uri: + # Assume the pretrained model will be pulled for uptraining. + model_file = _PYTORCH_FILE_NAME + else: + raise ValueError("Support loading PyTorch, scikit-learn and TensorFlow only.") + + serializer = any_serializer.AnySerializer() + model = serializer.deserialize(os.path.join(artifact_uri, model_file)) + + rewrapper = serializer.deserialize(os.path.join(artifact_uri, _REWRAPPER_NAME)) + + # Rewrap model (in-place) for following remote training. + rewrapper(model) + return model diff --git a/vertexai/preview/_workflow/shared/supported_frameworks.py b/vertexai/preview/_workflow/shared/supported_frameworks.py new file mode 100644 index 0000000000..0794a43901 --- /dev/null +++ b/vertexai/preview/_workflow/shared/supported_frameworks.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib + +try: + from importlib import metadata as importlib_metadata +except ImportError: + import importlib_metadata +import inspect +import sys +from typing import Any, List, Tuple +import warnings + +from google.cloud.aiplatform import base +from packaging import version + + +_LOGGER = base.Logger(__name__) + +# This most likely needs to be map +REMOTE_FRAMEWORKS = frozenset(["sklearn", "keras", "lightning"]) + +REMOTE_TRAINING_MODEL_UPDATE_ONLY_OVERRIDE_LIST = frozenset(["fit", "train"]) + +# Methods that change the state of the object during a training workflow +REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST = frozenset(["fit", "train", "fit_transform"]) + +# Methods that don't change the state of the object during a training workflow +REMOTE_TRAINING_FUNCTIONAL_OVERRIDE_LIST = frozenset(["transform"]) + +# Methods involved in training process +REMOTE_TRAINING_OVERRIDE_LIST = ( + REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST | REMOTE_TRAINING_FUNCTIONAL_OVERRIDE_LIST +) + +REMOTE_PREDICTION_OVERRIDE_LIST = frozenset(["predict"]) + +REMOTE_OVERRIDE_LIST = REMOTE_TRAINING_OVERRIDE_LIST.union( + REMOTE_PREDICTION_OVERRIDE_LIST +) + + +LIBRARY_TO_MODULE_MAP = {"scikit-learn": "sklearn", "tf-models-official": "official"} + + +def _get_version_for_package(package_name: str) -> str: + try: + # Note: this doesn't work in the internal environment since + # importlib.metadata relies on the directory site-packages to collect + # the metadata of Python packages. + return importlib_metadata.version(package_name) + except importlib_metadata.PackageNotFoundError: + _LOGGER.info( + "Didn't find package %s via importlib.metadata. Trying to import it.", + package_name, + ) + try: + if package_name in LIBRARY_TO_MODULE_MAP: + module_name = LIBRARY_TO_MODULE_MAP[package_name] + else: + # Note: this assumes the top-level module name is the same as the + # package name after replacing "-" in the package name by "_". + # This is not always true. + module_name = package_name.replace("-", "_") + + module = importlib.import_module(module_name) + # This assumes the top-level module has __version__ attribute, but this + # is not always true. + return module.__version__ + except Exception as exc: + raise RuntimeError(f"{package_name} is not installed.") from exc + + +def _get_mro(cls_or_ins: Any) -> Tuple[Any, ...]: + if inspect.isclass(cls_or_ins): + return cls_or_ins.__mro__ + else: + return cls_or_ins.__class__.__mro__ + + +# pylint: disable=g-import-not-at-top +def _is_keras(cls_or_ins: Any) -> bool: + try: + global keras + from tensorflow import keras + + return keras.layers.Layer in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_sklearn(cls_or_ins: Any) -> bool: + try: + global sklearn + import sklearn + + return sklearn.base.BaseEstimator in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_lightning(cls_or_ins: Any) -> bool: + try: + global torch + global lightning + import torch + import lightning + + return lightning.pytorch.trainer.trainer.Trainer in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_torch(cls_or_ins: Any) -> bool: + try: + global torch + import torch + + return torch.nn.modules.module.Module in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_torch_dataloader(cls_or_ins: Any) -> bool: + try: + global torch + import torch + + return torch.utils.data.DataLoader in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_tensorflow(cls_or_ins: Any) -> bool: + try: + global tf + import tensorflow as tf + + return tf.Module in _get_mro(cls_or_ins) + except ImportError: + return False + + +def _is_pandas_dataframe(possible_dataframe: Any) -> bool: + try: + global pd + import pandas as pd + + return pd.DataFrame in _get_mro(possible_dataframe) + except ImportError: + return False + + +def _is_bigframe(possible_dataframe: Any) -> bool: + try: + global bf + import bigframes as bf + + return bf.dataframe.DataFrame in _get_mro(possible_dataframe) + except ImportError: + return False + + +# pylint: enable=g-import-not-at-top +def _is_oss(cls_or_ins: Any) -> bool: + return any( + [_is_sklearn(cls_or_ins), _is_keras(cls_or_ins), _is_lightning(cls_or_ins)] + ) + + +# pylint: disable=undefined-variable +def _get_deps_if_sklearn_model(model: Any) -> List[str]: + deps = [] + if _is_sklearn(model): + dep_version = version.Version(sklearn.__version__).base_version + deps.append(f"scikit-learn=={dep_version}") + return deps + + +def _get_deps_if_tensorflow_model(model: Any) -> List[str]: + deps = [] + if _is_tensorflow(model): + dep_version = version.Version(tf.__version__).base_version + deps.append(f"tensorflow=={dep_version}") + return deps + + +def _get_deps_if_torch_model(model: Any) -> List[str]: + deps = [] + if _is_torch(model): + dep_version = version.Version(torch.__version__).base_version + deps.append(f"torch=={dep_version}") + return deps + + +def _get_deps_if_lightning_model(model: Any) -> List[str]: + deps = [] + if _is_lightning(model): + lightning_version = version.Version(lightning.__version__).base_version + torch_version = version.Version(torch.__version__).base_version + deps.append(f"lightning=={lightning_version}") + deps.append(f"torch=={torch_version}") + try: + global tensorboard + import tensorboard + + tensorboard_version = version.Version(tensorboard.__version__).base_version + deps.append(f"tensorboard=={tensorboard_version}") + except ImportError: + pass + try: + global tensorboardX + import tensorboardX + + tensorboardX_version = version.Version( + tensorboardX.__version__ + ).base_version + deps.append(f"tensorboardX=={tensorboardX_version}") + except ImportError: + pass + + return deps + + +def _get_deps_if_torch_dataloader(obj: Any) -> List[str]: + deps = [] + if _is_torch_dataloader(obj): + dep_version = version.Version(torch.__version__).base_version + deps.append(f"torch=={dep_version}") + deps.extend(_get_cloudpickle_deps()) + return deps + + +def _get_cloudpickle_deps() -> List[str]: + deps = [] + try: + global cloudpickle + import cloudpickle + + dep_version = version.Version(cloudpickle.__version__).base_version + deps.append(f"cloudpickle=={dep_version}") + except ImportError as e: + raise ImportError( + "Cloudpickle is not installed. Please call `pip install google-cloud-aiplatform[preview]`." + ) from e + + return deps + + +def _get_deps_if_pandas_dataframe(possible_dataframe: Any) -> List[str]: + deps = [] + if _is_pandas_dataframe(possible_dataframe): + dep_version = version.Version(pd.__version__).base_version + deps.append(f"pandas=={dep_version}") + try: + import pyarrow as pa + + pyarrow_version = version.Version(pa.__version__).base_version + deps.append(f"pyarrow=={pyarrow_version}") + except ImportError: + deps.append("pyarrow") + # Note: it's likely that a DataFrame can be changed to other format, and + # therefore needs to be serialized by CloudPickleSerializer. An example + # is sklearn's Transformer.fit_transform() method, whose output is always + # a ndarray. + deps += _get_cloudpickle_deps() + return deps + + +def _get_deps_if_bigframe(possible_dataframe: Any) -> List[str]: + deps = [] + if _is_bigframe(possible_dataframe): + dep_version = version.Version(bf.__version__).base_version + deps.append(f"bigframes=={dep_version}") + + # Note: it's likely that a DataFrame can be changed to other format, and + # therefore needs to be serialized by CloudPickleSerializer. An example + # is sklearn's Transformer.fit_transform() method, whose output is always + # a ndarray. + deps += _get_cloudpickle_deps() + deps += _get_pandas_deps() + return deps + + +def _get_numpy_deps() -> List[str]: + deps = [] + try: + global numpy + import numpy + + dep_version = version.Version(numpy.__version__).base_version + deps.append(f"numpy=={dep_version}") + except ImportError: + deps.append("numpy") + return deps + + +def _get_pandas_deps() -> List[str]: + deps = [] + try: + global pd + import pandas as pd + + dep_version = version.Version(pd.__version__).base_version + deps.append(f"pandas=={dep_version}") + except ImportError: + deps.append("pandas") + return deps + + +# pylint: enable=undefined-variable + + +def _get_estimator_requirement(estimator: Any) -> List[str]: + """Returns a list of requirements given an estimator.""" + deps = [] + deps.extend(_get_numpy_deps()) + deps.extend(_get_pandas_deps()) + deps.extend(_get_cloudpickle_deps()) + deps.extend(_get_deps_if_sklearn_model(estimator)) + deps.extend(_get_deps_if_tensorflow_model(estimator)) + deps.extend(_get_deps_if_torch_model(estimator)) + deps.extend(_get_deps_if_lightning_model(estimator)) + # dedupe the dependencies by casting it to a dict first (dict perserves the + # order while set doesn't) + return list(dict.fromkeys(deps)) + + +def _get_python_minor_version() -> str: + # this will generally be the container with least or no security vulnerabilities + return ".".join(sys.version.split()[0].split(".")[0:2]) + + +def _get_cpu_container_uri() -> str: + """Returns the container uri used for cpu training.""" + return f"python:{_get_python_minor_version()}" + + +def _get_gpu_container_uri(estimator: Any) -> str: + """Returns the container uri used for gpu training given an estimator.""" + local_python_version = _get_python_minor_version() + if _is_tensorflow(estimator): + if local_python_version != "3.10": + warnings.warn( + f"Your local runtime has python{local_python_version}, but your " + "remote GPU training will be executed in python3.10" + ) + return "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-11.py310:latest" + + elif _is_torch(estimator) or _is_lightning(estimator): + if local_python_version != "3.10": + warnings.warn( + f"Your local runtime has python{local_python_version}, but your " + "remote GPU training will be executed in python3.10" + ) + return "pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime" + + else: + raise ValueError(f"{estimator} is not supported for GPU training.") diff --git a/vertexai/preview/developer/__init__.py b/vertexai/preview/developer/__init__.py new file mode 100644 index 0000000000..06d668e12a --- /dev/null +++ b/vertexai/preview/developer/__init__.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, +) +from vertexai.preview._workflow.serialization_engine import ( + serializers_base, +) +from vertexai.preview._workflow.shared import configs +from vertexai.preview.developer import mark +from vertexai.preview.developer import remote_specs + + +PersistentResourceConfig = configs.PersistentResourceConfig +Serializer = serializers_base.Serializer +SerializationMetadata = serializers_base.SerializationMetadata +RemoteConfig = configs.RemoteConfig +WorkerPoolSpec = remote_specs.WorkerPoolSpec +WorkerPoolSepcs = remote_specs.WorkerPoolSpecs + +register_serializer = any_serializer.register_serializer + + +__all__ = ( + "mark", + "PersistentResourceConfig", + "register_serializer", + "Serializer", + "SerializationMetadata", + "RemoteConfig", + "WorkerPoolSpec", + "WorkerPoolSepcs", +) diff --git a/vertexai/preview/developer/base_classes.py b/vertexai/preview/developer/base_classes.py new file mode 100644 index 0000000000..255fe5185c --- /dev/null +++ b/vertexai/preview/developer/base_classes.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Placeholder for base Model and FeatureTransformer classes. +""" + + +class Model: + pass + + +class FeatureTransformer: + pass diff --git a/vertexai/preview/developer/mark.py b/vertexai/preview/developer/mark.py new file mode 100644 index 0000000000..c2fa4f40bc --- /dev/null +++ b/vertexai/preview/developer/mark.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import inspect +from typing import Any, Callable, List, Optional, Union + +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.executor import ( + remote_container_training, +) +from vertexai.preview._workflow.executor import ( + training, + prediction, +) +from vertexai.preview._workflow.shared import configs +from vertexai.preview.developer import remote_specs + + +def train( + remote_config: Optional[configs.RemoteConfig] = None, +) -> Callable[..., Any]: + """Decorator to enable Vertex remote training on a method. + + Example Usage: + ``` + vertexai.init( + project="my-project", + location="my-location", + staging_bucket="gs://my-bucket", + ) + vertexai.preview.init(remote=True) + + class MyModel(vertexai.preview.VertexModel): + ... + + @vertexai.preview.developer.mark.train() + def my_train_method(...): + ... + + model = MyModel(...) + + # This train method will be executed remotely + model.my_train_method(...) + ``` + + Args: + remote_config (config.RemoteConfig): + Optional. A class that holds the configuration for the remote job. + + Returns: + A wrapped method with its original signature. + """ + + def remote_training_wrapper(method: Callable[..., Any]) -> Callable[..., Any]: + functor = remote.remote_method_decorator(method, training.remote_training) + if remote_config is not None: + if inspect.ismethod(method): + functor.vertex.remote_config = remote_config + else: + functor.vertex = functools.partial( + configs.VertexConfig, remote_config=remote_config + ) + + return functor + + return remote_training_wrapper + + +# pylint: disable=protected-access +def _remote_container_train( + image_uri: str, + additional_data: List[ + Union[remote_specs._InputParameterSpec, remote_specs._OutputParameterSpec] + ], + remote_config: Optional[configs.DistributedTrainingConfig] = None, +) -> Callable[..., Any]: + """Decorator to enable remote training with a container image. + + This decorator takes the parameters from the __init__ function (requires + setting up binding outside of the decorator) and the function that it + decorates, preprocesses the arguments, and launches a custom job for + training. + + As the custom job is running, the inputs are read and parsed according to + the container code, and the outputs are written to the GCS paths specified + for each output field. + + If the custom job succeeds, the decorator deserializes the outputs from the + custom job and sets them as instance attributes. Each output will be either + a string or bytes, and the function this decorator decorates may + additionally post-process the outputs to their corresponding types. + + Args: + image_uri (str): + Required. The pre-built docker image uri for CustomJob. + additional_data (List): + Required. A list of input and output parameter specs. + remote_config (config.DistributedTrainingConfig): + Optional. A class that holds the configuration for the distributed + training remote job. + + Returns: + An inner decorator that returns the decorated remote container training + function. + + Raises: + ValueError if the decorated function has a duplicate argument name as + the parameters in existing binding, or if an additional data is neither + an input parameter spec or an output parameter spec. + """ + + def remote_training_wrapper(method: Callable[..., Any]) -> Callable[..., Any]: + functor = remote.remote_method_decorator( + method, + remote_container_training.train, + remote_executor_kwargs={ + "image_uri": image_uri, + "additional_data": additional_data, + }, + ) + config = remote_config or configs.DistributedTrainingConfig() + if inspect.ismethod(method): + functor.vertex.remote_config = config + functor.vertex.remote = True + else: + functor.vertex = functools.partial( + configs.VertexConfig, remote=True, remote_config=config + ) + + return functor + + return remote_training_wrapper + + +def predict( + remote_config: Optional[configs.RemoteConfig] = None, +) -> Callable[..., Any]: + """Decorator to enable Vertex remote prediction on a method. + + Example Usage: + ``` + vertexai.init( + project="my-project", + location="my-location", + staging_bucket="gs://my-bucket", + ) + vertexai.preview.init(remote=True) + + class MyModel(vertexai.preview.VertexModel): + ... + + @vertexai.preview.developer.mark.predict() + def my_predict_method(...): + ... + + model = MyModel(...) + + # This train method will be executed remotely + model.my_predict_method(...) + ``` + + Args: + remote_config (config.RemoteConfig): + Optional. A class that holds the configuration for the remote job. + + Returns: + A wrapped method with its original signature. + """ + + def remote_prediction_wrapper(method: Callable[..., Any]) -> Callable[..., Any]: + functor = remote.remote_method_decorator(method, prediction.remote_prediction) + if remote_config is not None: + if inspect.ismethod(method): + functor.vertex.remote_config = remote_config + else: + functor.vertex = functools.partial( + configs.VertexConfig, remote_config=remote_config + ) + + return functor + + return remote_prediction_wrapper diff --git a/vertexai/preview/developer/remote_specs.py b/vertexai/preview/developer/remote_specs.py new file mode 100644 index 0000000000..eb27e38bb4 --- /dev/null +++ b/vertexai/preview/developer/remote_specs.py @@ -0,0 +1,862 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Remote workload specs and helper functions for developer.mark.fit. + +""" + +import dataclasses +import json +import os +import tempfile + +from typing import Any, Dict, List, Optional + +from google.cloud.aiplatform import base +from google.cloud.aiplatform.utils import gcs_utils +from google.cloud.aiplatform.utils import worker_spec_utils +from vertexai.preview._workflow.serialization_engine import ( + serializers, +) + +try: + import tensorflow as tf +except ImportError: + pass +try: + import torch +except ImportError: + pass + +_LOGGER = base.Logger(__name__) + +_LITERAL: str = "literal" +_PARQUET: str = "parquet" +_CLOUDPICKLE: str = "cloudpickle" + +# Constants for serializer. +_SERIALIZER = frozenset([_LITERAL, _PARQUET, _CLOUDPICKLE]) + +_METADATA_FILE_NAME = "metadata" +_DATA_FILE_NAME = "data.parquet" +_FLOAT16 = "float16" +_FLOAT32 = "float32" +_FLOAT64 = "float64" +_INT8 = "int8" +_INT16 = "int16" +_INT32 = "int32" +_INT64 = "int64" +_UINT8 = "uint8" +_UINT16 = "uint16" +_UINT32 = "uint32" +_UINT64 = "uint64" +_SUPPORTED_NUMERICAL_DTYPES = [ + _FLOAT16, + _FLOAT32, + _FLOAT64, + _INT8, + _INT16, + _INT32, + _INT64, + _UINT8, + _UINT16, + _UINT32, + _UINT64, +] +_DENSE = "dense" + +# Constants for deserializer. +_DESERIALIZER = frozenset([_LITERAL, _CLOUDPICKLE]) + +# Constants for Cluster and ClusterSpec +_CHIEF = "workerpool0" +_WORKER = "workerpool1" +_SERVER = "workerpool2" +_EVALUATOR = "workerpool3" +_WORKER_POOLS = frozenset([_CHIEF, _WORKER, _SERVER, _EVALUATOR]) +_CLUSTER = "cluster" +_TASK = "task" +_TYPE = "type" +_INDEX = "index" +_TRIAL = "trial" + +_CLUSTER_SPEC = "CLUSTER_SPEC" +_MASTER_ADDR = "MASTER_ADDR" +_MASTER_PORT = "MASTER_PORT" + + +def _gen_gcs_path(base_dir: str, name: str) -> str: + """Generates a GCS path for a file or a directory. + + The created path can be used for either a file or a directory. If it is a + file, we can directly write to it. If it is a directory, file paths could + be generated by joining the derectly path and the file names. + + Example usages: + 1. When passing in parameters to a custom job, we will be able to write + serialized parameter value content to a GCS path in 'cloudpickle' mode. + 2. We will also provide GCS paths so that the custom job can write the + output parameter values to the dedicated paths. + + Args: + base_dir (str): + Required. The base GCS directory. Must be a valid GCS path that + starts with 'gs://'. + name (str): + Required. The name of a file or directory. If name ends with '/', + removes '/' for consistency since we do not need the suffix to + identify the path as a directory. + Returns: + The generated GCS path. + Raises: + ValueError if the input base_dir is not a valid GCS path. + """ + if not base_dir.startswith("gs://"): + raise ValueError(f"base_dir {base_dir} is not a valid GCS path.") + name = name[:-1] if name.endswith("/") else name + return os.path.join(base_dir, name) + + +def _get_argument_name(name: str) -> str: + """Gets an argument name for the inputs and outputs of a container. + + 1. If the name contains dots such as a.b.arg_name or self.arg_name, use the + string following the right-most dot (arg_name) as the argument name. + 2. If the name has a single leading underscore, such as _arg_name, remove + the leading underscore in the argument name (arg_name). If the name has a + double leading underscore such as __arg_name, use the argument name + __arg_name directly. + + Args: + name (str): + Required. The name of the parameter in the InputParameterSpec. + + Returns: + The name of the argument in the container. + """ + argument_name = name.split(".")[-1] + if argument_name.startswith("_") and not argument_name.startswith("__"): + argument_name = argument_name[1:] + if not argument_name: + raise ValueError(f"Failed to get argument name from name {name}.") + return argument_name + + +@dataclasses.dataclass +class _FeatureMetadata: + dtype: str + feature_type: str = _DENSE + + +@dataclasses.dataclass +class _CategoricalFeatureMetadata: + dtype: str + categories: List[Any] + feature_type: str = _DENSE + + +@dataclasses.dataclass +class _TaskInfo: + """Describes the task of the particular node on which code is running. + + Args: + task_type (str): + Required. The type of worker pool this task is running in. One of 'workerpool0' for chief, 'workerpool1' for worker, 'workerpool2' for server or 'workerpool3' for evaluator. + task_index (int): + Required. The zero-based index of the task. If a training job has two workers, this value is set to 0 on one and 1 on the other. + task_trial (int): + Optional. The identifier of the hyperparameter tuning trial currently running. + """ + + task_type: str + task_index: int + task_trial: int = None + + +class _InputParameterSpec: + """Input parameter spec for remote trainers.""" + + def __init__( + self, + name: str, + argument_name: Optional[str] = None, + serializer: str = _LITERAL, + ) -> None: + """Initializes an _InputParameterSpec instance. + + When creating CustomJob spec, each _InputParameterSpec will be + transformed into a custom job input. + + Args: + name (str): + Required. The parameter name that stores the input value. + argument_name (str): + Optional. The argument name for the custom job input. If not + specified, an argument_name will be derived from name. + serializer (str): + Optional. The serializer for the input. Must be one of + 'literal', 'parquet', and 'cloudpickle'. + + Raises: + ValueError: If name or serializer is invalid. + """ + if not name: + raise ValueError("Input parameter name cannot be empty.") + self.name = name + self.argument_name = argument_name or _get_argument_name(name) + if serializer not in _SERIALIZER: + raise ValueError( + f"Invalid serializer {serializer} for {name}. Please" + f"choose one of {list(_SERIALIZER)}." + ) + self.serializer = serializer + + def format_arg(self, input_dir: str, binding: Dict[str, Any]) -> Any: + """Formats an argument based on the spec. + + Args: + input_dir (str): + Required. The GCS input directory to save the serialized input + value when necessary. + binding (Dict[str, Any]): + Required. A dictionary that contains maps an input name to its + value. + + Returns: + The formatted argument. + + Raises: + ValueError if the input is not found in binding, tries to serialize + a non-pandas.DataFrame to parquet, or the serialization format is + not supported. + """ + try: + # pylint: disable=g-import-not-at-top + import pandas as pd + except ImportError: + raise ImportError( + "pandas is not installed and is required for remote training." + ) from None + if self.name not in binding: + raise ValueError(f"Input {self.name} not found in binding: " f"{binding}.") + + value = binding[self.name] + if self.serializer == _LITERAL: + return value + + gcs_path = _gen_gcs_path(input_dir, self.argument_name) + if self.serializer == _PARQUET: + if not isinstance(value, pd.DataFrame): + raise ValueError( + "Parquet serializer is only supported for " + f"pandas.DataFrame, but {self.name} has type " + f"{type(value)}." + ) + # Serializes data + data_serializer = serializers.PandasDataSerializer() + data_path = _gen_gcs_path(gcs_path, _DATA_FILE_NAME) + data_serializer.serialize( + to_serialize=value, + gcs_path=data_path, + ) + + # Serializes feature metadata + metadata_serializer = serializers.CloudPickleSerializer() + metadata_path = _gen_gcs_path(gcs_path, _METADATA_FILE_NAME) + feature_metadata = _generate_feature_metadata(value) + metadata_serializer.serialize( + to_serialize=feature_metadata, gcs_path=metadata_path + ) + + elif self.serializer == _CLOUDPICKLE: + serializer = serializers.CloudPickleSerializer() + serializer.serialize( + to_serialize=value, + gcs_path=gcs_path, + ) + + else: + raise ValueError( + f"Unsupported serializer: {self.serializer}." + "The input serializer must be one of " + f"{_SERIALIZER}." + ) + return gcs_path + + +class _OutputParameterSpec: + """Output parameter spec for remote trainers.""" + + def __init__( + self, + name: str, + argument_name: Optional[str] = None, + deserializer: Optional[str] = _LITERAL, + ) -> None: + """Initializes an OutputParameterSpec instance. + + When creating CustomJob spec, each OutputParameterSpec will be + transformed into a custom job argument that will store the output value. + + Args: + name (str): + Required. The parameter name that will store the output value. + argument_name (str): + Optional. The argument name for the custom job argument. If not + specified, an argument_name will be derived from name. + deserializer (str): + Optional. The deserializer for the output. Must be one of + 'literal', and 'cloudpickle'. + + Raises: + ValueError: If name or deserializer is invalid. + """ + if not name: + raise ValueError("Output parameter name cannot be empty.") + self.name = name + self.argument_name = argument_name or _get_argument_name(name) + if deserializer not in _DESERIALIZER: + raise ValueError( + f"Invalid deserializer {deserializer} for {name}. Please" + f"choose one of {list(_DESERIALIZER)}." + ) + self.deserializer = deserializer + + def deserialize_output(self, gcs_path: str) -> Any: + """Deserializes an output based on the spec. + + Args: + gcs_path (str): + Required. The gcs path containing the output. + + Returns: + The deserialized output. + + Raises: + ValueError if the deserialization format is unsupported. + """ + if self.deserializer == _LITERAL: + with tempfile.NamedTemporaryFile() as temp_file: + gcs_utils.download_file_from_gcs(gcs_path, temp_file.name) + with open(temp_file.name, "r") as f: + return f.read() + elif self.deserializer == _CLOUDPICKLE: + serializer = serializers.CloudPickleSerializer() + return serializer.deserialize(serialized_gcs_path=gcs_path) + else: + raise ValueError(f"Unsupported deserializer: {self.deserializer}.") + + +def _generate_feature_metadata(df: Any) -> Dict[str, Any]: + """Helper function to generate feature metadata from a pandas DataFrame. + + When column types are not supported, the corresponding columns are excluded + from feature metadata. + + Args: + df (pandas.DataFrame): + Required. A DataFrame to generate feature metadata from. + + Returns: + A dictionary that maps column names to metadata. + + Raises: + ValueError if df is not a valid/ supported DataFrame. + """ + try: + # pylint: disable=g-import-not-at-top + import pandas as pd + except ImportError: + raise ImportError( + "pandas is not installed and is required for remote training." + ) from None + + if not isinstance(df, pd.DataFrame): + raise ValueError( + "Generating feature metadata is only supported for " + f"pandas.DataFrame, but {df} has type {type(df)}." + ) + + feature_metadata = {} + for col in df.columns: + if df[col].dtypes in _SUPPORTED_NUMERICAL_DTYPES: + feature_metadata[str(col)] = dataclasses.asdict( + _FeatureMetadata(str(df[col].dtypes)) + ) + # Ignores categorical columns that are not integers. + elif df[col].dtypes == "category" and df[col].cat.categories.dtype == _INT64: + categories = df[col].cat.categories.tolist() + feature_metadata[str(col)] = dataclasses.asdict( + _CategoricalFeatureMetadata(_INT64, categories) + ) + else: + # Ignores unsupported column type. + pass + return feature_metadata + + +class _Cluster: + """Represents a Cluster as a set of "tasks". + + Task type or worker pool can be one of chief, worker, server or evaluator. + + To create a cluster with two task types and three tasks, specify the + mapping from worker pool to list of network addresses. + + ```python + cluster = Cluster({"workerpool0": ["cmle-training-workerpool0-ab-0:2222"], + "workerpool1": ["cmle-training-workerpool1-ab-0:2222", + "cmle-training-workerpool1-ab-1:2222"]}) + ``` + """ + + def __init__(self, cluster_info: Dict[str, Any]): + """Initializes a Cluster instance. + + The cluster description contains a list of tasks for each + task type or worker pool specified in a CustomJob. + + Args: + cluster_info (Dict[str, Any]): Required. The cluster description + containing the list of tasks for each task type. + + Raises: + ValueError: If cluster description contains invalid task types. + """ + for task_type in cluster_info: + if task_type not in _WORKER_POOLS: + raise ValueError( + f"Invalid task type: {task_type}. Must be one of {_WORKER_POOLS}." + ) + self.cluster_info = cluster_info + + # Different worker pool types + @property + def chief_task_type(self) -> str: + return _CHIEF + + @property + def worker_task_type(self) -> str: + return _WORKER + + @property + def server_task_type(self) -> str: + return _SERVER + + @property + def evaluator_task_type(self) -> str: + return _EVALUATOR + + @property + def task_types(self) -> List[str]: + """Returns a list of task types in this cluster. + + Returns: + A list of task types in this cluster. + """ + return list(self.cluster_info.keys()) + + def get_num_tasks(self, task_type): + """Returns the number of tasks of a given task type. + + Args: + task_type (str): The task type. + + Returns: + The number of tasks of the given task type. + """ + if task_type not in self.cluster_info: + return 0 + return len(self.cluster_info[task_type]) + + def get_task_addresses(self, task_type): + """Returns list of task address for the task type. + + Args: + task_type (str): The task type. + + Returns: + A list of task address for the given task type. + + Raises: + ValueError: If the task type passed does not exist in the cluster. + """ + if task_type not in self.cluster_info: + raise ValueError(f"No such task type in cluster: {task_type}") + return self.cluster_info[task_type] + + +class _ClusterSpec: + """ClusterSpec for a distributed training job.""" + + def __init__(self, cluster_spec: Dict[str, Any]): + """Initializes a ClusterSpec instance. + + Vertex AI populates an environment variable, CLUSTER_SPEC, on every + replica to describe how the overall cluster is set up. For + distributed + training, this environment variable will be used to create a + ClusterSpec. + + A sample CLUSTER_SPEC: + ``` + { + "cluster": { + "workerpool0": [ + "cmle-training-workerpool0-ab-0:2222" + ], + "workerpool1": [ + "cmle-training-workerpool1-ab-0:2222", + "cmle-training-workerpool1-ab-1:2222" + ], + "workerpool2": [ + "cmle-training-workerpool2-ab-0:2222" + ], + "workerpool3": [ + "cmle-training-workerpool3-ab-0:2222" + ] + }, + "environment":"cloud", + "task":{ + "type": "workerpool0", + "index": 0 + } + } + ``` + Args: + cluster_spec (Dict[str, Any]): Required. The cluster spec + containing the cluster and current task specification. + + Raises: + ValueError: If `cluster_spec` is missing required keys. + """ + if _CLUSTER not in cluster_spec or _TASK not in cluster_spec: + raise ValueError(f"`cluster_spec` must contain {_CLUSTER} and {_TASK}") + self.cluster = _Cluster(cluster_spec[_CLUSTER]) + self.task = _TaskInfo( + task_type=cluster_spec[_TASK][_TYPE], + task_index=cluster_spec[_TASK][_INDEX], + task_trial=cluster_spec[_TASK].get(_TRIAL, None), + ) + + def get_rank(self): + """Returns the world rank of the current task. + + Returns: + The world rank of the current task. + """ + task_type = self.task.task_type + task_index = self.task.task_index + + if task_type == self.cluster.chief_task_type: + return 0 + if task_type == self.cluster.worker_task_type: + return task_index + 1 + + num_workers = self.cluster.get_num_tasks(self.cluster.worker_task_type) + if task_type == self.cluster.server_task_type: + return num_workers + task_index + 1 + + num_ps = self.cluster.get_num_tasks(self.cluster.server_task_type) + if task_type == self.cluster.evaluator_task_type: + return num_ps + num_workers + task_index + 1 + + def get_world_size(self): + """Returns the world size (total number of workers) for the current run. + + Returns: + The world size for the current run. + """ + num_chief = self.cluster.get_num_tasks(self.cluster.chief_task_type) + num_workers = self.cluster.get_num_tasks(self.cluster.worker_task_type) + num_ps = self.cluster.get_num_tasks(self.cluster.server_task_type) + num_evaluators = self.cluster.get_num_tasks(self.cluster.evaluator_task_type) + + return num_chief + num_workers + num_ps + num_evaluators + + def get_chief_address_port(self): + """Returns address and port for chief task. + + Returns: + A tuple of task address and port. + + Raises: + ValueError: If the chief task type does not exist in the cluster + """ + if self.cluster.chief_task_type not in self.cluster.task_types: + raise ValueError("Cluster must have a chief task.") + chief_task = self.cluster.get_task_addresses(self.cluster.chief_task_type)[0] + address, port = chief_task.split(":") + return address, int(port) + + +# pylint: disable=protected-access +class WorkerPoolSpec(worker_spec_utils._WorkerPoolSpec): + """Wraps class that holds a worker pool spec configuration. + + Attributes: + replica_count (int): + The number of worker replicas. + machine_type (str): + The type of machine to use for remote training. + accelerator_count (int): + The number of accelerators to attach to a worker replica. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_A100, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, + NVIDIA_TESLA_K80, NVIDIA_TESLA_T4, NVIDIA_TESLA_P4 + boot_disk_type (str): + Type of the boot disk (default is `pd-ssd`). + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk (default is 100GB). + boot disk size must be within the range of [100, 64000]. + """ + + +@dataclasses.dataclass +class WorkerPoolSpecs: + """A class that holds the worker pool specs configuration for a remote job. + + Attributes: + chief (WorkerPoolSpec): + The `cheif` or `workerpool0` worker pool spec configuration. + worker (WorkerPoolSpec): + The `worker` or `workerpool1` worker pool spec configuration. + server (WorkerPoolSpec): + The `server` or `workerpool2` worker pool spec configuration. + evaluator (WorkerPoolSpec): + The `evaluator` or `workerpool3` worker pool spec configuration. + """ + + chief: WorkerPoolSpec + worker: Optional[WorkerPoolSpec] = None + server: Optional[WorkerPoolSpec] = None + evaluator: Optional[WorkerPoolSpec] = None + + +def _prepare_worker_pool_specs( + worker_pool_specs: WorkerPoolSpecs, + image_uri: str, + command: Optional[List[Any]] = [], + args: Optional[List[Any]] = [], +): + """Return each worker pools spec in order for Vertex AI Training as a list of dicts. + + Args: + worker_pool_specs (WorkerPoolSpecs): Required. Worker pool specs configuration for a remote job. + image_uri (str): Required. Image uri for training. + command (str): Command for training. + args (str): Args for training. + + Returns: + Ordered list of worker pool specs for Vertex AI Training. + + Raises: + ValueError: If replica_count for cheif worker pool spec is greater than 1. + """ + + if worker_pool_specs.chief.replica_count > 1: + raise ValueError( + "Chief worker pool spec replica_count cannot be greater than 1." + ) + spec_order = [ + worker_pool_specs.chief, + worker_pool_specs.worker, + worker_pool_specs.server, + worker_pool_specs.evaluator, + ] + formatted_specs = [{} if not spec else spec.spec_dict for spec in spec_order] + + # Remove empty trailing worker pool specs + for i in reversed(range(len(spec_order))): + if spec_order[i]: + break + formatted_specs.pop() + + # Add container spec to each non-empty worker pool spec + for spec in formatted_specs: + if spec: + spec["container_spec"] = { + "image_uri": image_uri, + "command": command, + "args": args, + } + + return formatted_specs + + +def _verify_specified_remote_config_values( + worker_pool_specs: WorkerPoolSpecs, + machine_type: str, + accelerator_type: str, + accelerator_count: int, + replica_count: Optional[int] = None, + boot_disk_type: Optional[str] = None, + boot_disk_size_gb: Optional[int] = None, +): + """Helper to validate if remote_config.worker_pool_specs is set, other remote job config values are not.""" + if worker_pool_specs and ( + machine_type + or accelerator_type + or accelerator_count + or replica_count + or boot_disk_type + or boot_disk_size_gb + ): + raise ValueError( + "Cannot specify both 'worker_pool_specs' and ['machine_type', 'accelerator_type', 'accelerator_count', 'replica_count', 'boot_disk_type', 'boot_disk_size_gb']." + ) + + +def _get_cluster_spec() -> _ClusterSpec: + """Helper to check for CLUSTER_SPEC environment variable and return object if it exists.""" + cluster_spec_str = os.getenv(_CLUSTER_SPEC, "") + if cluster_spec_str: + return _ClusterSpec(json.loads(cluster_spec_str)) + return None + + +def _get_output_path_for_distributed_training(base_dir, name) -> str: + """Helper to get output path for distributed training.""" + cluster_spec = _get_cluster_spec() + if cluster_spec: + task_type = cluster_spec.task.task_type + task_id = cluster_spec.task.task_index + + if task_type != cluster_spec.cluster.chief_task_type: + temp_path = os.path.join(base_dir, "temp") + os.makedirs(temp_path, exist_ok=True) + temp_path = os.path.join(temp_path, f"{task_type}_{task_id}") + return temp_path + + return os.path.join(base_dir, name) + + +def _get_keras_distributed_strategy(enable_distributed: bool, accelerator_count: int): + """Returns distribute strategy for Keras distributed training. + + For multi-worker training, use tf.distribute.MultiWorkerMirroredStrategy(). + For single worker, multi-GPU training, use tf.distribute.MirroredStrategy(). + For non-distributed training, return None. Requires TensorFlow >= 2.12.0. + + Args: + enable_distributed (boolean): Required. Whether distributed training is enabled. + accelerator_count (int): Accelerator count specified for single worker training. + + Returns: + A tf.distribute.Strategy. + """ + if enable_distributed: + cluster_spec = _get_cluster_spec() + # Multiple workers, use tf.distribute.MultiWorkerMirroredStrategy(). + if cluster_spec and len(cluster_spec.cluster.task_types) >= 2: + return tf.distribute.MultiWorkerMirroredStrategy() + # Single worker, use tf.distribute.MirroredStrategy(). We validate accelerator_count > 1 before + # creating CustomJob. + else: + return tf.distribute.MirroredStrategy() + # Multi-GPU training, but enable_distributed is false, use tf.distribute.MirroredStrategy(). + elif accelerator_count and accelerator_count > 1: + return tf.distribute.MirroredStrategy() + # Not distributed, return None. + else: + return None + + +def _set_keras_distributed_strategy(model: Any, strategy: Any): + """Returns a model compiled within the scope of the specified distribute strategy. + + Requires TensorFlow >= 2.12.0. + + Args: + model (Any): Required. An instance of a Keras model. + strategy (tf.distribute.Strategy): The distribute strategy. + + Returns: + A tf.distribute.Strategy. + """ + # Clone and compile model within scope of chosen strategy. + with strategy.scope(): + cloned_model = tf.keras.models.clone_model(model) + cloned_model.compile_from_config(model.get_compile_config()) + + return cloned_model + + +def setup_pytorch_distributed_training(model: Any) -> Any: + """Sets up environment for PyTorch distributed training. + + The number of nodes or processes (`world_size`) is the number of + workers being used for the training run. This helper can be called + within the Vertex remote training-enabled function of a custom model + built on top of `torch.nn.Module`. + + Example Usage: + ``` + vertexai.init( + project="my-project", + location="my-location", + staging_bucket="gs://my-bucket", + ) + vertexai.preview.init(remote=True) + + class MyModel(vertexai.preview.VertexModel, torch.nn.Module): + ... + + @vertexai.preview.developer.mark.train() + def my_train_method(self, ...): + self = setup_pytorch_distributed_training(self) + ... + + model = MyModel(...) + + # This will execute distributed, remote training + model.my_train_method(...) + ``` + Args: + model (Any): Required. An instance of a custom PyTorch model. + + Returns: + A custom model built on top of `torch.nn.Module` wrapped in DistributedDataParallel. + """ + if not model.cluster_spec: # cluster_spec is populated for multi-worker training + return model + + device = "cuda" if model._enable_cuda else "cpu" + rank = model.cluster_spec.get_rank() + world_size = model.cluster_spec.get_world_size() + address, port = model.cluster_spec.get_chief_address_port() + + os.environ[_MASTER_ADDR] = address + os.environ[_MASTER_PORT] = str(port) + + torch.distributed.init_process_group( + backend="nccl" if device == "cuda" else "gloo", + rank=rank, + world_size=world_size, + ) + + if device == "cuda": + model.to(device) + model = torch.nn.parallel.DistributedDataParallel(model) + + _LOGGER.info( + f"Initialized process rank: {rank}, world_size: {world_size}, device: {device}", + ) + return model diff --git a/vertexai/preview/hyperparameter_tuning/__init__.py b/vertexai/preview/hyperparameter_tuning/__init__.py new file mode 100644 index 0000000000..16a54be359 --- /dev/null +++ b/vertexai/preview/hyperparameter_tuning/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from vertexai.preview.hyperparameter_tuning import ( + vizier_hyperparameter_tuner, +) + + +VizierHyperparameterTuner = vizier_hyperparameter_tuner.VizierHyperparameterTuner + + +__all__ = ("VizierHyperparameterTuner",) diff --git a/vertexai/preview/hyperparameter_tuning/vizier_hyperparameter_tuner.py b/vertexai/preview/hyperparameter_tuning/vizier_hyperparameter_tuner.py new file mode 100644 index 0000000000..9214d78932 --- /dev/null +++ b/vertexai/preview/hyperparameter_tuning/vizier_hyperparameter_tuner.py @@ -0,0 +1,961 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import concurrent +import functools +import inspect +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import uuid + +from google.cloud.aiplatform import base +from google.cloud.aiplatform_v1.services.vizier_service import ( + VizierServiceClient, +) +from google.cloud.aiplatform_v1.types import study as gca_study +import vertexai +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.driver import ( + VertexRemoteFunctor, +) +from vertexai.preview._workflow.executor import ( + remote_container_training, +) +from vertexai.preview._workflow.executor import ( + training, +) +from vertexai.preview._workflow.shared import configs +from vertexai.preview._workflow.shared import ( + supported_frameworks, +) + + +try: + import pandas as pd + + PandasData = pd.DataFrame + +except ImportError: + PandasData = Any + +_LOGGER = base.Logger(__name__) + +# Metric id constants +_CUSTOM_METRIC_ID = "custom" +_ROC_AUC_METRIC_ID = "roc_auc" +_F1_METRIC_ID = "f1" +_PRECISION_METRIC_ID = "precision" +_RECALL_METRIC_ID = "recall" +_ACCURACY_METRIC_ID = "accuracy" +_MAE_METRIC_ID = "mae" +_MAPE_METRIC_ID = "mape" +_R2_METRIC_ID = "r2" +_RMSE_METRIC_ID = "rmse" +_RMSLE_METRIC_ID = "rmsle" +_MSE_METRIC_ID = "mse" + +try: # Only used by local tuning loop + import sklearn.metrics + from sklearn.model_selection import train_test_split + + _SUPPORTED_METRIC_FUNCTIONS = { + _ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score, + _F1_METRIC_ID: sklearn.metrics.f1_score, + _PRECISION_METRIC_ID: sklearn.metrics.precision_score, + _RECALL_METRIC_ID: sklearn.metrics.recall_score, + _ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score, + _MAE_METRIC_ID: sklearn.metrics.mean_absolute_error, + _MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error, + _R2_METRIC_ID: sklearn.metrics.r2_score, + _RMSE_METRIC_ID: functools.partial( + sklearn.metrics.mean_squared_error, squared=False + ), + _RMSLE_METRIC_ID: functools.partial( + sklearn.metrics.mean_squared_log_error, squared=False + ), + _MSE_METRIC_ID: sklearn.metrics.mean_squared_error, + } + _SUPPORTED_METRIC_IDS = frozenset(_SUPPORTED_METRIC_FUNCTIONS.keys()).union( + frozenset([_CUSTOM_METRIC_ID]) + ) + _SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset( + [ + _ROC_AUC_METRIC_ID, + _F1_METRIC_ID, + _PRECISION_METRIC_ID, + _RECALL_METRIC_ID, + _ACCURACY_METRIC_ID, + ] + ) + +except ImportError: + pass + +# Vizier client constnats +_STUDY_NAME_PREFIX = "vizier_hyperparameter_tuner_study" +_CLIENT_ID = "client" + +# Train and test split constants +_DEFAULT_TEST_FRACTION = 0.25 + +# Parameter constants +_TRAINING_X_PARAMS = ["X", "x", "X_train", "x_train"] +_TRAINING_DATA_PARAMS = ["X", "x", "X_train", "x_train", "training_data"] +_OSS_TRAINING_DATA_PARAMS = ["X", "x"] +_TRAINING_TARGET_VALUE_PARAMS = ["y", "y_train"] +_Y_DATA_PARAM = "y" +_X_TEST_PARAMS = ["X_test", "x_test"] +_Y_TEST = "y_test" +_VALIDATION_DATA = "validation_data" + + +class VizierHyperparameterTuner: + """The Vizier hyperparameter tuner for local and remote tuning.""" + + def __init__( + self, + get_model_func: Callable[..., Any], + max_trial_count: int, + parallel_trial_count: int, + hparam_space: List[Dict[str, Any]], + metric_id: str = _ACCURACY_METRIC_ID, + metric_goal: str = "MAXIMIZE", + max_failed_trial_count: int = 0, + search_algorithm: str = "ALGORITHM_UNSPECIFIED", + project: Optional[str] = None, + location: Optional[str] = None, + study_display_name_prefix: str = _STUDY_NAME_PREFIX, + ): + """Initializes a VizierHyperparameterTuner instance. + + VizierHyperparameterTuner provides support for local and remote Vizier + hyperparameter tuning. For information on Vertex AI Vizier, refer to + https://cloud.google.com/vertex-ai/docs/vizier/overview. + + Args: + get_model_func (Callable[..., Any]): + Required. A function that returns a model to be tuned. Non-tunable + parameters should be preset by get_model_func, and tunable + parameters will be set byVizierHyperparameterTuner. + + Example: + # parameter_a and parameter_b are tunable. + def get_model_func(parameter_a, parameter_b): + # parameter_c is non-tunable + parameter_c = 10 + return ExampleModel(parameter_a, parameter_b, parameter_c) + + For lightning models, get_model_func should return a dictionary + containing the following keys: 'model', 'trainer', + 'train_dataloaders'; each representing the lightning model, the + trainer and the training dataloader(s) respectively. + + max_trial_count (int): + Required. The desired total number of trials. + parallel_trial_count (int): + Required. The desired number of trials to run in parallel. For + pytorch lightning, currently we only support parallel_trial_count=1. + hparam_space (List[Dict[str, Any]]): + Required. A list of parameter specs each representing a single + tunable parameter. For parameter specs, refer to + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/StudySpec#parameterspec + metric_id (str): + Optional. The ID of the metric. Must be one of 'roc_auc', 'f1', + 'precision', 'recall', 'accuracy', 'mae', 'mape', 'r2', 'rmse', + 'rmsle', 'mse' or 'custom'. Only 'accuracy' supports multi-class + classification. Set to 'custom' to use a custom score function. + Default is 'accuracy'. + metric_goal (str): + Optional. The optimization goal of the metric. Must be one of + 'GOAL_TYPE_UNSPECIFIED', 'MAXIMIZE' and 'MINIMIZE'. + 'GOAL_TYPE_UNSPECIFIED' defaults to maximize. Default is + 'MAXIMIZE'. Refer to + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/StudySpec#goaltype + for details on goal types. + max_failed_trial_count (int): + Optional. The number of failed trials that need to be seen before + failing the tuning process. If 0, the tuning process only fails + when all trials have failed. Default is 0. + search_algorithm (str): + Optional. The search algorithm specified for the study. Must be + one of 'ALGORITHM_UNSPECIFIED', 'GRID_SEARCH' and 'RANDOM_SEARCH'. + Default is 'ALGORITHM_UNSPECIFIED'. Refer to + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/StudySpec#algorithm + for details on the study algorithms. + project (str): + Optional. Project for the study. If not set, project set in + vertexai.init will be used. + location (str): + Optional. Location for the study. If not set, location set in + vertexai.init will be used. + study_display_name_prefix (str): + Optional. Prefix of the study display name. Default is + 'vizier-hyperparameter-tuner-study'. + """ + self.get_model_func = get_model_func + self.max_trial_count = max_trial_count + self.parallel_trial_count = parallel_trial_count + self.hparam_space = hparam_space + + if metric_id not in _SUPPORTED_METRIC_IDS: + raise ValueError( + f"Unsupported metric_id {metric_id}. Supported metric_ids: {_SUPPORTED_METRIC_IDS}" + ) + self.metric_id = metric_id + + self.metric_goal = metric_goal + self.max_failed_trial_count = max_failed_trial_count + self.search_algorithm = search_algorithm + + # Initializes Vertex config + self.vertex = configs.VertexConfig() + + # Creates Vizier client, study and trials + project = project or vertexai.preview.global_config.project + location = location or vertexai.preview.global_config.location + self.vizier_client, self.study = self._create_study( + project, location, study_display_name_prefix + ) + + # self.models should be a mapping from trial names to trained models. + self.models = {} + + def _create_study( + self, + project: str, + location: str, + study_display_name_prefix: str = _STUDY_NAME_PREFIX, + ) -> Tuple[VizierServiceClient, gca_study.Study]: + """Creates a Vizier study config. + + Args: + project (str): + Project for the study. + location (str): + Location for the study. + study_display_name_prefix (str): + Prefix for the study display name. Default is + 'vizier-hyperparameter-tuner-study'. + Returns: + A Vizier client and the created study. + """ + vizier_client = VizierServiceClient( + client_options=dict(api_endpoint=f"{location}-aiplatform.googleapis.com") + ) + study_config = { + "display_name": f"{study_display_name_prefix}_{uuid.uuid4()}".replace( + "-", "_" + ), + "study_spec": { + "algorithm": self.search_algorithm, + "parameters": self.hparam_space, + "metrics": [{"metric_id": self.metric_id, "goal": self.metric_goal}], + }, + } + parent = f"projects/{project}/locations/{location}" + study = vizier_client.create_study(parent=parent, study=study_config) + return vizier_client, study + + def _suggest_trials(self, num_trials: int) -> List[gca_study.Trial]: + """Suggests trials using the Vizier client. + + During each round of tuning, num_trials number of trials will + be suggested. For each trial, training will be performed locally or + remotely. After training finishes, we use the trained model to measure + the metrics and report the metrics to the trial before marking it as + completed. At the next round of tuning, another parallel_trial_count + of trials will be suggested based on previous measurements. + + Args: + num_trials (int): Required. Number of trials to suggest. + Returns: + A list of suggested trials. + """ + return ( + self.vizier_client.suggest_trials( + { + "parent": self.study.name, + "suggestion_count": num_trials, + "client_id": _CLIENT_ID, + } + ) + .result() + .trials + ) + + def get_best_models(self, num_models: int = 1) -> List[Any]: + """Gets the best models from completed trials. + + Args: + num_models (int): + Optional. The number of best models to return. Default is 1. + + Returns: + A list of best models. + """ + trials = [] + for trial in self.vizier_client.list_trials({"parent": self.study.name}).trials: + if ( + trial.state == gca_study.Trial.State.SUCCEEDED + and trial.name in self.models + ): + trials.append((trial.name, trial.final_measurement.metrics[0].value)) + + maximize = True if self.metric_goal == "MAXIMIZE" else False + trials.sort(reverse=maximize, key=lambda x: x[1]) + + return [self.models[trial[0]] for trial in trials[:num_models]] + + def _create_train_and_test_splits( + self, + x: PandasData, + y: Union[PandasData, str], + test_fraction: float = _DEFAULT_TEST_FRACTION, + ) -> Tuple[PandasData, PandasData, Optional[PandasData], PandasData]: + """Creates train and test splits if no manual test splits provided. + + Depending on the model to be tuned, the training step may take in either + one or two DataFrames for training data and target values. + + 1. Two pandas DataFrames: + - One contains training data and the other contains target values. + - Four DataFrames will be returned, ie. X_train, X_test, y_train, + y_test. + 2. One pandas DataFrame: + - Contains both training data and target values. + - Only three DataFrames will be returned, ie. X_train, X_test, + y_test. X_train contains both training data and target values. The + testing splits need to be separated into data and values to make + predictions. + + Args: + x (pandas.DataFrame): + Required. A pandas DataFrame for the dataset. If it contains the + target column, y must be a string specifying the target column + name. + y (Union[pandas.DataFrame, str]): + Required. A pandas DataFrame containing target values for the + dataset or a string specifying the target column name. + test_fraction (float): + Optional. The proportion of the dataset to include in the test + split. eg. test_fraction=0.25 for a pandas Dataframe with 100 + rows would result in 75 rows for training and 25 rows for + testing. Default is 0.25. + Returns: + A tuple containing training data, testing data, training target + values, testing target values. Training target values may be None if + training data contrains training target. + """ + if test_fraction <= 0 or test_fraction >= 1: + raise ValueError( + "test_fraction must be greater than 0 and less than 1 but was " + f"{test_fraction}." + ) + + if isinstance(y, str): + try: + import pandas as pd + except ImportError: + raise ImportError( + "pandas must be installed to create train and test splits " + "with a target column name." + ) from None + x_train, x_test = train_test_split(x, test_size=test_fraction) + y_test = pd.DataFrame(x_test.pop(y)) + return x_train, x_test, None, y_test + else: + return train_test_split(x, y, test_size=test_fraction) + + def _evaluate_model( + self, model: Any, x_test: PandasData, y_test: PandasData + ) -> Tuple[Any, float]: + """Evaluates a model. + + Metrics are calculated based on the metric_id set by the user. After + reporting the metrics, mark the trial as complete. Only completed trials + can be listed as optimal trials. + + Supported metric_id: 'roc_auc', 'f1', 'precision', 'recall', 'accuracy', + 'mae', 'mape', 'r2', 'rmse', 'rmsle', 'mse' or 'custom'. Only 'accuracy' + supports multi-class classification. + + When metric_id is 'custom', the model must provide a score() function to + provide a metric value. Otherwise, the model must provide a predict() + function that returns array-like prediction results. + + e.g. + class ExampleModel: + def score(x_test, y_test): + # Code to make predictions and calculate metrics + return custom_metric(y_true=y_test, y_pred=self.predict(x_test)) + + Args: + model (Any): + Required. The model trained during the trial. + x_test (pandas.DataFrame): + Required. The testing data. + y_test (pandas.DataFrame): + Required. The testing values. + Returns: + A tuple containing the model and the corresponding metric value. + """ + if self.metric_id == _CUSTOM_METRIC_ID: + metric_value = model.score(x_test, y_test) + else: + if self.metric_id in _SUPPORTED_METRIC_IDS: + predictions = model.predict(x_test) + # Keras outputs probabilities. Must convert to output label. + if ( + supported_frameworks._is_keras(model) + and self.metric_id in _SUPPORTED_CLASSIFICATION_METRIC_IDS + ): + if isinstance(predictions, pd.DataFrame): + predictions = predictions.to_numpy() + predictions = ( + predictions.argmax(axis=-1) + if predictions.shape[-1] > 1 + else (predictions > 0.5).astype("int32") + ) + metric_value = _SUPPORTED_METRIC_FUNCTIONS[self.metric_id]( + y_test, predictions + ) + else: + raise ValueError( + f"Unsupported metric_id {self.metric_id}. Supported metric_ids: {_SUPPORTED_METRIC_IDS}" + ) + return (model, metric_value) + + def _add_model_and_report_trial_metrics( + self, trial_name: str, trial_output: Optional[Tuple[Any, float]] + ) -> None: + """Adds a model to the dictionary of trained models and report metrics. + + If trial_output is None, it means that the trial has failed and should + be marked as infeasible. + + Args: + trial_name (str): + Required. The trial name. + trial_output (Optional[Tuple[Any, float]]): + Required. A tuple containing the model and the metric value, or + None if the trial has failed. + """ + if trial_output is not None: + model, metric_value = trial_output + self.vizier_client.complete_trial( + { + "name": trial_name, + "final_measurement": { + "metrics": [ + {"metric_id": self.metric_id, "value": metric_value} + ] + }, + } + ) + self.models[trial_name] = model + else: + self.vizier_client.complete_trial( + {"name": trial_name, "trial_infeasible": True} + ) + + def _get_model_param_type_mapping(self): + """Gets a mapping from parameter_id to its type. + + Returns: + A mapping from parameter id to its type. + """ + model_param_type_mapping = {} + for param in self.hparam_space: + param_id = param["parameter_id"] + if "double_value_spec" in param: + param_type = float + elif "integer_value_spec" in param: + param_type = int + elif "categorical_value_spec" in param: + param_type = str + elif "discrete_value_spec" in param: + param_type = type(param["discrete_value_spec"]["values"][0]) + else: + raise ValueError( + f"Invalid hparam_space configuration for parameter {param_id}" + ) + model_param_type_mapping[param_id] = param_type + + return model_param_type_mapping + + def _set_model_parameters( + self, + trial: gca_study.Trial, + fixed_init_params: Optional[Dict[Any, Any]] = None, + fixed_runtime_params: Optional[Dict[Any, Any]] = None, + ) -> Tuple[Any, Dict[Any, Any]]: + """Returns a model intialized with trial parameters and a dictionary of runtime parameters. + + Initialization parameters are passed to the get_model_func. Runtime parameters + will be passed to the model's fit() or @developer.mark.train()-decorated + method outside of this function. + + Args: + trial (gca_study.Trial): Required. A trial suggested by Vizier. + fixed_init_params (Dict[Any, Any]): Optional. A dictionary of fixed + parameters to be passed to get_model_func. + fixed_runtime_params (Dict[Any, Any]): Optional. A dictionary of fixed + runtime parameters. + + Returns: + A model initialized using parameters from the specified trial and + a dictionary of runtime parameters. + """ + model_init_params = {} + model_runtime_params = {} + get_model_func_binding = inspect.signature(self.get_model_func).parameters + + model_param_type_mapping = self._get_model_param_type_mapping() + + for param in trial.parameters: + param_id = param.parameter_id + param_value = ( + model_param_type_mapping[param_id](param.value) + if param_id in model_param_type_mapping + else param.value + ) + if param_id in get_model_func_binding: + model_init_params[param_id] = param_value + else: + model_runtime_params[param_id] = param_value + + if fixed_init_params: + model_init_params.update(fixed_init_params) + if fixed_runtime_params: + model_runtime_params.update(fixed_runtime_params) + + return self.get_model_func(**model_init_params), model_runtime_params + + def _is_remote(self, train_method: VertexRemoteFunctor) -> bool: + """Checks if a train method will be executed locally or remotely. + + The train method will be executed remotely if: + - The train method's vertex config sets remote to True (eg. + train.vertex.remote=True) + - Or, .vertex.remote is not set but the global config defaults + remote to True. (eg. vertexai.preview.init(remote=True, ...)) + + Otherwise, the train method will be executed locally. + + Args: + train_method (VertexRemoteFunctor): + Required. The train method. + Returns: + Whether the train method will be executed locally or remotely. + """ + return train_method.vertex.remote or ( + train_method.vertex.remote is None and vertexai.preview.global_config.remote + ) + + def _override_staging_bucket( + self, train_method: VertexRemoteFunctor, trial_name: str + ) -> None: + """Overrides the staging bucket for a train method. + + A staging bucket must be specified by: + - The train method's training config. + eg. train.vertex.remote_config.staging_bucket = ... + - Or, .vertex.remote_config.staging_bucket is not set, but a + default staging bucket is specified in the global config. + eg. vertexai.init(staging_bucket=...) + + The staging bucket for each trial is overriden so that each trial uses + its own directory. + + Args: + train_method (VertexRemoteFunctor): + Required. The train method. + trial_name (str): Required. The trial name. + Raises: + ValueError if no staging bucket specified and no default staging + bucket set. + """ + staging_bucket = ( + train_method.vertex.remote_config.staging_bucket + or vertexai.preview.global_config.staging_bucket + ) + if not staging_bucket: + raise ValueError( + "No default staging bucket set. " + "Please call `vertexai.init(staging_bucket='gs://my-bucket')." + ) + train_method.vertex.remote_config.staging_bucket = os.path.join( + staging_bucket, + "-".join(trial_name.split("/")[:-1]), + trial_name.split("/")[-1], + ) + + def _get_vertex_model_train_method_and_params( + self, + model: remote.VertexModel, + x_train: PandasData, + y_train: Optional[PandasData], + x_test: PandasData, + y_test: PandasData, + trial_name: str, + ) -> Tuple[VertexRemoteFunctor, Dict[str, Any]]: + """Gets the train method for a VertexModel model and data parameters. + + Supported parameter names: + - Training data: ['X', 'X_train', 'x', 'x_train', 'training_data']. + - Training target values: ['y', 'y_train']. If not provided, training + data should contain target values. + - Testing data: ['X_test', 'x_test', 'validation_data']. + - Testing target values: ['y_test']. If not provided, testing data + should contain target values. + + If remote mode is turned on, overrides the training staging bucket for + each trial. + + Args: + model (remote.VertexModel): + Required. An instance of VertexModel. + x_train (pandas.DataFrame): + Required. Training data. + y_train (Optional[pandas.DataFrame]): + Required. Training target values. If None, x_train should + include training target values. + x_test (pandas.DataFrame): + Required. Testing data. + y_test (pandas.DataFrame): + Required. Testing target values. + trial_name (str): + Required. The trial name. + Returns: + The train method for the Vertex model and data params. + Raises: + ValueError if there is no remote executable train method. + """ + data_params = {} + for _, attr_value in inspect.getmembers(model): + if isinstance(attr_value, VertexRemoteFunctor) and ( + attr_value._remote_executor == training.remote_training + or attr_value._remote_executor == remote_container_training.train + ): + params = inspect.signature(attr_value).parameters + for param in params: + if param in _TRAINING_DATA_PARAMS: + data_params[param] = x_train + elif param in _TRAINING_TARGET_VALUE_PARAMS: + data_params[param] = y_train + elif param in _X_TEST_PARAMS: + data_params[param] = x_test + elif param == _Y_TEST: + data_params[_Y_TEST] = y_test + elif param == _VALIDATION_DATA: + data_params[_VALIDATION_DATA] = pd.concat( + [x_test, y_test], axis=1 + ) + if self._is_remote(attr_value): + self._override_staging_bucket(attr_value, trial_name) + return (attr_value, data_params) + raise ValueError("No remote executable train method.") + + def _get_lightning_train_method_and_params( + self, + model: Dict[str, Any], + trial_name: str, + ): + """Gets the train method and parameters for a Lightning model. + + Given the lightning model, the trainer and the training dataloader(s), + returns trainer.fit and the parameters containing the model and the + training dataloader(s). If the trainer is enabled to run remotely and + remote mode is turned on, overrides the training staging bucket for + each trial. + + Training data and target values have already been passed into the + training dataloader(s), so no additional runtime parameters need to be + set. + + Args: + model (Dict[str, Any]): + Required. A dictionary containing the following keys: 'model', + 'trainer', 'train_dataloaders'; each representing the lightning + model, the trainer and the training dataloader(s) respectively. + trial_name (str): + Required. The trial name. + Returns: + The train method and its parameters for the lightning model. + """ + trainer = model["trainer"] + if isinstance(trainer.fit, VertexRemoteFunctor) and self._is_remote( + trainer.fit + ): + self._override_staging_bucket(trainer.fit, trial_name) + return trainer.fit, { + "model": model["model"], + "train_dataloaders": model["train_dataloaders"], + } + + def _run_trial( + self, + x_train: PandasData, + y_train: Optional[PandasData], + x_test: PandasData, + y_test: PandasData, + trial: gca_study.Trial, + fixed_init_params: Optional[Dict[Any, Any]] = None, + fixed_runtime_params: Optional[Dict[Any, Any]] = None, + ) -> Optional[Tuple[Any, float]]: + """Runs a trial. + + This function sets model parameters and train method parameters, + launches either local or remote training, and evaluates the model. With + parallel tuning, this function can be the target function that would be + executed in parallel. + + Args: + x_train (pandas.DataFrame): + Required. Training data. + y_train (Optional[pandas.DataFrame]): + Required. Training target values. If None, x_train should + include training target values. + x_test (pandas.DataFrame): + Required. Testing data. + y_test (pandas.DataFrame): + Required. Testing target values. + trial (gca_study.Trial): Required. A trial suggested by Vizier. + fixed_init_params (Dict[Any, Any]): Optional. A dictionary of fixed + parameters to be passed to get_model_func. + fixed_runtime_params (Dict[Any, Any]): Optional. A dictionary of + fixed runtime parameters. + Returns: + If the trial is feasible, returns a tuple of the trained model and + its corresponding metric value. If the trial is infeasible, returns + None. + """ + model, model_runtime_params = self._set_model_parameters( + trial, fixed_init_params, fixed_runtime_params + ) + + if isinstance(model, remote.VertexModel): + train_method, params = self._get_vertex_model_train_method_and_params( + model, + x_train, + y_train, + x_test, + y_test, + trial.name, + ) + elif isinstance(model, dict): + train_method, params = self._get_lightning_train_method_and_params( + model, + trial.name, + ) + elif supported_frameworks._is_keras(model): + train_method, params = self._get_train_method_and_params( + model, x_train, y_train, trial.name, params=["x", "y"] + ) + elif supported_frameworks._is_sklearn(model): + train_method, params = self._get_train_method_and_params( + model, x_train, y_train, trial.name, params=["X", "y"] + ) + else: + raise ValueError(f"Unsupported model type {type(model)}") + + model_runtime_params.update(params) + + try: + train_method(**model_runtime_params) + except Exception as e: + _LOGGER.warning(f"Trial {trial.name} failed: {e}.") + return None + + if isinstance(model, dict): + # For lightning, evaluate the model and keep track of the dictionary + # containing the model, the trainer, and the training dataloader(s). + _, metric_value = self._evaluate_model(model["model"], x_test, y_test) + return model, metric_value + + return self._evaluate_model(model, x_test, y_test) + + def _get_train_method_and_params( + self, + model: Any, + x_train: PandasData, + y_train: Optional[PandasData], + trial_name: str, + params: List[str], + ) -> Tuple[VertexRemoteFunctor, Dict[str, Any]]: + """Gets the train method for an Sklearn or Keras model and data parameters. + + Args: + model (Any): + Required. An instance of an Sklearn or Keras model. + x_train (pandas.DataFrame): + Required. Training data. + y_train (Optional[pandas.DataFrame]): + Required. Training target values. + trial_name (str): + Required. The trial name. + params (str): + Required. The list of data parameters. + Returns: + The train method for the model and data params. + Raises: + ValueError if there is no remote executable train method. + """ + data_params = {} + if isinstance(model.fit, VertexRemoteFunctor) and self._is_remote(model.fit): + self._override_staging_bucket(model.fit, trial_name) + attr_params = inspect.signature(model.fit).parameters + for param in params: + if param not in attr_params: + raise ValueError(f"Invalid data parameter {param}.") + if param in _OSS_TRAINING_DATA_PARAMS: + data_params[param] = x_train + elif param == _Y_DATA_PARAM: + data_params[param] = y_train + return (model.fit, data_params) + + def fit( + self, + x: PandasData, + y: Union[PandasData, str], + x_test: Optional[PandasData] = None, + y_test: Optional[PandasData] = None, + test_fraction: Optional[float] = _DEFAULT_TEST_FRACTION, + **kwargs, + ): + """Runs Vizier-backed hyperparameter tuning for a model. + + Extra runtime arguments will be forwarded to a model's fit() or + @vertexai.preview.developer.mark.train()-decorated method. + + Example Usage: + ``` + def get_model_func(parameter_a, parameter_b): + # parameter_c is non-tunable + parameter_c = 10 + return ExampleModel(parameter_a, parameter_b, parameter_c) + + x, y = pd.DataFrame(...), pd.DataFrame(...) + tuner = VizierHyperparameterTuner(get_model_func, ...) + # num_epochs will be passed to ExampleModel.fit() + # (ex: ExampleModel.fit(x, y, num_epochs=5)) + tuner.fit(x, y, num_epochs=5) + ``` + + Args: + x (pandas.DataFrame): + Required. A pandas DataFrame for the dataset. If it contains the + target column, y must be a string specifying the target column + name. + y (Union[pandas.DataFrame, str]): + Required. A pandas DataFrame containing target values for the + dataset or a string specifying the target column name. + x_test (pandas.DataFrame): + Optional. A pandas DataFrame for the test dataset. If not provided, + X will be split into X_train and X_test based on test_fraction. + y_test (pandas.DataFrame): + Optional. A pandas DataFrame containing target values for the test + dataset. If not provided, y will be split into y_train and t_test + based on test_fraction. + test_fraction (float): + Optional. The proportion of the dataset to include in the test + split. eg. test_fraction=0.25 for a pandas Dataframe with 100 + rows would result in 75 rows for training and 25 rows for + testing. Default is 0.25. + **kwargs (Any): + Optional. Keyword arguments to pass to the model's fit(), + or @vertexai.preview.developer.mark.train()-decorated method. + + Returns: + A model initialized using parameters from the specified trial and + a dictionary of runtime parameters. + """ + if x_test is None or y_test is None or x_test.empty or y_test.empty: + x, x_test, y, y_test = self._create_train_and_test_splits( + x, y, test_fraction + ) + + # Fixed params that are passed to get_model_func. + # Lightning, for example, requires X and y to be passed to get_model_func. + fixed_init_params = {} + get_model_func_binding = inspect.signature(self.get_model_func).parameters + for x_param_name in _TRAINING_X_PARAMS: + if x_param_name in get_model_func_binding: + # Temporary solution for b/295191253 + # TODO(b/295191253) + if self.parallel_trial_count > 1: + raise ValueError( + "Currently pytorch lightning only supports `parallel_trial_count = 1`. " + f"In {self} it was set to {self.parallel_trial_count}." + ) + fixed_init_params[x_param_name] = x + break + for y_param_name in _TRAINING_TARGET_VALUE_PARAMS: + if y_param_name in get_model_func_binding: + fixed_init_params[y_param_name] = y + break + + # Disable remote job logs when running trials. + logging.getLogger("vertexai.remote_execution").disabled = True + try: + num_completed_trials = 0 + num_failed_trials = 0 + while num_completed_trials < self.max_trial_count: + num_new_trials = min( + (self.max_trial_count - num_completed_trials), + self.parallel_trial_count, + ) + suggested_trials = self._suggest_trials(num_new_trials) + inputs = [ + (x, y, x_test, y_test, trial, fixed_init_params, kwargs) + for trial in suggested_trials + ] + _LOGGER.info( + f"Number of completed trials: {num_completed_trials}, " + f"Number of new trials: {num_new_trials}." + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_new_trials + ) as executor: + trial_outputs = list( + executor.map(lambda t: self._run_trial(*t), inputs) + ) + + for i in range(num_new_trials): + trial_output = trial_outputs[i] + self._add_model_and_report_trial_metrics( + suggested_trials[i].name, trial_output + ) + if not trial_output: + num_failed_trials += 1 + if num_failed_trials == self.max_failed_trial_count: + raise ValueError("Maximum number of failed trials reached.") + num_completed_trials += num_new_trials + except Exception as e: + raise e + finally: + # Enable remote job logs after trials are complete. + logging.getLogger("vertexai.remote_execution").disabled = False + + if num_failed_trials == num_completed_trials: + raise ValueError("All trials failed.") + + _LOGGER.info( + f"Number of completed trials: {num_completed_trials}. Tuning complete." + ) diff --git a/vertexai/preview/initializer.py b/vertexai/preview/initializer.py new file mode 100644 index 0000000000..88bb1aba21 --- /dev/null +++ b/vertexai/preview/initializer.py @@ -0,0 +1,92 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from vertexai.preview._workflow.executor import ( + persistent_resource_util, +) + + +_LOGGER = base.Logger(__name__) + + +class _Config: + """Store common configurations and current workflow for remote execution.""" + + def __init__(self): + self._remote = False + self._cluster_name = None + + def init( + self, + *, + remote: Optional[bool] = None, + autolog: Optional[bool] = None, + ): + """Updates preview global parameters for Vertex remote execution. + + Args: + remote (bool): + Optional. A global flag to indicate whether or not a method will + be executed remotely. Default is Flase. The method level remote + flag has higher priority than this global flag. + autolog (bool): + Optional. Whether or not to turn on autologging feature for remote + execution. To learn more about the autologging feature, see + https://cloud.google.com/vertex-ai/docs/experiments/autolog-data. + """ + if remote is not None: + self._remote = remote + + if autolog is True: + aiplatform.autolog() + elif autolog is False: + aiplatform.autolog(disable=True) + + cluster = None + if cluster is not None: + self._cluster_name = cluster.name + cluster_resource_name = f"projects/{self.project}/locations/{self.location}/persistentResources/{self._cluster_name}" + cluster_exists = persistent_resource_util.check_persistent_resource( + cluster_resource_name=cluster_resource_name + ) + if cluster_exists: + _LOGGER.info(f"Using existing cluster: {cluster_resource_name}") + return + # create a default one + persistent_resource_util.create_persistent_resource( + cluster_resource_name=cluster_resource_name + ) + + @property + def remote(self): + return self._remote + + @property + def autolog(self): + return aiplatform.utils.autologging_utils._is_autologging_enabled() + + @property + def cluster_name(self): + return self._cluster_name + + def __getattr__(self, name): + return getattr(aiplatform.initializer.global_config, name) + + +global_config = _Config() diff --git a/vertexai/preview/tabular_models/__init__.py b/vertexai/preview/tabular_models/__init__.py new file mode 100644 index 0000000000..d96f82480a --- /dev/null +++ b/vertexai/preview/tabular_models/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from vertexai.preview.tabular_models import tabnet_trainer + + +TabNetTrainer = tabnet_trainer.TabNetTrainer + + +__all__ = ("TabNetTrainer",) diff --git a/vertexai/preview/tabular_models/tabnet_trainer.py b/vertexai/preview/tabular_models/tabnet_trainer.py new file mode 100644 index 0000000000..93218fd4bf --- /dev/null +++ b/vertexai/preview/tabular_models/tabnet_trainer.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +from typing import Any + +from google.cloud.aiplatform import base +from google.cloud.aiplatform.utils import gcs_utils +from vertexai.preview import developer +from vertexai.preview._workflow.driver import remote +from vertexai.preview._workflow.shared import configs +from vertexai.preview.developer import remote_specs + + +try: + import pandas as pd + + PandasData = pd.DataFrame + +except ImportError: + PandasData = Any + + +_LOGGER = base.Logger(__name__) + +# Constants for TabNetTrainer +_TABNET_TRAINING_IMAGE = "us-docker.pkg.dev/vertex-ai-restricted/automl-tabular/tabnet-training:20230605_1325" + +_TABNET_FIT_DISPLAY_NAME = "fit" +_TABNET_MACHINE_TYPE = "c2-standard-16" +_TABNET_BOOT_DISK_TYPE = "pd-ssd" +_TABNET_BOOT_DISK_SIZE_GB = 100 + +_CLASSIFICATION = "classification" +_REGRESSION = "regression" + + +class TabNetTrainer(remote.VertexModel): + """The TabNet trainer for remote training and prediction.""" + + def __init__( + self, + model_type: str, + target_column: str, + learning_rate: float, + job_dir: str = "", + enable_profiler: bool = False, + cache_data: str = "auto", + seed: int = 1, + large_category_dim: int = 1, + large_category_thresh: int = 300, + yeo_johnson_transform: bool = False, + weight_column: str = "", + max_steps: int = -1, + max_train_secs: int = -1, + measurement_selection_type: str = "BEST_MEASUREMENT", + optimization_metric: str = "", + eval_steps: int = 0, + batch_size: int = 100, + eval_frequency_secs: int = 600, + feature_dim: int = 64, + feature_dim_ratio: float = 0.5, + num_decision_steps: int = 6, + relaxation_factor: float = 1.5, + decay_every: float = 100.0, + decay_rate: float = 0.95, + gradient_thresh: float = 2000.0, + sparsity_loss_weight: float = 0.00001, + batch_momentum: float = 0.95, + batch_size_ratio: float = 0.25, + num_transformer_layers: int = 4, + num_transformer_layers_ratio: float = 0.25, + class_weight: float = 1.0, + loss_function_type: str = "default", + alpha_focal_loss: float = 0.25, + gamma_focal_loss: float = 2.0, + ): + """Initializes a TabNetTrainer instance. + + is_remote_trainer is always set to True because TabNetTrainer only + supports remote training. + + Args: + model_type (str): + Required. The type of prediction the model is to produce. + 'classification' or 'regression'. + target_column (str): + Required. The target column name. + learning_rate (float): + Required. The learning rate used by the linear optimizer. + job_dir (str): + Optional. The GCS directory for reading and writing inside the + the custom job. If provided, must start with 'gs://'. Default is + ''. + enable_profiler (bool): + Optional. Enables profiling and saves a trace during evaluation. + Default is False. + cache_data (str): + Optional. Whether to cache data or not. If set to 'auto', + caching is determined based on the dataset size. Default is + 'auto'. + seed (int): + Optional. Seed to be used for this run. Default is 1. + large_category_dim (int): + Optional. Embedding dimension for categorical feature with + large number of categories. Default is 1. + large_category_thresh (int): + Optional. Threshold for number of categories to apply + large_category_dim embedding dimension to. Default is 300. + yeo_johnson_transform (bool): + Optional. Enables trainable Yeo-Johnson power transform. Default + is False. + weight_column (str): + Optional. The weight column name. ''(empty string) for no + weight column. Default is ''(empty string). + max_steps (int): + Optional. Number of steps to run the trainer for. -1 for no + maximum steps. Default is -1. + max_train_seconds (int): + Optional. Amount of time in seconds to run the trainer for. -1 + for no maximum train seconds. Default is -1. + measurement_selection_type (str): + Optional. Which measurement to use if/when the service + automatically selects the final measurement from previously + reported intermediate measurements. One of 'BEST_MEASUREMENT' + or 'LAST_MEASUREMENT'. Default is 'BEST_MEASUREMENT'. + optimization_metric (str): + Optional. Optimization metric used for + `measurement_selection_type`. ''(empty string) for using the + default value: 'rmse' for regression and 'auc' for + classification. Default is ''(empty string). + eval_steps (int): + Optional. Number of steps to run evaluation for. If not + specified or negative, it means run evaluation on the whole + validation dataset. If set to 0, it means run evaluation for a + fixed number of samples. Default is 0. + batch_size (int): + Optional. Batch size for training. Default is 100. + eval_frequency_secs (int): + Optional. Frequency at which evaluation and checkpointing will + take place. Default is 600. + feature_dim (int): + Optional. Dimensionality of the hidden representation in feature + transformation block. Default is 64. + feature_dim_ratio (float): + Optional. The ratio of output dimension (dimensionality of the + outputs of each decision step) to feature dimension. Default is + 0.5. + num_decision_steps (int): + Optional. Number of sequential decision steps. Default is 6. + relaxation_factor (float): + Optional. Relaxation factor that promotes the reuse of each + feature at different decision steps. When it is 1, a feature is + enforced to be used only at one decision step and as it + increases, more flexibility is provided to use a feature at + multiple decision steps. Default is 1.5. + decay_every (float): + Optional. Number of iterations for periodically applying + learning rate decaying. Default is 100.0. + decay_rate (float): + Optional. Learning rate decaying. Default is 0.95. + gradient_thresh (float): + Optional. Threshold for the norm of gradients for clipping. + Default is 2000.0. + sparsity_loss_weight (float): + Optional. Weight of the loss for sparsity regularization + (increasing it will yield more sparse feature selection). + Default is 0.00001. + batch_momentum (float): + Optional. Momentum in ghost batch normalization. Default is + 0.95. + batch_size_ratio (float): + Optional. The ratio of virtual batch size (size of the ghost + batch normalization) to batch size. Default is 0.25. + num_transformer_layers (int): + Optional. The number of transformer layers for each decision + step. used only at one decision step and as it increases, more + flexibility is provided to use a feature at multiple decision + steps. Default is 4. + num_transformer_layers_ratio (float): + Optional. The ratio of shared transformer layer to transformer + layers. Default is 0.25. + class_weight (float): + Optional. The class weight is used to compute a weighted cross + entropy which is helpful in classifying imbalanced dataset. Only + used for classification. Default is 1.0. + loss_function_type (str): + Optional. Loss function type. Loss function in classification + [cross_entropy, weighted_cross_entropy, focal_loss], default is + cross_entropy. Loss function in regression: [rmse, mae, mse], + default is mse. "default" for default values. Default is + "default". + alpha_focal_loss (float): + Optional. Alpha value (balancing factor) in focal_loss function. + Only used for classification. Default is 0.25. + gamma_focal_loss (float): + Optional. Gamma value (modulating factor) for focal loss for + focal loss. Only used for classification. Default is 2.0. + Raises: + ValueError if job_dir is set to an invalid GCS path. + """ + super().__init__() + if job_dir: + gcs_utils.validate_gcs_path(job_dir) + sig = inspect.signature(self.__init__) + self._binding = sig.bind( + model_type, + target_column, + learning_rate, + job_dir, + enable_profiler, + cache_data, + seed, + large_category_dim, + large_category_thresh, + yeo_johnson_transform, + weight_column, + max_steps, + max_train_secs, + measurement_selection_type, + optimization_metric, + eval_steps, + batch_size, + eval_frequency_secs, + feature_dim, + feature_dim_ratio, + num_decision_steps, + relaxation_factor, + decay_every, + decay_rate, + gradient_thresh, + sparsity_loss_weight, + batch_momentum, + batch_size_ratio, + num_transformer_layers, + num_transformer_layers_ratio, + class_weight, + loss_function_type, + alpha_focal_loss, + gamma_focal_loss, + ).arguments + self._binding["is_remote_trainer"] = True + self.model = None + + @developer.mark._remote_container_train( + image_uri=_TABNET_TRAINING_IMAGE, + additional_data=[ + remote_specs._InputParameterSpec( + "training_data", + argument_name="training_data_path", + serializer="parquet", + ), + remote_specs._InputParameterSpec( + "validation_data", + argument_name="validation_data_path", + serializer="parquet", + ), + remote_specs._InputParameterSpec("model_type"), + remote_specs._InputParameterSpec("target_column"), + remote_specs._InputParameterSpec("learning_rate"), + remote_specs._InputParameterSpec("job_dir"), + remote_specs._InputParameterSpec("enable_profiler"), + remote_specs._InputParameterSpec("cache_data"), + remote_specs._InputParameterSpec("seed"), + remote_specs._InputParameterSpec("large_category_dim"), + remote_specs._InputParameterSpec("large_category_thresh"), + remote_specs._InputParameterSpec("yeo_johnson_transform"), + remote_specs._InputParameterSpec("weight_column"), + remote_specs._InputParameterSpec("max_steps"), + remote_specs._InputParameterSpec("max_train_secs"), + remote_specs._InputParameterSpec("measurement_selection_type"), + remote_specs._InputParameterSpec("optimization_metric"), + remote_specs._InputParameterSpec("eval_steps"), + remote_specs._InputParameterSpec("batch_size"), + remote_specs._InputParameterSpec("eval_frequency_secs"), + remote_specs._InputParameterSpec("feature_dim"), + remote_specs._InputParameterSpec("feature_dim_ratio"), + remote_specs._InputParameterSpec("num_decision_steps"), + remote_specs._InputParameterSpec("relaxation_factor"), + remote_specs._InputParameterSpec("decay_every"), + remote_specs._InputParameterSpec("decay_rate"), + remote_specs._InputParameterSpec("gradient_thresh"), + remote_specs._InputParameterSpec("sparsity_loss_weight"), + remote_specs._InputParameterSpec("batch_momentum"), + remote_specs._InputParameterSpec("batch_size_ratio"), + remote_specs._InputParameterSpec("num_transformer_layers"), + remote_specs._InputParameterSpec("num_transformer_layers_ratio"), + remote_specs._InputParameterSpec("class_weight"), + remote_specs._InputParameterSpec("loss_function_type"), + remote_specs._InputParameterSpec("alpha_focal_loss"), + remote_specs._InputParameterSpec("gamma_focal_loss"), + remote_specs._InputParameterSpec("is_remote_trainer"), + remote_specs._OutputParameterSpec("output_model_path"), + ], + remote_config=configs.DistributedTrainingConfig( + display_name=_TABNET_FIT_DISPLAY_NAME, + machine_type=_TABNET_MACHINE_TYPE, + boot_disk_type=_TABNET_BOOT_DISK_TYPE, + boot_disk_size_gb=_TABNET_BOOT_DISK_SIZE_GB, + ), + ) + def fit(self, training_data: PandasData, validation_data: PandasData) -> None: + """Trains a tabnet model in a custom job. + + After the custom job successfully finishes, load the model and set it to + self.model to enable prediction. If TensorFlow is not installed, the + model will not be loaded. + + Training config can be overriden by setting the training config. + + Example Usage: + ` + tabnet_trainer = TabNetTrainer(...) + tabnet_trainer.fit.vertex.remote_config.staging_bucket = 'gs://...' + tabnet_trainer.fit.vertex.remote_config.display_name = 'example' + tabnet_trainer.fit(...) + ` + + PandasData refers to a pandas DataFrame. Each data frame should meet the + following requirements: + 1. All entries should be numerical (no string, array or object). + 2. For categorical columns, the entries should be integers. In + addition, the column type should be set to 'category'. Otherwise, it + will be treated as numerical columns. + 3. The column names should be string. + + Args: + training_data (pandas.DataFrame): + Required. A pandas DataFrame for training. + validation_data (pandas.DataFrame): + Required. A pandas DataFrame for validation. + """ + try: + import tensorflow.saved_model as tf_saved_model + + self.model = tf_saved_model.load(self.output_model_path) + except ImportError: + _LOGGER.warning( + "TensorFlow must be installed to load the trained model. The model is stored at %s", + self.output_model_path, + ) + + def predict(self, input_data: PandasData) -> PandasData: + """Makes prediction on input data through a trained model. + + Unlike in training and validation data, the categorical columns in + prediction input data can have dtypes either 'category' or 'int', with + 'int' being numpy.int64 in pandas DataFrame. + + + Args: + input_data (pandas.DataFrame): + Required. An input Pandas DataFrame containing data for + prediction. It will be preprocessed into a dictionary as the + input for to the trained model. + Returns: + Prediction results in the format of pandas DataFrame. + """ + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "TensorFlow must be installed to make predictions." + ) from None + + if self.model is None: + if not hasattr(self, "output_model_path") or self.output_model_path is None: + raise ValueError("No trained model. Please call .fit first.") + self.model = tf.saved_model.load(self.output_model_path) + + prediction_inputs = {} + for col in input_data.columns: + if input_data[col].dtypes == "category": + dtype = tf.int64 + else: + dtype = tf.dtypes.as_dtype(input_data[col].dtypes) + prediction_inputs[col] = tf.constant(input_data[col].to_list(), dtype=dtype) + prediction_outputs = self.model.signatures["serving_default"]( + **prediction_inputs + ) + if self._binding["model_type"] == _CLASSIFICATION: + predicted_labels = [] + for score, labels in zip( + prediction_outputs["scores"].numpy(), + prediction_outputs["classes"].numpy().astype(int), + ): + predicted_labels.append(labels[score.argmax()]) + return pd.DataFrame({self._binding["target_column"]: predicted_labels}) + elif self._binding["model_type"] == _REGRESSION: + return pd.DataFrame( + { + self._binding["target_column"]: prediction_outputs["value"] + .numpy() + .reshape(-1) + } + ) + else: + raise ValueError(f"Unsupported model type: {self._binding['model_type']}.")