fix: LLM - Fixed batch prediction on tuned models
PiperOrigin-RevId: 560910428
Ark-kun authored and copybara-github committed Aug 29, 2023
1 parent 2e3090b · commit 2a08535
Showing 2 changed files with 34 additions and 9 deletions.
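What the fix enables, as a minimal sketch rather than code from this commit: the project, location, bucket paths, and tuned-model ID below are placeholders, and loading the model via get_tuned_model is assumed from the SDK's usual API (the commit itself only exercises batch_predict).

    import vertexai
    from vertexai.language_models import TextGenerationModel

    vertexai.init(project="my-project", location="us-central1")

    # Load a previously tuned model by its full resource name (placeholder ID).
    tuned_model = TextGenerationModel.get_tuned_model(
        "projects/my-project/locations/us-central1/models/1234567890"
    )

    # Before this commit, batch_predict failed for tuned models: their resource
    # names contain no "/publishers/" segment (see the second changed file).
    job = tuned_model.batch_predict(
        dataset="gs://my-bucket/prompts.jsonl",
        destination_uri_prefix="gs://my-bucket/results/",
        model_parameters={"temperature": 0.1},
    )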
tests/unit/aiplatform/test_language_models.py (34 additions, 4 deletions)
@@ -45,7 +45,7 @@
     artifact as gca_artifact,
     prediction_service as gca_prediction_service,
     context as gca_context,
-    endpoint as gca_endpoint,
+    endpoint_v1 as gca_endpoint,
     pipeline_job as gca_pipeline_job,
     pipeline_state as gca_pipeline_state,
     deployed_model_ref_v1,
@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
     get_endpoint_mock.return_value = gca_endpoint.Endpoint(
         display_name="test-display-name",
         name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+        deployed_models=[
+            gca_endpoint.DeployedModel(
+                model=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+            ),
+        ],
     )
     yield get_endpoint_mock
 
@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
         assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
         assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
 
-    def test_batch_prediction(self):
+    def test_batch_prediction(
+        self,
+        get_endpoint_mock,
+    ):
         """Tests batch prediction."""
         aiplatform.init(
             project=_TEST_PROJECT,
@@ -2447,7 +2455,29 @@ def test_batch_prediction(self):
             model_parameters={"temperature": 0.1},
         )
         mock_create.assert_called_once_with(
-            model_name="publishers/google/models/text-bison@001",
+            model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001",
             job_display_name=None,
             gcs_source="gs://test-bucket/test_table.jsonl",
             gcs_destination_prefix="gs://test-bucket/results/",
             model_parameters={"temperature": 0.1},
         )
+
+        # Testing tuned model batch prediction
+        tuned_model = language_models.TextGenerationModel(
+            model_id=model._model_id,
+            endpoint_name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+        )
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            tuned_model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
+        mock_create.assert_called_once_with(
+            model_name=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
+            job_display_name=None,
+            gcs_source="gs://test-bucket/test_table.jsonl",
+            gcs_destination_prefix="gs://test-bucket/results/",
@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
             model_parameters={},
         )
         mock_create.assert_called_once_with(
-            model_name="publishers/google/models/textembedding-gecko@001",
+            model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
             job_display_name=None,
             gcs_source="gs://test-bucket/test_table.jsonl",
             gcs_destination_prefix="gs://test-bucket/results/",
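Why the test now takes the get_endpoint_mock fixture, and why that mock gains a deployed model: a tuned TextGenerationModel is backed by an endpoint, and the SDK resolves the tuned model's resource name from the endpoint's deployed model. Roughly, as a sketch of that resolution rather than the SDK's exact code:

    # The mock endpoint's DeployedModel gives this lookup something to return;
    # .model is the tuned model's resource name.
    endpoint = aiplatform.Endpoint(test_constants.EndpointConstants._TEST_ENDPOint_NAME.replace("ENDPOint", "ENDPOINT"))
    model_resource_name = endpoint.list_models()[0].model
    # -> test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME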
vertexai/language_models/_language_models.py (0 additions, 5 deletions)
@@ -839,11 +839,6 @@ def batch_predict(
             raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")
 
         model_name = self._model_resource_name
-        # TODO(b/284512065): Batch prediction service does not support
-        # fully qualified publisher model names yet
-        publishers_index = model_name.index("/publishers/")
-        if publishers_index > 0:
-            model_name = model_name[publishers_index + 1 :]
 
         job = aiplatform.BatchPredictionJob.create(
             model_name=model_name,
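The deleted block is likely the actual failure mode: str.index raises ValueError when the substring is absent, and a tuned model's resource name has no "/publishers/" segment. An illustration of both paths (the resource names are placeholders):

    # Publisher model: the old trimming produced the short name.
    full = "projects/123/locations/us-central1/publishers/google/models/text-bison@001"
    i = full.index("/publishers/")
    print(full[i + 1:])   # publishers/google/models/text-bison@001

    # Tuned model: no "/publishers/" segment, so .index() raises and
    # batch_predict crashed client-side before reaching the service.
    "projects/123/locations/us-central1/models/456".index("/publishers/")
    # ValueError: substring not found

The updated test assertions tell the other half of the story: the batch prediction service evidently accepts fully qualified publisher model names now, so the trimming and its TODO (b/284512065) could be removed outright.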
