diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index fc4f829882..00d6f11780 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,6 +40,7 @@ job_state as gca_job_state, hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat, machine_resources as gca_machine_resources_compat, + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, study as gca_study_compat, ) from google.cloud.aiplatform.constants import base as constants @@ -376,6 +377,7 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + batch_size: Optional[int] = None, ) -> "BatchPredictionJob": """Create a batch prediction job. @@ -534,6 +536,13 @@ def create( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + batch_size (int): + Optional. The number of records (e.g. instances) of the operation given in each batch + to a machine replica. The machine type and the size of a single record should be considered + when setting this parameter: a higher value speeds up the batch operation's execution, + but a value that is too high will result in a whole batch not fitting in a machine's memory, + and the whole operation will fail. + The default value is 64. Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. 
@@ -647,7 +656,14 @@ def create( gapic_batch_prediction_job.dedicated_resources = dedicated_resources - gapic_batch_prediction_job.manual_batch_tuning_parameters = None + manual_batch_tuning_parameters = ( + gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters() + ) + manual_batch_tuning_parameters.batch_size = batch_size + + gapic_batch_prediction_job.manual_batch_tuning_parameters = ( + manual_batch_tuning_parameters + ) # User Labels gapic_batch_prediction_job.labels = labels diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b15ed791bf..95f3044cbe 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -2284,6 +2284,7 @@ def batch_predict( encryption_spec_key_name: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + batch_size: Optional[int] = None, ) -> jobs.BatchPredictionJob: """Creates a batch prediction job using this Model and outputs prediction results to the provided destination prefix in the specified @@ -2442,6 +2443,13 @@ def batch_predict( Overrides encryption_spec_key_name set in aiplatform.init. create_request_timeout (float): Optional. The timeout for the create request in seconds. + batch_size (int): + Optional. The number of records (e.g. instances) of the operation given in each batch + to a machine replica. The machine type and the size of a single record should be considered + when setting this parameter: a higher value speeds up the batch operation's execution, + but a value that is too high will result in a whole batch not fitting in a machine's memory, + and the whole operation will fail. + The default value is 64. 
Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. @@ -2462,6 +2470,7 @@ def batch_predict( accelerator_count=accelerator_count, starting_replica_count=starting_replica_count, max_replica_count=max_replica_count, + batch_size=batch_size, generate_explanation=generate_explanation, explanation_metadata=explanation_metadata, explanation_parameters=explanation_parameters, diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 6b8d908dd2..73a4f8da0c 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -37,6 +37,7 @@ io as gca_io_compat, job_state as gca_job_state_compat, machine_resources as gca_machine_resources_compat, + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, ) from google.cloud.aiplatform_v1.services.job_service import client as job_service_client @@ -132,6 +133,7 @@ _TEST_ACCELERATOR_COUNT = 2 _TEST_STARTING_REPLICA_COUNT = 2 _TEST_MAX_REPLICA_COUNT = 12 +_TEST_BATCH_SIZE = 16 _TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} @@ -725,6 +727,7 @@ def test_batch_predict_with_all_args( credentials=creds, sync=sync, create_request_timeout=None, + batch_size=_TEST_BATCH_SIZE, ) batch_prediction_job.wait_for_resource_creation() @@ -756,6 +759,9 @@ def test_batch_predict_with_all_args( starting_replica_count=_TEST_STARTING_REPLICA_COUNT, max_replica_count=_TEST_MAX_REPLICA_COUNT, ), + manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters( + batch_size=_TEST_BATCH_SIZE + ), generate_explanation=True, explanation_spec=gca_explanation_compat.ExplanationSpec( metadata=_TEST_EXPLANATION_METADATA, diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index f6561cffaa..eaf63d9fdd 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -49,6 +49,7 @@ env_var as gca_env_var, explanation as 
gca_explanation, machine_resources as gca_machine_resources, + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, model_service as gca_model_service, model_evaluation as gca_model_evaluation, endpoint_service as gca_endpoint_service, @@ -86,6 +87,8 @@ _TEST_STARTING_REPLICA_COUNT = 2 _TEST_MAX_REPLICA_COUNT = 12 +_TEST_BATCH_SIZE = 16 + _TEST_PIPELINE_RESOURCE_NAME = ( "projects/my-project/locations/us-central1/trainingPipeline/12345" ) @@ -1402,47 +1405,47 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync, create_request_timeout=None, + batch_size=_TEST_BATCH_SIZE, ) if not sync: batch_prediction_job.wait() # Construct expected request - expected_gapic_batch_prediction_job = ( - gca_batch_prediction_job.BatchPredictionJob( - display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, - model=model_service_client.ModelServiceClient.model_path( - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID - ), - input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( - instances_format="jsonl", - gcs_source=gca_io.GcsSource( - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] - ), - ), - output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( - gcs_destination=gca_io.GcsDestination( - output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX - ), - predictions_format="csv", - ), - dedicated_resources=gca_machine_resources.BatchDedicatedResources( - machine_spec=gca_machine_resources.MachineSpec( - machine_type=_TEST_MACHINE_TYPE, - accelerator_type=_TEST_ACCELERATOR_TYPE, - accelerator_count=_TEST_ACCELERATOR_COUNT, - ), - starting_replica_count=_TEST_STARTING_REPLICA_COUNT, - max_replica_count=_TEST_MAX_REPLICA_COUNT, + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + 
), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ), - generate_explanation=True, - explanation_spec=gca_explanation.ExplanationSpec( - metadata=_TEST_EXPLANATION_METADATA, - parameters=_TEST_EXPLANATION_PARAMETERS, + predictions_format="csv", + ), + dedicated_resources=gca_machine_resources.BatchDedicatedResources( + machine_spec=gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, ), - labels=_TEST_LABEL, - encryption_spec=_TEST_ENCRYPTION_SPEC, - ) + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + ), + manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters( + batch_size=_TEST_BATCH_SIZE + ), + generate_explanation=True, + explanation_spec=gca_explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + labels=_TEST_LABEL, + encryption_spec=_TEST_ENCRYPTION_SPEC, ) create_batch_prediction_job_mock.assert_called_once_with(