Skip to content

Commit

Permalink
Feat: add batch_size kwarg for batch prediction jobs (#1194)
Browse files Browse the repository at this point in the history
* Add batch_size kwarg for batch prediction jobs

* Fix errors

Update the copyright year. Change the order of the argument. Fix the syntax error.

* fix: change description layout
  • Loading branch information
jaycee-li authored May 10, 2022
1 parent 7c70484 commit 50bdb01
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 36 deletions.
20 changes: 18 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,6 +40,7 @@
job_state as gca_job_state,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
study as gca_study_compat,
)
from google.cloud.aiplatform.constants import base as constants
Expand Down Expand Up @@ -376,6 +377,7 @@ def create(
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
Expand Down Expand Up @@ -534,6 +536,13 @@ def create(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
batch_size (int):
Optional. The number of the records (e.g. instances) of the operation given in each batch
to a machine replica. Machine type, and size of a single record should be considered
when setting this parameter, higher value speeds up the batch operation's execution,
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand Down Expand Up @@ -647,7 +656,14 @@ def create(

gapic_batch_prediction_job.dedicated_resources = dedicated_resources

gapic_batch_prediction_job.manual_batch_tuning_parameters = None
manual_batch_tuning_parameters = (
gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters()
)
manual_batch_tuning_parameters.batch_size = batch_size

gapic_batch_prediction_job.manual_batch_tuning_parameters = (
manual_batch_tuning_parameters
)

# User Labels
gapic_batch_prediction_job.labels = labels
Expand Down
11 changes: 10 additions & 1 deletion google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2284,6 +2284,7 @@ def batch_predict(
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
) -> jobs.BatchPredictionJob:
"""Creates a batch prediction job using this Model and outputs
prediction results to the provided destination prefix in the specified
Expand Down Expand Up @@ -2442,6 +2443,13 @@ def batch_predict(
Overrides encryption_spec_key_name set in aiplatform.init.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
batch_size (int):
Optional. The number of the records (e.g. instances) of the operation given in each batch
to a machine replica. Machine type, and size of a single record should be considered
when setting this parameter, higher value speeds up the batch operation's execution,
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand All @@ -2462,6 +2470,7 @@ def batch_predict(
accelerator_count=accelerator_count,
starting_replica_count=starting_replica_count,
max_replica_count=max_replica_count,
batch_size=batch_size,
generate_explanation=generate_explanation,
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
io as gca_io_compat,
job_state as gca_job_state_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
)

from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
Expand Down Expand Up @@ -132,6 +133,7 @@
_TEST_ACCELERATOR_COUNT = 2
_TEST_STARTING_REPLICA_COUNT = 2
_TEST_MAX_REPLICA_COUNT = 12
_TEST_BATCH_SIZE = 16

_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"}

Expand Down Expand Up @@ -725,6 +727,7 @@ def test_batch_predict_with_all_args(
credentials=creds,
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
)

batch_prediction_job.wait_for_resource_creation()
Expand Down Expand Up @@ -756,6 +759,9 @@ def test_batch_predict_with_all_args(
starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
max_replica_count=_TEST_MAX_REPLICA_COUNT,
),
manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(
batch_size=_TEST_BATCH_SIZE
),
generate_explanation=True,
explanation_spec=gca_explanation_compat.ExplanationSpec(
metadata=_TEST_EXPLANATION_METADATA,
Expand Down
69 changes: 36 additions & 33 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
env_var as gca_env_var,
explanation as gca_explanation,
machine_resources as gca_machine_resources,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
model_service as gca_model_service,
model_evaluation as gca_model_evaluation,
endpoint_service as gca_endpoint_service,
Expand Down Expand Up @@ -86,6 +87,8 @@
_TEST_STARTING_REPLICA_COUNT = 2
_TEST_MAX_REPLICA_COUNT = 12

_TEST_BATCH_SIZE = 16

_TEST_PIPELINE_RESOURCE_NAME = (
"projects/my-project/locations/us-central1/trainingPipeline/12345"
)
Expand Down Expand Up @@ -1402,47 +1405,47 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
)

if not sync:
batch_prediction_job.wait()

# Construct expected request
expected_gapic_batch_prediction_job = (
gca_batch_prediction_job.BatchPredictionJob(
display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
model=model_service_client.ModelServiceClient.model_path(
_TEST_PROJECT, _TEST_LOCATION, _TEST_ID
),
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io.GcsSource(
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
),
),
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io.GcsDestination(
output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
),
predictions_format="csv",
),
dedicated_resources=gca_machine_resources.BatchDedicatedResources(
machine_spec=gca_machine_resources.MachineSpec(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
),
starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
max_replica_count=_TEST_MAX_REPLICA_COUNT,
expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
model=model_service_client.ModelServiceClient.model_path(
_TEST_PROJECT, _TEST_LOCATION, _TEST_ID
),
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
),
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io.GcsDestination(
output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
),
generate_explanation=True,
explanation_spec=gca_explanation.ExplanationSpec(
metadata=_TEST_EXPLANATION_METADATA,
parameters=_TEST_EXPLANATION_PARAMETERS,
predictions_format="csv",
),
dedicated_resources=gca_machine_resources.BatchDedicatedResources(
machine_spec=gca_machine_resources.MachineSpec(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
),
labels=_TEST_LABEL,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)
starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
max_replica_count=_TEST_MAX_REPLICA_COUNT,
),
manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(
batch_size=_TEST_BATCH_SIZE
),
generate_explanation=True,
explanation_spec=gca_explanation.ExplanationSpec(
metadata=_TEST_EXPLANATION_METADATA,
parameters=_TEST_EXPLANATION_PARAMETERS,
),
labels=_TEST_LABEL,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand Down

0 comments on commit 50bdb01

Please sign in to comment.