Skip to content

Commit

Permalink
chore: Instantiate AnySerializer in any_serializer.py to fix edge c…
Browse files Browse the repository at this point in the history
…ase bugs

PiperOrigin-RevId: 562083494
  • Loading branch information
jaycee-li authored and copybara-github committed Sep 2, 2023
1 parent 5978d31 commit 50c1591
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 6 deletions.
32 changes: 30 additions & 2 deletions tests/system/vertexai/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import vertexai
from tests.system.aiplatform import e2e_base
from vertexai.preview._workflow.executor import training
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
)
from vertexai.preview._workflow.serialization_engine import (
serializers,
)
import pytest
from sklearn.datasets import load_iris
import torch
Expand All @@ -34,15 +40,15 @@
"google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@mock.patch.object(
training,
"VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING",
"google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@pytest.mark.usefixtures(
"prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
Expand Down Expand Up @@ -136,6 +142,17 @@ def predict(self, X):
)
model.train(train_loader, num_epochs=100, lr=0.05)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(model.__class__.__mro__[2])
is serializers.TorchModelSerializer
)
assert (
serializer._get_predefined_serializer(train_loader.__class__)
is serializers.TorchDataLoaderSerializer
)

# Remote prediction on Torch custom model
model.predict.vertex.remote_config.display_name = self._make_display_name(
"pytorch-prediction"
Expand All @@ -156,3 +173,14 @@ def predict(self, X):
"pytorch-cpu-uptraining"
)
pulled_model.train(retrain_loader, num_epochs=100, lr=0.05)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(pulled_model.__class__.__mro__[2])
is serializers.TorchModelSerializer
)
assert (
serializer._get_predefined_serializer(retrain_loader.__class__)
is serializers.TorchDataLoaderSerializer
)
38 changes: 36 additions & 2 deletions tests/system/vertexai/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import vertexai
from tests.system.aiplatform import e2e_base
from vertexai.preview._workflow.executor import training
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
)
from vertexai.preview._workflow.serialization_engine import (
serializers,
)
import pandas as pd
import pytest
from sklearn.datasets import load_iris
Expand All @@ -40,15 +46,15 @@
"google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@mock.patch.object(
training,
"VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING",
"google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@pytest.mark.usefixtures(
"prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
Expand Down Expand Up @@ -83,12 +89,26 @@ def test_remote_execution_sklearn(self, shared_state):
)
X_train = transformer.fit_transform(X_train)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(transformer.__class__.__mro__[-2])
is serializers.SklearnEstimatorSerializer
)

# Remote transform on test dataset
transformer.transform.vertex.set_config(
display_name=self._make_display_name("transform"),
)
X_test = transformer.transform(X_test)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(transformer.__class__.__mro__[-2])
is serializers.SklearnEstimatorSerializer
)

# Local transform on retrain data
vertexai.preview.init(remote=False)
X_retrain = transformer.transform(X_retrain)
Expand All @@ -105,6 +125,13 @@ def test_remote_execution_sklearn(self, shared_state):
)
model.fit(X_train, y_train)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(model.__class__.__mro__[-2])
is serializers.SklearnEstimatorSerializer
)

# Remote prediction on sklearn
model.predict.vertex.remote_config.display_name = self._make_display_name(
"sklearn-prediction"
Expand All @@ -122,3 +149,10 @@ def test_remote_execution_sklearn(self, shared_state):

# Retrain model with pandas df on Vertex
pulled_model.fit(X_retrain_df, y_retrain_df)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(pulled_model.__class__.__mro__[-2])
is serializers.SklearnEstimatorSerializer
)
34 changes: 32 additions & 2 deletions tests/system/vertexai/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import vertexai
from tests.system.aiplatform import e2e_base
from vertexai.preview._workflow.executor import training
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
)
from vertexai.preview._workflow.serialization_engine import (
serializers,
)
import pytest
from sklearn.datasets import load_iris
import tensorflow as tf
Expand All @@ -39,15 +45,15 @@
"google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@mock.patch.object(
training,
"VERTEX_AI_DEPENDENCY_PATH_AUTOLOGGING",
"google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/"
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}"
if os.environ.get("KOKORO_GIT_COMMIT")
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@copybara_557913723",
else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main",
)
@pytest.mark.usefixtures(
"prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
Expand Down Expand Up @@ -101,6 +107,17 @@ def test_remote_execution_keras(self, shared_state):
)
model.fit(tf_train_dataset, epochs=10)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(model.__class__.__mro__[3])
is serializers.KerasModelSerializer
)
assert (
serializer._get_predefined_serializer(tf_train_dataset.__class__.__mro__[2])
is serializers.TFDatasetSerializer
)

# Remote prediction on keras
model.predict.vertex.remote_config.display_name = self._make_display_name(
"keras-prediction"
Expand All @@ -122,3 +139,16 @@ def test_remote_execution_keras(self, shared_state):
"keras-cpu-uptraining"
)
pulled_model.fit(tf_retrain_dataset, epochs=10)

# Assert the right serializer is being used
serializer = any_serializer.AnySerializer()
assert (
serializer._get_predefined_serializer(pulled_model.__class__.__mro__[3])
is serializers.KerasModelSerializer
)
assert (
serializer._get_predefined_serializer(
tf_retrain_dataset.__class__.__mro__[2]
)
is serializers.TFDatasetSerializer
)
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,11 @@ def deserialize(self, gcs_path):
any_serializer.register_custom(
to_serialize_type=to_serialize_type, serializer_cls=serializer_cls
)


try:
_any_serializer = AnySerializer()
except ImportError:
_LOGGER.warning(
"cloudpickle is not installed. Please call `pip install google-cloud-aiplatform[preview]`."
)

0 comments on commit 50c1591

Please sign in to comment.