feat(functions): support models with no grammar, add tests
mudler committed Apr 18, 2024
1 parent 502c1ee commit a61e4e7
Showing 13 changed files with 257 additions and 116 deletions.
10 changes: 2 additions & 8 deletions core/config/backend_config.go
@@ -12,6 +12,7 @@ import (

"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/functions"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
@@ -39,7 +40,7 @@ type BackendConfig struct {
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`

FunctionsConfig Functions `yaml:"function"`
FunctionsConfig functions.FunctionsConfig `yaml:"function"`

FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
@@ -157,13 +158,6 @@ type AutoGPTQ struct {
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}

type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}

type TemplateConfig struct {
Chat string `yaml:"chat"`
ChatMessage string `yaml:"chat_message"`
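The Functions struct deleted here moves into the new pkg/functions package and is referenced above as functions.FunctionsConfig. A minimal sketch of the relocated type, assuming it keeps the previous fields; the NoGrammar field itself is confirmed by the config.FunctionsConfig.NoGrammar check in chat.go below, but its YAML key is an assumption:

package functions

// FunctionsConfig replaces the old core/config Functions struct (sketch only;
// the authoritative definition lives in pkg/functions in this commit).
type FunctionsConfig struct {
	DisableNoAction         bool   `yaml:"disable_no_action"`
	NoActionFunctionName    string `yaml:"no_action_function_name"`
	NoActionDescriptionName string `yaml:"no_action_description_name"`
	ParallelCalls           bool   `yaml:"parallel_calls"`

	// NoGrammar skips grammar-constrained decoding and relies on parsing the
	// model's raw output for function calls instead (yaml key assumed).
	NoGrammar bool `yaml:"no_grammar"`
}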
130 changes: 42 additions & 88 deletions core/http/endpoints/openai/chat.go
@@ -11,9 +11,8 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
@@ -68,8 +67,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
return true
})

results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0

switch {
case noActionToRun:
@@ -82,7 +81,12 @@
}
responses <- initialMessage

result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
replyMessage := ""
if len(results) > 0 {
replyMessage = results[0].Arguments
}

result, err := handleQuestion(config, req, ml, startupOptions, replyMessage, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -105,7 +109,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

default:
for i, ss := range results {
name, args := ss.name, ss.arguments
name, args := ss.Name, ss.Arguments

initialMessage := schema.OpenAIResponse{
ID: id,
@@ -156,8 +160,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -169,6 +171,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
log.Debug().Msgf("Configuration read: %+v", config)

funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()

// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
@@ -182,18 +187,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
input.Grammar = functions.JSONBNF
}

config.Grammar = input.Grammar

// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
if shouldUseFn {
log.Debug().Msgf("Response needs to process functions")
}

processFunctions = true

noActionGrammar := grammar.Function{
switch {
case !config.FunctionsConfig.NoGrammar && shouldUseFn:
noActionGrammar := functions.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
@@ -206,7 +211,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
@@ -219,10 +223,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
}

// process functions if we have any defined or if we have a function call string

// functions are not supported in stream mode (yet?)
toStream := input.Stream

@@ -232,8 +243,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {

if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
@@ -346,11 +356,11 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
templateFile = config.Model
}

if config.TemplateConfig.Chat != "" && !processFunctions {
if config.TemplateConfig.Chat != "" && !shouldUseFn {
templateFile = config.TemplateConfig.Chat
}

if config.TemplateConfig.Functions != "" && processFunctions {
if config.TemplateConfig.Functions != "" && shouldUseFn {
templateFile = config.TemplateConfig.Functions
}

@@ -370,7 +380,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
if shouldUseFn && config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}
@@ -388,7 +398,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

responses := make(chan schema.OpenAIResponse)

if !processFunctions {
if !shouldUseFn {
go process(predInput, input, config, ml, responses)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
@@ -446,18 +456,23 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !processFunctions {
if !shouldUseFn {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}

results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0

switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
replyMessage := ""
if len(results) > 0 {
replyMessage = results[0].Arguments
}

result, err := handleQuestion(config, input, ml, startupOptions, replyMessage, predInput)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -476,7 +491,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

for _, ss := range results {
name, args := ss.name, ss.arguments
name, args := ss.Name, ss.Arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
@@ -534,7 +549,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Return the prediction in the response body
return c.JSON(resp)
}

}
}

@@ -580,63 +594,3 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}

type funcCallResults struct {
name string
arguments string
}

func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}

// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)

for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)

// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}

return results
}
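The unexported parseFunctionCall helper removed above is replaced by an exported functions.ParseFunctionCall that takes the whole FunctionsConfig. Only the call sites are visible in this diff, so the following is a sketch under assumptions: the result type name is a guess, while the Name/Arguments fields and the (string, FunctionsConfig) signature match the usage in chat.go; the body is adapted from the deleted helper, and the real package presumably also covers the NoGrammar path and newline escaping.

package functions

import "encoding/json"

// FuncCallResults is assumed to mirror the unexported funcCallResults struct
// deleted from chat.go, with exported fields.
type FuncCallResults struct {
	Name      string
	Arguments string
}

// ParseFunctionCall extracts function calls from the raw LLM output.
func ParseFunctionCall(llmResult string, config FunctionsConfig) []FuncCallResults {
	results := []FuncCallResults{}

	if config.ParallelCalls {
		// Parallel calls arrive as a JSON array of
		// {"function": ..., "arguments": ...} objects.
		calls := []map[string]interface{}{}
		if err := json.Unmarshal([]byte(llmResult), &calls); err != nil {
			return results
		}
		for _, call := range calls {
			name, ok := call["function"].(string)
			if !ok {
				continue
			}
			args, _ := json.Marshal(call["arguments"])
			results = append(results, FuncCallResults{Name: name, Arguments: string(args)})
		}
		return results
	}

	// Single call: one JSON object with "function" and "arguments" keys.
	call := map[string]interface{}{}
	if err := json.Unmarshal([]byte(llmResult), &call); err != nil {
		return results
	}
	name, ok := call["function"].(string)
	if !ok {
		return results
	}
	args, _ := json.Marshal(call["arguments"])
	results = append(results, FuncCallResults{Name: name, Arguments: string(args)})
	return results
}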
4 changes: 2 additions & 2 deletions core/http/endpoints/openai/completion.go
@@ -12,7 +12,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"

"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
@@ -70,7 +70,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}

if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
input.Grammar = functions.JSONBNF
}

config.Grammar = input.Grammar
4 changes: 2 additions & 2 deletions core/http/endpoints/openai/request.go
@@ -12,7 +12,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@@ -145,7 +145,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
}

if input.ToolsChoice != nil {
var toolChoice grammar.Tool
var toolChoice functions.Tool

switch content := input.ToolsChoice.(type) {
case string:
14 changes: 7 additions & 7 deletions core/schema/openai.go
@@ -3,7 +3,7 @@ package schema
import (
"context"

"github.com/go-skynet/LocalAI/pkg/grammar"
functions "github.com/go-skynet/LocalAI/pkg/functions"
)

// APIError provides error information returned by the OpenAI API.
@@ -108,7 +108,7 @@ type ChatCompletionResponseFormat struct {
type OpenAIRequest struct {
PredictionOptions

Context context.Context `json:"-"`
Context context.Context `json:"-"`
Cancel context.CancelFunc `json:"-"`

// whisper
@@ -130,11 +130,11 @@ type OpenAIRequest struct {
Messages []Message `json:"messages" yaml:"messages"`

// A list of available functions to call
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Functions functions.Functions `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object

Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
Tools []functions.Tool `json:"tools,omitempty" yaml:"tools"`
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`

Stream bool `json:"stream"`

@@ -145,7 +145,7 @@ type OpenAIRequest struct {
// A grammar to constrain the LLM output
Grammar string `json:"grammar" yaml:"grammar"`

JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
JSONFunctionGrammarObject *functions.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`

Backend string `json:"backend" yaml:"backend"`

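With the request schema now pointing at pkg/functions, the chat endpoint builds its constraining grammar straight from the request's function list via ToJSONStructure().Grammar(...), as shown in chat.go above. A self-contained sketch of that flow; the get_weather definition is invented for illustration:

package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// A hypothetical function definition, as a client would send it in the
	// "functions" field of a chat request.
	funcs := functions.Functions{
		{
			Name:        "get_weather",
			Description: "Get the current weather for a city",
			Parameters: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"city": map[string]interface{}{"type": "string"},
				},
			},
		},
	}

	// Unless the model's functions config sets NoGrammar, the endpoint turns
	// the schema into a grammar constraining generation to a valid call; the
	// second argument enables parallel calls.
	grammar := funcs.ToJSONStructure().Grammar("", false)
	fmt.Println(grammar)
}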
2 changes: 1 addition & 1 deletion pkg/grammar/functions.go → pkg/functions/functions.go
@@ -1,4 +1,4 @@
package grammar
package functions

import (
"encoding/json"
@@ -1,4 +1,4 @@
package grammar
package functions

import (
"testing"
@@ -1,7 +1,7 @@
package grammar_test
package functions_test

import (
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/go-skynet/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
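The commit title also mentions new tests; they are not part of the portion of the diff shown here, so the spec below is only an illustration of what a case for the relocated parser could look like, following the Ginkgo style of the suite files above (input string and expectations are invented):

package functions_test

import (
	. "github.com/go-skynet/LocalAI/pkg/functions"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("ParseFunctionCall", func() {
	It("extracts the function name and arguments from a single call", func() {
		llmResult := `{"function": "get_weather", "arguments": {"city": "Rome"}}`
		results := ParseFunctionCall(llmResult, FunctionsConfig{})
		Expect(results).To(HaveLen(1))
		Expect(results[0].Name).To(Equal("get_weather"))
	})
})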