Skip to content

Commit

Permalink
feat: ask selected code block (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone authored Aug 17, 2024
1 parent dea737b commit 3dca5f4
Show file tree
Hide file tree
Showing 9 changed files with 397 additions and 89 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Default setup configuration:
},
},
mappings = {
show_sidebar = "<leader>aa",
ask = "<leader>aa",
diff = {
ours = "co",
theirs = "ct",
Expand Down
87 changes: 64 additions & 23 deletions lua/avante/ai_bot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ Your primary task is to suggest code modifications with precise line number rang
1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
2. When suggesting modifications:
a. Explain why the change is necessary or beneficial.
b. Provide the exact code snippet to be replaced using this format:
a. Use the language in the question to reply. If there are non-English parts in the question, use the language of those parts.
b. Explain why the change is necessary or beneficial.
c. Provide the exact code snippet to be replaced using this format:
Replace lines: {{start_line}}-{{end_line}}
```{{language}}
Expand Down Expand Up @@ -58,14 +59,12 @@ Replace lines: {{start_line}}-{{end_line}}
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
]]

local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key then
error("ANTHROPIC_API_KEY environment variable is not set")
end

local user_prompt = base_user_prompt

local tokens = Config.claude.max_tokens
local headers = {
["Content-Type"] = "application/json",
Expand All @@ -79,33 +78,56 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
}

if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end

if selected_code_content then
code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
end

local message_content = {
code_prompt_obj,
}

if selected_code_content then
local selected_code_obj = {
type = "text",
text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
}

if Tiktoken.count(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end

table.insert(message_content, selected_code_obj)
end

table.insert(message_content, {
type = "text",
text = string.format("<question>%s</question>", question),
})

local user_prompt = base_user_prompt

local user_prompt_obj = {
type = "text",
text = user_prompt,
}

if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end

if Tiktoken.count(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end

table.insert(message_content, user_prompt_obj)

local body = {
model = Config.claude.model,
system = system_prompt,
messages = {
{
role = "user",
content = {
code_prompt_obj,
{
type = "text",
text = string.format("<question>%s</question>", question),
},
user_prompt_obj,
},
content = message_content,
},
},
stream = true,
Expand Down Expand Up @@ -154,21 +176,39 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
})
end

local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("OPENAI_API_KEY")
if not api_key and Config.provider == "openai" then
error("OPENAI_API_KEY environment variable is not set")
end

local user_prompt = base_user_prompt
.. "\n\nQUESTION:\n"
.. question
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question

if selected_code_content then
user_prompt = base_user_prompt
.. "\n\nCODE CONTEXT:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. selected_code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
end

local url, headers, body
if Config.provider == "azure" then
Expand Down Expand Up @@ -258,13 +298,14 @@ end
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
if Config.provider == "openai" or Config.provider == "azure" then
call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
elseif Config.provider == "claude" then
call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
end
end

Expand Down
3 changes: 2 additions & 1 deletion lua/avante/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ M.defaults = {
},
},
mappings = {
show_sidebar = "<leader>aa",
ask = "<leader>aa",
edit = "<leader>ae",
diff = {
ours = "co",
theirs = "ct",
Expand Down
12 changes: 10 additions & 2 deletions lua/avante/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ local Tiktoken = require("avante.tiktoken")
local Sidebar = require("avante.sidebar")
local Config = require("avante.config")
local Diff = require("avante.diff")
local Selection = require("avante.selection")

---@class Avante
local M = {
---@type avante.Sidebar[] we use this to track chat command across tabs
sidebars = {},
---@type avante.Sidebar
current = nil,
selection = nil,
_once = false,
}

Expand All @@ -35,7 +37,7 @@ H.commands = function()
end

H.keymaps = function()
vim.keymap.set({ "n" }, Config.mappings.show_sidebar, M.toggle, { noremap = true })
vim.keymap.set({ "n", "v" }, Config.mappings.ask, M.toggle, { noremap = true })
end

H.autocmds = function()
Expand Down Expand Up @@ -76,7 +78,9 @@ H.autocmds = function()
if s then
s:destroy()
end
M.sidebars[tab] = nil
if tab ~= nil then
M.sidebars[tab] = nil
end
end,
})

Expand Down Expand Up @@ -137,6 +141,10 @@ function M.setup(opts)
highlights = Config.highlights.diff,
})

local selection = Selection:new()
selection:setup()
M.selection = selection

-- setup helpers
H.autocmds()
H.commands()
Expand Down
24 changes: 24 additions & 0 deletions lua/avante/range.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
--@class avante.Range
--@field start table Selection start point
--@field start.line number Line number of the selection start
--@field start.col number Column number of the selection start
--@field finish table Selection end point
--@field finish.line number Line number of the selection end
--@field finish.col number Column number of the selection end
local Range = {}
Range.__index = Range
-- Create a selection range
-- @param start table Selection start point
-- @param start.line number Line number of the selection start
-- @param start.col number Column number of the selection start
-- @param finish table Selection end point
-- @param finish.line number Line number of the selection end
-- @param finish.col number Column number of the selection end
function Range.new(start, finish)
local self = setmetatable({}, Range)
self.start = start
self.finish = finish
return self
end

return Range
95 changes: 95 additions & 0 deletions lua/avante/selection.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
local Config = require("avante.config")

local api = vim.api
local fn = vim.fn

local NAMESPACE = api.nvim_create_namespace("avante_selection")
local PRIORITY = vim.highlight.priorities.user

local Selection = {}

function Selection:new()
return setmetatable({
hints_popup_extmark_id = nil,
edit_popup_renderer = nil,
augroup = api.nvim_create_augroup("avante_selection", { clear = true }),
}, { __index = self })
end

function Selection:get_virt_text_line()
local current_pos = fn.getpos(".")

-- Get the current and start position line numbers
local current_line = current_pos[2] - 1 -- 0-indexed

-- Ensure line numbers are not negative and don't exceed buffer range
local total_lines = api.nvim_buf_line_count(0)
if current_line < 0 then
current_line = 0
end
if current_line >= total_lines then
current_line = total_lines - 1
end

-- Take the first line of the selection to ensure virt_text is always in the top right corner
return current_line
end

function Selection:show_hints_popup()
self:close_hints_popup()

local hint_text = string.format(" [Ask %s] ", Config.mappings.ask)

local virt_text_line = self:get_virt_text_line()

self.hints_popup_extmark_id = vim.api.nvim_buf_set_extmark(0, NAMESPACE, virt_text_line, -1, {
virt_text = { { hint_text, "Keyword" } },
virt_text_pos = "eol",
priority = PRIORITY,
})
end

function Selection:close_hints_popup()
if self.hints_popup_extmark_id then
vim.api.nvim_buf_del_extmark(0, NAMESPACE, self.hints_popup_extmark_id)
self.hints_popup_extmark_id = nil
end
end

function Selection:setup()
vim.api.nvim_create_autocmd({ "ModeChanged" }, {
group = self.augroup,
pattern = { "n:v", "n:V", "n:" }, -- Entering Visual mode from Normal mode
callback = function()
self:show_hints_popup()
end,
})

api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, {
group = self.augroup,
callback = function()
if vim.fn.mode() == "v" or vim.fn.mode() == "V" or vim.fn.mode() == "" then
self:show_hints_popup()
else
self:close_hints_popup()
end
end,
})

api.nvim_create_autocmd({ "ModeChanged" }, {
group = self.augroup,
pattern = { "v:n", "v:i", "v:c" }, -- Switching from visual mode back to normal, insert, or other modes
callback = function()
self:close_hints_popup()
end,
})
end

function Selection:delete_autocmds()
if self.augroup then
vim.api.nvim_del_augroup_by_id(self.augroup)
end
self.augroup = nil
end

return Selection
17 changes: 17 additions & 0 deletions lua/avante/selection_result.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
--@class avante.SelectionResult
--@field content string Selected content
--@field range avante.Range Selection range
local SelectionResult = {}
SelectionResult.__index = SelectionResult

-- Create a selection content and range
--@param content string Selected content
--@param range avante.Range Selection range
function SelectionResult.new(content, range)
local self = setmetatable({}, SelectionResult)
self.content = content
self.range = range
return self
end

return SelectionResult
Loading

0 comments on commit 3dca5f4

Please sign in to comment.