Skip to content

Commit

Permalink
feat: add Custom Job support to from_pretrained
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565175389
  • Loading branch information
sararob authored and copybara-github committed Sep 13, 2023
1 parent 220cbe8 commit 8b0add1
Show file tree
Hide file tree
Showing 12 changed files with 457 additions and 58 deletions.
13 changes: 13 additions & 0 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
)

# Job states that are still in progress (non-terminal), across both the v1
# and v1beta1 GAPIC job-state enums.  NOTE: despite the "PENDING" name this
# deliberately includes RUNNING/CANCELLING/UPDATING as well as
# QUEUED/PENDING — i.e. every state in which a poller should keep waiting.
_JOB_PENDING_STATES = (
    gca_job_state.JobState.JOB_STATE_QUEUED,
    gca_job_state.JobState.JOB_STATE_PENDING,
    gca_job_state.JobState.JOB_STATE_RUNNING,
    gca_job_state.JobState.JOB_STATE_CANCELLING,
    gca_job_state.JobState.JOB_STATE_UPDATING,
    gca_job_state_v1beta1.JobState.JOB_STATE_QUEUED,
    gca_job_state_v1beta1.JobState.JOB_STATE_PENDING,
    gca_job_state_v1beta1.JobState.JOB_STATE_RUNNING,
    gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLING,
    gca_job_state_v1beta1.JobState.JOB_STATE_UPDATING,
)

# _block_until_complete wait times
_JOB_WAIT_TIME = 5 # start at five seconds
_LOG_WAIT_TIME = 5
Expand Down
1 change: 1 addition & 0 deletions tests/unit/vertexai/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
),
},
labels={"trained_by_vertex_ai": "true"},
)


Expand Down
202 changes: 202 additions & 0 deletions tests/unit/vertexai/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,19 @@
import vertexai
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
serializers_base,
)
from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
job_state as gca_job_state,
custom_job as gca_custom_job,
io as gca_io,
)
import pytest

import cloudpickle
import numpy as np
import sklearn
from sklearn.linear_model import _logistic
import tensorflow
import torch
Expand All @@ -45,6 +55,9 @@
_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
_REWRAPPER = "rewrapper"

# customJob constants
_TEST_CUSTOM_JOB_RESOURCE_NAME = "projects/123/locations/us-central1/customJobs/456"


@pytest.fixture
def mock_serialize_model():
Expand Down Expand Up @@ -123,6 +136,126 @@ def mock_deserialize_model_exception():
yield mock_deserialize_model_exception


@pytest.fixture
def mock_any_serializer_serialize_sklearn():
    """Mocks ``AnySerializer.serialize`` to report per-object dependencies.

    The first call (the sklearn estimator) reports a scikit-learn dependency;
    each of the next three calls (the serialized data objects) reports
    numpy + cloudpickle dependencies.
    """

    def _metadata(*dependencies):
        # Build a fresh dict per call so consumers never share (or mutate)
        # the same metadata object between side-effect results.
        return {
            serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: list(
                dependencies
            )
        }

    data_deps = (
        f"numpy=={np.__version__}",
        f"cloudpickle=={cloudpickle.__version__}",
    )
    with mock.patch.object(
        any_serializer.AnySerializer,
        "serialize",
        side_effect=[
            _metadata(f"scikit-learn=={sklearn.__version__}"),
            _metadata(*data_deps),
            _metadata(*data_deps),
            _metadata(*data_deps),
        ],
    ) as mock_any_serializer_serialize:
        yield mock_any_serializer_serialize


# --- Test constants for the Custom Job `from_pretrained` tests below. ---
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
_TEST_BUCKET_NAME = "gs://test_bucket"
_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir"

