Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow capturing multiple nodes in textobject queries #1611

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 110 additions & 10 deletions helix-core/src/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,56 @@ pub struct TextObjectQuery {
pub query: Query,
}

pub enum CapturedNode<'a> {
Single(Node<'a>),
/// Guarenteed to be not empty
Grouped(Vec<Node<'a>>),
}

impl<'a> CapturedNode<'a> {
pub fn start_byte(&self) -> usize {
match self {
Self::Single(n) => n.start_byte(),
Self::Grouped(ns) => ns[0].start_byte(),
}
}

pub fn end_byte(&self) -> usize {
match self {
Self::Single(n) => n.end_byte(),
Self::Grouped(ns) => ns.last().unwrap().end_byte(),
sudormrfbin marked this conversation as resolved.
Show resolved Hide resolved
}
}

pub fn byte_range(&self) -> std::ops::Range<usize> {
self.start_byte()..self.end_byte()
}
}

impl TextObjectQuery {
/// Run the query on the given node and return sub nodes which match given
/// capture ("function.inside", "class.around", etc).
///
/// Captures may contain multiple nodes by using quantifiers (+, *, etc),
/// and support for this is partial and could use improvement.
///
/// ```query
/// ;; supported:
/// (comment)+ @capture
///
/// ;; unsupported:
/// (
/// (comment)+
/// (function)
/// ) @capture
/// ```
pub fn capture_nodes<'a>(
&'a self,
capture_name: &str,
node: Node<'a>,
slice: RopeSlice<'a>,
cursor: &'a mut QueryCursor,
) -> Option<impl Iterator<Item = Node<'a>>> {
) -> Option<impl Iterator<Item = CapturedNode<'a>>> {
self.capture_nodes_any(&[capture_name], node, slice, cursor)
}

Expand All @@ -265,17 +305,33 @@ impl TextObjectQuery {
node: Node<'a>,
slice: RopeSlice<'a>,
cursor: &'a mut QueryCursor,
) -> Option<impl Iterator<Item = Node<'a>>> {
) -> Option<impl Iterator<Item = CapturedNode<'a>>> {
let capture_idx = capture_names
.iter()
.find_map(|cap| self.query.capture_index_for_name(cap))?;
let captures = cursor.captures(&self.query, node, RopeProvider(slice));
let captures = cursor.matches(&self.query, node, RopeProvider(slice));

let nodes = captures.flat_map(move |mat| {
let captures = mat.captures.iter().filter(move |c| c.index == capture_idx);
let nodes = captures.map(|c| c.node);
let pattern_idx = mat.pattern_index;
let quantifier = self.query.capture_quantifiers(pattern_idx)[capture_idx as usize];

let iter: Box<dyn Iterator<Item = CapturedNode>> = match quantifier {
CaptureQuantifier::OneOrMore | CaptureQuantifier::ZeroOrMore => {
let nodes: Vec<Node> = nodes.collect();
if nodes.is_empty() {
Box::new(std::iter::empty())
} else {
Box::new(std::iter::once(CapturedNode::Grouped(nodes)))
}
}
_ => Box::new(nodes.map(CapturedNode::Single)),
};

captures
.filter_map(move |(mat, idx)| {
(mat.captures[idx].index == capture_idx).then(|| mat.captures[idx].node)
})
.into()
iter
});
Some(nodes)
}
}

Expand Down Expand Up @@ -1075,8 +1131,8 @@ pub(crate) fn generate_edits(
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{iter, mem, ops, str, usize};
use tree_sitter::{
Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor, QueryError,
QueryMatch, Range, TextProvider, Tree,
CaptureQuantifier, Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor,
QueryError, QueryMatch, Range, TextProvider, Tree,
};

const CANCELLATION_CHECK_INTERVAL: usize = 100;
Expand Down Expand Up @@ -1928,6 +1984,50 @@ mod test {
use super::*;
use crate::{Rope, Transaction};

#[test]
fn test_textobject_queries() {
let query_str = r#"
(line_comment)+ @quantified_nodes
((line_comment)+) @quantified_nodes_grouped
((line_comment) (line_comment)) @multiple_nodes_grouped
"#;
let source = Rope::from_str(
r#"
/// a comment on
/// mutiple lines
"#,
);

let loader = Loader::new(Configuration { language: vec![] });
let language = get_language(&crate::RUNTIME_DIR, "Rust").unwrap();

let query = Query::new(language, query_str).unwrap();
let textobject = TextObjectQuery { query };
let mut cursor = QueryCursor::new();

let config = HighlightConfiguration::new(language, "", "", "").unwrap();
let syntax = Syntax::new(&source, Arc::new(config), Arc::new(loader));

let root = syntax.tree().root_node();
let mut test = |capture, range| {
let matches: Vec<_> = textobject
.capture_nodes(capture, root, source.slice(..), &mut cursor)
.unwrap()
.collect();

assert_eq!(
matches[0].byte_range(),
range,
"@{capture} expected {range:?}"
)
};

test("quantified_nodes", 1..35);
// NOTE: Enable after implementing proper node group capturing
// test("quantified_nodes_grouped", 1..35);
// test("multiple_nodes_grouped", 1..35);
}

#[test]
fn test_parser() {
let highlight_names: Vec<String> = [
Expand Down