Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add done method for pipeline, training, and batch prediction jobs #1062

Merged
merged 10 commits into from
Mar 9, 2022
55 changes: 55 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,58 @@ def get_annotation_class(annotation: type) -> type:
return annotation.__args__[0]
else:
return annotation


class DoneMixin(abc.ABC):
    """Abstract mixin declaring a ``done`` check.

    Subclasses implement ``done`` to report whether a job has
    finished running.
    """

    @abc.abstractmethod
    def done(self) -> bool:
        """Return True if the job has completed, False otherwise."""


class StatefulResource(DoneMixin):
    """Extends DoneMixin to check whether a job returning a stateful resource has completed.

    Concrete subclasses must provide ``state`` (the job's current state)
    and ``_valid_done_states`` (the terminal states); ``done`` compares
    the former against the latter.
    """

    @property
    @abc.abstractmethod
    def state(self):
        """The current state of the job."""

    @property
    @classmethod
    @abc.abstractmethod
    def _valid_done_states(cls):
        """A set() containing all job states associated with a completed job.

        NOTE(review): subclasses shadow this with a plain class attribute
        (e.g. ``_valid_done_states = _JOB_COMPLETE_STATES``), so this
        property/classmethod stack only documents the contract and is not
        invoked at runtime.
        """

    def done(self) -> bool:
        """Method indicating whether a job has completed.

        Returns:
            True if the job has completed.
        """
        # Membership test returns the boolean directly instead of the
        # previous if/else that returned True/False explicitly.
        return self.state in self._valid_done_states


class VertexAiStatefulResource(VertexAiResourceNounWithFutureManager, StatefulResource):
    """Extends StatefulResource to include a check for self._gca_resource."""

    def done(self) -> bool:
        """Method indicating whether a job has completed.

        Returns:
            True if the job has completed.
        """
        # Without a populated backing resource the job cannot be done yet.
        resource = self._gca_resource
        if not (resource and resource.name):
            return False
        return super().done()
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)


