Skip to content

Commit

Permalink
[#432] Add Groq Provider - tool calls (#630)
Browse files Browse the repository at this point in the history
# What does this PR do?

Contributes to issue #432

- Adds tool calls to Groq provider
- Enables tool call integration tests

### PR Train

- #609 
- #630 👈

## Test Plan
Environment:

```shell
export GROQ_API_KEY=<api-key>

# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml

# Create environment if not already
conda create --prefix ./envs python=3.10
conda activate ./envs

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq
```

<details>
<summary>Unit tests</summary>

```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s

# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .....................

======================================== 21 passed, 1 warning in 0.05s ========================================
```
</details>

<details>
<summary>Integration tests</summary>

```shell
# Run
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s

# Result
llama_stack/providers/tests/inference/test_text_inference.py .sss.s.ss.sss.s...

========================== 8 passed, 10 skipped, 180 deselected, 7 warnings in 2.73s ==========================
```
</details>

<details>
<summary>Manual</summary>

```bash
llama stack run ./run.yaml --port 5001
```

Via this Jupyter notebook:
https://github.com/aidando73/llama-stack/blob/9165502582cd7cb178bc1dcf89955b45768ab6c1/hello.ipynb
</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation. (no relevant documentation it
seems)
- [x] Wrote necessary unit or integration tests.
  • Loading branch information
aidando73 authored Jan 14, 2025
1 parent ace8dd6 commit fdcc74f
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 57 deletions.
12 changes: 11 additions & 1 deletion llama_stack/providers/remote/inference/groq/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from typing import AsyncIterator, List, Optional, Union

import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
Expand Down Expand Up @@ -123,7 +124,16 @@ async def chat_completion(
)
)

response = self._get_client().chat.completions.create(**request)
try:
response = self._get_client().chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError(
"Groq failed to call a tool", e.body.get("error", {})
) from e
else:
raise e

if stream:
return convert_chat_completion_response_stream(response)
Expand Down
140 changes: 113 additions & 27 deletions llama_stack/providers/remote/inference/groq/groq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import warnings
from typing import AsyncGenerator, Literal

Expand All @@ -14,14 +15,20 @@
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)

from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition

from llama_models.llama3.api.datatypes import ToolParamDefinition

from llama_stack.apis.inference import (
ChatCompletionRequest,
Expand All @@ -32,6 +39,11 @@
CompletionMessage,
Message,
StopReason,
ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolPromptFormat,
)


Expand Down Expand Up @@ -59,8 +71,8 @@ def convert_chat_completion_request(
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")

if request.tools:
warnings.warn("tools are not supported yet")
if request.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

return CompletionCreateParams(
model=request.model,
Expand All @@ -71,6 +83,8 @@ def convert_chat_completion_request(
max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None,
)


Expand All @@ -87,17 +101,64 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
raise ValueError(f"Invalid message role: {message.role}")


def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
raise AssertionError("tool_definition.description is required")

tool_parameters = tool_definition.parameters or {}
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_definition.tool_name,
description=tool_definition.description,
parameters={
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
),
)


def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
param = {
"type": tool_parameter.param_type,
}
if tool_parameter.description is not None:
param["description"] = tool_parameter.description
if tool_parameter.required is not None:
param["required"] = tool_parameter.required
if tool_parameter.default is not None:
param["default"] = tool_parameter.default
return param


def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
if choice.finish_reason == "tool_calls":
tool_calls = [
_convert_groq_tool_call(tool_call)
for tool_call in choice.message.tool_calls
]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)


def _map_finish_reason_to_stop_reason(
Expand All @@ -116,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
elif finish_reason == "length":
return StopReason.out_of_tokens
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
return StopReason.end_of_message
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")

Expand All @@ -129,25 +190,50 @@ async def convert_chat_completion_response_stream(
for chunk in stream:
choice = chunk.choices[0]

# We assume there's only one finish_reason for the entire stream.
# We collect the last finish_reason
if choice.finish_reason:
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
logprobs=None,
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=choice.delta.content or "",
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
)

# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
logprobs=None,
)
)
)
event_type = ChatCompletionResponseEventType.progress

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
logprobs=None,
stop_reason=stop_reason,
)

def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
# Note that Groq may return a string that is not valid JSON here
# So this may raise a 500 error. Going to leave this as is to see
# how big of an issue this is and what we can do about it.
arguments=json.loads(tool_call.function.arguments),
)
Loading

0 comments on commit fdcc74f

Please sign in to comment.