Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor compare.scopes #2007

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
74 changes: 31 additions & 43 deletions lua/cmp/config/compare.lua
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ compare.locality = setmetatable({
---scopes: Entries defined in a closer scope will be ranked higher (e.g., prefer local variables to globals).
---@type cmp.ComparatorFunctor
compare.scopes = setmetatable({
scopes_map = {},
definition_depths = {},
has_nvim_0_9_features = vim.fn.has('nvim-0.9') == 1,
update = function(self)
local config = require('cmp').get_config()
if not vim.tbl_contains(config.sorting.comparators, compare.scopes) then
Expand All @@ -207,64 +208,51 @@ compare.scopes = setmetatable({

local ok, locals = pcall(require, 'nvim-treesitter.locals')
if ok then
local win, buf = vim.api.nvim_get_current_win(), vim.api.nvim_get_current_buf()
local cursor_row = vim.api.nvim_win_get_cursor(win)[1] - 1

-- Cursor scope.
local cursor_scope = nil
-- Prioritize the older get_scopes method from nvim-treesitter `master` over get from `main`
local scopes = locals.get_scopes and locals.get_scopes(buf) or select(3, locals.get(buf))
for _, scope in ipairs(scopes) do
if scope:start() <= cursor_row and cursor_row <= scope:end_() then
if not cursor_scope then
cursor_scope = scope
else
if cursor_scope:start() <= scope:start() and scope:end_() <= cursor_scope:end_() then
cursor_scope = scope
end
end
elseif cursor_scope and cursor_scope:end_() <= scope:start() then
break
end
self.definition_depths = {}
local buf = vim.api.nvim_get_current_buf()
if self.has_nvim_0_9_features and not vim.b[buf].cmp_buf_has_ts_parser then
return
end

-- Definitions.
local definitions = locals.get_definitions_lookup_table(buf)

-- Narrow definitions.
local get_cursor_node = vim.treesitter.get_node or require('nvim-treesitter.ts_utils').get_node_at_cursor
local cursor_node = get_cursor_node()
local scope_depths = {}
local depth = 0
for scope in locals.iter_scope_tree(cursor_scope, buf) do
local s, e = scope:start(), scope:end_()
-- If there's no cursor node, no iterations are made.
---@diagnostic disable-next-line: param-type-mismatch
for scope in locals.iter_scope_tree(cursor_node, buf) do
scope_depths[scope:id()] = depth
depth = depth + 1
end

-- Check scope's direct child.
for _, definition in pairs(definitions) do
if s <= definition.node:start() and definition.node:end_() <= e then
if scope:id() == locals.containing_scope(definition.node, buf):id() then
local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text
local text = get_node_text(definition.node, buf) or ''
if not self.scopes_map[text] then
self.scopes_map[text] = depth
end
end
-- Map definitions based on their scope relative to the cursor.
local definitions = locals.get_definitions_lookup_table(buf)
local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text
for _, definition in pairs(definitions) do
local definition_depth = scope_depths[locals.containing_scope(definition.node, buf):id()]
local def_text = get_node_text(definition.node, buf) or ''
if definition_depth then
-- Prefer the closest scoped definitions.
if not self.definition_depths[def_text] or self.definition_depths[def_text] > definition_depth then
self.definition_depths[def_text] = definition_depth
end
end
depth = depth + 1
end
end
end,
}, {
---@type fun(self: table, entry1: cmp.Entry, entry2: cmp.Entry): boolean|nil
__call = function(self, entry1, entry2)
local local1 = self.scopes_map[entry1.word]
local local2 = self.scopes_map[entry2.word]
if local1 ~= local2 then
if local1 == nil then
local def_depth1 = self.definition_depths[entry1.word]
local def_depth2 = self.definition_depths[entry2.word]
if def_depth1 ~= def_depth2 then
if def_depth1 == nil then
return false
end
if local2 == nil then
if def_depth2 == nil then
return true
end
return local1 < local2
return def_depth1 < def_depth2
end
end,
})
Expand Down
9 changes: 5 additions & 4 deletions lua/cmp/utils/autocmd.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ autocmd.subscribe = function(events, callback)
vim.api.nvim_create_autocmd(event, {
desc = ('nvim-cmp: autocmd: %s'):format(event),
group = autocmd.group,
callback = function()
autocmd.emit(event)
callback = function(details)
autocmd.emit(event, details)
end,
})
end
Expand All @@ -41,12 +41,13 @@ end

---Emit autocmd
---@param event string
autocmd.emit = function(event)
---@param details table|nil
autocmd.emit = function(event, details)
debug.log(' ')
debug.log(string.format('>>> %s', event))
autocmd.events[event] = autocmd.events[event] or {}
for _, callback in ipairs(autocmd.events[event]) do
callback()
callback(details)
end
end

Expand Down
24 changes: 24 additions & 0 deletions plugin/cmp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ if vim.on_key then
end, vim.api.nvim_create_namespace('cmp.plugin'))
end

-- see compare.scopes
if vim.fn.has('nvim-0.9') == 1 then
local ts = vim.treesitter
local has_ts_parser = ts.language.get_lang
-- vim.treesitter.language.add is recommended for checking treesitter in 0.11 nightly
if vim.fn.has('nvim-0.11') then
has_ts_parser = function(filetype)
local lang = ts.language.get_lang(filetype)
return lang and ts.language.add(lang)
end
end
autocmd.subscribe({ 'FileType' }, function(details)
if has_ts_parser(details.match) then
vim.b[details.buf].cmp_buf_has_ts_parser = true
else
vim.b[details.buf].cmp_buf_has_ts_parser = false
end
end)
autocmd.subscribe({ 'BufUnload' }, function(details)
if vim.treesitter.language.get_lang(details.match) then
vim.b[details.buf].cmp_buf_has_ts_parser = false
end
end)
end

vim.api.nvim_create_user_command('CmpStatus', function()
require('cmp').status()
Expand Down
Loading