-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathchat.go
129 lines (110 loc) · 2.88 KB
/
chat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package gollama
import (
"context"
"fmt"
"strings"
)
type (
	// ChatOption is an optional argument to Chat. It is an empty interface, so
	// a value of any type may be passed; Chat acts on PromptImage,
	// []PromptImage, Tool, []Tool and StructuredFormat values and ignores
	// values of every other type.
	ChatOption interface{}
)
// Chat sends a single-turn, non-streaming chat request to the Ollama API
// (POST /api/chat) and returns the model's reply.
//
// prompt is sent as the user message. If c.SystemPrompt is non-empty, it is
// sent before the user message as a system message.
//
// options may contain values of the following types; repeated options are
// accumulated in the order given:
//   - PromptImage or []PromptImage: the file named by each image's Filename is
//     base64-encoded and attached to the user message as vision input.
//   - Tool or []Tool: tools made available to the model.
//   - StructuredFormat: sent as the request's format, but only when it has at
//     least one property. If several are given, the last one wins.
//
// Values of any other type are ignored.
//
// c.SeedOrNegative, c.TopK, c.TopP and c.ContextLength are forwarded as
// request options. The temperature is c.TemperatureIfNegativeSeed when the
// seed is negative and 0 otherwise.
//
// On success Chat returns the reply's role, content (trimmed of surrounding
// whitespace when c.TrimSpace is set), tool calls and token counts. It returns
// nil and an error if an image cannot be encoded, if the API call fails, or if
// the response reports a model other than c.ModelName.
func (c *Gollama) Chat(ctx context.Context, prompt string, options ...ChatOption) (*ChatOuput, error) {
	var (
		promptImages []PromptImage
		tools        []Tool
		format       StructuredFormat
	)
	for _, option := range options {
		switch opt := option.(type) {
		case PromptImage:
			promptImages = append(promptImages, opt)
		case []PromptImage:
			// Append rather than replace so images passed individually
			// alongside a slice are not silently dropped.
			promptImages = append(promptImages, opt...)
		case Tool:
			tools = append(tools, opt)
		case []Tool:
			// Same reasoning as for images; appending also avoids aliasing
			// the caller's slice in the request.
			tools = append(tools, opt...)
		case StructuredFormat:
			format = opt
		}
		// Options of any other type are deliberately ignored.
	}

	seed := c.SeedOrNegative
	// Temperature stays at 0 whenever a seed is set (seed >= 0); the
	// configured temperature applies only to unseeded requests.
	var temperature float64
	if seed < 0 {
		temperature = c.TemperatureIfNegativeSeed
	}

	// At most a system message and the user message.
	messages := make([]chatMessage, 0, 2)
	if c.SystemPrompt != "" {
		messages = append(messages, chatMessage{Role: "system", Content: c.SystemPrompt})
	}

	userMessage := chatMessage{Role: "user", Content: prompt}
	if len(promptImages) > 0 {
		encoded := make([]string, 0, len(promptImages))
		for _, image := range promptImages {
			b64, err := base64EncodeFile(image.Filename)
			if err != nil {
				return nil, fmt.Errorf("encoding image %q: %w", image.Filename, err)
			}
			encoded = append(encoded, b64)
		}
		userMessage.Images = encoded
	}
	messages = append(messages, userMessage)

	req := chatRequest{
		Stream:   false,
		Model:    c.ModelName,
		Messages: messages,
		Options: chatOptionsRequest{
			Seed:          seed,
			Temperature:   temperature,
			TopK:          c.TopK,
			TopP:          c.TopP,
			ContextLength: c.ContextLength,
		},
	}
	// Optional fields are only set when there is something to send.
	if len(tools) > 0 {
		req.Tools = &tools
	}
	if len(format.Properties) > 0 {
		req.Format = &format
	}

	var resp chatResponse
	if err := c.apiPost(ctx, "/api/chat", &resp, req); err != nil {
		return nil, err
	}
	// The response names the model that produced it; treat any mismatch with
	// the requested model as the model not being available.
	if resp.Model != c.ModelName {
		return nil, fmt.Errorf("model %q not found (response came from model %q)", c.ModelName, resp.Model)
	}

	out := &ChatOuput{
		Role:           resp.Message.Role,
		Content:        resp.Message.Content,
		ToolCalls:      resp.Message.ToolCalls,
		PromptTokens:   resp.PromptEvalCount,
		ResponseTokens: resp.EvalCount,
	}
	if c.TrimSpace {
		out.Content = strings.TrimSpace(out.Content)
	}
	return out, nil
}