From f3b25ab694eaee18f5cc34f800f1b6021d291bca Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Fri, 11 Aug 2023 07:13:51 -0700 Subject: [PATCH] fix: LLM - Fixed the `TextGenerationModel.predict` parameters PiperOrigin-RevId: 555940714 --- tests/unit/aiplatform/test_language_models.py | 27 ++++++++++++++++++- vertexai/language_models/_language_models.py | 19 ++++++++----- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 4d5e286070..1c4781c9bc 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -615,7 +615,7 @@ def test_text_generation_ga(self): target=prediction_service_client.PredictionServiceClient, attribute="predict", return_value=gca_predict_response, - ): + ) as mock_predict: response = model.predict( "What is the best recipe for banana bread? Recipe:", max_output_tokens=128, @@ -624,8 +624,33 @@ def test_text_generation_ga(self): top_k=5, ) + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert prediction_parameters["maxDecodeSteps"] == 128 + assert prediction_parameters["temperature"] == 0 + assert prediction_parameters["topP"] == 1 + assert prediction_parameters["topK"] == 5 assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + # Validating that unspecified parameters are not passed to the model + # (except `max_output_tokens`). + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ) as mock_predict: + model.predict( + "What is the best recipe for banana bread? Recipe:", + ) + + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert ( + prediction_parameters["maxDecodeSteps"] + == language_models.TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS + ) + assert "temperature" not in prediction_parameters + assert "topP" not in prediction_parameters + assert "topK" not in prediction_parameters + @pytest.mark.parametrize( "job_spec", [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB], diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 4cd02a1142..620bc4e708 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -290,12 +290,19 @@ def _batch_predict( A list of `TextGenerationResponse` objects that contain the texts produced by the model. """ instances = [{"content": str(prompt)} for prompt in prompts] - prediction_parameters = { - "temperature": temperature, - "maxDecodeSteps": max_output_tokens, - "topP": top_p, - "topK": top_k, - } + prediction_parameters = {} + + if max_output_tokens: + prediction_parameters["maxDecodeSteps"] = max_output_tokens + + if temperature is not None: + prediction_parameters["temperature"] = temperature + + if top_p: + prediction_parameters["topP"] = top_p + + if top_k: + prediction_parameters["topK"] = top_k prediction_response = self._endpoint.predict( instances=instances,