Skip to content

Commit

Permalink
function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
venjiang committed Sep 2, 2024
1 parent 110d45d commit 3cc968f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
40 changes: 33 additions & 7 deletions pkg/bridge/ai/provider/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package anthropic

import (
"context"
"encoding/json"
"fmt"
"os"
"time"

Expand Down Expand Up @@ -72,7 +74,7 @@ func (p *Provider) GetChatCompletions(
msgs := make([]anthropic.MessageParam, 0)
systemMsgs := make([]anthropic.TextBlockParam, 0)
tools := make([]anthropic.ToolParam, 0)
toolResult := make([]anthropic.ToolResultBlockParam, 0)
toolResult := []anthropic.MessageParamContentUnion{}

// tools
for _, tool := range req.Tools {
Expand All @@ -84,6 +86,7 @@ func (p *Provider) GetChatCompletions(
})
}
}
ylog.Debug("anthropic tools", "tools", fmt.Sprintf("%+v", tools))

// messages
for _, msg := range req.Messages {
Expand All @@ -92,13 +95,23 @@ func (p *Provider) GetChatCompletions(
msgs = append(msgs, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)))
case openai.ChatMessageRoleAssistant:
// tool use, check if there are tool calls
ylog.Debug("openai request", "tool_calls", len(msg.ToolCalls))
if len(msg.ToolCalls) > 0 {
toolUses := make([]anthropic.MessageParamContentUnion, 0)
for _, toolCall := range msg.ToolCalls {
msgs = append(
msgs,
anthropic.NewAssistantMessage(anthropic.NewToolUseBlockParam(toolCall.ID, toolCall.Function.Name, toolCall.Function.Arguments)),
)
var args map[string]any
if len(toolCall.Function.Arguments) > 0 {
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
if err != nil {
// TODO: handle error
ylog.Error("anthropic tool use unmarshal input", "err", err)
}
}
toolUse := anthropic.NewToolUseBlockParam(toolCall.ID, toolCall.Function.Name, args)
ylog.Debug("anthropic tool use", "tool_use", fmt.Sprintf("%+v", toolUse))
toolUses = append(toolUses, toolUse)
}
msgs = append(msgs, anthropic.NewAssistantMessage(toolUses...))
} else { // normal assistant message
msgs = append(msgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)))
}
Expand All @@ -110,8 +123,8 @@ func (p *Provider) GetChatCompletions(
}
}
// add tool result to user messages
for _, tr := range toolResult {
msgs = append(msgs, anthropic.NewUserMessage(tr))
if len(toolResult) > 0 {
msgs = append(msgs, anthropic.NewUserMessage(toolResult...))
}

// send anthropic request
Expand Down Expand Up @@ -149,12 +162,16 @@ func (p *Provider) GetChatCompletions(
toolCallIndex := 0
for _, content := range result.Content {
switch content.Type {
// switch content.AsUnion().(type) {
// text
case anthropic.ContentBlockTypeText:
// case anthropic.TextBlock:
message.Content = content.Text
// tool use
case anthropic.ContentBlockTypeToolUse:
// case anthropic.ToolUseBlock:
i := toolCallIndex
ylog.Debug("anthropic tool use ", "function", content.Name, "arguments", string(content.Input))
message.ToolCalls = append(message.ToolCalls, openai.ToolCall{
Index: &i,
ID: content.ID,
Expand All @@ -173,6 +190,15 @@ func (p *Provider) GetChatCompletions(
choice.FinishReason = convertFinishReason(result.StopReason)
resp.Choices = append(resp.Choices, choice)
// usage
// BUG:
// total tokens = input tokens + output tokens
// #1 429, 139, 568
// #2 613, 171, 784
// = 1042, 310, 1352
// actual:
// "prompt_tokens": 1042,
// "completion_tokens": 310,
// "total_tokens": 923
resp.Usage = openai.Usage{
PromptTokens: int(result.Usage.InputTokens),
CompletionTokens: int(result.Usage.OutputTokens),
Expand Down
7 changes: 6 additions & 1 deletion pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,10 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
// 7. do the second call (the second call messages are from user input, first call resopnse and sfn calls result)
req.Messages = append(reqMessages, assistantMessage)
req.Messages = append(req.Messages, llmCalls...)
req.Tools = nil // reset tools field
// anthropic must define tools
if srv.provider.Name() != "anthropic" {
req.Tools = nil // reset tools field
}

srv.logger.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req))

Expand Down Expand Up @@ -434,6 +437,8 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
resp.Usage.TotalTokens += totalUsage

secondCallSpan.End()

srv.logger.Debug(" #2 second call", "response", fmt.Sprintf("%+v", resp))
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(resp)
}
Expand Down

0 comments on commit 3cc968f

Please sign in to comment.