Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: use openai tools #156

Merged
merged 3 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/pipeline/summarize/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
func main() {

summarize := summarizepipeline.New(
openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3TextDavinci002),
openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3Dot5TurboInstruct),
loader.NewTextLoader("state_of_the_union.txt", nil).
WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 0)),
)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/henomis/qdrant-go v1.1.0
github.com/invopop/jsonschema v0.7.0
github.com/pkoukk/tiktoken-go v0.1.1
github.com/sashabaranov/go-openai v1.12.0
github.com/sashabaranov/go-openai v1.17.9
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.12.0 h1:aRNHH0gtVfrpIaEolD0sWrLLRnYQNK4cH/bIAHwL8Rk=
github.com/sashabaranov/go-openai v1.12.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY=
github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
Expand Down
27 changes: 17 additions & 10 deletions llm/openai/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,21 @@ func (o *OpenAI) BindFunction(
return nil
}

func (o *OpenAI) getFunctions() []openai.FunctionDefinition {
functions := []openai.FunctionDefinition{}
func (o *OpenAI) getFunctions() []openai.Tool {
tools := []openai.Tool{}

for _, function := range o.functions {
functions = append(functions, openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
tools = append(tools, openai.Tool{
Type: "function",
Function: openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
})
}

return functions
return tools
}

func extractFunctionParameter(f interface{}) (map[string]interface{}, error) {
Expand Down Expand Up @@ -170,12 +173,16 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er
}

func (o *OpenAI) functionCall(response openai.ChatCompletionResponse) (string, error) {
fn, ok := o.functions[response.Choices[0].Message.FunctionCall.Name]
fn, ok := o.functions[response.Choices[0].Message.ToolCalls[0].Function.Name]
if !ok {
return "", fmt.Errorf("%w: unknown function %s", ErrOpenAIChat, response.Choices[0].Message.FunctionCall.Name)
return "", fmt.Errorf(
"%w: unknown function %s",
ErrOpenAIChat,
response.Choices[0].Message.ToolCalls[0].Function.Name,
)
}

resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.FunctionCall.Arguments)
resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.ToolCalls[0].Function.Arguments)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}
Expand Down
72 changes: 44 additions & 28 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,29 @@ const (
type Model string

const (
GPT432K0613 Model = openai.GPT432K0613
GPT432K0314 Model = openai.GPT432K0314
GPT432K Model = openai.GPT432K
GPT40613 Model = openai.GPT40613
GPT40314 Model = openai.GPT40314
GPT4 Model = openai.GPT4
GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613
GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301
GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K
GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613
GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo
GPT3TextDavinci003 Model = openai.GPT3TextDavinci003
GPT3TextDavinci002 Model = openai.GPT3TextDavinci002
GPT3TextCurie001 Model = openai.GPT3TextCurie001
GPT3TextBabbage001 Model = openai.GPT3TextBabbage001
GPT3TextAda001 Model = openai.GPT3TextAda001
GPT3TextDavinci001 Model = openai.GPT3TextDavinci001
GPT3DavinciInstructBeta Model = openai.GPT3DavinciInstructBeta
GPT3Davinci Model = openai.GPT3Davinci
GPT3CurieInstructBeta Model = openai.GPT3CurieInstructBeta
GPT3Curie Model = openai.GPT3Curie
GPT3Ada Model = openai.GPT3Ada
GPT3Babbage Model = openai.GPT3Babbage
GPT432K0613 Model = openai.GPT432K0613
GPT432K0314 Model = openai.GPT432K0314
GPT432K Model = openai.GPT432K
GPT40613 Model = openai.GPT40613
GPT40314 Model = openai.GPT40314
GPT4TurboPreview Model = openai.GPT4TurboPreview
GPT4VisionPreview Model = openai.GPT4VisionPreview
GPT4 Model = openai.GPT4
GPT3Dot5Turbo1106 Model = openai.GPT3Dot5Turbo1106
GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613
GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301
GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K
GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613
GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo
GPT3Dot5TurboInstruct Model = openai.GPT3Dot5TurboInstruct
GPT3Davinci Model = openai.GPT3Davinci
GPT3Davinci002 Model = openai.GPT3Davinci002
GPT3Curie Model = openai.GPT3Curie
GPT3Curie002 Model = openai.GPT3Curie002
GPT3Ada Model = openai.GPT3Ada
GPT3Ada002 Model = openai.GPT3Ada002
GPT3Babbage Model = openai.GPT3Babbage
GPT3Babbage002 Model = openai.GPT3Babbage002
)

type UsageCallback func(types.Meta)
Expand All @@ -70,6 +70,7 @@ type OpenAI struct {
usageCallback UsageCallback
functions map[string]Function
functionsMaxIterations uint
toolChoice *string
calledFunctionName *string
finishReason string
cache *cache.Cache
Expand Down Expand Up @@ -137,6 +138,11 @@ func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI {
return o
}

func (o *OpenAI) WithToolChoice(toolChoice string) *OpenAI {
o.toolChoice = &toolChoice
return o
}

// CalledFunctionName returns the name of the function that was called.
func (o *OpenAI) CalledFunctionName() *string {
return o.calledFunctionName
Expand All @@ -149,7 +155,7 @@ func (o *OpenAI) FinishReason() string {

func NewCompletion() *OpenAI {
return New(
GPT3TextDavinci003,
GPT3Dot5TurboInstruct,
DefaultOpenAITemperature,
DefaultOpenAIMaxTokens,
false,
Expand Down Expand Up @@ -308,7 +314,17 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
}

if len(o.functions) > 0 {
chatCompletionRequest.Functions = o.getFunctions()
chatCompletionRequest.Tools = o.getFunctions()
if o.toolChoice != nil {
chatCompletionRequest.ToolChoice = openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: *o.toolChoice,
},
}
} else {
chatCompletionRequest.ToolChoice = "auto"
}
}

response, err := o.openAIClient.CreateChatCompletion(
Expand All @@ -332,10 +348,10 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {

o.finishReason = string(response.Choices[0].FinishReason)
o.calledFunctionName = nil
if response.Choices[0].FinishReason == "function_call" && len(o.functions) > 0 {
if len(response.Choices[0].Message.ToolCalls) > 0 && len(o.functions) > 0 {
if o.verbose {
fmt.Printf("Calling function %s\n", response.Choices[0].Message.FunctionCall.Name)
fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.FunctionCall.Arguments)
fmt.Printf("Calling function %s\n", response.Choices[0].Message.ToolCalls[0].Function.Name)
fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.ToolCalls[0].Function.Arguments)
}

content, err = o.functionCall(response)
Expand Down