Skip to content

Commit

Permalink
feat: LLM - Added batch prediction
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 542106410
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 21, 2023
1 parent cd67734 commit 2235305
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
23 changes: 23 additions & 0 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
# pylint: disable=protected-access, g-multiple-import

from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import (
job_state_v1beta1 as gca_job_state_v1beta1,
)
from tests.system.aiplatform import e2e_base
from vertexai.preview.language_models import (
ChatModel,
Expand Down Expand Up @@ -144,3 +147,23 @@ def test_tuning(self, shared_state):
top_k=5,
)
assert tuned_model_response.text

def test_batch_prediction(self):
    """System test: runs a real batch prediction job and checks it succeeds.

    The job resource is deleted even when waiting fails, so failed test
    runs do not leak batch prediction jobs in the test project.
    """
    source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
    destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_"

    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    model = TextGenerationModel.from_pretrained("text-bison@001")
    job = model.batch_predict(
        source_uri=source_uri,
        destination_uri_prefix=destination_uri_prefix,
        model_parameters={"temperature": 0, "top_p": 1, "top_k": 5},
    )

    # `job.wait()` raises when the job fails; without try/finally the
    # delete below would be skipped and the job resource would leak.
    try:
        job.wait_for_resource_creation()
        job.wait()
        # Snapshot the GAPIC resource before deletion so the state can
        # still be asserted afterwards.
        gapic_job = job._gca_resource
    finally:
        job.delete()

    assert gapic_job.state == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED
34 changes: 34 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,3 +1135,37 @@ def test_text_embedding_ga(self):
vector = embedding.values
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]

def test_batch_prediction(self):
    """Tests batch prediction."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Loading the model goes through the Model Garden service; stub out
    # the publisher-model lookup so no network call is made.
    publisher_model_patcher = mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_BISON_PUBLISHER_MODEL_DICT
        ),
    )
    with publisher_model_patcher:
        model = preview_language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    # Intercept job creation and verify the arguments that batch_predict
    # forwards to BatchPredictionJob.create.
    create_patcher = mock.patch.object(
        target=aiplatform.BatchPredictionJob,
        attribute="create",
    )
    with create_patcher as create_mock:
        model.batch_predict(
            source_uri="gs://test-bucket/test_table.jsonl",
            destination_uri_prefix="gs://test-bucket/results/",
            model_parameters={"temperature": 0.1},
        )
    create_mock.assert_called_once_with(
        model_name="publishers/google/models/text-bison@001",
        job_display_name=None,
        gcs_source="gs://test-bucket/test_table.jsonl",
        gcs_destination_prefix="gs://test-bucket/results/",
        model_parameters={"temperature": 0.1},
    )
71 changes: 69 additions & 2 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,75 @@ def _batch_predict(
# NOTE(review): presumably a backward-compatibility alias for callers that
# imported the underscore-prefixed name — confirm before removing.
_TextGenerationModel = TextGenerationModel


class _PreviewTextGenerationModel(TextGenerationModel, _TunableModelMixin):
"""Tunable text generation model."""
class _ModelWithBatchPredict(_LanguageModel):
    """Model that supports batch prediction."""

    def batch_predict(
        self,
        *,
        source_uri: Union[str, List[str]],
        destination_uri_prefix: str,
        model_parameters: Optional[Dict] = None,
    ) -> aiplatform.BatchPredictionJob:
        """Starts a batch prediction job with the model.

        Args:
            source_uri: The location of the dataset.
                `gs://` and `bq://` URIs are supported.
            destination_uri_prefix: The URI prefix for the prediction.
                `gs://` and `bq://` URIs are supported.
            model_parameters: Model-specific parameters to send to the model.

        Returns:
            A `BatchPredictionJob` object

        Raises:
            ValueError: When source or destination URI is not supported.
        """
        arguments = {}
        # Validate based on the first URI: a list must be all-GCS, and
        # BigQuery only accepts a single source.
        first_source_uri = source_uri if isinstance(source_uri, str) else source_uri[0]
        if first_source_uri.startswith("gs://"):
            if not isinstance(source_uri, str):
                if not all(uri.startswith("gs://") for uri in source_uri):
                    raise ValueError(
                        f"All URIs in the list must start with 'gs://': {source_uri}"
                    )
            arguments["gcs_source"] = source_uri
        elif first_source_uri.startswith("bq://"):
            if not isinstance(source_uri, str):
                raise ValueError(
                    f"Only single BigQuery source can be specified: {source_uri}"
                )
            arguments["bigquery_source"] = source_uri
        else:
            raise ValueError(f"Unsupported source_uri: {source_uri}")

        if destination_uri_prefix.startswith("gs://"):
            arguments["gcs_destination_prefix"] = destination_uri_prefix
        elif destination_uri_prefix.startswith("bq://"):
            arguments["bigquery_destination_prefix"] = destination_uri_prefix
        else:
            raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")

        model_name = self._model_resource_name
        # TODO(b/284512065): Batch prediction service does not support
        # fully qualified publisher model names yet, so strip everything
        # before the "publishers/" segment.
        # Use `find` (returns -1 when absent) instead of `index` (raises
        # ValueError when absent): the `> 0` guard below implies the
        # segment may legitimately be missing (e.g. a resource name that
        # already starts with "publishers/"), and that case must not crash.
        publishers_index = model_name.find("/publishers/")
        if publishers_index > 0:
            model_name = model_name[publishers_index + 1 :]

        job = aiplatform.BatchPredictionJob.create(
            model_name=model_name,
            job_display_name=None,
            **arguments,
            model_parameters=model_parameters,
        )
        return job


class _PreviewTextGenerationModel(
TextGenerationModel, _TunableModelMixin, _ModelWithBatchPredict
):
"""Preview text generation model."""

_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

Expand Down

0 comments on commit 2235305

Please sign in to comment.