diff --git a/doc/neotest.txt b/doc/neotest.txt index a5a174be..c02dba7d 100644 --- a/doc/neotest.txt +++ b/doc/neotest.txt @@ -1278,6 +1278,26 @@ Return~ Return~ `(fun(): neotest.Tree)` + *neotest.Tree:closest_node_with()* +`Tree:closest_node_with`({data_attr}) + +Fetch the first node ascending the tree (including the current one) with the +given data attribute e.g. `range` +Parameters~ +{data_attr} `(string)` +Return~ +`(neotest.Tree)` | nil + + *neotest.Tree:closest_value_for()* +`Tree:closest_value_for`({data_attr}) + +Fetch the first non-nil value for the given data attribute ascending the +tree (including the current node) with the given data attribute. +Parameters~ +{data_attr} `(string)` +Return~ +`(any)` | nil + *neotest.Tree:root()* `Tree:root`() diff --git a/lua/neotest/client/init.lua b/lua/neotest/client/init.lua index 9297af48..a5ba5de4 100644 --- a/lua/neotest/client/init.lua +++ b/lua/neotest/client/init.lua @@ -154,10 +154,11 @@ function neotest.Client:get_nearest(file_path, row, args) return end local nearest - for _, pos in positions:iter_nodes() do - local data = pos:data() - if data.range and data.range[1] <= row then - nearest = pos + for _, node in positions:iter_nodes() do + node = node:closest_node_with("range") or node + local range = node:data().range + if range and range[1] <= row then + nearest = node else return nearest, adapter_id end diff --git a/lua/neotest/client/runner.lua b/lua/neotest/client/runner.lua index 550f4545..109d6e39 100644 --- a/lua/neotest/client/runner.lua +++ b/lua/neotest/client/runner.lua @@ -227,10 +227,15 @@ function TestRunner:_missing_results(tree, results, partial) local root = tree:data() local missing_tests = {} + local all_position_ids = {} + for _, pos in tree:iter() do + all_position_ids[pos.id] = true + end + local function propagate_result_upwards(node) for parent in node:iter_parents() do local parent_pos = parent:data() - if not lib.positions.contains(root, parent_pos) then + if not all_position_ids[parent_pos.id] then return end diff --git a/lua/neotest/consumers/diagnostic.lua b/lua/neotest/consumers/diagnostic.lua index 605c4fda..7c84ee24 100644 --- a/lua/neotest/consumers/diagnostic.lua +++ b/lua/neotest/consumers/diagnostic.lua @@ -79,7 +79,11 @@ local function init(client) local result = results[pos_id] if position.type == "test" and result and result.errors and #result.errors > 0 then local placed = self.tracking_marks[pos_id] - or self:init_mark(pos_id, result.errors, positions:get_key(pos_id):data().range[1]) + or self:init_mark( + pos_id, + result.errors, + positions:get_key(pos_id):closest_value_for("range")[1] + ) if placed then for error_i, error in pairs(result.errors or {}) do local mark = api.nvim_buf_get_extmark_by_id( diff --git a/lua/neotest/consumers/jump.lua b/lua/neotest/consumers/jump.lua index 7bd8a32c..b9d708c8 100644 --- a/lua/neotest/consumers/jump.lua +++ b/lua/neotest/consumers/jump.lua @@ -31,7 +31,7 @@ local get_nearest = function() end local function jump_to(node) - local range = node:data().range + local range = node:closest_value_for("range") async.api.nvim_win_set_cursor(0, { range[1] + 1, range[2] }) end @@ -48,7 +48,7 @@ local jump_to_prev = function(pos, predicate) if pos:data().type == "file" then return false end - if async.api.nvim_win_get_cursor(0)[1] - 1 > pos:data().range[1] then + if async.api.nvim_win_get_cursor(0)[1] - 1 > pos:closest_value_for("range")[1] then jump_to(pos) return true end diff --git a/lua/neotest/consumers/output.lua b/lua/neotest/consumers/output.lua index e6645c6f..2926a630 100644 --- a/lua/neotest/consumers/output.lua +++ b/lua/neotest/consumers/output.lua @@ -112,13 +112,15 @@ local init = function() if not positions then return end - for _, pos in positions:iter() do + for _, node in positions:iter_nodes() do + local pos = node:data() + local range = node:closest_value_for("range") if pos.type == "test" and results[pos.id] and results[pos.id].status == "failed" - and pos.range[1] <= line - and pos.range[3] >= line + and range[1] <= line + and range[3] >= line then open_output( results[pos.id], diff --git a/lua/neotest/consumers/status.lua b/lua/neotest/consumers/status.lua index ec7f1d0b..f570acd1 100644 --- a/lua/neotest/consumers/status.lua +++ b/lua/neotest/consumers/status.lua @@ -15,7 +15,7 @@ local function init(client) local namespace = async.api.nvim_create_namespace(sign_group) - local function place_sign(buf, pos, adapter_id, results) + local function place_sign(buf, pos, range, adapter_id, results) local status if results[pos.id] then local result = results[pos.id] @@ -28,12 +28,12 @@ local function init(client) end if config.status.signs then async.fn.sign_place(0, sign_group, "neotest_" .. status, pos.path, { - lnum = pos.range[1] + 1, + lnum = range[1] + 1, priority = 1000, }) end if config.status.virtual_text then - async.api.nvim_buf_set_extmark(buf, namespace, pos.range[1], 0, { + async.api.nvim_buf_set_extmark(buf, namespace, range[1], 0, { virt_text = { { statuses[status].text .. " ", statuses[status].texthl }, }, @@ -51,9 +51,11 @@ local function init(client) if not tree then return end - for _, pos in tree:iter() do + for _, node in tree:iter_nodes() do + local pos = node:data() + local range = node:closest_value_for("range") if pos.type ~= "file" then - place_sign(async.fn.bufnr(file_path), pos, adapter_id, results) + place_sign(async.fn.bufnr(file_path), pos, range, adapter_id, results) end end end diff --git a/lua/neotest/consumers/summary/component.lua b/lua/neotest/consumers/summary/component.lua index 533518b3..99273b48 100644 --- a/lua/neotest/consumers/summary/component.lua +++ b/lua/neotest/consumers/summary/component.lua @@ -167,7 +167,8 @@ function SummaryComponent:_render(canvas, tree, expanded, focused, indent) if position.type == "file" then lib.ui.open_buf(buf) else - lib.ui.open_buf(buf, position.range[1], position.range[2]) + local range = node:closest_value_for("range") + lib.ui.open_buf(buf, range[1], range[2]) end end) end diff --git a/lua/neotest/types/tree.lua b/lua/neotest/types/tree.lua index f84b8bec..87915c52 100644 --- a/lua/neotest/types/tree.lua +++ b/lua/neotest/types/tree.lua @@ -156,6 +156,29 @@ function neotest.Tree:iter_parents() end end +--- Fetch the first node ascending the tree (including the current one) with the +--- given data attribute e.g. `range` +---@param data_attr string +---@return neotest.Tree | nil +function neotest.Tree:closest_node_with(data_attr) + if self:data()[data_attr] ~= nil then + return self + end + for parent in self:iter_parents() do + if parent:data()[data_attr] ~= nil then + return parent + end + end +end + +--- Fetch the first non-nil value for the given data attribute ascending the +--- tree (including the current node) with the given data attribute. +---@param data_attr string +---@return any | nil +function neotest.Tree:closest_value_for(data_attr) + return self:closest_node_with(data_attr):data()[data_attr] +end + ---@return neotest.Tree function neotest.Tree:root() local node = self