Skip to content

Commit

Permalink
feat: introduce class method to create ChatMessage from the OpenAI …
Browse files Browse the repository at this point in the history
…dictionary format (#8670)

* add ChatMessage.from_openai_dict_format

* remove print

* release note

* improve docstring

* separate validation logic

* rm obvious comment
  • Loading branch information
anakin87 authored Jan 2, 2025
1 parent 3ea128c commit 7b4d9ba
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 4 deletions.
4 changes: 0 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
print("callback")
print(callback)
print("-" * 100)

chunks: List[StreamingChunk] = []
chunk = None

Expand Down
78 changes: 78 additions & 0 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,81 @@ def to_openai_dict_format(self) -> Dict[str, Any]:
)
openai_msg["tool_calls"] = openai_tool_calls
return openai_msg

@staticmethod
def _validate_openai_message(message: Dict[str, Any]) -> None:
"""
Validate that a message dictionary follows OpenAI's Chat API format.
:param message: The message dictionary to validate
:raises ValueError: If the message format is invalid
"""
if "role" not in message:
raise ValueError("The `role` field is required in the message dictionary.")

role = message["role"]
content = message.get("content")
tool_calls = message.get("tool_calls")

if role not in ["assistant", "user", "system", "developer", "tool"]:
raise ValueError(f"Unsupported role: {role}")

if role == "assistant":
if not content and not tool_calls:
raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
if tool_calls:
for tc in tool_calls:
if "function" not in tc:
raise ValueError("Tool calls must contain the `function` field")
elif not content:
raise ValueError(f"The `content` field is required for {role} messages.")

@classmethod
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
"""
Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.
NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
accepts messages without it to support shallow OpenAI-compatible APIs.
If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll
encounter validation errors.
:param message:
The OpenAI dictionary to build the ChatMessage object.
:returns:
The created ChatMessage object.
:raises ValueError:
If the message dictionary is missing required fields.
"""
cls._validate_openai_message(message)

role = message["role"]
content = message.get("content")
name = message.get("name")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")

if role == "assistant":
haystack_tool_calls = None
if tool_calls:
haystack_tool_calls = []
for tc in tool_calls:
haystack_tc = ToolCall(
id=tc.get("id"),
tool_name=tc["function"]["name"],
arguments=json.loads(tc["function"]["arguments"]),
)
haystack_tool_calls.append(haystack_tc)
return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls)

assert content is not None # ensured by _validate_openai_message, but we need to make mypy happy

if role == "user":
return cls.from_user(text=content, name=name)
if role in ["system", "developer"]:
return cls.from_system(text=content, name=name)

return cls.from_tool(
tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
from a dictionary in the format expected by OpenAI's Chat API.
80 changes: 80 additions & 0 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
message.to_openai_dict_format()


def test_from_openai_dict_format_user_message():
openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "user"
assert message.text == "Hello, how are you?"
assert message.name == "John"


def test_from_openai_dict_format_system_message():
openai_msg = {"role": "system", "content": "You are a helpful assistant"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "system"
assert message.text == "You are a helpful assistant"


def test_from_openai_dict_format_assistant_message_with_content():
openai_msg = {"role": "assistant", "content": "I can help with that"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "assistant"
assert message.text == "I can help with that"


def test_from_openai_dict_format_assistant_message_with_tool_calls():
openai_msg = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}],
}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "assistant"
assert message.text is None
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call.id == "call_123"
assert tool_call.tool_name == "get_weather"
assert tool_call.arguments == {"location": "Berlin"}


def test_from_openai_dict_format_tool_message():
openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "tool"
assert message.tool_call_result.result == "The weather is sunny"
assert message.tool_call_result.origin.id == "call_123"


def test_from_openai_dict_format_tool_without_id():
openai_msg = {"role": "tool", "content": "The weather is sunny"}
message = ChatMessage.from_openai_dict_format(openai_msg)
assert message.role.value == "tool"
assert message.tool_call_result.result == "The weather is sunny"
assert message.tool_call_result.origin.id is None


def test_from_openai_dict_format_missing_role():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"content": "test"})


def test_from_openai_dict_format_missing_content():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "user"})


def test_from_openai_dict_format_invalid_tool_calls():
openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]}
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format(openai_msg)


def test_from_openai_dict_format_unsupported_role():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"})


def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
with pytest.raises(ValueError):
ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})


@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
Expand Down

0 comments on commit 7b4d9ba

Please sign in to comment.