Skip to content

Commit

Permalink
Python: Add parallel_tool_calls attribute to OpenAI chat prompt execu…
Browse files Browse the repository at this point in the history
…tion settings (#9479)

### Motivation and Context

OpenAI / Azure OpenAI released the ability to specify the
`parallel_tool_calls` = True | False boolean attribute on the prompt
execution settings in July. This was never brought into SK Python.

Further, there appears to be a pesky bug related to function calling
where enabling parallel tool calls can cause 500s. The only way to work
around this is to disable parallel function calling. This has been
tested on both the Azure Chat Completion and OpenAI Chat Completion code
and works well -- the chat history shows the sequential function calls
when multiple tool calls are required.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

Add the `parallel_tool_calls` attribute to the
`OpenAIChatPromptExecutionSettings`, which are also used by the
`AzureChatCompletion` class.
- Closes #9478
- Adds unit tests for this prompt execution setting

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Oct 30, 2024
1 parent 82f248f commit 5e632fd
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class OpenAIChatPromptExecutionSettings(OpenAIPromptExecutionSettings):
functions: list[dict[str, Any]] | None = None
messages: list[dict[str, Any]] | None = None
function_call_behavior: FunctionCallBehavior | None = Field(None, exclude=True)
parallel_tool_calls: bool = True
tools: list[dict[str, Any]] | None = Field(
None,
max_length=64,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ async def _send_request(
if self.ai_model_type == OpenAIModelTypes.CHAT:
assert isinstance(request_settings, OpenAIChatPromptExecutionSettings) # nosec
self._handle_structured_output(request_settings, settings)
if request_settings.tools is None:
settings.pop("parallel_tool_calls", None)
response = await self.client.chat.completions.create(**settings)
else:
response = await self.client.completions.create(**settings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
from copy import deepcopy
from unittest.mock import AsyncMock, MagicMock, patch

import openai
Expand All @@ -17,6 +18,7 @@

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_behavior import FunctionCallBehavior
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.connectors.ai.open_ai.exceptions.content_filter_ai_exception import (
ContentFilterAIException,
Expand All @@ -32,6 +34,8 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidExecutionSettingsError
from semantic_kernel.exceptions.service_exceptions import ServiceResponseException
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from semantic_kernel.kernel import Kernel

# region Service Setup
Expand Down Expand Up @@ -595,6 +599,162 @@ async def test_cmc_tool_calling(
)


@pytest.mark.asyncio
@patch.object(AsyncChatCompletions, "create", new_callable=AsyncMock)
async def test_cmc_tool_calling_parallel_tool_calls(
    mock_create,
    kernel: Kernel,
    azure_openai_unit_test_env,
    chat_history: ChatHistory,
    mock_chat_completion_response: ChatCompletion,
) -> None:
    """Verify that the default parallel_tool_calls=True setting is forwarded
    to the chat completions API when tools are included in the request."""
    # Fake a single assistant turn that requests one tool call.
    # NOTE(review): finish_reason is "stop" even though tool calls are present;
    # real responses use "tool_calls" -- presumably irrelevant here because
    # Kernel.invoke_function_call is mocked out below.
    tool_call_message = ChatCompletionMessage(
        content=None,
        role="assistant",
        tool_calls=[
            {
                "id": "test id",
                "function": {"name": "test-tool", "arguments": '{"key": "value"}'},
                "type": "function",
            }
        ],
    )
    mock_chat_completion_response.choices = [
        Choice(index=0, message=tool_call_message, finish_reason="stop")
    ]
    mock_create.return_value = mock_chat_completion_response
    chat_history.add_user_message("hello world")

    class MockPlugin:
        @kernel_function(name="test_tool")
        def test_tool(self, key: str):
            return "test"

    kernel.add_plugin(MockPlugin(), plugin_name="test_tool")

    # Snapshot the history before the call: the assertion below must compare
    # against the messages as they were sent, not as they were later mutated.
    history_before_request = deepcopy(chat_history)
    execution_settings = AzureChatPromptExecutionSettings(
        service_id="test_service_id", function_choice_behavior=FunctionChoiceBehavior.Auto()
    )

    # The tool schema the connector is expected to derive from MockPlugin.
    expected_tools = [
        {
            "type": "function",
            "function": {
                "name": "test_tool-test_tool",
                "description": "",
                "parameters": {
                    "type": "object",
                    "properties": {"key": {"type": "string"}},
                    "required": ["key"],
                },
            },
        }
    ]

    with patch(
        "semantic_kernel.kernel.Kernel.invoke_function_call",
        new_callable=AsyncMock,
    ) as mock_invoke_function_call:
        azure_chat_completion = AzureChatCompletion()
        await azure_chat_completion.get_chat_message_contents(
            chat_history=chat_history,
            settings=execution_settings,
            kernel=kernel,
            arguments=KernelArguments(),
        )
        # parallel_tool_calls must be present (and True) alongside the tools.
        mock_create.assert_awaited_once_with(
            model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
            stream=False,
            messages=azure_chat_completion._prepare_chat_history_for_request(history_before_request),
            parallel_tool_calls=True,
            tools=expected_tools,
            tool_choice="auto",
        )
        mock_invoke_function_call.assert_awaited()


@pytest.mark.asyncio
@patch.object(AsyncChatCompletions, "create", new_callable=AsyncMock)
async def test_cmc_tool_calling_parallel_tool_calls_disabled(
    mock_create,
    kernel: Kernel,
    azure_openai_unit_test_env,
    chat_history: ChatHistory,
    mock_chat_completion_response: ChatCompletion,
) -> None:
    """Verify that parallel_tool_calls=False set on the execution settings is
    forwarded unchanged to the chat completions API."""
    # Fake a single assistant turn that requests one tool call.
    # NOTE(review): finish_reason is "stop" even though tool calls are present;
    # real responses use "tool_calls" -- presumably irrelevant here because
    # Kernel.invoke_function_call is mocked out below.
    tool_call_message = ChatCompletionMessage(
        content=None,
        role="assistant",
        tool_calls=[
            {
                "id": "test id",
                "function": {"name": "test-tool", "arguments": '{"key": "value"}'},
                "type": "function",
            }
        ],
    )
    mock_chat_completion_response.choices = [
        Choice(index=0, message=tool_call_message, finish_reason="stop")
    ]
    mock_create.return_value = mock_chat_completion_response
    chat_history.add_user_message("hello world")

    class MockPlugin:
        @kernel_function(name="test_tool")
        def test_tool(self, key: str):
            return "test"

    kernel.add_plugin(MockPlugin(), plugin_name="test_tool")

    # Snapshot the history before the call: the assertion below must compare
    # against the messages as they were sent, not as they were later mutated.
    history_before_request = deepcopy(chat_history)
    execution_settings = AzureChatPromptExecutionSettings(
        service_id="test_service_id",
        function_choice_behavior=FunctionChoiceBehavior.Auto(),
        parallel_tool_calls=False,
    )

    # The tool schema the connector is expected to derive from MockPlugin.
    expected_tools = [
        {
            "type": "function",
            "function": {
                "name": "test_tool-test_tool",
                "description": "",
                "parameters": {
                    "type": "object",
                    "properties": {"key": {"type": "string"}},
                    "required": ["key"],
                },
            },
        }
    ]

    with patch(
        "semantic_kernel.kernel.Kernel.invoke_function_call",
        new_callable=AsyncMock,
    ) as mock_invoke_function_call:
        azure_chat_completion = AzureChatCompletion()
        await azure_chat_completion.get_chat_message_contents(
            chat_history=chat_history,
            settings=execution_settings,
            kernel=kernel,
            arguments=KernelArguments(),
        )
        # The disabled flag must reach the API call as parallel_tool_calls=False.
        mock_create.assert_awaited_once_with(
            model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
            stream=False,
            messages=azure_chat_completion._prepare_chat_history_for_request(history_before_request),
            parallel_tool_calls=False,
            tools=expected_tools,
            tool_choice="auto",
        )
        mock_invoke_function_call.assert_awaited()


CONTENT_FILTERED_ERROR_MESSAGE = (
"The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please "
"modify your prompt and retry. To learn more about our content filtering policies please read our "
Expand Down
Loading

0 comments on commit 5e632fd

Please sign in to comment.