Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce class method to create ChatMessage from the OpenAI dictionary format #8670

Merged
merged 8 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
print("callback")
print(callback)
print("-" * 100)

mpangrazzi marked this conversation as resolved.
Show resolved Hide resolved
chunks: List[StreamingChunk] = []
chunk = None

Expand Down
63 changes: 63 additions & 0 deletions haystack/dataclasses/chat_message.py
mpangrazzi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union

from haystack import logging

logger = logging.getLogger(__name__)
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"}


Expand Down Expand Up @@ -426,3 +430,62 @@ def to_openai_dict_format(self) -> Dict[str, Any]:
)
openai_msg["tool_calls"] = openai_tool_calls
return openai_msg

@classmethod
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
"""
Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.

NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
accepts messages without it to support shallow OpenAI-compatible APIs. However, if you plan to use the
resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll encounter validation errors.

:param message:
The OpenAI dictionary to build the ChatMessage object.
:returns:
The created ChatMessage object.

:raises ValueError:
If the message dictionary is missing required fields.
"""

if "role" not in message:
raise ValueError("The `role` field is required in the message dictionary.")
role = message["role"]

content = message.get("content")
name = message.get("name")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")

if role == "assistant":
if not content and not tool_calls:
raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")

haystack_tool_calls = None
if tool_calls:
haystack_tool_calls = []
for tc in tool_calls:
if "function" not in tc:
raise ValueError("Tool calls must contain the `function` field")
haystack_tc = ToolCall(
id=tc.get("id"),
tool_name=tc["function"]["name"],
arguments=json.loads(tc["function"]["arguments"]),
)
haystack_tool_calls.append(haystack_tc)
return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls)

if not content:
raise ValueError(f"The `content` field is required for {role} messages.")

if role == "user":
return cls.from_user(text=content, name=name)
if role in ["system", "developer"]:
return cls.from_system(text=content, name=name)
if role == "tool":
return cls.from_tool(
tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False
)

raise ValueError(f"Unsupported role: {role}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
from a dictionary in the format expected by OpenAI's Chat API.
80 changes: 80 additions & 0 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
message.to_openai_dict_format()


def test_from_openai_dict_format_user_message():
openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "user"
assert message.text == "Hello, how are you?"
assert message.name == "John"


def test_from_openai_dict_format_system_message():
openai_msg = {"role": "system", "content": "You are a helpful assistant"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "system"
assert message.text == "You are a helpful assistant"


def test_from_openai_dict_format_assistant_message_with_content():
openai_msg = {"role": "assistant", "content": "I can help with that"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "assistant"
assert message.text == "I can help with that"


def test_from_openai_dict_format_assistant_message_with_tool_calls():
openai_msg = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}],
}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "assistant"
assert message.text is None
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call.id == "call_123"
assert tool_call.tool_name == "get_weather"
assert tool_call.arguments == {"location": "Berlin"}


def test_from_openai_dict_format_tool_message():
openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "tool"
assert message.tool_call_result.result == "The weather is sunny"
assert message.tool_call_result.origin.id == "call_123"


def test_from_openai_dict_format_tool_without_id():
openai_msg = {"role": "tool", "content": "The weather is sunny"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "tool"
assert message.tool_call_result.result == "The weather is sunny"
assert message.tool_call_result.origin.id is None


def test_from_openai_dict_format_missing_role():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"content": "test"})


def test_from_openai_dict_format_missing_content():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "user"})


def test_from_openai_dict_format_invalid_tool_calls():
openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]}
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format(openai_msg)


def test_from_openai_dict_format_unsupported_role():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"})


def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})


@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
Expand Down
Loading