# Command-line args passed through to the training container.
_TEST_INPUTS = [
    "--arg_0=string_val_0",
    "--arg_1=string_val_1",
    "--arg_2=int_val_0",
    "--arg_3=int_val_1",
]
_TEST_IMAGE_URI = "test_image_uri"
_TEST_MACHINE_TYPE = "test_machine_type"
# Single-replica worker pool spec used by every CustomJob proto in this file.
_TEST_WORKER_POOL_SPEC = [
    {
        "machine_spec": {
            "machine_type": _TEST_MACHINE_TYPE,
        },
        "replica_count": 1,
        "container_spec": {
            "image_uri": _TEST_IMAGE_URI,
            "args": _TEST_INPUTS,
        },
    }
]
# Template CustomJob proto.  NOTE(review): this is a shared mutable module
# global — fixtures must copy it before setting fields on it.
_TEST_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob(
    display_name=_TEST_DISPLAY_NAME,
    job_spec={
        "worker_pool_specs": _TEST_WORKER_POOL_SPEC,
        "base_output_directory": gca_io.GcsDestination(
            output_uri_prefix=_TEST_BASE_OUTPUT_DIR
        ),
    },
    labels={"trained_by_vertex_ai": "true"},
)


@pytest.fixture
def mock_get_custom_job_pending():
    """Mocks ``get_custom_job`` to return RUNNING once, then SUCCEEDED.

    This exercises the polling path of ``from_pretrained``: the first poll
    sees an in-progress job, the second sees completion.
    """

    def _custom_job_with_state(state):
        # One fresh proto per side-effect entry; the two results differ
        # only in `state`, so build them from a single factory.
        return gca_custom_job.CustomJob(
            name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
            state=state,
            display_name=_TEST_DISPLAY_NAME,
            job_spec={
                "worker_pool_specs": _TEST_WORKER_POOL_SPEC,
                "base_output_directory": gca_io.GcsDestination(
                    output_uri_prefix=_TEST_BASE_OUTPUT_DIR
                ),
            },
            labels={"trained_by_vertex_ai": "true"},
        )

    with mock.patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as mock_get_custom_job:
        mock_get_custom_job.side_effect = [
            _custom_job_with_state(gca_job_state.JobState.JOB_STATE_RUNNING),
            _custom_job_with_state(gca_job_state.JobState.JOB_STATE_SUCCEEDED),
        ]
        yield mock_get_custom_job


@pytest.fixture
def mock_get_custom_job_failed():
    """Mocks ``get_custom_job`` to return a FAILED custom job.

    Deep-copies the shared template proto before mutating it: the original
    code set ``name``/``state`` directly on the module-level
    ``_TEST_CUSTOM_JOB_PROTO``, permanently mutating it and leaking the
    FAILED state into every other test that uses the template.
    """
    import copy  # local import: only this fixture needs it

    with mock.patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as mock_get_custom_job:
        custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO)
        custom_job_proto.name = _TEST_CUSTOM_JOB_RESOURCE_NAME
        custom_job_proto.state = gca_job_state.JobState.JOB_STATE_FAILED
        mock_get_custom_job.return_value = custom_job_proto
        yield mock_get_custom_job


@pytest.mark.usefixtures("google_auth_mock")
class TestModelUtils:
def setup_method(self):
Expand Down Expand Up @@ -289,3 +422,72 @@ def test_local_model_from_pretrained_fail(self):

with pytest.raises(ValueError):
vertexai.preview.from_pretrained(model_name=_MODEL_RESOURCE_NAME)

@pytest.mark.usefixtures(
"mock_get_vertex_model",
"mock_get_custom_job_succeeded",
)
def test_custom_job_from_pretrained_succeed(self, mock_deserialize_model):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_BUCKET,
)

local_model = vertexai.preview.from_pretrained(
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
)
assert local_model == _SKLEARN_MODEL
assert 2 == mock_deserialize_model.call_count

mock_deserialize_model.assert_has_calls(
calls=[
mock.call(
f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator",
),
],
any_order=True,
)

@pytest.mark.usefixtures(
"mock_get_vertex_model",
"mock_get_custom_job_pending",
"mock_cloud_logging_list_entries",
)
def test_custom_job_from_pretrained_logs_and_blocks_until_complete_on_pending_job(
self, mock_deserialize_model
):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_BUCKET,
)

