-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathchat.lua
322 lines (263 loc) · 8.59 KB
/
chat.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
local segment = require('model.util.segment')
local util = require('model.util')
local juice = require('model.util.juice')
local M = {}
---@class ChatPrompt
---@field provider Provider The API provider for this prompt
---@field create fun(input: string, context: Context): string | ChatContents Converts input and context to the first message text or ChatContents
---@field run fun(messages: ChatMessage[], config: ChatConfig): table | fun(resolve: fun(params: table): nil) Converts chat messages and config into completion request params
---@field system? string System instruction
---@field params? table Static request parameters
---@field options? table Provider options
---@class ChatMessage
---@field role 'user' | 'assistant'
---@field content string
---@alias ChatConfig { system?: string, params?: table, options?: table }
---@class ChatContents
---@field config ChatConfig Configuration for this chat buffer, used by chatprompt.run
---@field messages ChatMessage[] Messages in the chat buffer
--- Break message text into alternating user / assistant messages.
--- A line that is exactly `======` ends the current message and flips the
--- role; the first message always belongs to the user. When the very first
--- line has the form `> text`, `text` is returned as the system instruction
--- and that line is left out of the first message.
---@param text string Text of buffer. '\n======\n' denote alternations between user and assistant roles
---@return { messages: { role: 'user'|'assistant', content: string}[], system?: string }
local function split_messages(text)
  local buf_lines = vim.fn.split(text, '\n')
  local result = { messages = {} }
  local pending = {}
  local role = 'user'

  --- Emit the pending lines as one message, then start a fresh chunk for the
  --- other role. Only user content is trimmed.
  local function flush()
    local body = table.concat(pending, '\n')
    if role == 'user' then
      body = vim.trim(body)
    end
    table.insert(result.messages, { role = role, content = body })
    pending = {}
    if role == 'user' then
      role = 'assistant'
    else
      role = 'user'
    end
  end

  for index, line in ipairs(buf_lines) do
    if index > 1 then
      if line == '======' then
        flush()
      else
        pending[#pending + 1] = line
      end
    else
      -- The first line is never a separator: it is either the system
      -- instruction or ordinary content.
      result.system = line:match('^> (.+)')
      if result.system == nil then
        pending[#pending + 1] = line
      end
    end
  end

  -- Keep whatever follows the last `======`, unless it is all empty lines.
  if #table.concat(pending) > 0 then
    flush()
  end

  return result
end
--- Split a chat buffer into its chat name, optional config table and the
--- remaining message text.
---
--- The first line is the chat name. It may be followed by a lua table literal
--- fenced by `---` lines, which is evaluated into the config. Only a fence at
--- the start of the remaining text (ignoring leading whitespace) counts as
--- config, so `---` rules that appear later inside messages are never
--- evaluated.
---
--- Raises if the first line is a config fence, a system instruction or empty
--- instead of a chat name, or if the config text does not evaluate to a table.
---@param text string Input text of buffer
---@return { chat: string, config: table, rest: string }
local function parse_config(text)
  -- Split on the first newline by hand: a buffer holding only the chat name
  -- has no newline at all and must still parse (with an empty rest).
  local newline_at = text:find('\n', 1, true)
  local chat_name = newline_at and text:sub(1, newline_at - 1) or text
  local name_rest = newline_at and text:sub(newline_at + 1) or ''

  if chat_name == '---' then
    error('Chat buffer must start with chat name, not config')
  end
  if chat_name:sub(1, 1) == '>' then
    error('Chat buffer must start with chat name, not system instruction')
  end
  if chat_name == '' then
    error('Chat buffer must start with chat name, not empty line')
  end

  -- The appended '\n' lets a closing fence on the buffer's last line match;
  -- it cannot leak into the result because `rest` is trimmed below.
  local params_text, rest =
    (name_rest .. '\n'):match('^%s*%-%-%-\n(.-)\n%-%-%-\n(.*)')

  if params_text == nil then
    return {
      config = {},
      rest = vim.fn.trim(name_rest),
      chat = chat_name,
    }
  end

  -- NOTE(review): this evaluates buffer text as lua code. Chat buffers from
  -- untrusted sources should not be run.
  local config = vim.fn.luaeval(params_text)
  if type(config) ~= 'table' then
    error('Evaluated config text is not a lua table')
  end

  return {
    config = config,
    rest = vim.fn.trim(rest),
    chat = chat_name,
  }
end
--- Parse the text of a chat file.
--- The text must open with a chat name, optionally followed by a lua table of
--- config fenced by `---` lines. A following line starting with `> ` becomes
--- the system instruction. Everything after that is read as alternating
--- user/assistant messages separated by `\n======\n`.
--- The parsed system instruction always replaces `config.system`, including
--- when there is none (the field is then nil).
---@param text string
---@return { contents: ChatContents, chat: string }
function M.parse(text)
  local header = parse_config(text)
  local body = split_messages(header.rest)

  local config = header.config
  config.system = body.system

  return {
    chat = header.chat,
    contents = {
      config = config,
      messages = body.messages,
    },
  }
end
--- Serialize chat contents to buffer text: the chat name, then the config
--- (minus `system`) fenced by `---` lines when non-empty, then the system
--- instruction as a `> ` line, then the messages joined by `\n======\n`.
---@param contents ChatContents
---@param name string
---@return string
function M.to_string(contents, name)
  local out = name .. '\n'

  local function emit(piece)
    out = out .. piece
  end

  local config = contents.config
  if not vim.tbl_isempty(config) then
    -- TODO: `system` is special-cased here. Consider either moving it out of
    -- contents.config so it sits beside the config, or treating it as an
    -- ordinary config field.
    local extras = util.table.without(config, 'system')
    if extras and not vim.tbl_isempty(extras) then
      emit('---\n' .. vim.inspect(extras) .. '\n---\n')
    end

    if config.system then
      emit('> ' .. config.system .. '\n')
    end
  end

  for index, message in ipairs(contents.messages) do
    if index ~= 1 then
      emit('\n======\n')
    end

    if message.role == 'user' then
      -- User text is padded with a newline on each side.
      emit('\n' .. message.content .. '\n')
    else
      emit(message.content)
    end
  end

  -- An even message count gets a closing separator.
  if #contents.messages % 2 == 0 then
    emit('\n======\n')
  end

  -- Strip newlines from the end only (trim direction 2).
  return vim.fn.trim(out, '\n', 2)
end
--- Build the initial contents of a chat from a chat prompt.
--- Calls `chat_prompt.create` with the input and context. A string result
--- becomes the first user message; a table result is deep-merged over the
--- prompt's own config (`options`, `params`, `system`), so values in the
--- returned table win. Any other result type raises an error.
---@param chat_prompt ChatPrompt
---@param input_context { input: string, context: Context }
---@return ChatContents
function M.build_contents(chat_prompt, input_context)
  local created = chat_prompt.create(input_context.input, input_context.context)

  local prompt_config = {
    options = chat_prompt.options,
    params = chat_prompt.params,
    system = chat_prompt.system,
  }

  local created_type = type(created)

  if created_type == 'string' then
    ---@type ChatContents
    return {
      config = prompt_config,
      messages = {
        { role = 'user', content = created },
      },
    }
  end

  if created_type == 'table' then
    ---@type ChatContents
    return vim.tbl_deep_extend('force', { config = prompt_config }, created)
  end

  error(
    'ChatPrompt.create() needs to return a string for the first message or an ChatContents'
  )
end
--- Open a new window with a fresh buffer, mark it as an `mchat` buffer and
--- insert `text` at the top.
--- The window is a new tab when `smods.tab > 0`, a `:new` split when
--- `smods.horizontal` is set, and a `:vnew` split otherwise.
---@param text string Text to insert, split on newlines
---@param smods table Command modifiers; only `tab` and `horizontal` are read
function M.create_buffer(text, smods)
  local opener
  if smods.tab > 0 then
    opener = 'tabnew'
  elseif smods.horizontal then
    opener = 'new'
  else
    opener = 'vnew'
  end
  vim.cmd[opener]()

  vim.o.ft = 'mchat'
  vim.cmd.syntax({ 'sync', 'fromstart' })

  local new_lines = vim.fn.split(text, '\n')
  ---@cast new_lines string[]
  -- A 0-0 range inserts ahead of the first line of the current buffer.
  vim.api.nvim_buf_set_lines(0, 0, 0, false, new_lines)
end
--- Report whether the buffer lacks a blank final line.
---@param buf_lines string[] Lines of the buffer
---@return boolean # true when there is no last line, or when the last line holds anything but whitespace
local function needs_nl(buf_lines)
  local tail = buf_lines[#buf_lines]
  if not tail then
    return true
  end
  return vim.fn.trim(tail) ~= ''
end
--- Run the chat held in the current buffer.
--- Parses the buffer, looks up the chat prompt named on its first line in
--- `opts.chats`, asks the prompt's `run` for request params and hands them to
--- the prompt's provider. Response text is written into a segment at the end
--- of the buffer.
---
--- Raises when the named chat is not in `opts.chats`, when `run` returns nil,
--- or when the buffer contains no messages.
---@param opts { chats?: table<string, ChatPrompt> }
function M.run_chat(opts)
  local buf_lines = vim.api.nvim_buf_get_lines(0, 0, -1, false)
  local parsed = M.parse(table.concat(buf_lines, '\n'))

  local chat_name =
    assert(parsed.chat, 'Chat buffer first line must be a chat prompt name')

  ---@type ChatPrompt
  local chat_prompt = assert(
    vim.tbl_get(opts, 'chats', chat_name),
    'Chat "' .. chat_name .. '" not found'
  )

  local run_params =
    chat_prompt.run(parsed.contents.messages, parsed.contents.config)
  if run_params == nil then
    error('Chat prompt run() returned nil')
  end

  -- Skip the leading newline when the buffer already ends in a blank line.
  local starter_separator = needs_nl(buf_lines) and '\n======\n' or '======\n'

  local messages = parsed.contents.messages
  local last_msg = messages[#messages]
  if last_msg == nil then
    -- Without this guard, a buffer holding only a chat name (and maybe a
    -- config) fails below with an opaque nil-index error.
    error('Chat buffer has no messages')
  end

  -- NOTE(review): the arguments to create_segment_at look like a 0-based
  -- row and a column — confirm against model.util.segment.
  local seg
  if last_msg.role == 'user' then
    -- New assistant reply: start past the last line and open with a separator.
    seg = segment.create_segment_at(#buf_lines, 0)
    seg.add(starter_separator)
  else
    -- Last message is the assistant's: continue at the end of the last line.
    seg = segment.create_segment_at(#buf_lines - 1, #buf_lines[#buf_lines])
  end

  local sayer = juice.sayer()

  ---@type StreamHandlers
  local handlers = {
    on_partial = function(text)
      seg.add(text)
      sayer.say(text)
    end,
    on_finish = function(text, reason)
      sayer.finish()

      if text then
        -- NOTE(review): the separator is prepended here even when continuing
        -- an assistant message, where none was added above — confirm this is
        -- intended.
        seg.set_text(starter_separator .. text .. '\n======\n')
      else
        seg.add('\n======\n')
      end

      seg.clear_hl()

      -- 'stop' and 'done' are the quiet finish reasons; surface anything else.
      if reason and reason ~= 'stop' and reason ~= 'done' then
        util.notify(reason)
      end
    end,
    on_error = function(err, label)
      util.notify(vim.inspect(err), vim.log.levels.ERROR, { title = label })
      seg.set_text('')
      seg.clear_hl()
    end,
    segment = seg,
  }

  local options = parsed.contents.config.options or {}
  local params = parsed.contents.config.params or {}

  if type(run_params) == 'function' then
    -- `run` returned a function: call it with a callback that receives the
    -- params and starts the request. Params from `run` override config params.
    run_params(function(async_params)
      local merged_params = vim.tbl_deep_extend('force', params, async_params)
      seg.data.cancel = chat_prompt.provider.request_completion(
        handlers,
        merged_params,
        options
      )
    end)
  else
    seg.data.cancel = chat_prompt.provider.request_completion(
      handlers,
      vim.tbl_deep_extend('force', params, run_params),
      options
    )
  end
end
return M