diff --git a/lua/avante/init.lua b/lua/avante/init.lua
index 01350fedd..86319838b 100644
--- a/lua/avante/init.lua
+++ b/lua/avante/init.lua
@@ -4,6 +4,7 @@
 local Path = require("plenary.path")
 local n = require("nui-components")
 local diff = require("avante.diff")
 local utils = require("avante.utils")
+local tiktoken = require("avante.tiktoken")
 local api = vim.api
 local fn = vim.fn
@@ -140,7 +141,7 @@ local system_prompt = [[
 You are an excellent programming expert.
 ]]
 
-local user_prompt_tpl = [[
+local base_user_prompt = [[
 Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously:
 
 1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
@@ -183,87 +184,119 @@ Replace lines: {{start_line}}-{{end_line}}
 - Do not show the content after these modifications.
 
 Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
-
-QUESTION: ${{question}}
-
-CODE:
-```
-${{code}}
-```
 ]]
 
-local function call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   local api_key = os.getenv("ANTHROPIC_API_KEY")
   if not api_key then
     error("ANTHROPIC_API_KEY environment variable is not set")
   end
 
-  local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
-
-  print("Sending request to Claude API...")
+  local user_prompt = base_user_prompt
 
-  local tokens = M.config.claude.model == "claude-3-5-sonnet-20240620" and 8192 or 4096
+  local tokens = M.config.claude.max_tokens
 
   local headers = {
     ["Content-Type"] = "application/json",
     ["x-api-key"] = api_key,
     ["anthropic-version"] = "2023-06-01",
-    ["anthropic-beta"] = "messages-2023-12-15",
+    ["anthropic-beta"] = "prompt-caching-2024-07-31",
   }
 
-  if M.config.claude.model == "claude-3-5-sonnet-20240620" then
-    headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
+  local code_prompt_obj = {
+    type = "text",
+    text = string.format("```%s\n%s```", code_lang, code_content),
+  }
+
+  local user_prompt_obj = {
+    type = "text",
+    text = user_prompt,
+  }
+
+  if tiktoken.count(code_prompt_obj.text) > 1024 then
+    code_prompt_obj.cache_control = { type = "ephemeral" }
+  end
+
+  if tiktoken.count(user_prompt_obj.text) > 1024 then
+    user_prompt_obj.cache_control = { type = "ephemeral" }
   end
 
+  local params = {
+    model = M.config.claude.model,
+    system = system_prompt,
+    messages = {
+      {
+        role = "user",
+        content = {
+          code_prompt_obj,
+          {
+            type = "text",
+            text = string.format("%s", question),
+          },
+          user_prompt_obj,
+        },
+      },
+    },
+    stream = true,
+    temperature = M.config.claude.temperature,
+    max_tokens = tokens,
+  }
+
   local url = utils.trim_suffix(M.config.claude.endpoint, "/") .. "/v1/messages"
 
+  print("Sending request to Claude API...")
+
   curl.post(url, {
     ---@diagnostic disable-next-line: unused-local
     stream = function(err, data, job)
       if err then
-        error("Error: " .. vim.inspect(err))
+        on_complete(err)
         return
       end
-      if data then
-        for line in data:gmatch("[^\r\n]+") do
-          if line:sub(1, 6) == "data: " then
-            vim.schedule(function()
-              local success, parsed = pcall(fn.json_decode, line:sub(7))
-              if success and parsed and parsed.type == "content_block_delta" then
-                on_chunk(parsed.delta.text)
-              elseif success and parsed and parsed.type == "message_stop" then
-                -- Stream request completed
-                on_complete()
-              elseif success and parsed and parsed.type == "error" then
-                print("Error: " .. vim.inspect(parsed))
-                -- Stream request completed
-                on_complete()
-              end
-            end)
-          end
-        end
+      if not data then
+        return
+      end
+      for line in data:gmatch("[^\r\n]+") do
+        if line:sub(1, 6) ~= "data: " then
+          return
+        end
+        vim.schedule(function()
+          local success, parsed = pcall(fn.json_decode, line:sub(7))
+          if not success then
+            error("Error: failed to parse json: " .. parsed)
+            return
+          end
+          if parsed and parsed.type == "content_block_delta" then
+            on_chunk(parsed.delta.text)
+          elseif parsed and parsed.type == "message_stop" then
+            -- Stream request completed
+            on_complete(nil)
+          elseif parsed and parsed.type == "error" then
+            -- Stream request completed
+            on_complete(parsed)
+          end
+        end)
       end
     end,
     headers = headers,
-    body = fn.json_encode({
-      model = M.config.claude.model,
-      system = system_prompt,
-      messages = {
-        { role = "user", content = user_prompt },
-      },
-      stream = true,
-      temperature = M.config.claude.temperature,
-      max_tokens = tokens,
-    }),
+    body = fn.json_encode(params),
   })
 end
 
-local function call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   local api_key = os.getenv("OPENAI_API_KEY")
   if not api_key then
     error("OPENAI_API_KEY environment variable is not set")
   end
 
-  local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
+  local user_prompt = base_user_prompt
+    .. "\n\nQUESTION:\n"
+    .. question
+    .. "\n\nCODE:\n"
+    .. "```"
+    .. code_lang
+    .. "\n"
+    .. code_content
+    .. "\n```"
 
   local url = utils.trim_suffix(M.config.openai.endpoint, "/") .. "/v1/chat/completions"
   if M.config.provider == "azure" then
@@ -276,23 +309,29 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
     ---@diagnostic disable-next-line: unused-local
     stream = function(err, data, job)
       if err then
-        error("Error: " .. vim.inspect(err))
+        on_complete(err)
         return
       end
-      if data then
-        for line in data:gmatch("[^\r\n]+") do
-          if line:sub(1, 6) == "data: " then
-            vim.schedule(function()
-              local success, parsed = pcall(fn.json_decode, line:sub(7))
-              if success and parsed and parsed.choices and parsed.choices[1].delta.content then
-                on_chunk(parsed.choices[1].delta.content)
-              elseif success and parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
-                -- Stream request completed
-                on_complete()
-              end
-            end)
-          end
-        end
+      if not data then
+        return
+      end
+      for line in data:gmatch("[^\r\n]+") do
+        if line:sub(1, 6) ~= "data: " then
+          return
+        end
+        vim.schedule(function()
+          local success, parsed = pcall(fn.json_decode, line:sub(7))
+          if not success then
+            error("Error: failed to parse json: " .. parsed)
+            return
+          end
+          if parsed and parsed.choices and parsed.choices[1].delta.content then
+            on_chunk(parsed.choices[1].delta.content)
+          elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
+            -- Stream request completed
+            on_complete(nil)
+          end
+        end)
       end
     end,
     headers = {
@@ -313,11 +352,11 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
   })
 end
 
-local function call_ai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   if M.config.provider == "openai" or M.config.provider == "azure" then
-    call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+    call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   elseif M.config.provider == "claude" then
-    call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+    call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   end
 end
 
@@ -522,7 +561,9 @@ function M.render_sidebar()
 
     signal.is_loading = true
 
-    call_ai_api_stream(user_input, content_with_line_numbers, function(chunk)
+    local filetype = api.nvim_get_option_value("filetype", { buf = code_buf })
+
+    call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk)
       full_response = full_response .. chunk
       update_result_buf_content(
         "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response
@@ -530,8 +571,23 @@
       vim.schedule(function()
        vim.cmd("redraw")
       end)
-    end, function()
+    end, function(err)
       signal.is_loading = false
+
+      if err ~= nil then
+        update_result_buf_content(
+          "## "
+            .. timestamp
+            .. "\n\n> "
+            .. user_input:gsub("\n", "\n> ")
+            .. "\n\n"
+            .. full_response
+            .. "\n\n**Error**: "
+            .. vim.inspect(err)
+        )
+        return
+      end
+
       -- Execute when the stream request is actually completed
       update_result_buf_content(
         "## "
@@ -687,6 +743,8 @@ function M.setup(opts)
     _cur_code_buf = bufnr
   end
 
+  tiktoken.setup("gpt-4o")
+
   diff.setup({
     debug = false, -- log output to console
     default_mappings = M.config.mappings.diff, -- disable buffer local mapping created by this plugin
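
Note (not part of the patch): the Claude request above now sends three content blocks — the code, the question, and the reusable instructions — and adds `cache_control = { type = "ephemeral" }` to a block only when `tiktoken.count` puts it over 1024 tokens, since Anthropic's prompt-caching beta does not cache blocks below a minimum size. Below is a minimal sketch of the resulting request shape; all values are placeholders invented for illustration.

```lua
-- Sketch only: assumes avante is on the runtimepath, and that the request
-- is sent with the ["anthropic-beta"] = "prompt-caching-2024-07-31" header.
local tiktoken = require("avante.tiktoken")

local question = "What does this function do?" -- placeholder
local code_content = "local function add(a, b) return a + b end" -- placeholder

local code_prompt_obj = {
  type = "text",
  text = string.format("```%s\n%s```", "lua", code_content),
}

-- Mark the block as cacheable only when it is large enough to qualify,
-- mirroring the tiktoken.count() checks in the diff above.
if tiktoken.count(code_prompt_obj.text) > 1024 then
  code_prompt_obj.cache_control = { type = "ephemeral" }
end

local params = {
  model = "claude-3-5-sonnet-20240620",
  system = "You are an excellent programming expert.",
  messages = {
    {
      role = "user",
      -- The large, stable code block comes before the per-request
      -- question, so the cached prefix can be reused across requests.
      content = {
        code_prompt_obj,
        { type = "text", text = question },
      },
    },
  },
  stream = true,
  max_tokens = 8192,
}
```

Gating the marker on block size also avoids paying the cache-write surcharge for content that is too small to be cached at all.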
diff --git a/lua/avante/tiktoken.lua b/lua/avante/tiktoken.lua
new file mode 100644
index 000000000..779266286
--- /dev/null
+++ b/lua/avante/tiktoken.lua
@@ -0,0 +1,103 @@
+-- NOTE: this file is copied from: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/tiktoken.lua
+
+local curl = require("plenary.curl")
+local tiktoken_core = nil
+
+---Get the path of the cache directory
+---@param fname string
+---@return string
+local function get_cache_path(fname)
+  return vim.fn.stdpath("cache") .. "/" .. fname
+end
+
+local function file_exists(name)
+  local f = io.open(name, "r")
+  if f ~= nil then
+    io.close(f)
+    return true
+  else
+    return false
+  end
+end
+
+--- Load tiktoken data from cache or download it
+local function load_tiktoken_data(done, model)
+  local tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+  -- If model is gpt-4o, use o200k_base.tiktoken
+  if model ~= nil and vim.startswith(model, "gpt-4o") then
+    tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
+  end
+  local async
+  async = vim.loop.new_async(function()
+    -- Take filename after the last slash of the url
+    local cache_path = get_cache_path(tiktoken_url:match(".+/(.+)"))
+    if not file_exists(cache_path) then
+      vim.schedule(function()
+        curl.get(tiktoken_url, {
+          output = cache_path,
+        })
+        done(cache_path)
+      end)
+    else
+      done(cache_path)
+    end
+    async:close()
+  end)
+  async:send()
+end
+
+local M = {}
+
+---@param model string|nil
+function M.setup(model)
+  local ok, core = pcall(require, "tiktoken_core")
+  if not ok then
+    print("Warn: tiktoken_core is not found!!!!")
+    return
+  end
+
+  load_tiktoken_data(function(path)
+    local special_tokens = {}
+    special_tokens["<|endoftext|>"] = 100257
+    special_tokens["<|fim_prefix|>"] = 100258
+    special_tokens["<|fim_middle|>"] = 100259
+    special_tokens["<|fim_suffix|>"] = 100260
+    special_tokens["<|endofprompt|>"] = 100276
+    local pat_str =
+      "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+    core.new(path, special_tokens, pat_str)
+    tiktoken_core = core
+  end, model)
+end
+
+function M.available()
+  return tiktoken_core ~= nil
+end
+
+function M.encode(prompt)
+  if not tiktoken_core then
+    return nil
+  end
+  if not prompt or prompt == "" then
+    return nil
+  end
+  -- Check if prompt is a string
+  if type(prompt) ~= "string" then
+    error("Prompt must be a string")
+  end
+  return tiktoken_core.encode(prompt)
+end
+
+function M.count(prompt)
+  if not tiktoken_core then
+    return math.ceil(#prompt * 0.2) -- Fallback to 0.2 character count
+  end
+
+  local tokens = M.encode(prompt)
+  if not tokens then
+    return 0
+  end
+  return #tokens
+end
+
+return M
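
Note (not part of the patch): a hypothetical usage sketch of the new module. `setup` depends on the optional native `tiktoken_core` library and loads the encoder data asynchronously (via `vim.loop.new_async`), so calls to `count` made before loading finishes use the character-based fallback shown above.

```lua
-- Hypothetical usage; the text and threshold are placeholders.
local tiktoken = require("avante.tiktoken")

-- Downloads o200k_base.tiktoken into stdpath("cache") on first run and
-- feeds it to tiktoken_core, when that native module is installed.
tiktoken.setup("gpt-4o")

-- Until the encoder is ready, count() returns the rough
-- math.ceil(#text * 0.2) estimate instead of an exact token count.
local text = string.rep("local x = 1\n", 500)
print(("estimated tokens: %d (exact: %s)"):format(tiktoken.count(text), tostring(tiktoken.available())))
```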