Skip to content

Commit

Permalink
Small refactoring and adaptations
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Dec 8, 2024
1 parent 614bb5e commit f47f344
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 165 deletions.
2 changes: 1 addition & 1 deletion core/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
modelLoader: model.NewModelLoader(appConfig.ModelPath),
applicationConfig: appConfig,
templatesEvaluator: templates.NewEvaluator(templates.NewTemplateCache(appConfig.ModelPath)),
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
}
}

Expand Down
2 changes: 1 addition & 1 deletion core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)

log.Debug().Msgf("Prompt (after templating): %s", predInput)
if shouldUseFn && config.Grammar != "" {
if config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}
Expand Down
127 changes: 63 additions & 64 deletions pkg/templates/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,59 +20,32 @@ import (
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
type TemplateType int

type TemplateCache struct {
type templateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
jinjaTemplates map[TemplateType]map[string]*exec.Template
}

func NewTemplateCache(templatesPath string) *TemplateCache {
tc := &TemplateCache{
func newTemplateCache(templatesPath string) *templateCache {
tc := &templateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
}
return tc
}

func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
if _, ok := tc.templates[tt]; !ok {
tc.templates[tt] = make(map[string]*template.Template)
}
}

func (tc *TemplateCache) ExistsInModelPath(s string) bool {
func (tc *templateCache) existsInModelPath(s string) bool {
return utils.ExistsInPath(tc.templatesPath, s)
}

func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()

tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}

var buf bytes.Buffer

if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}

func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {

// Check if the template was already loaded
if _, ok := tc.templates[templateType][templateName]; ok {
Expand All @@ -92,7 +65,7 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
}

// can either be a file in the system or a string with the template
if tc.ExistsInModelPath(modelTemplateFile) {
if tc.existsInModelPath(modelTemplateFile) {
d, err := os.ReadFile(file)
if err != nil {
return err
Expand All @@ -112,41 +85,13 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
return nil
}

func (tc *TemplateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
if _, ok := tc.jinjaTemplates[tt]; !ok {
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
}
}

func (tc *TemplateCache) EvaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()

tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}

var buf bytes.Buffer

data := exec.NewContext(in)

if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}

func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
return nil
Expand Down Expand Up @@ -183,3 +128,57 @@ func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, te

return nil
}

func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()

tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}

var buf bytes.Buffer

data := exec.NewContext(in)

if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}

func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()

tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}

var buf bytes.Buffer

if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}
89 changes: 0 additions & 89 deletions pkg/templates/cache_test.go

This file was deleted.

16 changes: 8 additions & 8 deletions pkg/templates/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ const (
)

type Evaluator struct {
cache *TemplateCache
cache *templateCache
}

func NewEvaluator(cache *TemplateCache) *Evaluator {
func NewEvaluator(modelPath string) *Evaluator {
return &Evaluator{
cache: cache,
cache: newTemplateCache(modelPath),
}
}

func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
template := ""

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if e.cache.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
template = config.Model
}

Expand Down Expand Up @@ -88,11 +88,11 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
}

return e.cache.EvaluateTemplate(templateType, template, in)
return e.cache.evaluateTemplate(templateType, template, in)
}

func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
}

func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
Expand Down Expand Up @@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes
conversation["tools"] = funcs
}

return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}

func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
Expand All @@ -130,7 +130,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
conversation["system_prompt"] = in.SystemPrompt
conversation["content"] = in.Input

return e.cache.EvaluateJinjaTemplate(templateType, templateName, conversation)
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}

func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
Expand Down
Loading

0 comments on commit f47f344

Please sign in to comment.