From 9716ecd470c81682f6ff55d9ddbcd8db29c46962 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Mon, 23 Nov 2020 21:25:30 -0800 Subject: [PATCH] feat: Job and BatchPredictionJob classes (#79) * Create and move all constants to constants.py * Fix tests after constants.py, drop unused vars * Init Job and BatchPredictionJob class, unit tests * Address all reviewer comments * Update docstring to bigquery.table.RowIterator * Get GCS/BQ clients to use same creds as uCAIP --- google/cloud/aiplatform/__init__.py | 10 +- google/cloud/aiplatform/constants.py | 31 +++ google/cloud/aiplatform/initializer.py | 5 +- google/cloud/aiplatform/jobs.py | 215 +++++++++++++++++- google/cloud/aiplatform/utils.py | 13 +- setup.py | 1 + tests/unit/aiplatform/test_initializer.py | 8 +- tests/unit/aiplatform/test_jobs.py | 265 ++++++++++++++++++++++ 8 files changed, 533 insertions(+), 15 deletions(-) create mode 100644 google/cloud/aiplatform/constants.py create mode 100644 tests/unit/aiplatform/test_jobs.py diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index c2de451b97..18b020e95d 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -22,6 +22,7 @@ from google.cloud.aiplatform.models import Endpoint from google.cloud.aiplatform.models import Model from google.cloud.aiplatform.training_jobs import CustomTrainingJob +from google.cloud.aiplatform.jobs import BatchPredictionJob """ Usage: @@ -31,4 +32,11 @@ """ init = initializer.global_config.init -__all__ = ("gapic", "CustomTrainingJob", "Model", "Dataset", "Endpoint") +__all__ = ( + "gapic", + "BatchPredictionJob", + "CustomTrainingJob", + "Model", + "Dataset", + "Endpoint", +) diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants.py new file mode 100644 index 0000000000..e601c80f9c --- /dev/null +++ b/google/cloud/aiplatform/constants.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DEFAULT_REGION = "us-central1" +SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1") +API_BASE_PATH = "aiplatform.googleapis.com" + +# Batch Prediction +BATCH_PREDICTION_INPUT_STORAGE_FORMATS = ( + "jsonl", + "csv", + "tf-record", + "tf-record-gzip", + "bigquery", + "file-list", +) +BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery") diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 7d1ff588b3..767462aa8d 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -24,6 +24,7 @@ from google.auth import credentials as auth_credentials from google.auth.exceptions import GoogleAuthError from google.cloud.aiplatform import utils +from google.cloud.aiplatform import constants class _Config: @@ -97,7 +98,7 @@ def project(self) -> str: @property def location(self) -> str: """Default location.""" - return self._location or utils.DEFAULT_REGION + return self._location or constants.DEFAULT_REGION @property def experiment(self) -> Optional[str]: @@ -147,7 +148,7 @@ def get_client_options( utils.validate_region(region) return client_options.ClientOptions( - api_endpoint=f"{region}-{prediction}{utils.PROD_API_ENDPOINT}" + api_endpoint=f"{region}-{prediction}{constants.API_BASE_PATH}" ) def common_location_path( diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 1134d1fe0d..e4dc47dcd3 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -15,10 +15,221 @@ # limitations under the License. # +import abc +from typing import Iterable, Optional, Union -class Job: +from google.cloud import storage +from google.cloud import bigquery + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.services import job_service + + +class _Job(base.AiPlatformResourceNoun): + """ + Class that represents a general Job resource in AI Platform (Unified). + Cannot be directly instantiated. + + Serves as base class to specific Job types, i.e. BatchPredictionJob or + DataLabelingJob to re-use shared functionality. + + Subclasses requires one class attribute: + + _getter_method (str): The name of JobServiceClient getter method for specific + Job type, i.e. 'get_custom_job' for CustomJob + """ + + client_class = job_service.JobServiceClient + _is_client_prediction_client = False + + @property + @abc.abstractclassmethod + def _getter_method(cls) -> str: + """Name of getter method of Job subclass, i.e. 'get_custom_job' for CustomJob""" + pass + + def _get_job(self, job_name: str): + """Returns GAPIC service representation of Job subclass resource""" + return getattr(self.api_client, self._getter_method)(name=job_name) + + def __init__( + self, + valid_job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrives Job subclass resource by calling a subclass-specific getter method. + + Args: + valid_job_name (str): + A validated, fully-qualified Job resource name. For example: + 'projects/.../locations/.../batchPredictionJobs/456' or + 'projects/.../locations/.../customJobs/789' + project: Optional[str] = None, + Optional project to retrieve Job subclass from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve Job subclass from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. + """ + super().__init__(project=project, location=location, credentials=credentials) + self._gca_resource = self._get_job(job_name=valid_job_name) + + def status(self) -> job_state.JobState: + """Fetch Job again and return the current JobState. + + Returns: + state (job_state.JobState): + Enum that describes the state of a AI Platform job. + """ + + # Fetch the Job again for most up-to-date job state + self._gca_resource = self._get_job(job_name=self._gca_resource.name) + + return self._gca_resource.state + + +class BatchPredictionJob(_Job): + + _getter_method = "get_batch_prediction_job" + + def __init__( + self, + batch_prediction_job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrieves a BatchPredictionJob resource and instantiates its representation. + + Args: + batch_prediction_job_name (str): + Required. A fully-qualified BatchPredictionJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or + "456" when project and location are initialized or passed. + project: Optional[str] = None, + Optional project to retrieve BatchPredictionJob from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve BatchPredictionJob from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. + """ + valid_batch_prediction_job_name = utils.full_resource_name( + resource_name=batch_prediction_job_name, + resource_noun="batchPredictionJobs", + project=project, + location=location, + ) + + super().__init__( + valid_job_name=valid_batch_prediction_job_name, + project=project, + location=location, + credentials=credentials, + ) + + def iter_outputs( + self, bq_max_results: Optional[int] = 100 + ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + """Returns an Iterable object to traverse the output files, either a list + of GCS Blobs or a BigQuery RowIterator depending on the output config set + when the BatchPredictionJob was created. + + Args: + bq_max_results: Optional[int] = 100 + Limit on rows to retrieve from prediction table in BigQuery dataset. + Only used when retrieving predictions from a bigquery_destination_prefix. + Default is 100. + + Returns: + Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + Either a list of GCS Blob objects within the prediction output + directory or an iterable BigQuery RowIterator with predictions. + + Raises: + RuntimeError: + If BatchPredictionJob is in a JobState other than SUCCEEDED, + since outputs cannot be retrieved until the Job has finished. + NotImplementedError: + If BatchPredictionJob succeeded and output_info does not have a + GCS or BQ output provided. + """ + + job_status = self.status() + + if job_status != job_state.JobState.JOB_STATE_SUCCEEDED: + raise RuntimeError( + f"Cannot read outputs until BatchPredictionJob has succeeded, " + f"current status: {job_status}" + ) + + output_info = self._gca_resource.output_info + + # GCS Destination, return Blobs + if output_info.gcs_output_directory: + + # Build a Storage Client using the same credentials as JobServiceClient + storage_client = storage.Client( + credentials=self.api_client._transport._credentials + ) + + blobs = storage_client.list_blobs(output_info.gcs_output_directory) + return blobs + + # BigQuery Destination, return RowIterator + elif output_info.bigquery_output_dataset: + + # Build a BigQuery Client using the same credentials as JobServiceClient + bq_client = bigquery.Client( + credentials=self.api_client._transport._credentials + ) + + # Format from service is `bq://projectId.bqDatasetId` + bq_dataset = output_info.bigquery_output_dataset + + if bq_dataset.startswith("bq://"): + bq_dataset = bq_dataset[5:] + + # # Split project ID and BQ dataset ID + _, bq_dataset_id = bq_dataset.split(".", 1) + + row_iterator = bq_client.list_rows( + table=f"{bq_dataset_id}.predictions", max_results=bq_max_results + ) + + return row_iterator + + # Unknown Destination type + else: + raise NotImplementedError( + f"Unsupported batch prediction output location, here are details" + f"on your prediction output:\n{output_info}" + ) + + +class CustomJob(_Job): + _getter_method = "get_custom_job" + pass + + +class DataLabelingJob(_Job): + _getter_method = "get_data_labeling_job" pass -class BatchPredictionJob(Job): +class HyperparameterTuningJob(_Job): + _getter_method = "get_hyperparameter_tuning_job" pass diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index a1ab202213..a5f1e6be94 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -35,9 +35,8 @@ client as prediction_client, ) -DEFAULT_REGION = "us-central1" -SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1") -PROD_API_ENDPOINT = "aiplatform.googleapis.com" +from google.cloud.aiplatform import constants + AiPlatformServiceClient = TypeVar( "AiPlatformServiceClient", @@ -218,12 +217,14 @@ def validate_region(region: str) -> bool: ValueError: If region is not in supported regions. """ if not region: - raise ValueError(f"Please provide a region, select from {SUPPORTED_REGIONS}") + raise ValueError( + f"Please provide a region, select from {constants.SUPPORTED_REGIONS}" + ) region = region.lower() - if region not in SUPPORTED_REGIONS: + if region not in constants.SUPPORTED_REGIONS: raise ValueError( - f"Unsupported region for AI Platform, select from {SUPPORTED_REGIONS}" + f"Unsupported region for AI Platform, select from {constants.SUPPORTED_REGIONS}" ) return True diff --git a/setup.py b/setup.py index e707fa9cc3..4b5fb49da0 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ install_requires=( "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "google-cloud-storage >= 1.32.0, < 2.0.0dev", + "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", "libcst >= 0.2.5", "proto-plus >= 1.10.1", "mock >= 4.0.2", diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 94abdacfba..e79509babf 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -22,7 +22,7 @@ import google.auth from google.auth import credentials from google.cloud.aiplatform import initializer -from google.cloud.aiplatform import utils +from google.cloud.aiplatform import constants from google.cloud.aiplatform_v1beta1.services.model_service.client import ( ModelServiceClient, ) @@ -56,7 +56,7 @@ def test_init_location_sets_location(self): assert initializer.global_config.location == _TEST_LOCATION def test_not_init_location_gets_default_location(self): - assert initializer.global_config.location == utils.DEFAULT_REGION + assert initializer.global_config.location == constants.DEFAULT_REGION def test_init_location_with_invalid_location_raises(self): with pytest.raises(ValueError): @@ -94,7 +94,7 @@ def test_create_client_returns_client(self): client = initializer.global_config.create_client(ModelServiceClient) assert isinstance(client, ModelServiceClient) assert ( - client._transport._host == f"{_TEST_LOCATION}-{utils.PROD_API_ENDPOINT}:443" + client._transport._host == f"{_TEST_LOCATION}-{constants.API_BASE_PATH}:443" ) def test_create_client_overrides(self): @@ -109,7 +109,7 @@ def test_create_client_overrides(self): assert isinstance(client, ModelServiceClient) assert ( client._transport._host - == f"{_TEST_LOCATION_2}-prediction-{utils.PROD_API_ENDPOINT}:443" + == f"{_TEST_LOCATION_2}-prediction-{constants.API_BASE_PATH}:443" ) assert client._transport._credentials == creds diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py new file mode 100644 index 0000000000..3f4c5702be --- /dev/null +++ b/tests/unit/aiplatform/test_jobs.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.cloud import storage +from google.cloud import bigquery + +from google.cloud import aiplatform +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform_v1beta1 import types +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.services import job_service + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ID = "1028944691210842416" +_TEST_ALT_ID = "8834795523125638878" +_TEST_DISPLAY_NAME = "my_job_1234" +_TEST_BQ_DATASET_ID = "bqDatasetId" +_TEST_BQ_JOB_ID = "123459876" +_TEST_BQ_MAX_RESULTS = 100 +_TEST_GCS_BUCKET_NAME = "my-bucket" + +_TEST_BQ_PATH = f"bq://projectId.{_TEST_BQ_DATASET_ID}" +_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" +_TEST_GCS_JSONL_SOURCE_URI = f"{_TEST_GCS_BUCKET_PATH}/bp_input_config.jsonl" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}" +) +_TEST_BATCH_PREDICTION_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}" + +_TEST_JOB_STATE_SUCCESS = job_state.JobState(4) +_TEST_JOB_STATE_RUNNING = job_state.JobState(3) +_TEST_JOB_STATE_PENDING = job_state.JobState(2) + +_TEST_GCS_INPUT_CONFIG = batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=types.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]), +) +_TEST_GCS_OUTPUT_CONFIG = batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="jsonl", + gcs_destination=types.GcsDestination(output_uri_prefix=_TEST_GCS_BUCKET_PATH), +) + +_TEST_BQ_INPUT_CONFIG = batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=types.BigQuerySource(input_uri=_TEST_BQ_PATH), +) +_TEST_BQ_OUTPUT_CONFIG = batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="bigquery", + bigquery_destination=types.BigQueryDestination(output_uri=_TEST_BQ_PATH), +) + +_TEST_GCS_OUTPUT_INFO = batch_prediction_job.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_BUCKET_PATH +) +_TEST_BQ_OUTPUT_INFO = batch_prediction_job.BatchPredictionJob.OutputInfo( + bigquery_output_dataset=_TEST_BQ_PATH +) + +_TEST_EMPTY_OUTPUT_INFO = batch_prediction_job.BatchPredictionJob.OutputInfo() + +_TEST_ITER_DIRS_BQ_QUERY = f"SELECT * FROM {_TEST_BQ_DATASET_ID}.predictions LIMIT 100" + +_TEST_GCS_BLOBS = [ + storage.Blob(name="some/path/prediction.jsonl", bucket=_TEST_GCS_BUCKET_NAME) +] + + +# TODO(b/171333554): Move reusable test fixtures to conftest.py file +class TestJob: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + # Test Fixtures + + @pytest.fixture + def get_batch_prediction_job_gcs_output_mock(self): + with patch.object( + job_service.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_GCS_OUTPUT_CONFIG, + output_info=_TEST_GCS_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + @pytest.fixture + def get_batch_prediction_job_bq_output_mock(self): + with patch.object( + job_service.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + @pytest.fixture + def get_batch_prediction_job_empty_output_mock(self): + with patch.object( + job_service.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_EMPTY_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + @pytest.fixture + def get_batch_prediction_job_running_bq_output_mock(self): + with patch.object( + job_service.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_RUNNING, + ) + yield get_batch_prediction_job_mock + + @pytest.fixture + def storage_list_blobs_mock(self): + with patch.object(storage.Client, "list_blobs") as list_blobs_mock: + list_blobs_mock.return_value = _TEST_GCS_BLOBS + yield list_blobs_mock + + @pytest.fixture + def bq_list_rows_mock(self): + with patch.object(bigquery.Client, "list_rows") as list_rows_mock: + + list_rows_mock.return_value = mock.Mock(bigquery.table.RowIterator) + yield list_rows_mock + + # Unit Tests + + def test_init_job_class(self): + """ + Raises TypeError since abstract property '_getter_method' is not set, + the _Job class should only be instantiated through a child class. + """ + with pytest.raises(TypeError): + jobs._Job(valid_job_name=_TEST_BATCH_PREDICTION_NAME) + + def test_init_batch_prediction_job_class( + self, get_batch_prediction_job_bq_output_mock + ): + aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + get_batch_prediction_job_bq_output_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_NAME + ) + + def test_batch_prediction_job_status(self, get_batch_prediction_job_bq_output_mock): + bp = aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + + # get_batch_prediction() is called again here + bp_job_state = bp.status() + + assert get_batch_prediction_job_bq_output_mock.call_count == 2 + assert bp_job_state == _TEST_JOB_STATE_SUCCESS + + get_batch_prediction_job_bq_output_mock.assert_called_with( + name=_TEST_BATCH_PREDICTION_NAME + ) + + def test_batch_prediction_iter_dirs_gcs( + self, get_batch_prediction_job_gcs_output_mock, storage_list_blobs_mock + ): + bp = aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + blobs = bp.iter_outputs() + + storage_list_blobs_mock.assert_called_once_with( + _TEST_GCS_OUTPUT_INFO.gcs_output_directory + ) + + assert blobs == _TEST_GCS_BLOBS + + def test_batch_prediction_iter_dirs_bq( + self, get_batch_prediction_job_bq_output_mock, bq_list_rows_mock + ): + bp = aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + + bp.iter_outputs() + + bq_list_rows_mock.assert_called_once_with( + table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS + ) + + def test_batch_prediction_iter_dirs_while_running( + self, get_batch_prediction_job_running_bq_output_mock + ): + """ + Raises RuntimeError since outputs cannot be read while BatchPredictionJob is still running + """ + with pytest.raises(RuntimeError): + bp = aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + bp.iter_outputs() + + def test_batch_prediction_iter_dirs_invalid_output_info( + self, get_batch_prediction_job_empty_output_mock + ): + """ + Raises NotImplementedError since the BatchPredictionJob's output_info + contains no output GCS directory or BQ dataset. + """ + with pytest.raises(NotImplementedError): + bp = aiplatform.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_NAME + ) + bp.iter_outputs()