feat(tokens): add token count display to sidebar (#956)
* feat(tokens): add token count display to sidebar

* refactor: calculate the real token count and reuse the input hint window to avoid occlusion

---------

Co-authored-by: yetone <[email protected]>
Mng-dev-ai and yetone authored Dec 17, 2024
1 parent e612ad7 commit e98fa46
Showing 6 changed files with 133 additions and 93 deletions.
2 changes: 1 addition & 1 deletion crates/avante-templates/src/lib.rs
@@ -18,7 +18,7 @@ impl<'a> State<'a> {
#[derive(Debug, Serialize, Deserialize)]
struct SelectedFile {
path: String,
content: String,
content: Option<String>,
file_type: String,
}

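With `content` now an `Option<String>`, the Lua side can presumably hand the template engine a selected-file entry without its body, e.g. when only the path needs to be rendered. A minimal sketch of such an entry, assuming serde's usual treatment of a missing field as None (values illustrative, not from this commit):

-- Hypothetical selected-file table on the Lua side; leaving out `content`
-- should deserialize into None for the new Option<String> field.
local selected_file = {
  path = "lua/avante/sidebar.lua",
  file_type = "lua",
  -- content omitted
}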
68 changes: 46 additions & 22 deletions lua/avante/llm.lua
@@ -18,10 +18,10 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"

local group = api.nvim_create_augroup("avante_llm", { clear = true })

---@param opts StreamOptions
---@param Provider AvanteProviderFunctor
M._stream = function(opts, Provider)
-- print opts
---@param opts GeneratePromptsOptions
---@return AvantePromptOptions
M.generate_prompts = function(opts)
local Provider = opts.provider or P[Config.provider]
local mode = opts.mode or "planning"
---@type AvanteProviderFunctor
local _, body_opts = P.parse_config(Provider)
@@ -42,7 +42,8 @@ M._stream = function(opts, Provider)
instructions = table.concat(lines, "\n")
end

Path.prompts.initialize(Path.prompts.get(opts.bufnr))
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))

local template_opts = {
use_xml_format = Provider.use_xml_format,
@@ -104,11 +105,30 @@ M._stream = function(opts, Provider)
end

---@type AvantePromptOptions
local code_opts = {
return {
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
}
end

---@param opts GeneratePromptsOptions
---@return integer
M.calculate_tokens = function(opts)
local code_opts = M.generate_prompts(opts)
local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt)
for _, message in ipairs(code_opts.messages) do
tokens = tokens + Utils.tokens.calculate_tokens(message.content)
end
return tokens
end

---@param opts StreamOptions
M._stream = function(opts)
local Provider = opts.provider or P[Config.provider]

local code_opts = M.generate_prompts(opts)

---@type string
local current_event_state = nil

@@ -248,7 +268,7 @@ M._stream = function(opts, Provider)
return active_job
end

local function _merge_response(first_response, second_response, opts, Provider)
local function _merge_response(first_response, second_response, opts)
local prompt = "\n" .. Config.dual_boost.prompt
prompt = prompt
:gsub("{{[%s]*provider1_output[%s]*}}", first_response)
@@ -259,28 +279,28 @@ local function _merge_response(first_response, second_response, opts, Provider)
-- append this reference prompt to the end of the code_opts messages
opts.instructions = opts.instructions .. prompt

M._stream(opts, Provider)
M._stream(opts)
end

local function _collector_process_responses(collector, opts, Provider)
local function _collector_process_responses(collector, opts)
if not collector[1] or not collector[2] then
Utils.error("One or both responses failed to complete")
return
end
_merge_response(collector[1], collector[2], opts, Provider)
_merge_response(collector[1], collector[2], opts)
end

local function _collector_add_response(collector, index, response, opts, Provider)
local function _collector_add_response(collector, index, response, opts)
collector[index] = response
collector.count = collector.count + 1

if collector.count == 2 then
collector.timer:stop()
_collector_process_responses(collector, opts, Provider)
_collector_process_responses(collector, opts)
end
end

M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
M._dual_boost_stream = function(opts, Provider1, Provider2)
Utils.debug("Starting Dual Boost Stream")

local collector = {
@@ -299,7 +319,7 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
Utils.warn("Dual boost stream timeout reached")
collector.timer:stop()
-- Process whatever responses we have
_collector_process_responses(collector, opts, Provider)
_collector_process_responses(collector, opts)
end
end)
)
@@ -317,15 +337,19 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
return
end
Utils.debug(string.format("Response %d completed", index))
_collector_add_response(collector, index, response, opts, Provider)
_collector_add_response(collector, index, response, opts)
end,
})
end

-- Start both streams
local success, err = xpcall(function()
M._stream(create_stream_opts(1), Provider1)
M._stream(create_stream_opts(2), Provider2)
local opts1 = create_stream_opts(1)
opts1.provider = Provider1
M._stream(opts1)
local opts2 = create_stream_opts(2)
opts2.provider = Provider2
M._stream(opts2)
end, function(err) return err end)
if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
end
@@ -348,12 +372,13 @@ end
---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[]
---
---@class StreamOptions: TemplateOptions
---@class GeneratePromptsOptions: TemplateOptions
---@field ask boolean
---@field bufnr integer
---@field instructions string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | nil
---
---@class StreamOptions: GeneratePromptsOptions
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser

@@ -375,11 +400,10 @@ M.stream = function(opts)
return original_on_complete(err)
end)
end
local Provider = opts.provider or P[Config.provider]
if Config.dual_boost.enabled then
M._dual_boost_stream(opts, Provider, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
else
M._stream(opts, Provider)
M._stream(opts)
end
end

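The net effect of this refactor: prompt construction moves into `M.generate_prompts`, which both `M._stream` and the new `M.calculate_tokens` consume, and the provider now rides along on `opts.provider` rather than a separate argument. A minimal usage sketch of the token-count entry point (option values illustrative):

local Llm = require("avante.llm")

-- Build the same GeneratePromptsOptions the sidebar would, then ask for
-- a token estimate without streaming anything.
local opts = {
  ask = true,
  instructions = "Explain the selected code",
  mode = "planning",
}

local tokens = Llm.calculate_tokens(opts)
print(("estimated prompt size: %d tokens"):format(tokens))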
7 changes: 3 additions & 4 deletions lua/avante/path.lua
@@ -88,16 +88,15 @@ local templates = nil

Prompt.templates = { planning = nil, editing = nil, suggesting = nil }

-- Creates a directory in the cache path for the given buffer and copies the custom prompts to it.
-- We need to do this because the prompt template engine requires a given directory to load all required files.
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
---@param bufnr number
---@param project_root string
---@return string the resulting cache_directory to be loaded with avante_templates
Prompt.get = function(bufnr)
Prompt.get = function(project_root)
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end

-- get root directory of given bufnr
local directory = Path:new(Utils.root.get({ buf = bufnr }))
local directory = Path:new(project_root)
if Utils.get_os_name() == "windows" then directory = Path:new(directory:absolute():gsub("^%a:", "")[1]) end
---@cast directory Path
---@type Path
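`Prompt.get` no longer derives the root from a buffer; callers resolve it once and pass it in, which is exactly what `M.generate_prompts` does above. The new call shape, taken from `lua/avante/llm.lua`:

-- Resolve the project root once, then initialize the template engine
-- with the cache directory prepared for that root.
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))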
1 change: 0 additions & 1 deletion lua/avante/selection.lua
@@ -206,7 +206,6 @@ function Selection:create_editing_input()
local diagnostics = Utils.get_current_selection_diagnostics(code_bufnr, self.selection)

Llm.stream({
bufnr = code_bufnr,
ask = true,
project_context = vim.json.encode(project_context),
diagnostics = vim.json.encode(diagnostics),
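Because templates now load per project root rather than per buffer, callers of `Llm.stream` drop the `bufnr` field entirely; the same one-line removal appears in `lua/avante/suggestion.lua` below. A sketch of the slimmed-down call at this site (remaining fields abridged):

Llm.stream({
  ask = true,
  project_context = vim.json.encode(project_context),
  diagnostics = vim.json.encode(diagnostics),
  -- bufnr removed: prompt templates resolve from the project root now
})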
147 changes: 83 additions & 64 deletions lua/avante/sidebar.lua
@@ -1491,6 +1491,74 @@ function Sidebar:create_input_container(opts)

local chat_history = Path.history.load(self.code.bufnr)

---@param request string
---@return GeneratePromptsOptions
local function get_generate_prompts_options(request)
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })

local selected_code_content = nil
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end

local mentions = Utils.extract_mentions(request)
request = mentions.new_content

local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")

local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil

local selected_files_contents = self.file_selector:get_selected_files_contents()

local diagnostics = nil
if mentions.enable_diagnostics then
if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then
diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection)
else
diagnostics = Utils.get_diagnostics(self.code.bufnr)
end
end

local history_messages = {}
for i = #chat_history, 1, -1 do
local entry = chat_history[i]
if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(history_messages, 1, { role = "assistant", content = entry.original_response })
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
.. entry.selected_code.filetype
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(history_messages, 1, { role = "user", content = user_content })
end

return {
ask = opts.ask,
project_context = vim.json.encode(project_context),
selected_files = selected_files_contents,
diagnostics = vim.json.encode(diagnostics),
history_messages = history_messages,
code_lang = filetype,
selected_code = selected_code_content,
instructions = request,
mode = "planning",
}
end

---@param request string
local function handle_submit(request)
local model = Config.has_provider(Config.provider) and Config.get_provider(Config.provider).model or "default"
Expand Down Expand Up @@ -1518,9 +1586,6 @@ function Sidebar:create_input_container(opts)
self:update_content("", { focus = true, scroll = false })
self:update_content(content_prefix .. generating_text)

local selected_code_content = nil
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end

if request:sub(1, 1) == "/" then
local command, args = request:match("^/(%S+)%s*(.*)")
if command == nil then
Expand All @@ -1542,8 +1607,6 @@ function Sidebar:create_input_container(opts)
Utils.error("Invalid end line number", { once = true, title = "Avante" })
return
end
selected_code_content =
table.concat(api.nvim_buf_get_lines(self.code.bufnr, start_line - 1, end_line, false), "\n")
request = question
end)
else
@@ -1632,67 +1695,15 @@ function Sidebar:create_input_container(opts)
Path.history.save(self.code.bufnr, chat_history)
end

local mentions = Utils.extract_mentions(request)
request = mentions.new_content

local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")

local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil

local selected_files_contents = self.file_selector:get_selected_files_contents()

local diagnostics = nil
if mentions.enable_diagnostics then
if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then
diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection)
else
diagnostics = Utils.get_diagnostics(self.code.bufnr)
end
end

local history_messages = {}
for i = #chat_history, 1, -1 do
local entry = chat_history[i]
if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(history_messages, 1, { role = "assistant", content = entry.original_response })
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
.. entry.selected_code.filetype
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(history_messages, 1, { role = "user", content = user_content })
end

Llm.stream({
bufnr = self.code.bufnr,
ask = opts.ask,
project_context = vim.json.encode(project_context),
selected_files = selected_files_contents,
diagnostics = vim.json.encode(diagnostics),
history_messages = history_messages,
code_lang = filetype,
selected_code = selected_code_content,
instructions = request,
mode = "planning",
local generate_prompts_options = get_generate_prompts_options(request)
---@type StreamOptions
---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_chunk = on_chunk,
on_complete = on_complete,
})

Llm.stream(stream_options)
end

local get_position = function()
@@ -1827,7 +1838,15 @@ function Sidebar:create_input_container(opts)
local function show_hint()
close_hint() -- Close the existing hint window

local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")

local generate_prompts_options = get_generate_prompts_options(input_value)
local tokens = Llm.calculate_tokens(generate_prompts_options)

local hint_text = "Tokens: "
.. tostring(tokens)
.. "; "
.. (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
.. ": submit"

local buf = api.nvim_create_buf(false, true)
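`show_hint` now rebuilds the prompts for the current input on every refresh and prefixes the token estimate to the submit hint. A sketch of the resulting assembly, assuming a count of 1234 and `<C-s>` as the normal-mode submit mapping:

-- Values illustrative; `get_generate_prompts_options` and `input_value`
-- are the locals introduced in the hunk above.
local tokens = Llm.calculate_tokens(get_generate_prompts_options(input_value))
local submit_key = Config.mappings.submit.normal -- e.g. "<C-s>"
local hint_text = "Tokens: " .. tostring(tokens) .. "; " .. submit_key .. ": submit"
-- => "Tokens: 1234; <C-s>: submit"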
1 change: 0 additions & 1 deletion lua/avante/suggestion.lua
@@ -69,7 +69,6 @@ function Suggestion:suggest()

Llm.stream({
provider = provider,
bufnr = bufnr,
ask = true,
selected_files = { { content = code_content, file_type = filetype, path = "" } },
code_lang = filetype,
