fix: Fix handling of streaming response in AnthropicClaudeInvocationLayer #4993

Merged (5 commits) on Jun 7, 2023

haystack/nodes/prompt/invocation_layer/anthropic_claude.py (12 additions, 7 deletions)
@@ -10,7 +10,11 @@

 from haystack.errors import AnthropicError, AnthropicRateLimitError, AnthropicUnauthorizedError
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
-from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer.handlers import (
+    TokenStreamingHandler,
+    AnthropicTokenStreamingHandler,
+    DefaultTokenStreamingHandler,
+)
 from haystack.utils.requests import request_with_retry
 from haystack.environment import HAYSTACK_REMOTE_API_MAX_RETRIES, HAYSTACK_REMOTE_API_TIMEOUT_SEC

@@ -126,21 +130,22 @@ def invoke(self, *args, **kwargs):
             return [res.json()["completion"].strip()]

         res = self._post(data=data, stream=True)
+        # Anthropic streamed response always includes the whole string that has been
+        # streamed until that point, so we use a stream handler built ad hoc for this
+        # invocation layer.
         handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+        handler = AnthropicTokenStreamingHandler(handler)
         client = sseclient.SSEClient(res)
-        tokens = ""
+        tokens = []
         try:
             for event in client.events():
                 if event.data == TokenStreamingHandler.DONE_MARKER:
                     continue
                 ed = json.loads(event.data)
-                # Anthropic streamed response always includes the whole
-                # string that has been streamed until that point, so
-                # we can just store the last received event
-                tokens = handler(ed["completion"])
+                tokens.append(handler(ed["completion"]))
         finally:
             client.close()
-        return [tokens.strip()]  # return a list of strings just like non-streaming
+        return ["".join(tokens)]  # return a list of strings just like non-streaming

     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
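For illustration (not part of the PR), the accumulation pattern the new invoke() code relies on can be sketched in isolation; the stand-in handler below plays the role of AnthropicTokenStreamingHandler, and the events list stands in for the SSE stream parsed by sseclient:

# Sketch only: FakeDeltaHandler mimics the "return only the unseen suffix" behaviour.
class FakeDeltaHandler:
    def __init__(self):
        self.previous_text = ""

    def __call__(self, completion: str) -> str:
        # Each Anthropic event repeats everything generated so far,
        # so slice off the part we have already seen.
        delta = completion[len(self.previous_text):]
        self.previous_text = completion
        return delta


events = [" The sky appears", " The sky appears blue to", " The sky appears blue to us"]
handler = FakeDeltaHandler()

tokens = []
for completion in events:
    tokens.append(handler(completion))  # keep only the newly generated fragment

assert tokens == [" The sky appears", " blue to", " us"]
assert "".join(tokens) == " The sky appears blue to us"  # same shape as the non-streaming result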

haystack/nodes/prompt/invocation_layer/handlers.py (42 additions, 0 deletions)
@@ -36,6 +36,48 @@ def __call__(self, token_received, **kwargs) -> str:
         return token_received


+class AnthropicTokenStreamingHandler(TokenStreamingHandler):
+    """
+    Anthropic handles streaming responses in an unusual way: each
+    streamed response contains all the tokens generated up to that
+    point.
+    This makes DefaultTokenStreamingHandler hard to use on its own, as
+    the user would see the generated text printed multiple times.
+
+    This streaming handler removes the repeated text and prints
+    only the newly generated part.
+    """
+
+    def __init__(self, token_handler: TokenStreamingHandler):
+        self.token_handler = token_handler
+        self.previous_text = ""
+
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        When the handler is called with a response string from Anthropic,
+        we compare it with the text previously received by this handler
+        and return only the new part.
+
+        If the text is completely different from the previously received one,
+        we replace it and return it in full.
+
+        :param token_received: Text response received from the Anthropic backend.
+        :type token_received: str
+        :return: The part of the text that has not been received previously.
+        :rtype: str
+        """
+        if self.previous_text not in token_received:
+            # The handler is being reused; handle this gracefully by
+            # clearing the previously received text and carrying on
+            self.previous_text = ""
+
+        previous_text_length = len(self.previous_text)
+        chopped_text = token_received[previous_text_length:]
+        self.token_handler(chopped_text)
+        self.previous_text = token_received
+        return chopped_text
+
+
 class HFTokenStreamingHandler(TextStreamer):
     def __init__(
         self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stream_handler: TokenStreamingHandler
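As a usage sketch (assuming the import path and class names added in this diff), wrapping DefaultTokenStreamingHandler in the new AnthropicTokenStreamingHandler turns Anthropic's cumulative payloads into non-repeating fragments, so the streamed text is printed exactly once:

# Usage sketch; the import below assumes the handlers module as modified in this PR.
from haystack.nodes.prompt.invocation_layer.handlers import (
    AnthropicTokenStreamingHandler,
    DefaultTokenStreamingHandler,
)

# Each streamed event repeats everything generated so far.
cumulative_events = [" The sky appears", " The sky appears blue to", " The sky appears blue to us"]

handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())
for completion in cumulative_events:
    handler(completion)  # prints only the new suffix (" The sky appears", " blue to", " us")
# Overall printed output: " The sky appears blue to us", with no repeated text.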

test/prompt/invocation_layer/test_anthropic_claude.py (3 additions, 7 deletions)
@@ -146,7 +146,7 @@ def mock_iter(self):
     res = layer.invoke(prompt="Some prompt", stream=True)

     assert len(res) == 1
-    assert res[0] == "The sky appears blue to us due to how"
+    assert res[0] == " The sky appears blue to us due to how"


 @pytest.mark.unit
@@ -177,15 +177,11 @@ def mock_iter(self):

     assert len(res) == 1
     # This is not the real result but the values returned by the mock handler
-    assert res[0] == "token"
+    assert res[0] == " The sky appears blue to us due to how"

     # Verifies the handler has been called the expected times with the expected args
     assert mock_stream_handler.call_count == 3
-    expected_call_list = [
-        call(" The sky appears"),
-        call(" The sky appears blue to"),
-        call(" The sky appears blue to us due to how"),
-    ]
+    expected_call_list = [call(" The sky appears"), call(" blue to"), call(" us due to how")]
     assert mock_stream_handler.call_args_list == expected_call_list


test/prompt/test_handlers.py (33 additions, 1 deletion)
@@ -1,6 +1,12 @@
+from unittest.mock import patch
+
 import pytest

-from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
+from haystack.nodes.prompt.invocation_layer.handlers import (
+    DefaultTokenStreamingHandler,
+    DefaultPromptHandler,
+    AnthropicTokenStreamingHandler,
+)


 @pytest.mark.integration
@@ -76,3 +82,29 @@ def test_flan_prompt_handler():
         "model_max_length": 20,
         "new_prompt_length": 0,
     }
+
+
+@pytest.mark.unit
+@patch("builtins.print")
+def test_anthropic_token_streaming_handler(mock_print):
+    handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())
+
+    res = handler(" This")
+    assert res == " This"
+    mock_print.assert_called_with(" This", flush=True, end="")
+
+    res = handler(" This is a new")
+    assert res == " is a new"
+    mock_print.assert_called_with(" is a new", flush=True, end="")
+
+    res = handler(" This is a new token")
+    assert res == " token"
+    mock_print.assert_called_with(" token", flush=True, end="")
+
+    res = handler("And now")
+    assert res == "And now"
+    mock_print.assert_called_with("And now", flush=True, end="")
+
+    res = handler("And now something completely different")
+    assert res == " something completely different"
+    mock_print.assert_called_with(" something completely different", flush=True, end="")