
Commit

fix: LLM - Fixed the TextGenerationModel.predict parameters
PiperOrigin-RevId: 555940714
Ark-kun authored and copybara-github committed Aug 11, 2023
1 parent af6e455 commit f3b25ab
Showing 2 changed files with 39 additions and 7 deletions.
27 changes: 26 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
@@ -615,7 +615,7 @@ def test_text_generation_ga(self):
             target=prediction_service_client.PredictionServiceClient,
             attribute="predict",
             return_value=gca_predict_response,
-        ):
+        ) as mock_predict:
             response = model.predict(
                 "What is the best recipe for banana bread? Recipe:",
                 max_output_tokens=128,
@@ -624,8 +624,33 @@ def test_text_generation_ga(self):
                 top_k=5,
             )
 
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert prediction_parameters["maxDecodeSteps"] == 128
+        assert prediction_parameters["temperature"] == 0
+        assert prediction_parameters["topP"] == 1
+        assert prediction_parameters["topK"] == 5
         assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+
+        # Validating that unspecified parameters are not passed to the model
+        # (except `max_output_tokens`).
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            model.predict(
+                "What is the best recipe for banana bread? Recipe:",
+            )
+
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert (
+            prediction_parameters["maxDecodeSteps"]
+            == language_models.TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
+        )
+        assert "temperature" not in prediction_parameters
+        assert "topP" not in prediction_parameters
+        assert "topK" not in prediction_parameters
 
     @pytest.mark.parametrize(
         "job_spec",
         [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
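Note: the updated test works by patching the prediction client and then reading the keyword arguments the mock received. A minimal, self-contained sketch of that pattern, using a hypothetical FakeClient and send_request helper that are not part of the SDK, could look like this:

# Minimal sketch, assuming a hypothetical FakeClient and send_request helper
# (not from this commit), of inspecting the parameters a mocked call received.
from unittest import mock


class FakeClient:
    def predict(self, instances, parameters):
        return {"predictions": []}


def send_request(client):
    # Forward only explicitly set parameters, mirroring the fixed behavior.
    return client.predict(
        instances=[{"content": "Hello"}],
        parameters={"maxDecodeSteps": 128},
    )


with mock.patch.object(
    target=FakeClient,
    attribute="predict",
    return_value={"predictions": []},
) as mock_predict:
    send_request(FakeClient())

# call_args[1] holds the keyword arguments of the most recent call.
sent_parameters = mock_predict.call_args[1]["parameters"]
assert sent_parameters["maxDecodeSteps"] == 128
assert "temperature" not in sent_parameters

This is the same call_args[1]["parameters"] inspection the test above uses to assert on what would be sent to the service.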
19 changes: 13 additions & 6 deletions vertexai/language_models/_language_models.py
@@ -290,12 +290,19 @@ def _batch_predict(
            A list of `TextGenerationResponse` objects that contain the texts produced by the model.
        """
        instances = [{"content": str(prompt)} for prompt in prompts]
-        prediction_parameters = {
-            "temperature": temperature,
-            "maxDecodeSteps": max_output_tokens,
-            "topP": top_p,
-            "topK": top_k,
-        }
+        prediction_parameters = {}
+
+        if max_output_tokens:
+            prediction_parameters["maxDecodeSteps"] = max_output_tokens
+
+        if temperature is not None:
+            prediction_parameters["temperature"] = temperature
+
+        if top_p:
+            prediction_parameters["topP"] = top_p
+
+        if top_k:
+            prediction_parameters["topK"] = top_k
 
        prediction_response = self._endpoint.predict(
            instances=instances,
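Note: the effect of this change is that unset generation parameters are no longer forwarded to the prediction service. A simplified, standalone sketch of the new parameter-building logic, with the helper name build_prediction_parameters assumed purely for illustration (the real code lives inside _batch_predict), might look like:

# Simplified sketch of the parameter-building logic introduced above; the
# helper name is an assumption for illustration, not part of the SDK.
from typing import Optional


def build_prediction_parameters(
    max_output_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> dict:
    parameters = {}

    if max_output_tokens:
        parameters["maxDecodeSteps"] = max_output_tokens

    # temperature=0 is a meaningful value, so it is compared against None
    # rather than tested for truthiness.
    if temperature is not None:
        parameters["temperature"] = temperature

    if top_p:
        parameters["topP"] = top_p

    if top_k:
        parameters["topK"] = top_k

    return parameters


# Before this commit every key was always sent, even when its value was None;
# now only the explicitly provided values appear in the request.
assert build_prediction_parameters(max_output_tokens=128, temperature=0) == {
    "maxDecodeSteps": 128,
    "temperature": 0,
}
assert build_prediction_parameters() == {}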