local_model = vertexai.preview.from_pretrained(
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
)
assert local_model == _SKLEARN_MODEL
assert 2 == mock_deserialize_model.call_count

mock_deserialize_model.assert_has_calls(
calls=[
mock.call(
f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator",
),
],
any_order=True,
)

@pytest.mark.usefixtures("mock_get_vertex_model", "mock_get_custom_job_failed")
def test_custom_job_from_pretrained_fails_on_errored_job(self):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_BUCKET,
)

with pytest.raises(ValueError) as err_msg:
vertexai.preview.from_pretrained(
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
)
assert "did not complete" in err_msg
13 changes: 13 additions & 0 deletions tests/unit/vertexai/test_remote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def _get_custom_job_proto(
env.append(
{"name": metadata_constants.ENV_EXPERIMENT_RUN_KEY, "value": experiment_run}
)
job.labels = ({"trained_by_vertex_ai": "true"},)
return job


Expand Down Expand Up @@ -480,6 +481,12 @@ def mock_any_serializer_serialize_sklearn():
f"cloudpickle=={cloudpickle.__version__}",
]
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
},
],
) as mock_any_serializer_serialize:
yield mock_any_serializer_serialize
Expand Down Expand Up @@ -557,6 +564,12 @@ def mock_any_serializer_serialize_keras():
f"cloudpickle=={cloudpickle.__version__}",
]
},
{
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
f"numpy=={np.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
]
},
],
) as mock_any_serializer_serialize:
yield mock_any_serializer_serialize
Expand Down
8 changes: 5 additions & 3 deletions vertexai/preview/_workflow/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def invoke(self, invokable: shared._Invokable) -> Any:
):
rewrapper = _unwrapper(invokable.instance)

result = self._launch(invokable)
result = self._launch(invokable, rewrapper)

# rewrap the original instance
if rewrapper and invokable.instance is not None:
Expand All @@ -255,12 +255,14 @@ def invoke(self, invokable: shared._Invokable) -> Any:

return result

def _launch(self, invokable: shared._Invokable) -> Any:
def _launch(self, invokable: shared._Invokable, rewrapper: Any) -> Any:
"""
Launches an invokable.
"""
return self._launcher.launch(
invokable=invokable, global_remote=vertexai.preview.global_config.remote
invokable=invokable,
global_remote=vertexai.preview.global_config.remote,
rewrapper=rewrapper,
)


Expand Down
7 changes: 5 additions & 2 deletions vertexai/preview/_workflow/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ def local_execute(self, invokable: shared._Invokable) -> Any:
*invokable.bound_arguments.args, **invokable.bound_arguments.kwargs
)

def remote_execute(self, invokable: shared._Invokable) -> Any:
def remote_execute(self, invokable: shared._Invokable, rewrapper: Any) -> Any:
if invokable.remote_executor not in (
remote_container_training.train,
training.remote_training,
prediction.remote_prediction,
):
raise ValueError(f"{invokable.remote_executor} is not supported.")

return invokable.remote_executor(invokable)
if invokable.remote_executor == remote_container_training.train:
invokable.remote_executor(invokable)
else:
return invokable.remote_executor(invokable, rewrapper=rewrapper)


_workflow_executor = _WorkflowExecutor()
6 changes: 4 additions & 2 deletions vertexai/preview/_workflow/executor/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any

from vertexai.preview._workflow import (
shared,
)
Expand All @@ -20,9 +22,9 @@
)


def remote_prediction(invokable: shared._Invokable, rewrapper: Any):
    """Wrapper function that makes a method executable by Vertex CustomJob.

    Delegates to ``training.remote_training`` (prediction reuses the remote
    training execution path), forwarding ``rewrapper`` so the executed
    instance can be restored, and returns the resulting predictions.
    """
    # Collapsed from the diff: the scrape contained both the old and new
    # variants of this function; this is the post-change two-argument version.
    predictions = training.remote_training(invokable=invokable, rewrapper=rewrapper)
    return predictions


Expand Down
Loading

0 comments on commit 8b0add1

Please sign in to comment.