diff --git a/lua/neorg/modules/core/dirman/utils/module.lua b/lua/neorg/modules/core/dirman/utils/module.lua index c97628783..10b33e6ab 100644 --- a/lua/neorg/modules/core/dirman/utils/module.lua +++ b/lua/neorg/modules/core/dirman/utils/module.lua @@ -18,10 +18,18 @@ local module = neorg.modules.create("core.dirman.utils") ---@class core.dirman.utils module.public = { ---Resolve `$/path/to/file` and return the real path - ---@param path string|PathlibPath # path - ---@param raw_path boolean? # If true, returns resolved path, otherwise, returns resolved path and append ".norg" - ---@return PathlibPath? # Resolved path. If path does not start with `$` or not absolute, adds relative from current file. - expand_pathlib = function(path, raw_path) + ---@param path string | PathlibPath # path + ---@param raw_path boolean? # If true, returns resolved path, otherwise, returns resolved path + ---and append ".norg" + ---@param host_file string | PathlibPath | nil file the link resides in, if the link is + ---relative, this file is used instead of the current file + ---@return PathlibPath?, boolean? # Resolved path. If path does not start with `$` or not absolute, adds + ---relative from current file. + expand_pathlib = function(path, raw_path, host_file) + local relative = false + if not host_file then + host_file = vim.fn.expand("%:p") + end local filepath = Path(path) -- Expand special chars like `$` local custom_workspace_path = filepath:match("^%$([^/\\]*)[/\\]") @@ -48,7 +56,8 @@ module.public = { filepath = workspace / filepath:relative_to(Path("$" .. custom_workspace_path)) end elseif filepath:is_relative() then - local this_file = Path(vim.fn.expand("%:p")):absolute() + relative = true + local this_file = Path(host_file):absolute() filepath = this_file:parent_assert() / filepath else filepath = filepath:absolute() @@ -66,7 +75,7 @@ module.public = { end filepath = filepath:add_suffix(".norg") end - return filepath + return filepath, relative end, ---Call attempt to edit a file, catches and suppresses the error caused by a swap file being diff --git a/lua/neorg/modules/core/integrations/treesitter/module.lua b/lua/neorg/modules/core/integrations/treesitter/module.lua index 72ad8c0b9..5d521494b 100644 --- a/lua/neorg/modules/core/integrations/treesitter/module.lua +++ b/lua/neorg/modules/core/integrations/treesitter/module.lua @@ -215,9 +215,9 @@ module.public = { end end, --- Gets all nodes of a given type from the AST - ---@param type string #The type of node to filter out + ---@param node_type string #The type of node to filter out ---@param opts? table #A table of two options: `buf` and `ft`, for the buffer and format to use respectively. - get_all_nodes = function(type, opts) + get_all_nodes = function(node_type, opts) local result = {} opts = opts or {} @@ -231,31 +231,51 @@ module.public = { -- Do we need to go through each tree? lol vim.treesitter.get_parser(opts.buf, opts.ft):for_each_tree(function(tree) - -- Get the root for that tree - ---@type TSNode - local root = tree:root() + table.insert(result, module.public.search_tree(tree, node_type)) + end) - --- Recursively searches for a node of a given type - ---@param node TSNode #The starting point for the search - local function descend(node) - -- Iterate over all children of the node and try to match their type - for child, _ in node:iter_children() do ---@diagnostic disable-line -- TODO: type error workaround - if child:type() == type then - table.insert(result, child) - else - -- If no match is found try descending further down the syntax tree - for _, child_node in ipairs(descend(child) or {}) do - table.insert(result, child_node) - end + return vim.iter(result):flatten():totable() + end, + + ---Gets all nodes of a given type from the AST + ---@param node_type string #The type of node to filter out + ---@param path string path to the file to parse + ---@param filetype string? file type of the file or `norg` if omitted + get_all_nodes_in_file = function(node_type, path, filetype) + path = vim.fs.normalize(path) + if not filetype then filetype = "norg" end + + local contents = io.open(path, "r"):read("*a") + local tree = vim.treesitter.get_string_parser(contents, filetype):parse()[1] + if not (tree or tree.root) then return {} end + + return module.public.search_tree(tree, node_type) + end, + + search_tree = function(tree, node_type) + local result = {} + local root = tree:root() + + --- Recursively searches for a node of a given type + ---@param node TSNode #The starting point for the search + local function descend(node) + -- Iterate over all children of the node and try to match their type + for child, _ in node:iter_children() do + if child:type() == node_type then + table.insert(result, child) + else + -- If no match is found try descending further down the syntax tree + for _, child_node in ipairs(descend(child) or {}) do + table.insert(result, child_node) end end end + end - descend(root) - end) - + descend(root) return result end, + --- Executes function callback on each child node of the root ---@param callback function ---@param ts_tree any #Optional syntax tree ---@diagnostic disable-line -- TODO: type error workaround @@ -566,7 +586,7 @@ module.public = { end, --- Gets the range of a given node ---@param node userdata #The node to get the range of - ---@return table #A table of `row_start`, `column_start`, `row_end` and `column_end` values + ---@return { row_start: number, column_start: number, row_end: number, column_end: number } range get_node_range = function(node) if not node then return { @@ -639,7 +659,7 @@ module.public = { ---@param line number #The line number (0-indexed) to get the node from -- the same line as `line`. ---@param stop_type string|table? #Don't recurse to the provided type(s) - ---@return userdata|nil #The first node on `line` + ---@return TSNode|nil #The first node on `line` get_first_node_on_line = function(buf, line, stop_type) if type(stop_type) == "string" then stop_type = { stop_type }