Skip to content

Commit

Permalink
feat: add Service Account support to BatchPredictionJob
Browse files Browse the repository at this point in the history
COPYBARA_INTEGRATE_REVIEW=#1872 from cymarechal-devoteam:feature/batch-prediction/service-account 4f015f3
PiperOrigin-RevId: 501301075
  • Loading branch information
cymarechal-devoteam authored and copybara-github committed Jan 11, 2023
1 parent 369a0cc commit deba06b
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 6 deletions.
7 changes: 4 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,11 @@ To create a batch prediction job:
batch_prediction_job = model.batch_predict(
job_display_name='my-batch-prediction-job',
instances_format='csv'
instances_format='csv',
machine_type='n1-standard-4',
gcs_source=['gs://path/to/my/file.csv']
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
gcs_source=['gs://path/to/my/file.csv'],
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
service_account='my-sa@my-project.iam.gserviceaccount.com'
)
You can also create a batch prediction job asynchronously by including the `sync=False` argument:
Expand Down
7 changes: 4 additions & 3 deletions docs/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,11 @@ To create a batch prediction job:
batch_prediction_job = model.batch_predict(
job_display_name='my-batch-prediction-job',
instances_format='csv'
instances_format='csv',
machine_type='n1-standard-4',
gcs_source=['gs://path/to/my/file.csv']
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
gcs_source=['gs://path/to/my/file.csv'],
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
service_account='my-sa@my-project.iam.gserviceaccount.com'
)
You can also create a batch prediction job asynchronously by including the `sync=False` argument:
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def create(
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
service_account: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
Expand Down Expand Up @@ -586,6 +587,9 @@ def create(
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
service_account (str):
Optional. The service account that the batch prediction workload runs as.
Users submitting jobs must have act-as permission on this service account.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand Down Expand Up @@ -745,6 +749,9 @@ def create(
)
gapic_batch_prediction_job.explanation_spec = explanation_spec

if service_account:
gapic_batch_prediction_job.service_account = service_account

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
Expand Down
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3511,6 +3511,7 @@ def batch_predict(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
service_account: Optional[str] = None,
) -> jobs.BatchPredictionJob:
"""Creates a batch prediction job using this Model and outputs
prediction results to the provided destination prefix in the specified
Expand Down Expand Up @@ -3673,6 +3674,9 @@ def batch_predict(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
service_account (str):
Optional. The service account that the batch prediction workload runs as.
Users submitting jobs must have act-as permission on this service account.
Returns:
job (jobs.BatchPredictionJob):
Expand Down Expand Up @@ -3705,6 +3709,7 @@ def batch_predict(
encryption_spec_key_name=encryption_spec_key_name,
sync=sync,
create_request_timeout=create_request_timeout,
service_account=service_account,
)

@classmethod
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
_TEST_BQ_JOB_ID = "123459876"
_TEST_BQ_MAX_RESULTS = 100
_TEST_GCS_BUCKET_NAME = "my-bucket"
_TEST_SERVICE_ACCOUNT = "test-sa@my-project.iam.gserviceaccount.com"


_TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}"
_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}"
Expand Down Expand Up @@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand All @@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand All @@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=180.0,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand All @@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand All @@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand All @@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand All @@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=False,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand All @@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest(
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand Down Expand Up @@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest(
),
predictions_format="bigquery",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand Down Expand Up @@ -946,6 +957,7 @@ def test_batch_predict_with_all_args(
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand Down Expand Up @@ -986,6 +998,7 @@ def test_batch_predict_with_all_args(
parameters=_TEST_EXPLANATION_PARAMETERS,
),
labels=_TEST_LABEL,
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
Expand Down Expand Up @@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
model_monitoring_objective_config=mm_obj_cfg,
model_monitoring_alert_config=mm_alert_cfg,
analysis_instance_schema_uri="",
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
Expand Down Expand Up @@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
generate_explanation=True,
model_monitoring_config=_TEST_MODEL_MONITORING_CFG,
labels=_TEST_LABEL,
service_account=_TEST_SERVICE_ACCOUNT,
)
create_batch_prediction_job_v1beta1_mock.assert_called_once_with(
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
Expand All @@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=False,
service_account=_TEST_SERVICE_ACCOUNT,
)

with pytest.raises(RuntimeError) as e:
Expand Down Expand Up @@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
model_name=_TEST_MODEL_NAME,
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"source")
Expand All @@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"source")
Expand All @@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self):
model_name=_TEST_MODEL_NAME,
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"destination")
Expand All @@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
instances_format="wrong",
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"accepted instances format")
Expand All @@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
predictions_format="wrong",
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"accepted prediction format")
Expand All @@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=True,
service_account=_TEST_SERVICE_ACCOUNT,
)
assert (
create_batch_prediction_job_mock.call_args_list[0][1][
Expand All @@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=True,
service_account=_TEST_SERVICE_ACCOUNT,
)
assert (
create_batch_prediction_job_mock.call_args_list[0][1][
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
Expand All @@ -1669,6 +1670,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
predictions_format="jsonl",
),
encryption_spec=_TEST_ENCRYPTION_SPEC,
service_account=_TEST_SERVICE_ACCOUNT,
)
)

Expand All @@ -1693,6 +1695,7 @@ def test_batch_predict_gcs_source_and_dest(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
Expand All @@ -1711,6 +1714,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
Expand All @@ -1733,6 +1737,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)
)

Expand All @@ -1757,6 +1762,7 @@ def test_batch_predict_gcs_source_bq_dest(
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
Expand All @@ -1781,6 +1787,7 @@ def test_batch_predict_gcs_source_bq_dest(
),
predictions_format="bigquery",
),
service_account=_TEST_SERVICE_ACCOUNT,
)
)

Expand Down Expand Up @@ -1817,6 +1824,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
Expand Down Expand Up @@ -1857,6 +1865,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
),
labels=_TEST_LABEL,
encryption_spec=_TEST_ENCRYPTION_SPEC,
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
Expand Down

0 comments on commit deba06b

Please sign in to comment.