Commit d62bb1b

feat: LLM - Added stop_sequences parameter to streaming methods and `CodeChatModel`

PiperOrigin-RevId: 562915062
Ark-kun authored and copybara-github committed Sep 5, 2023
1 parent f8d43bb commit d62bb1b
Showing 2 changed files with 36 additions and 1 deletion.
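At a glance, the commit lets callers pass custom stop sequences to the streaming text generation path. A minimal usage sketch (the model name and prompt are illustrative placeholders, not taken from this commit; the keyword arguments mirror the updated `predict_streaming` signature and the values used in the test below):

    from vertexai.language_models import TextGenerationModel

    # Placeholder model name for illustration only.
    model = TextGenerationModel.from_pretrained("text-bison@001")

    # stop_sequences is the newly added parameter: decoding stops as soon as the
    # model emits any of the listed strings.
    for response in model.predict_streaming(
        "Write a notebook cell that loads a CSV file",  # placeholder prompt
        max_output_tokens=128,
        temperature=0.0,
        top_p=1.0,
        top_k=5,
        stop_sequences=["# %%"],
    ):
        print(response.text)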
13 changes: 12 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
@@ -1304,6 +1304,7 @@ def test_text_generation_model_predict_streaming(self):
temperature=0.0,
top_p=1.0,
top_k=5,
stop_sequences=["# %%"],
):
assert len(response.text) > 10

@@ -1969,6 +1970,7 @@ def test_chat_model_send_message_streaming(self):
),
],
temperature=0.0,
stop_sequences=["\n"],
)

# Using list instead of a generator so that it can be reused.
@@ -1983,6 +1985,7 @@ def test_chat_model_send_message_streaming(self):
message_max_output_tokens = 200
message_top_k = 2
message_top_p = 0.2
message_stop_sequences = ["# %%"]

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
@@ -1998,6 +2001,7 @@ def test_chat_model_send_message_streaming(self):
temperature=message_temperature,
top_k=message_top_k,
top_p=message_top_p,
stop_sequences=message_stop_sequences,
)
):
assert len(response.text) > 10
@@ -2036,6 +2040,7 @@ def test_code_chat(self):
code_chat = model.start_chat(
max_output_tokens=128,
temperature=0.2,
stop_sequences=["\n"],
)

gca_predict_response1 = gca_prediction_service.PredictResponse()
@@ -2075,12 +2080,15 @@ def test_code_chat(self):
# Validating the parameters
chat_temperature = 0.1
chat_max_output_tokens = 100
chat_stop_sequences = ["\n"]
message_temperature = 0.2
message_max_output_tokens = 200
message_stop_sequences = ["# %%"]

code_chat2 = model.start_chat(
temperature=chat_temperature,
max_output_tokens=chat_max_output_tokens,
stop_sequences=chat_stop_sequences,
)

gca_predict_response3 = gca_prediction_service.PredictResponse()
@@ -2097,15 +2105,18 @@ def test_code_chat(self):
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == chat_temperature
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
assert prediction_parameters["stopSequences"] == chat_stop_sequences

code_chat2.send_message(
"Please help write a function to calculate the min of two numbers",
temperature=message_temperature,
max_output_tokens=message_max_output_tokens,
stop_sequences=message_stop_sequences,
)
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
assert prediction_parameters["stopSequences"] == message_stop_sequences

def test_code_chat_model_send_message_streaming(self):
"""Tests the chat generation model."""
@@ -2122,7 +2133,7 @@ def test_code_chat_model_send_message_streaming(self):
):
model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")

-chat = model.start_chat(temperature=0.0)
+chat = model.start_chat(temperature=0.0, stop_sequences=["\n"])

# Using list instead of a generator so that it can be reused.
response_generator = [
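The chat and code chat tests above set `stop_sequences` both at session start and per message. Roughly, the pattern being exercised (the model name and first prompt come from the tests; the follow-up prompt is a placeholder):

    from vertexai.language_models import CodeChatModel

    model = CodeChatModel.from_pretrained("codechat-bison@001")

    # Session-level default, added by this commit.
    chat = model.start_chat(temperature=0.2, stop_sequences=["\n"])

    # Per-message override works for both blocking and streaming calls.
    response = chat.send_message(
        "Please help write a function to calculate the min of two numbers",
        stop_sequences=["# %%"],
    )
    print(response.text)

    for chunk in chat.send_message_streaming(
        "Now please add docstrings",  # placeholder follow-up prompt
        stop_sequences=["# %%"],
    ):
        print(chunk.text)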
24 changes: 24 additions & 0 deletions vertexai/language_models/_language_models.py
@@ -734,6 +734,7 @@ def predict_streaming(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterator[TextGenerationResponse]:
"""Gets a streaming model response for a single prompt.
@@ -745,6 +746,7 @@ def predict_streaming(
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
stop_sequences: Customized stop sequences to stop the decoding process.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
@@ -771,6 +773,9 @@ def predict_streaming(
if top_k:
prediction_parameters["topK"] = top_k

if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
prediction_service_client=prediction_service_client,
endpoint_name=self._endpoint_name,
@@ -1299,12 +1304,14 @@ def start_chat(
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
) -> "CodeChatSession":
"""Starts a chat session with the code chat model.
Args:
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `ChatSession` object.
@@ -1314,6 +1321,7 @@ def start_chat(
max_output_tokens=max_output_tokens,
temperature=temperature,
message_history=message_history,
stop_sequences=stop_sequences,
)


@@ -1541,6 +1549,7 @@ def send_message_streaming(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterator[TextGenerationResponse]:
"""Sends message to the language model and gets a streamed response.
@@ -1556,6 +1565,8 @@ def send_message_streaming(
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Uses the value specified when calling `ChatModel.start_chat` by default.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
@@ -1567,6 +1578,7 @@ def send_message_streaming(
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)

prediction_service_client = self._model._endpoint._prediction_client
@@ -1644,12 +1656,14 @@ def __init__(
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
):
super().__init__(
model=model,
max_output_tokens=max_output_tokens,
temperature=temperature,
message_history=message_history,
stop_sequences=stop_sequences,
)

def send_message(
@@ -1658,6 +1672,7 @@ def send_message(
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> "TextGenerationResponse":
"""Sends message to the code chat model and gets a response.
@@ -1667,6 +1682,7 @@ def send_message(
Uses the value specified when calling `CodeChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1].
Uses the value specified when calling `CodeChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1675,6 +1691,7 @@ def send_message(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)

def send_message_streaming(
Expand All @@ -1683,6 +1700,7 @@ def send_message_streaming(
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterator[TextGenerationResponse]:
"""Sends message to the language model and gets a streamed response.
@@ -1694,6 +1712,8 @@ def send_message_streaming(
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Uses the value specified when calling `ChatModel.start_chat` by default.
Returns:
A stream of `TextGenerationResponse` objects that contain partial
@@ -1703,6 +1723,7 @@ def send_message_streaming(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)


@@ -1811,6 +1832,7 @@ def predict_streaming(
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> Iterator[TextGenerationResponse]:
"""Predicts the code based on previous code.
@@ -1821,6 +1843,7 @@ def predict_streaming(
suffix: Code after the current point.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
@@ -1831,6 +1854,7 @@ def predict_streaming(
suffix=suffix,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)

prediction_service_client = self._endpoint._prediction_client
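The final hunks extend streaming code completion in the same way. A hedged sketch, assuming the surrounding class is the SDK's code generation model (`CodeGenerationModel`) and using an illustrative model name, prefix, and suffix, none of which appear in this diff:

    from vertexai.language_models import CodeGenerationModel

    # Assumed model name; the diff only shows the updated predict_streaming signature.
    model = CodeGenerationModel.from_pretrained("code-gecko@001")

    for chunk in model.predict_streaming(
        prefix="def min_of_two(a, b):\n",
        suffix="\nprint(min_of_two(1, 2))",  # "Code after the current point", per the docstring
        max_output_tokens=64,
        temperature=0.2,
        stop_sequences=["\n\n"],
    ):
        print(chunk.text)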
