Skip to content

Commit

Permalink
feat: more like Cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone committed Aug 15, 2024
1 parent cca2c23 commit b4d4080
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
78 changes: 47 additions & 31 deletions lua/avante/diff.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ local INCOMING_LABEL_HL = "AvanteConflictIncomingLabel"
local ANCESTOR_LABEL_HL = "AvanteConflictAncestorLabel"
local PRIORITY = vim.highlight.priorities.user
local NAMESPACE = api.nvim_create_namespace("avante-conflict")
local KEYBINDING_NAMESPACE = api.nvim_create_namespace("avante-conflict-keybinding")
local AUGROUP_NAME = "AvanteConflictCommands"

local sep = package.config:sub(1, 1)

local conflict_start = "^<<<<<<<"
local conflict_middle = "^======="
local conflict_end = "^>>>>>>>"
Expand Down Expand Up @@ -165,35 +164,6 @@ local visited_buffers = create_visited_buffers()

-----------------------------------------------------------------------------//

---Get full path to the repository of the directory passed in
---@param dir any
---@param callback fun(data: string)
local function get_git_root(dir, callback)
job({ "git", "-C", dir, "rev-parse", "--show-toplevel" }, function(data)
callback(data[1])
end)
end

--- Get a list of the conflicted files within the specified directory
--- NOTE: only conflicted files within the git repository of the directory passed in are returned
--- also we add a line prefix to the git command so that the full path is returned
--- e.g. --line-prefix=`git rev-parse --show-toplevel`
---@reference: https://stackoverflow.com/a/10874862
---@param dir string?
---@param callback fun(files: table<string, integer[]>, string)
local function get_conflicted_files(dir, callback)
local cmd = { "git", "-C", dir, "diff", ("--line-prefix=%s%s"):format(dir, sep), "--name-only", "--diff-filter=U" }
job(cmd, function(data)
local files = {}
for _, filename in ipairs(data) do
if #filename > 0 then
files[filename] = files[filename] or {}
end
end
callback(files, dir)
end)
end

---Add the positions to the buffer in our in memory buffer list
---positions are keyed by a list of range start and end for each mark
---@param buf integer
Expand Down Expand Up @@ -397,6 +367,50 @@ local function set_cursor(position, side)
api.nvim_win_set_cursor(0, { target.range_start + 1, 0 })
end

local function register_cursor_move_events(bufnr)
local show_keybinding_hint_extmark_id = nil

local function show_keybinding_hint(lnum)
if show_keybinding_hint_extmark_id then
api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id)
end

local hint = string.format(
" [Press <%s> for CHOICE OURS, <%s> for CHOICE THEIRS, <%s> for PREV, <%s> for NEXT] ",
config.default_mappings.ours,
config.default_mappings.theirs,
config.default_mappings.prev,
config.default_mappings.next
)
local win_width = api.nvim_win_get_width(0)
local col = win_width - #hint - math.ceil(win_width * 0.3) - 4

if col < 0 then
col = 0
end

show_keybinding_hint_extmark_id = api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, {
hl_group = "Keyword",
virt_text = { { hint, "Keyword" } },
virt_text_win_col = col,
priority = PRIORITY,
})
end

api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, {
buffer = bufnr,
callback = function()
local position = get_current_position(bufnr)

if position then
show_keybinding_hint(position.current.range_start + 1)
else
api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
end
end,
})
end

---Get the conflict marker positions for a buffer if any and update the buffers state
---@param bufnr integer
---@param range_start integer
Expand All @@ -408,6 +422,7 @@ local function parse_buffer(bufnr, range_start, range_end)

update_visited_buffers(bufnr, positions)
if has_conflict then
register_cursor_move_events(bufnr)
highlight_conflicts(positions, lines)
else
M.clear(bufnr)
Expand Down Expand Up @@ -664,6 +679,7 @@ function M.clear(bufnr)
end
bufnr = bufnr or 0
api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
end

---@param side ConflictSide
Expand Down
19 changes: 11 additions & 8 deletions lua/avante/sidebar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ local fn = vim.fn
local RESULT_BUF_NAME = "AVANTE_RESULT"
local CONFLICT_BUF_NAME = "AVANTE_CONFLICT"

local NAMESPACE = vim.api.nvim_create_namespace("AVANTE_CODEBLOCK")
local CODEBLOCK_KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("AVANTE_CODEBLOCK_KEYBINDING")
local PRIORITY = vim.highlight.priorities.user

local function parse_codeblocks(buf)
local codeblocks = {}
Expand Down Expand Up @@ -370,14 +371,16 @@ function M.render_sidebar()

local function show_apply_button(block)
if current_apply_extmark_id then
api.nvim_buf_del_extmark(result_buf, NAMESPACE, current_apply_extmark_id)
api.nvim_buf_del_extmark(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, current_apply_extmark_id)
end

current_apply_extmark_id = api.nvim_buf_set_extmark(result_buf, NAMESPACE, block.start_line, -1, {
virt_text = { { "[Press A to Apply these patches]", "Keyword" } },
virt_text_pos = "right_align",
hl_group = "Keyword",
})
current_apply_extmark_id =
api.nvim_buf_set_extmark(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, {
virt_text = { { " [Press <A> to Apply these patches] ", "Keyword" } },
virt_text_pos = "right_align",
hl_group = "Keyword",
priority = PRIORITY,
})
end

local function apply()
Expand Down Expand Up @@ -429,7 +432,7 @@ function M.render_sidebar()
show_apply_button(block)
bind_apply_key()
else
vim.api.nvim_buf_clear_namespace(result_buf, NAMESPACE, 0, -1)
api.nvim_buf_clear_namespace(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, 0, -1)
unbind_apply_key()
end
end,
Expand Down

0 comments on commit b4d4080

Please sign in to comment.