
Commit

feat: LLM - Added support for multiple response candidates in code generation models

PiperOrigin-RevId: 573357986
Ark-kun authored and copybara-github committed Oct 14, 2023
1 parent 760a025 commit 0c371a4
Showing 2 changed files with 57 additions and 5 deletions.
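This commit lets callers of the code generation models request more than one response candidate per prompt. A minimal usage sketch, assuming the `code-bison@001` model used in the unit test below and placeholder project settings; the service treats `candidate_count` as an upper bound and may return fewer candidates:

import vertexai
from vertexai.language_models import CodeGenerationModel

# Placeholder project/location values for illustration only.
vertexai.init(project="my-project", location="us-central1")

model = CodeGenerationModel.from_pretrained("code-bison@001")

# candidate_count is a maximum, not an exact number of candidates.
response = model.predict(
    prefix="Write a function that checks if a year is a leap year.",
    candidate_count=4,
)

# response.text holds the first candidate; response.candidates holds all of them.
print(response.text)
for candidate in response.candidates:
    print(candidate.text)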
40 changes: 40 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -2540,6 +2540,46 @@ def test_code_generation(self):
assert "temperature" not in prediction_parameters
assert "maxOutputTokens" not in prediction_parameters

def test_code_generation_multiple_candidates(self):
"""Tests the code generation model with multiple candidates."""
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
),
autospec=True,
):
model = language_models.CodeGenerationModel.from_pretrained(
"code-bison@001"
)

gca_predict_response = gca_prediction_service.PredictResponse()
# Discrepancy between the number of `instances` and the number of `predictions`
# is a violation of the prediction service invariant, but the service does this.
gca_predict_response.predictions.append(_TEST_CODE_GENERATION_PREDICTION)
gca_predict_response.predictions.append(_TEST_CODE_GENERATION_PREDICTION)
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
autospec=True,
) as mock_predict:
response = model.predict(
prefix="Write a function that checks if a year is a leap year.",
# candidate_count acts as a maximum, not an exact number.
candidate_count=7,
)
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["candidateCount"] == 7

assert response.text == _TEST_CODE_GENERATION_PREDICTION["content"]
# The service can return a different number of candidates.
assert len(response.candidates) == 2
assert (
response.candidates[0].text == _TEST_CODE_GENERATION_PREDICTION["content"]
)

def test_code_completion(self):
"""Tests code completion with the code generation model."""
aiplatform.init(
22 changes: 17 additions & 5 deletions vertexai/language_models/_language_models.py
@@ -2254,6 +2254,7 @@ def _create_prediction_request(
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
candidate_count: Optional[int] = None,
) -> _PredictionRequest:
"""Creates a code generation prediction request.
@@ -2263,7 +2264,7 @@ def _create_prediction_request(
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of response candidates to return.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -2285,6 +2286,9 @@ def _create_prediction_request(
if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

if candidate_count is not None:
prediction_parameters["candidateCount"] = candidate_count

return _PredictionRequest(instance=instance, parameters=prediction_parameters)

def predict(
@@ -2295,6 +2299,7 @@ def predict(
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
candidate_count: Optional[int] = None,
) -> "TextGenerationResponse":
"""Gets model response for a single prompt.
@@ -2304,23 +2309,26 @@ def predict(
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of response candidates to return.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
A `MultiCandidateTextGenerationResponse` object that contains the
text produced by the model.
"""
prediction_request = self._create_prediction_request(
prefix=prefix,
suffix=suffix,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
candidate_count=candidate_count,
)

prediction_response = self._endpoint.predict(
instances=[prediction_request.instance],
parameters=prediction_request.parameters,
)
return _parse_text_generation_model_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(prediction_response)

async def predict_async(
self,
@@ -2330,6 +2338,7 @@ async def predict_async(
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
candidate_count: Optional[int] = None,
) -> "TextGenerationResponse":
"""Asynchronously gets model response for a single prompt.
@@ -2339,23 +2348,26 @@ async def predict_async(
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of response candidates to return.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
A `MultiCandidateTextGenerationResponse` object that contains the
text produced by the model.
"""
prediction_request = self._create_prediction_request(
prefix=prefix,
suffix=suffix,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
candidate_count=candidate_count,
)

prediction_response = await self._endpoint.predict_async(
instances=[prediction_request.instance],
parameters=prediction_request.parameters,
)
return _parse_text_generation_model_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(prediction_response)

def predict_streaming(
self,
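Both `predict` and `predict_async` now forward `candidate_count` as the `candidateCount` prediction parameter and hand the response to `_parse_text_generation_model_multi_candidate_response`, whose body is not part of this diff. Based only on the test assertions above (each prediction's `content` becomes one candidate, and `response.text` mirrors the first candidate), a rough sketch of what such a parser could look like; the class and function names here are illustrative assumptions, not the SDK's actual implementation:

from typing import List


class _CandidateSketch:
    """Stand-in for the per-candidate response object (exposes `.text`)."""

    def __init__(self, text: str):
        self.text = text


class _MultiCandidateResponseSketch:
    """Stand-in for MultiCandidateTextGenerationResponse."""

    def __init__(self, candidates: List[_CandidateSketch]):
        self.candidates = candidates
        # The tests treat response.text as the first candidate's content.
        self.text = candidates[0].text if candidates else ""


def _parse_multi_candidate_response_sketch(
    predictions: List[dict],
) -> _MultiCandidateResponseSketch:
    # One candidate per returned prediction, keyed by "content",
    # regardless of how many candidates were requested.
    return _MultiCandidateResponseSketch(
        [_CandidateSketch(p.get("content", "")) for p in predictions]
    )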
