diff --git a/haystack/nodes/prompt/invocation_layer/anthropic_claude.py b/haystack/nodes/prompt/invocation_layer/anthropic_claude.py
index 535a91d935..bea7313adb 100644
--- a/haystack/nodes/prompt/invocation_layer/anthropic_claude.py
+++ b/haystack/nodes/prompt/invocation_layer/anthropic_claude.py
@@ -10,7 +10,11 @@
 from haystack.errors import AnthropicError, AnthropicRateLimitError, AnthropicUnauthorizedError
 from haystack.nodes.prompt.invocation_layer.base import PromptModelInvocationLayer
-from haystack.nodes.prompt.invocation_layer.handlers import TokenStreamingHandler, DefaultTokenStreamingHandler
+from haystack.nodes.prompt.invocation_layer.handlers import (
+    TokenStreamingHandler,
+    AnthropicTokenStreamingHandler,
+    DefaultTokenStreamingHandler,
+)
 from haystack.utils.requests import request_with_retry
 from haystack.environment import HAYSTACK_REMOTE_API_MAX_RETRIES, HAYSTACK_REMOTE_API_TIMEOUT_SEC
 
 logger = logging.getLogger(__name__)
@@ -126,21 +130,22 @@ def invoke(self, *args, **kwargs):
             return [res.json()["completion"].strip()]
 
         res = self._post(data=data, stream=True)
+        # The Anthropic streamed response always includes the whole string that has
+        # been streamed up to that point, so we use a stream handler built ad hoc
+        # for this invocation layer.
         handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+        handler = AnthropicTokenStreamingHandler(handler)
         client = sseclient.SSEClient(res)
-        tokens = ""
+        tokens = []
         try:
             for event in client.events():
                 if event.data == TokenStreamingHandler.DONE_MARKER:
                     continue
                 ed = json.loads(event.data)
-                # Anthropic streamed response always includes the whole
-                # string that has been streamed until that point, so
-                # we can just store the last received event
-                tokens = handler(ed["completion"])
+                tokens.append(handler(ed["completion"]))
         finally:
             client.close()
-        return [tokens.strip()]  # return a list of strings just like non-streaming
+        return ["".join(tokens)]  # return a list of strings just like non-streaming
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """Make sure the length of the prompt and answer is within the max tokens limit of the model.
diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py
index 0872fc53e5..77f5b37232 100644
--- a/haystack/nodes/prompt/invocation_layer/handlers.py
+++ b/haystack/nodes/prompt/invocation_layer/handlers.py
@@ -36,6 +36,48 @@ def __call__(self, token_received, **kwargs) -> str:
         return token_received
 
 
+class AnthropicTokenStreamingHandler(TokenStreamingHandler):
+    """
+    Anthropic handles streaming responses in an unusual way: each
+    streamed event contains all the tokens generated up to that
+    point.
+    This makes DefaultTokenStreamingHandler a poor fit, as the user
+    would see the generated text printed multiple times.
+
+    This streaming handler strips the repeated text and forwards
+    only the newly generated part.
+    """
+
+    def __init__(self, token_handler: TokenStreamingHandler):
+        self.token_handler = token_handler
+        self.previous_text = ""
+
+    def __call__(self, token_received: str, **kwargs) -> str:
+        """
+        When the handler is called with a response string from Anthropic, we
+        compare it with the text previously received by this handler and
+        return only the new part.
+
+        If the text is completely different from the previously received one,
+        we replace it and return it in full.
+
+        :param token_received: Text response received from the Anthropic backend.
+        :type token_received: str
+        :return: The part of the text that has not been received previously.
+        :rtype: str
+        """
+        if self.previous_text not in token_received:
+            # The handler is being reused; handle this case gracefully by
+            # clearing the previously received text and carrying on
+            self.previous_text = ""
+
+        previous_text_length = len(self.previous_text)
+        chopped_text = token_received[previous_text_length:]
+        self.token_handler(chopped_text)
+        self.previous_text = token_received
+        return chopped_text
+
+
 class HFTokenStreamingHandler(TextStreamer):
     def __init__(
         self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stream_handler: TokenStreamingHandler
diff --git a/test/prompt/invocation_layer/test_anthropic_claude.py b/test/prompt/invocation_layer/test_anthropic_claude.py
index b6467cfc65..6823ba9570 100644
--- a/test/prompt/invocation_layer/test_anthropic_claude.py
+++ b/test/prompt/invocation_layer/test_anthropic_claude.py
@@ -146,7 +146,7 @@ def mock_iter(self):
     res = layer.invoke(prompt="Some prompt", stream=True)
 
     assert len(res) == 1
-    assert res[0] == "The sky appears blue to us due to how"
+    assert res[0] == " The sky appears blue to us due to how"
 
 
 @pytest.mark.unit
@@ -177,15 +177,11 @@ def mock_iter(self):
     res = layer.invoke(prompt="Some prompt", stream=True, stream_handler=mock_stream_handler)
 
     assert len(res) == 1
     # This is not the real result but the values returned by the mock handler
-    assert res[0] == "token"
+    assert res[0] == " The sky appears blue to us due to how"
 
     # Verifies the handler has been called the expected number of times with the expected args
     assert mock_stream_handler.call_count == 3
-    expected_call_list = [
-        call(" The sky appears"),
-        call(" The sky appears blue to"),
-        call(" The sky appears blue to us due to how"),
-    ]
+    expected_call_list = [call(" The sky appears"), call(" blue to"), call(" us due to how")]
     assert mock_stream_handler.call_args_list == expected_call_list
 
diff --git a/test/prompt/test_handlers.py b/test/prompt/test_handlers.py
index 2f90eb4c41..45bd85fa2d 100644
--- a/test/prompt/test_handlers.py
+++ b/test/prompt/test_handlers.py
@@ -1,6 +1,12 @@
+from unittest.mock import patch
+
 import pytest
 
-from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler
+from haystack.nodes.prompt.invocation_layer.handlers import (
+    DefaultTokenStreamingHandler,
+    DefaultPromptHandler,
+    AnthropicTokenStreamingHandler,
+)
 
 
 @pytest.mark.integration
@@ -76,3 +82,29 @@ def test_flan_prompt_handler():
         "model_max_length": 20,
         "new_prompt_length": 0,
     }
+
+
+@pytest.mark.unit
+@patch("builtins.print")
+def test_anthropic_token_streaming_handler(mock_print):
+    handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())
+
+    res = handler(" This")
+    assert res == " This"
+    mock_print.assert_called_with(" This", flush=True, end="")
+
+    res = handler(" This is a new")
+    assert res == " is a new"
+    mock_print.assert_called_with(" is a new", flush=True, end="")
+
+    res = handler(" This is a new token")
+    assert res == " token"
+    mock_print.assert_called_with(" token", flush=True, end="")
+
+    res = handler("And now")
+    assert res == "And now"
+    mock_print.assert_called_with("And now", flush=True, end="")
+
+    res = handler("And now something completely different")
+    assert res == " something completely different"
+    mock_print.assert_called_with(" something completely different", flush=True, end="")
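
For a quick sanity check of the new behavior, here is a minimal usage sketch mirroring the unit tests above. It assumes the patched handlers module is importable; all names come from this diff:

    from haystack.nodes.prompt.invocation_layer.handlers import (
        AnthropicTokenStreamingHandler,
        DefaultTokenStreamingHandler,
    )

    # Anthropic streams cumulative completions; the wrapper forwards only the
    # newly generated suffix to the inner handler and returns it.
    handler = AnthropicTokenStreamingHandler(DefaultTokenStreamingHandler())
    assert handler(" The sky appears") == " The sky appears"  # first payload, returned in full
    assert handler(" The sky appears blue to") == " blue to"  # only the new suffix
    assert handler("Something else") == "Something else"      # no overlap: state resets, full text returned

Joining the returned deltas reproduces the full completion, which is why invoke() can now accumulate them in a list and return ["".join(tokens)].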