diff --git a/lua/avante/init.lua b/lua/avante/init.lua
index 01350fedd..86319838b 100644
--- a/lua/avante/init.lua
+++ b/lua/avante/init.lua
@@ -4,6 +4,7 @@ local Path = require("plenary.path")
local n = require("nui-components")
local diff = require("avante.diff")
local utils = require("avante.utils")
+local tiktoken = require("avante.tiktoken")
local api = vim.api
local fn = vim.fn
@@ -140,7 +141,7 @@ local system_prompt = [[
You are an excellent programming expert.
]]
-local user_prompt_tpl = [[
+local base_user_prompt = [[
Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously:
1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
@@ -183,87 +184,119 @@ Replace lines: {{start_line}}-{{end_line}}
- Do not show the content after these modifications.
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
-
-QUESTION: ${{question}}
-
-CODE:
-```
-${{code}}
-```
]]
-local function call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key then
error("ANTHROPIC_API_KEY environment variable is not set")
end
- local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
-
- print("Sending request to Claude API...")
+ local user_prompt = base_user_prompt
- local tokens = M.config.claude.model == "claude-3-5-sonnet-20240620" and 8192 or 4096
+ local tokens = M.config.claude.max_tokens
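+  -- The prompt-caching beta header enables Anthropic prompt caching for content blocks marked with cache_control below.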
local headers = {
["Content-Type"] = "application/json",
["x-api-key"] = api_key,
["anthropic-version"] = "2023-06-01",
- ["anthropic-beta"] = "messages-2023-12-15",
+ ["anthropic-beta"] = "prompt-caching-2024-07-31",
}
- if M.config.claude.model == "claude-3-5-sonnet-20240620" then
- headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
+ local code_prompt_obj = {
+ type = "text",
+    text = string.format("```%s\n%s```", code_lang, code_content),
+ }
+
+ local user_prompt_obj = {
+ type = "text",
+ text = user_prompt,
+ }
+
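+  -- Anthropic only caches blocks above a minimum size (1024 tokens on Sonnet-class models), so smaller blocks are left unmarked.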
+ if tiktoken.count(code_prompt_obj.text) > 1024 then
+ code_prompt_obj.cache_control = { type = "ephemeral" }
+ end
+
+ if tiktoken.count(user_prompt_obj.text) > 1024 then
+ user_prompt_obj.cache_control = { type = "ephemeral" }
end
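+  -- The code block comes first so its cache entry can be reused across follow-up questions about the same code.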
+ local params = {
+ model = M.config.claude.model,
+ system = system_prompt,
+ messages = {
+ {
+ role = "user",
+ content = {
+ code_prompt_obj,
+ {
+ type = "text",
+          text = question,
+ },
+ user_prompt_obj,
+ },
+ },
+ },
+ stream = true,
+ temperature = M.config.claude.temperature,
+ max_tokens = tokens,
+ }
+
local url = utils.trim_suffix(M.config.claude.endpoint, "/") .. "/v1/messages"
+ print("Sending request to Claude API...")
+
curl.post(url, {
---@diagnostic disable-next-line: unused-local
stream = function(err, data, job)
if err then
- error("Error: " .. vim.inspect(err))
+ on_complete(err)
return
end
- if data then
- for line in data:gmatch("[^\r\n]+") do
- if line:sub(1, 6) == "data: " then
- vim.schedule(function()
- local success, parsed = pcall(fn.json_decode, line:sub(7))
- if success and parsed and parsed.type == "content_block_delta" then
- on_chunk(parsed.delta.text)
- elseif success and parsed and parsed.type == "message_stop" then
- -- Stream request completed
- on_complete()
- elseif success and parsed and parsed.type == "error" then
- print("Error: " .. vim.inspect(parsed))
- -- Stream request completed
- on_complete()
- end
- end)
- end
+ if not data then
+ return
+ end
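+      -- Each SSE payload line is prefixed with "data: "; other lines (e.g. "event:" headers, keep-alives) carry no JSON and are skipped below.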
+ for line in data:gmatch("[^\r\n]+") do
+ if line:sub(1, 6) ~= "data: " then
+          goto continue
end
+ vim.schedule(function()
+ local success, parsed = pcall(fn.json_decode, line:sub(7))
+ if not success then
+ error("Error: failed to parse json: " .. parsed)
+ return
+ end
+ if parsed and parsed.type == "content_block_delta" then
+ on_chunk(parsed.delta.text)
+ elseif parsed and parsed.type == "message_stop" then
+ -- Stream request completed
+ on_complete(nil)
+ elseif parsed and parsed.type == "error" then
+ -- Stream request completed
+ on_complete(parsed)
+ end
+      end)
+        ::continue::
end
end,
headers = headers,
- body = fn.json_encode({
- model = M.config.claude.model,
- system = system_prompt,
- messages = {
- { role = "user", content = user_prompt },
- },
- stream = true,
- temperature = M.config.claude.temperature,
- max_tokens = tokens,
- }),
+ body = fn.json_encode(params),
})
end
-local function call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local api_key = os.getenv("OPENAI_API_KEY")
if not api_key then
error("OPENAI_API_KEY environment variable is not set")
end
- local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
+ local user_prompt = base_user_prompt
+ .. "\n\nQUESTION:\n"
+ .. question
+ .. "\n\nCODE:\n"
+ .. "```"
+ .. code_lang
+ .. "\n"
+ .. code_content
+ .. "\n```"
local url = utils.trim_suffix(M.config.openai.endpoint, "/") .. "/v1/chat/completions"
if M.config.provider == "azure" then
@@ -276,23 +309,29 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
---@diagnostic disable-next-line: unused-local
stream = function(err, data, job)
if err then
- error("Error: " .. vim.inspect(err))
+ on_complete(err)
return
end
- if data then
- for line in data:gmatch("[^\r\n]+") do
- if line:sub(1, 6) == "data: " then
- vim.schedule(function()
- local success, parsed = pcall(fn.json_decode, line:sub(7))
- if success and parsed and parsed.choices and parsed.choices[1].delta.content then
- on_chunk(parsed.choices[1].delta.content)
- elseif success and parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
- -- Stream request completed
- on_complete()
- end
- end)
- end
+ if not data then
+ return
+ end
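+      -- As with the Claude stream, only "data: "-prefixed SSE lines carry JSON payloads.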
+ for line in data:gmatch("[^\r\n]+") do
+        -- OpenAI ends the stream with a literal "data: [DONE]" line, which is not valid JSON and must be skipped.
+        if line:sub(1, 6) ~= "data: " or line == "data: [DONE]" then
+          goto continue
end
+ vim.schedule(function()
+ local success, parsed = pcall(fn.json_decode, line:sub(7))
+ if not success then
+ error("Error: failed to parse json: " .. parsed)
+ return
+ end
+ if parsed and parsed.choices and parsed.choices[1].delta.content then
+ on_chunk(parsed.choices[1].delta.content)
+ elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
+ -- Stream request completed
+ on_complete(nil)
+ end
+      end)
+        ::continue::
end
end,
headers = {
@@ -313,11 +352,11 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
})
end
-local function call_ai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
if M.config.provider == "openai" or M.config.provider == "azure" then
- call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+ call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
elseif M.config.provider == "claude" then
- call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+ call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
end
end
@@ -522,7 +561,9 @@ function M.render_sidebar()
signal.is_loading = true
- call_ai_api_stream(user_input, content_with_line_numbers, function(chunk)
+ local filetype = api.nvim_get_option_value("filetype", { buf = code_buf })
+
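+    -- Use the buffer's filetype as the language tag on the fenced code block sent to the model.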
+ call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk)
full_response = full_response .. chunk
update_result_buf_content(
"## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response
@@ -530,8 +571,23 @@ function M.render_sidebar()
vim.schedule(function()
vim.cmd("redraw")
end)
- end, function()
+ end, function(err)
signal.is_loading = false
+
+ if err ~= nil then
+ update_result_buf_content(
+ "## "
+ .. timestamp
+ .. "\n\n> "
+ .. user_input:gsub("\n", "\n> ")
+ .. "\n\n"
+ .. full_response
+ .. "\n\n**Error**: "
+ .. vim.inspect(err)
+ )
+ return
+ end
+
-- Execute when the stream request is actually completed
update_result_buf_content(
"## "
@@ -687,6 +743,8 @@ function M.setup(opts)
_cur_code_buf = bufnr
end
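+  -- Token counting only gates the 1024-token cache threshold, so the gpt-4o (o200k_base) encoding is an adequate approximation even for Claude prompts.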
+ tiktoken.setup("gpt-4o")
+
diff.setup({
debug = false, -- log output to console
default_mappings = M.config.mappings.diff, -- disable buffer local mapping created by this plugin
diff --git a/lua/avante/tiktoken.lua b/lua/avante/tiktoken.lua
new file mode 100644
index 000000000..779266286
--- /dev/null
+++ b/lua/avante/tiktoken.lua
@@ -0,0 +1,103 @@
+-- NOTE: this file is copied from: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/tiktoken.lua
+
+local curl = require("plenary.curl")
+local tiktoken_core = nil
+
+---Get the path of the cache directory
+---@param fname string
+---@return string
+local function get_cache_path(fname)
+ return vim.fn.stdpath("cache") .. "/" .. fname
+end
+
+local function file_exists(name)
+ local f = io.open(name, "r")
+ if f ~= nil then
+ io.close(f)
+ return true
+ else
+ return false
+ end
+end
+
+--- Load tiktoken data from cache or download it
+local function load_tiktoken_data(done, model)
+ local tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+ -- If model is gpt-4o, use o200k_base.tiktoken
+ if model ~= nil and vim.startswith(model, "gpt-4o") then
+ tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
+ end
+ local async
+ async = vim.loop.new_async(function()
+ -- Take filename after the last slash of the url
+ local cache_path = get_cache_path(tiktoken_url:match(".+/(.+)"))
+ if not file_exists(cache_path) then
+ vim.schedule(function()
+ curl.get(tiktoken_url, {
+ output = cache_path,
+ })
+ done(cache_path)
+ end)
+ else
+ done(cache_path)
+ end
+ async:close()
+ end)
+ async:send()
+end
+
+local M = {}
+
+---@param model string|nil
+function M.setup(model)
+ local ok, core = pcall(require, "tiktoken_core")
+ if not ok then
+    print("Warning: tiktoken_core not found; token counts will fall back to a character-based estimate")
+ return
+ end
+
+ load_tiktoken_data(function(path)
+ local special_tokens = {}
+ special_tokens["<|endoftext|>"] = 100257
+ special_tokens["<|fim_prefix|>"] = 100258
+ special_tokens["<|fim_middle|>"] = 100259
+ special_tokens["<|fim_suffix|>"] = 100260
+ special_tokens["<|endofprompt|>"] = 100276
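+    -- Regex that splits text into candidate token pieces (cl100k_base's pattern, reused here for both encodings).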
+ local pat_str =
+ "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ core.new(path, special_tokens, pat_str)
+ tiktoken_core = core
+ end, model)
+end
+
+function M.available()
+ return tiktoken_core ~= nil
+end
+
+function M.encode(prompt)
+ if not tiktoken_core then
+ return nil
+ end
+ if not prompt or prompt == "" then
+ return nil
+ end
+ -- Check if prompt is a string
+ if type(prompt) ~= "string" then
+ error("Prompt must be a string")
+ end
+ return tiktoken_core.encode(prompt)
+end
+
+function M.count(prompt)
+ if not tiktoken_core then
+    return math.ceil(#prompt * 0.2) -- fallback: rough estimate of ~0.2 tokens per character (about 5 characters per token)
+ end
+
+ local tokens = M.encode(prompt)
+ if not tokens then
+ return 0
+ end
+ return #tokens
+end
+
+return M