Merge branch 'deepset-ai:main' into draw-offline
lbux authored Jan 24, 2025
2 parents 692dc48 + 3119ae1 commit 779d50b
Showing 11 changed files with 300 additions and 18 deletions.
20 changes: 15 additions & 5 deletions haystack/components/generators/chat/azure.py
@@ -3,14 +3,15 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

# pylint: disable=import-error
from openai.lib.azure import AzureOpenAI

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingChunk
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

logger = logging.getLogger(__name__)
@@ -75,6 +76,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
max_retries: Optional[int] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
default_headers: Optional[Dict[str, str]] = None,
tools: Optional[List[Tool]] = None,
tools_strict: bool = False,
):
"""
Initialize the Azure OpenAI Chat Generator component.
@@ -112,6 +115,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param default_headers: Default headers to use for the AzureOpenAI client.
:param tools:
A list of tools for which the model can prepare calls.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@@ -142,10 +150,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}

# This ChatGenerator does not yet support tools. The following workaround ensures that we do not
# get an error when invoking the run method of the parent class (OpenAIChatGenerator).
self.tools = None
self.tools_strict = False
_check_duplicate_tool_names(tools)
self.tools = tools
self.tools_strict = tools_strict

self.client = AzureOpenAI(
api_version=api_version,
@@ -180,6 +187,8 @@ def to_dict(self) -> Dict[str, Any]:
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
default_headers=self.default_headers,
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
tools_strict=self.tools_strict,
)

@classmethod
@@ -192,6 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator":
The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
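Taken together, the azure.py changes let callers pass tools at construction time. A minimal usage sketch, mirroring the tool defined in the tests further below (the endpoint is a placeholder and AZURE_OPENAI_API_KEY is assumed to be set in the environment):

from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools.tool import Tool

weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=lambda city: f"Sunny in {city}",  # placeholder implementation
)

generator = AzureOpenAIChatGenerator(
    azure_endpoint="https://example.openai.azure.com",  # placeholder endpoint
    tools=[weather_tool],
    tools_strict=True,  # ask the model to follow the parameters schema exactly
)
result = generator.run([ChatMessage.from_user("What's the weather like in Paris?")])
print(result["replies"][0].tool_calls)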
18 changes: 14 additions & 4 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -220,6 +220,7 @@ def run(
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
@@ -231,6 +232,9 @@
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:param streaming_callback:
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
parameter set during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""
@@ -245,16 +249,22 @@
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)
streaming_callback = streaming_callback or self.streaming_callback
if streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

hf_tools = None
if tools:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
def _run_streaming(
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
streaming_callback: Callable[[StreamingChunk], None],
):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
messages, stream=True, **generation_kwargs
)
@@ -282,7 +292,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
first_chunk_time = datetime.now().isoformat()

stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
streaming_callback(stream_chunk)

meta.update(
{
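With this change, the Hugging Face API generator accepts a per-call streaming callback that overrides the one set at init. A minimal sketch, assuming a Serverless Inference API setup; the model name is illustrative and a valid HF token is assumed to be available:

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
)

# The run-time callback takes precedence over self.streaming_callback.
result = generator.run(
    [ChatMessage.from_user("Explain token streaming in one sentence.")],
    streaming_callback=lambda chunk: print(chunk.text, end="", flush=True),
)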
13 changes: 10 additions & 3 deletions haystack/components/generators/chat/hugging_face_local.py
@@ -233,12 +233,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
return default_from_dict(cls, data)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Invoke text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:returns:
A list containing the generated responses as ChatMessage instances.
"""
@@ -259,7 +265,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

if self.streaming_callback:
streaming_callback = streaming_callback or self.streaming_callback
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
@@ -270,7 +277,7 @@
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)

hf_messages = [convert_message_to_hf_format(message) for message in messages]

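The local generator follows the same pattern: the run-time streaming_callback is handed to HFTokenStreamingHandler in place of the one stored on the component. A short sketch under the same caveats (illustrative model name, transformers installed):

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage

generator = HuggingFaceLocalChatGenerator(model="HuggingFaceTB/SmolLM2-360M-Instruct")  # illustrative model
generator.warm_up()  # loads the local HF pipeline

# Per-call callback overrides any streaming_callback passed to __init__.
result = generator.run(
    [ChatMessage.from_user("Say hello.")],
    streaming_callback=lambda chunk: print(chunk.text, end="", flush=True),
)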
6 changes: 4 additions & 2 deletions haystack/core/pipeline/base.py
@@ -167,8 +167,10 @@ def from_dict(
f"Successfully imported module {module} but can't find it in the component registry."
"This is unexpected and most likely a bug."
)
except (ImportError, PipelineError) as e:
raise PipelineError(f"Component '{component_data['type']}' not imported.") from e
except (ImportError, PipelineError, ValueError) as e:
raise PipelineError(
f"Component '{component_data['type']}' (name: '{name}') not imported."
) from e

# Create a new one
component_class = component.registry[component_data["type"]]
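The broadened except clause means a ValueError raised while resolving the component type (for example, an empty type string) now surfaces as a PipelineError that names the offending component. A sketch of the behavior, using an illustrative serialized pipeline:

from haystack import Pipeline
from haystack.core.errors import PipelineError

bad_pipeline = {
    "metadata": {},
    "components": {"gen": {"type": "", "init_parameters": {}}},  # empty type string
    "connections": [],
}

try:
    Pipeline.from_dict(bad_pipeline)
except PipelineError as err:
    print(err)  # message includes the invalid type and the component name 'gen'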
@@ -0,0 +1,4 @@
---
features:
- |
Add support for Tools in the Azure OpenAI Chat Generator.
@@ -0,0 +1,5 @@
---
enhancements:
- |
When `Pipeline.from_dict` receives an invalid type (e.g. empty string), an informative `PipelineError` is now
raised.
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add support for the `streaming_callback` run parameter in the Hugging Face chat generators.
130 changes: 127 additions & 3 deletions test/components/generators/chat/test_azure.py
@@ -9,11 +9,25 @@
from haystack import Pipeline
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools.tool import Tool
from haystack.utils.auth import Secret


class TestOpenAIChatGenerator:
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
function=lambda x: x,
)

return [tool]


class TestAzureOpenAIChatGenerator:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
@@ -28,17 +42,21 @@ def test_init_fail_wo_api_key(self, monkeypatch):
with pytest.raises(OpenAIError):
AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")

def test_init_with_parameters(self):
def test_init_with_parameters(self, tools):
component = AzureOpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"),
azure_endpoint="some-non-existing-endpoint",
streaming_callback=print_streaming_chunk,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
tools=tools,
tools_strict=True,
)
assert component.client.api_key == "test-api-key"
assert component.azure_deployment == "gpt-4o-mini"
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.tools == tools
assert component.tools_strict

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
@@ -58,6 +76,8 @@ def test_to_dict_default(self, monkeypatch):
"timeout": 30.0,
"max_retries": 5,
"default_headers": {},
"tools": None,
"tools_strict": False,
},
}

@@ -85,15 +105,94 @@ def test_to_dict_with_parameters(self, monkeypatch):
"timeout": 2.5,
"max_retries": 10,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": None,
"tools_strict": False,
"default_headers": {},
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "test-ad-token")
data = {
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_endpoint": "some-non-existing-endpoint",
"azure_deployment": "gpt-4o-mini",
"organization": None,
"streaming_callback": None,
"generation_kwargs": {},
"timeout": 30.0,
"max_retries": 5,
"default_headers": {},
"tools": [
{
"type": "haystack.tools.tool.Tool",
"data": {
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
},
}
],
"tools_strict": False,
},
}

generator = AzureOpenAIChatGenerator.from_dict(data)
assert isinstance(generator, AzureOpenAIChatGenerator)

assert generator.api_key == Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False)
assert generator.azure_ad_token == Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False)
assert generator.api_version == "2023-05-15"
assert generator.azure_endpoint == "some-non-existing-endpoint"
assert generator.azure_deployment == "gpt-4o-mini"
assert generator.organization is None
assert generator.streaming_callback is None
assert generator.generation_kwargs == {}
assert generator.timeout == 30.0
assert generator.max_retries == 5
assert generator.default_headers == {}
assert generator.tools == [
Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
]
assert generator.tools_strict is False

def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
p = Pipeline()
p.add_component(instance=generator, name="generator")

assert p.to_dict() == {
"metadata": {},
"max_runs_per_component": 100,
"components": {
"generator": {
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
"init_parameters": {
"azure_endpoint": "some-non-existing-endpoint",
"azure_deployment": "gpt-4o-mini",
"organization": None,
"api_version": "2023-05-15",
"streaming_callback": None,
"generation_kwargs": {},
"timeout": 30.0,
"max_retries": 5,
"api_key": {"type": "env_var", "env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False},
"azure_ad_token": {"type": "env_var", "env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False},
"default_headers": {},
"tools": None,
"tools_strict": False,
},
}
},
"connections": [],
}
p_str = p.dumps()
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
@@ -117,4 +216,29 @@ def test_live_run(self):
assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
reason=(
"Please export env variables called AZURE_OPENAI_API_KEY containing "
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
"the Azure OpenAI endpoint URL to run this test."
),
)
def test_live_run_with_tools(self, tools):
chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
component = AzureOpenAIChatGenerator(organization="HaystackCI", tools=tools)
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message = results["replies"][0]

assert not message.texts
assert not message.text
assert message.tool_calls
tool_call = message.tool_call
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"

# additional tests intentionally omitted as they are covered by test_openai.py
