diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index b50403bc7f..39c4dba9d3 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -18,6 +18,9 @@
 # pylint: disable=protected-access, g-multiple-import
 
 from google.cloud import aiplatform
+from google.cloud.aiplatform.compat.types import (
+    job_state_v1beta1 as gca_job_state_v1beta1,
+)
 from tests.system.aiplatform import e2e_base
 from vertexai.preview.language_models import (
     ChatModel,
@@ -144,3 +147,23 @@ def test_tuning(self, shared_state):
             top_k=5,
         )
         assert tuned_model_response.text
+
+    def test_batch_prediction(self):
+        source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
+        destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_"
+
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextGenerationModel.from_pretrained("text-bison@001")
+        job = model.batch_predict(
+            source_uri=source_uri,
+            destination_uri_prefix=destination_uri_prefix,
+            model_parameters={"temperature": 0, "top_p": 1, "top_k": 5},
+        )
+
+        job.wait_for_resource_creation()
+        job.wait()
+        gapic_job = job._gca_resource
+        job.delete()
+
+        assert gapic_job.state == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 79e4b503fc..34a0b2721b 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -1135,3 +1135,37 @@ def test_text_embedding_ga(self):
         vector = embedding.values
         assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
         assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
+
+    def test_batch_prediction(self):
+        """Tests batch prediction."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            model.batch_predict(
+                source_uri="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
+            mock_create.assert_called_once_with(
+                model_name="publishers/google/models/text-bison@001",
+                job_display_name=None,
+                gcs_source="gs://test-bucket/test_table.jsonl",
+                gcs_destination_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 31b743afee..ffe280c2a1 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -320,8 +320,75 @@ def _batch_predict(
 _TextGenerationModel = TextGenerationModel
 
 
-class _PreviewTextGenerationModel(TextGenerationModel, _TunableModelMixin):
-    """Tunable text generation model."""
+class _ModelWithBatchPredict(_LanguageModel):
+    """Model that supports batch prediction."""
+
+    def batch_predict(
+        self,
+        *,
+        source_uri: Union[str, List[str]],
+        destination_uri_prefix: str,
+        model_parameters: Optional[Dict] = None,
+    ) -> aiplatform.BatchPredictionJob:
+        """Starts a batch prediction job with the model.
+
+        Args:
+            source_uri: The location of the dataset.
+                `gs://` and `bq://` URIs are supported.
+            destination_uri_prefix: The URI prefix for the prediction.
+                `gs://` and `bq://` URIs are supported.
+            model_parameters: Model-specific parameters to send to the model.
+
+        Returns:
+            A `BatchPredictionJob` object
+        Raises:
+            ValueError: When source or destination URI is not supported.
+        """
+        arguments = {}
+        first_source_uri = source_uri if isinstance(source_uri, str) else source_uri[0]
+        if first_source_uri.startswith("gs://"):
+            if not isinstance(source_uri, str):
+                if not all(uri.startswith("gs://") for uri in source_uri):
+                    raise ValueError(
+                        f"All URIs in the list must start with 'gs://': {source_uri}"
+                    )
+            arguments["gcs_source"] = source_uri
+        elif first_source_uri.startswith("bq://"):
+            if not isinstance(source_uri, str):
+                raise ValueError(
+                    f"Only single BigQuery source can be specified: {source_uri}"
+                )
+            arguments["bigquery_source"] = source_uri
+        else:
+            raise ValueError(f"Unsupported source_uri: {source_uri}")
+
+        if destination_uri_prefix.startswith("gs://"):
+            arguments["gcs_destination_prefix"] = destination_uri_prefix
+        elif destination_uri_prefix.startswith("bq://"):
+            arguments["bigquery_destination_prefix"] = destination_uri_prefix
+        else:
+            raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")
+
+        model_name = self._model_resource_name
+        # TODO(b/284512065): Batch prediction service does not support
+        # fully qualified publisher model names yet
+        publishers_index = model_name.index("/publishers/")
+        if publishers_index > 0:
+            model_name = model_name[publishers_index + 1 :]
+
+        job = aiplatform.BatchPredictionJob.create(
+            model_name=model_name,
+            job_display_name=None,
+            **arguments,
+            model_parameters=model_parameters,
+        )
+        return job
+
+
+class _PreviewTextGenerationModel(
+    TextGenerationModel, _TunableModelMixin, _ModelWithBatchPredict
+):
+    """Preview text generation model."""
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 