class _Job(base.VertexAiResourceNounWithFutureManager):
class _Job(base.VertexAiStatefulResource):
"""Class that represents a general Job resource in Vertex AI.
Cannot be directly instantiated.

Expand All @@ -83,6 +83,9 @@ class _Job(base.VertexAiResourceNounWithFutureManager):

client_class = utils.JobClientWithOverride

# Required by the done() method
_valid_done_states = _JOB_COMPLETE_STATES

def __init__(
self,
job_name: str,
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _set_enable_caching_value(
task["cachingOptions"] = {"enableCache": enable_caching}


class PipelineJob(base.VertexAiResourceNounWithFutureManager):
class PipelineJob(base.VertexAiStatefulResource):

client_class = utils.PipelineJobClientWithOverride
_resource_noun = "pipelineJobs"
Expand All @@ -87,6 +87,9 @@ class PipelineJob(base.VertexAiResourceNounWithFutureManager):
_parse_resource_name_method = "parse_pipeline_job_path"
_format_resource_name_method = "pipeline_job_path"

# Required by the done() method
_valid_done_states = _PIPELINE_COMPLETE_STATES

def __init__(
self,
display_name: str,
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)


class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
class _TrainingJob(base.VertexAiStatefulResource):

client_class = utils.PipelineClientWithOverride
_resource_noun = "trainingPipelines"
Expand All @@ -76,6 +76,9 @@ class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
_parse_resource_name_method = "parse_training_pipeline_path"
_format_resource_name_method = "training_pipeline_path"

# Required by the done() method
_valid_done_states = _PIPELINE_COMPLETE_STATES

def __init__(
self,
display_name: str,
Expand Down
8 changes: 8 additions & 0 deletions tests/system/aiplatform/test_e2e_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def test_end_to_end_tabular(self, shared_state):

shared_state["resources"].append(custom_batch_prediction_job)

in_progress_done_check = custom_job.done()
custom_job.wait_for_resource_creation()

automl_job.wait_for_resource_creation()
custom_batch_prediction_job.wait_for_resource_creation()

Expand Down Expand Up @@ -174,6 +176,8 @@ def test_end_to_end_tabular(self, shared_state):
# Test lazy loading of Endpoint, check getter was never called after predict()
custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name)
custom_endpoint.predict([_INSTANCE])

completion_done_check = custom_job.done()
assert custom_endpoint._skipped_getter_call()

assert (
Expand Down Expand Up @@ -201,3 +205,7 @@ def test_end_to_end_tabular(self, shared_state):
assert 200000 > custom_result > 50000
except KeyError as e:
raise RuntimeError("Unexpected prediction response structure:", e)

# Check done() method works correctly
assert in_progress_done_check is False
assert completion_done_check is True
29 changes: 29 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,14 @@ def test_batch_prediction_job_status(self, get_batch_prediction_job_mock):
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=base._DEFAULT_RETRY
)

def test_batch_prediction_job_done_get(self, get_batch_prediction_job_mock):
    """done() on a freshly fetched job refreshes the resource via a second get."""
    batch_job = jobs.BatchPredictionJob(
        batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
    )

    # The mocked job is not in a terminal state yet.
    assert batch_job.done() is False
    # One get from the constructor plus one triggered by done().
    assert get_batch_prediction_job_mock.call_count == 2

@pytest.mark.usefixtures("get_batch_prediction_job_gcs_output_mock")
def test_batch_prediction_iter_dirs_gcs(self, storage_list_blobs_mock):
bp = jobs.BatchPredictionJob(
Expand Down Expand Up @@ -507,6 +515,27 @@ def test_batch_predict_gcs_source_and_dest(
batch_prediction_job=expected_gapic_batch_prediction_job,
)

@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
    """done() flips from False to True as an async batch predict completes."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Launch asynchronously so the running state is observable.
    bp_job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        sync=False,
    )

    bp_job.wait_for_resource_creation()

    # Resource exists, but the mocked job is not in a terminal state yet.
    assert bp_job.done() is False

    bp_job.wait()

    # After wait() the mocked job reports a terminal state.
    assert bp_job.done() is True

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_gcs_source_bq_dest(
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,39 @@ def test_submit_call_pipeline_service_pipeline_job_create(
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
    "job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
def test_done_method_pipeline_service(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec_json,
    mock_load_json,
):
    """done() tracks a submitted PipelineJob from running to terminal state."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    pipeline_job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)

    # Submitted, but the mocked service still reports a running state.
    assert pipeline_job.done() is False

    pipeline_job.wait()

    # wait() drains the mocked get() side effects; the last state is terminal.
    assert pipeline_job.done() is True

@pytest.mark.parametrize(
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
)
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,65 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(

assert job._has_logged_custom_job

def test_custom_training_tabular_done(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_python_package_to_gcs,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """done() on an async CustomTrainingJob is False while running, True after wait()."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    training_job = training_jobs.CustomTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        labels=_TEST_LABELS,
        script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
    )

    # sync=False so the pipeline keeps "running" until wait() is called.
    training_job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        args=_TEST_RUN_ARGS,
        environment_variables=_TEST_ENVIRONMENT_VARIABLES,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        model_display_name=_TEST_MODEL_DISPLAY_NAME,
        model_labels=_TEST_MODEL_LABELS,
        training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
        timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME,
        tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
        sync=False,
    )

    # Still running on the mocked service.
    assert training_job.done() is False

    training_job.wait()

    # Mocked pipeline reaches a terminal state after wait().
    assert training_job.done() is True

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_bigquery_destination(
self,
Expand Down Expand Up @@ -2323,6 +2382,59 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_custom_container_training_tabular_done(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """done() on an async CustomContainerTrainingJob is False while running, True after wait()."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    container_job = training_jobs.CustomContainerTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        labels=_TEST_LABELS,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        command=_TEST_TRAINING_CONTAINER_CMD,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
    )

    # sync=False so the pipeline keeps "running" until wait() is called.
    container_job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        environment_variables=_TEST_ENVIRONMENT_VARIABLES,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        model_display_name=_TEST_MODEL_DISPLAY_NAME,
        model_labels=_TEST_MODEL_LABELS,
        predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
        service_account=_TEST_SERVICE_ACCOUNT,
        tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
        sync=False,
    )

    # Still running on the mocked service.
    assert container_job.done() is False

    container_job.wait()

    # Mocked pipeline reaches a terminal state after wait().
    assert container_job.done() is True

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_tabular_dataset(
self,
Expand Down