diff --git a/crates/core/src/binding.rs b/crates/core/src/binding.rs index 80efcfc5c..dd64786c1 100644 --- a/crates/core/src/binding.rs +++ b/crates/core/src/binding.rs @@ -3,12 +3,14 @@ use crate::pattern::resolved_pattern::CodeRange; use crate::pattern::state::{get_top_level_effects, FileRegistry}; use crate::pattern::{Effect, EffectKind}; use anyhow::{anyhow, Result}; +use grit_util::AstNode; use marzano_language::language::{FieldId, Language}; use marzano_language::target_language::TargetLanguage; use marzano_util::analysis_logs::{AnalysisLogBuilder, AnalysisLogs}; use marzano_util::node_with_source::NodeWithSource; use marzano_util::position::{Position, Range}; use marzano_util::tree_sitter_util::children_by_field_id_count; +use std::iter; use std::ops::Range as StdRange; use std::path::Path; use std::{borrow::Cow, collections::HashMap, fmt::Display}; @@ -103,6 +105,17 @@ fn pad_snippet(padding: &str, snippet: &str) -> String { result } +fn adjust_ranges(substitutions: &mut [(EffectRange, String)], index: usize, delta: isize) { + for (EffectRange { range, .. }, _) in substitutions.iter_mut() { + if range.start >= index { + range.start = (range.start as isize + delta) as usize; + } + if range.end >= index { + range.end = (range.end as isize + delta) as usize; + } + } +} + // in multiline snippets, remove padding from every line equal to the padding of the first line, // such that the first line is left-aligned. pub(crate) fn adjust_padding<'a>( @@ -136,14 +149,11 @@ pub(crate) fn adjust_padding<'a>( for line in lines { result.push('\n'); index += 1; - for (EffectRange { range, .. }, _) in substitutions.iter_mut() { - if range.start >= index { - range.start = (range.start as isize + delta) as usize; - } - if range.end >= index { - range.end = (range.end as isize + delta) as usize; - } + if line.trim().is_empty() { + adjust_ranges(substitutions, index, -(line.len() as isize)); + continue; } + adjust_ranges(substitutions, index, delta); let line = line.strip_prefix(&padding).ok_or_else(|| { anyhow!( "expected line \n{}\n to start with {} spaces, code is either not indented with spaces, or does not consistently indent code blocks", @@ -460,7 +470,9 @@ impl<'a> Binding<'a> { pub fn text(&self) -> String { match self { Binding::Empty(_, _, _) => "".to_string(), - Binding::Node(source, node) => node_text(source, node).to_string(), + Binding::Node(source, node) => { + NodeWithSource::new(node.clone(), source).text().to_string() + } Binding::String(s, r) => s[r.start_byte as usize..r.end_byte as usize].into(), Binding::FileName(s) => s.to_string_lossy().into(), Binding::List(source, _, _) => { @@ -470,13 +482,7 @@ impl<'a> Binding<'a> { "".to_string() } } - Binding::ConstantRef(c) => match c { - Constant::Boolean(b) => b.to_string(), - Constant::String(s) => s.to_string(), - Constant::Integer(i) => i.to_string(), - Constant::Float(d) => d.to_string(), - Constant::Undefined => String::new(), - }, + Binding::ConstantRef(c) => c.to_string(), } } @@ -522,6 +528,44 @@ impl<'a> Binding<'a> { matches!(self, Self::List(..)) } + /// Returns an iterator over the items in a list. + /// + /// Returns `None` if the binding is not bound to a list. + pub(crate) fn list_items(&self) -> Option>> { + match self { + Self::List(src, node, field_id) => { + let field_id = *field_id; + let mut cursor = node.walk(); + cursor.goto_first_child(); + let mut done = false; + Some( + iter::from_fn(move || { + while !done { + while cursor.field_id() != Some(field_id) { + if !cursor.goto_next_sibling() { + return None; + } + } + let result = cursor.node(); + if !cursor.goto_next_sibling() { + done = true; + } + return Some(result); + } + None + }) + .filter(|child| child.is_named()) + .map(|named_child| NodeWithSource::new(named_child, src)), + ) + } + Self::Empty(..) + | Self::Node(..) + | Self::String(..) + | Self::ConstantRef(..) + | Self::FileName(..) => None, + } + } + /// Returns the parent node of this binding. /// /// Returns `None` if the binding has no relation to a node. @@ -576,8 +620,3 @@ impl<'a> Binding<'a> { Ok(()) } } - -pub(crate) fn node_text<'a>(source: &'a str, node: &Node) -> &'a str { - let range = Range::from(node.range()); - &source[range.start_byte as usize..range.end_byte as usize] -} diff --git a/crates/core/src/pattern/after.rs b/crates/core/src/pattern/after.rs index 26abccf43..698826ef6 100644 --- a/crates/core/src/pattern/after.rs +++ b/crates/core/src/pattern/after.rs @@ -103,7 +103,7 @@ impl Matcher for After { }; let prev_node = resolve!(node.previous_non_trivia_node()); if !self.after.execute( - &ResolvedPattern::from_node(prev_node.source, prev_node.node), + &ResolvedPattern::from_node(prev_node), &mut cur_state, context, logs, diff --git a/crates/core/src/pattern/ast_node.rs b/crates/core/src/pattern/ast_node.rs index e1d5b9a54..c0804e913 100644 --- a/crates/core/src/pattern/ast_node.rs +++ b/crates/core/src/pattern/ast_node.rs @@ -203,7 +203,7 @@ impl Matcher for ASTNode { ) } else if let Some(child) = node.child_by_field_id(*field_id) { pattern.execute( - &ResolvedPattern::from_node(source, child), + &ResolvedPattern::from_node(NodeWithSource::new(child, source)), &mut cur_state, context, logs, diff --git a/crates/core/src/pattern/before.rs b/crates/core/src/pattern/before.rs index b8927ee0a..0b5c682e1 100644 --- a/crates/core/src/pattern/before.rs +++ b/crates/core/src/pattern/before.rs @@ -103,7 +103,7 @@ impl Matcher for Before { }; let next_node = resolve!(node.next_non_trivia_node()); if !self.before.execute( - &ResolvedPattern::from_node(next_node.source, next_node.node), + &ResolvedPattern::from_node(next_node), &mut cur_state, context, logs, diff --git a/crates/core/src/pattern/built_in_functions.rs b/crates/core/src/pattern/built_in_functions.rs index 5de4cdf68..3358ce503 100644 --- a/crates/core/src/pattern/built_in_functions.rs +++ b/crates/core/src/pattern/built_in_functions.rs @@ -402,28 +402,20 @@ fn distinct_fn<'a>( Ok(ResolvedPattern::List(unique_list)) } Some(ResolvedPattern::Binding(binding)) => match binding.last() { - Some(b) => match b { - Binding::List(src, parent_node, field_id) => { + Some(b) => { + if let Some(list_items) = b.list_items() { let mut unique_list = Vector::new(); - for child in parent_node - .children_by_field_id(*field_id, &mut parent_node.walk()) - .filter(|child| child.is_named()) - { - let resolved = ResolvedPattern::from_node(src, child); + for item in list_items { + let resolved = ResolvedPattern::from_node(item); if !unique_list.contains(&resolved) { unique_list.push_back(resolved); } } Ok(ResolvedPattern::List(unique_list)) - } - Binding::String(..) - | Binding::FileName(_) - | Binding::Node(..) - | Binding::Empty(..) - | Binding::ConstantRef(_) => { + } else { bail!("distinct takes a list as the first argument") } - }, + } None => Ok(ResolvedPattern::Binding(binding)), }, _ => Err(anyhow!("distinct takes a list as the first argument")), @@ -454,19 +446,17 @@ fn shuffle_fn<'a>( Ok(ResolvedPattern::List(shuffled_list.into())) } ResolvedPattern::Binding(binding) => match binding.last() { - Some(Binding::List(src, parent_node, field_id)) => { - let mut list = parent_node - .children_by_field_id(*field_id, &mut parent_node.walk()) - .filter(|child| child.is_named()) - .collect::>(); - list.shuffle(state.get_rng()); - let list = list - .into_iter() - .map(|child| ResolvedPattern::from_node(src, child)) - .collect::>(); - Ok(ResolvedPattern::List(list)) + Some(b) => { + if let Some(list_items) = b.list_items() { + let mut list: Vec<_> = list_items.collect(); + list.shuffle(state.get_rng()); + let list: Vector<_> = + list.into_iter().map(ResolvedPattern::from_node).collect(); + Ok(ResolvedPattern::List(list)) + } else { + Err(anyhow!("shuffle takes a list as the first argument")) + } } - Some(_) => Err(anyhow!("shuffle takes a list as the first argument")), None => Err(anyhow!("shuffle argument must be bound")), }, ResolvedPattern::Snippets(_) diff --git a/crates/core/src/pattern/code_snippet.rs b/crates/core/src/pattern/code_snippet.rs index 0e26bb033..c9779bd0c 100644 --- a/crates/core/src/pattern/code_snippet.rs +++ b/crates/core/src/pattern/code_snippet.rs @@ -1,5 +1,3 @@ -use crate::{binding::Binding, context::Context, resolve}; - use super::{ dynamic_snippet::{DynamicPattern, DynamicSnippet}, patterns::{Matcher, Name, Pattern}, @@ -7,16 +5,17 @@ use super::{ variable::{register_variable, VariableSourceLocations}, State, }; +use crate::{context::Context, resolve}; use anyhow::{anyhow, bail, Result}; use core::fmt::Debug; -use marzano_util::{analysis_logs::AnalysisLogs, position::Range}; -use std::collections::BTreeMap; -use tree_sitter::Node; - use marzano_language::{ language::{nodes_from_indices, Language, SortId}, target_language::TargetLanguage, }; +use marzano_util::{analysis_logs::AnalysisLogs, position::Range}; +use std::collections::BTreeMap; +use tree_sitter::Node; + #[derive(Debug, Clone)] pub struct CodeSnippet { pub(crate) patterns: Vec<(SortId, Pattern)>, @@ -117,28 +116,15 @@ impl Matcher for CodeSnippet { } }; - let node = match binding { - Binding::Empty(_, _, _) => return Ok(false), - Binding::Node(_, node) => node.to_owned(), - // maybe String should instead be fake node? eg for comment_content - Binding::String(_, _) => return Ok(false), - Binding::List(_, node, id) => { - let mut cursor = node.walk(); - let mut list = node.children_by_field_id(*id, &mut cursor); - if let Some(child) = list.next() { - if list.next().is_some() { - return Ok(false); - } - child - } else { - return Ok(false); - } - } - Binding::FileName(_) => return Ok(false), - Binding::ConstantRef(_) => return Ok(false), + let Some(node) = binding.singleton() else { + return Ok(false); }; - if let Some((_, pattern)) = self.patterns.iter().find(|(id, _)| *id == node.kind_id()) { + if let Some((_, pattern)) = self + .patterns + .iter() + .find(|(id, _)| *id == node.node.kind_id()) + { pattern.execute(resolved_pattern, state, context, logs) } else { Ok(false) diff --git a/crates/core/src/pattern/contains.rs b/crates/core/src/pattern/contains.rs index b56173dc1..85f8d3965 100644 --- a/crates/core/src/pattern/contains.rs +++ b/crates/core/src/pattern/contains.rs @@ -11,7 +11,7 @@ use super::{ use anyhow::{anyhow, Result}; use core::fmt::Debug; use im::vector; -use marzano_util::analysis_logs::AnalysisLogs; +use marzano_util::{analysis_logs::AnalysisLogs, node_with_source::NodeWithSource}; use std::collections::BTreeMap; @@ -86,7 +86,7 @@ fn execute_until<'a>( let mut still_computing = true; while still_computing { let node = cursor.node(); - let node_lhs = ResolvedPattern::from_node(src, node); + let node_lhs = ResolvedPattern::from_node(NodeWithSource::new(node, src)); let state = cur_state.clone(); if the_contained.execute(&node_lhs, &mut cur_state, context, logs)? { @@ -137,50 +137,43 @@ impl Matcher for Contains { match resolved_pattern { ResolvedPattern::Binding(bindings) => { let binding = resolve!(bindings.last()); - let mut did_match = false; - let mut cur_state = init_state.clone(); - let mut cursor; // needed for scope in case of list. - match binding { - Binding::Empty(_, _, _) => Ok(false), - Binding::String(_, _) => Ok(false), - Binding::Node(src, node) => execute_until( + if let Some(node) = binding.as_node() { + execute_until( init_state, - node, - src, + &node.node, + node.source, context, logs, &self.contains, &self.until, - ), - Binding::List(src, node, field_id) => { - cursor = node.walk(); - let children = node.children_by_field_id(*field_id, &mut cursor); - - for child in children { - let state = cur_state.clone(); - if self.execute( - &ResolvedPattern::from_node(src, child), - &mut cur_state, - context, - logs, - )? { - did_match = true; - } else { - cur_state = state; - } + ) + } else if let Some(list_items) = binding.list_items() { + let mut did_match = false; + let mut cur_state = init_state.clone(); + for item in list_items { + let state = cur_state.clone(); + if self.execute( + &ResolvedPattern::from_node(item), + &mut cur_state, + context, + logs, + )? { + did_match = true; + } else { + cur_state = state; } + } - if did_match { - *init_state = cur_state; - } - Ok(did_match) + if did_match { + *init_state = cur_state; } - Binding::FileName(_) => Ok(false), + Ok(did_match) + } else if let Some(_c) = binding.as_constant() { // this seems like an infinite loop, todo return false? - Binding::ConstantRef(_c) => { - self.contains - .execute(resolved_pattern, init_state, context, logs) - } + self.contains + .execute(resolved_pattern, init_state, context, logs) + } else { + Ok(false) } } ResolvedPattern::List(elements) => { diff --git a/crates/core/src/pattern/every.rs b/crates/core/src/pattern/every.rs index 1f343ebb6..7e51ddc80 100644 --- a/crates/core/src/pattern/every.rs +++ b/crates/core/src/pattern/every.rs @@ -5,7 +5,7 @@ use super::{ variable::VariableSourceLocations, State, }; -use crate::{binding::Binding, context::Context, resolve}; +use crate::{context::Context, resolve}; use anyhow::{anyhow, Result}; use im::vector; use marzano_util::analysis_logs::AnalysisLogs; @@ -67,32 +67,21 @@ impl Matcher for Every { match binding { ResolvedPattern::Binding(bindings) => { let binding = resolve!(bindings.last()); - let pattern = &self.pattern; + let Some(list_items) = binding.list_items() else { + return Ok(false); + }; - match binding { - Binding::Empty(_, _, _) => Ok(false), - Binding::Node(_, _node) => Ok(false), - Binding::String(_, _) => Ok(false), - Binding::List(src, node, field_id) => { - let mut cursor = node.walk(); - let children = node - .children_by_field_id(*field_id, &mut cursor) - .filter(|c| c.is_named()); - for child in children { - if !pattern.execute( - &ResolvedPattern::from_node(src, child), - init_state, - context, - logs, - )? { - return Ok(false); - } - } - Ok(true) + for item in list_items { + if !self.pattern.execute( + &ResolvedPattern::from_node(item), + init_state, + context, + logs, + )? { + return Ok(false); } - Binding::ConstantRef(_) => Ok(false), - Binding::FileName(_) => Ok(false), } + Ok(true) } ResolvedPattern::List(elements) => { let pattern = &self.pattern; diff --git a/crates/core/src/pattern/list.rs b/crates/core/src/pattern/list.rs index 2eb533b91..a64adbc26 100644 --- a/crates/core/src/pattern/list.rs +++ b/crates/core/src/pattern/list.rs @@ -6,7 +6,7 @@ use super::{ state::State, variable::VariableSourceLocations, }; -use crate::{binding::Binding, context::Context}; +use crate::context::Context; use anyhow::{anyhow, bail, Result}; use core::fmt::Debug; use marzano_language::language::Field; @@ -121,21 +121,16 @@ impl Matcher for List { ) -> Result { match binding { ResolvedPattern::Binding(v) => { - if let Some(Binding::List(src, node, field_id)) = v.last() { - let mut cursor = node.walk(); + let Some(list_items) = v.last().and_then(|b| b.list_items()) else { + return Ok(false); + }; - let children_vec: Vec> = node - .children_by_field_id(*field_id, &mut cursor) - .filter(|n| n.is_named()) - .map(|node| Cow::Owned(ResolvedPattern::from_node(src, node))) - .collect(); + let children: Vec> = list_items + .map(ResolvedPattern::from_node) + .map(Cow::Owned) + .collect(); - let children: &[Cow] = &children_vec; - let patterns: &[Pattern] = &self.patterns; - execute_assoc(patterns, children, state, context, logs) - } else { - Ok(false) - } + execute_assoc(&self.patterns, &children, state, context, logs) } ResolvedPattern::List(patterns) => { let patterns: Vec>> = diff --git a/crates/core/src/pattern/resolved_pattern.rs b/crates/core/src/pattern/resolved_pattern.rs index fde39a2de..5fdc9ca52 100644 --- a/crates/core/src/pattern/resolved_pattern.rs +++ b/crates/core/src/pattern/resolved_pattern.rs @@ -17,7 +17,9 @@ use crate::{ use anyhow::{anyhow, bail, Result}; use im::{vector, Vector}; use marzano_language::{language::FieldId, target_language::TargetLanguage}; -use marzano_util::{analysis_logs::AnalysisLogs, position::Range}; +use marzano_util::{ + analysis_logs::AnalysisLogs, node_with_source::NodeWithSource, position::Range, +}; use std::{ borrow::Cow, collections::{BTreeMap, HashMap}, @@ -478,8 +480,8 @@ impl<'a> ResolvedPattern<'a> { Self::Binding(vector![Binding::ConstantRef(constant)]) } - pub(crate) fn from_node(src: &'a str, node: Node<'a>) -> Self { - Self::Binding(vector![Binding::Node(src, node)]) + pub(crate) fn from_node(node: NodeWithSource<'a>) -> Self { + Self::Binding(vector![Binding::Node(node.source, node.node)]) } pub(crate) fn from_list(src: &'a str, node: Node<'a>, field_id: FieldId) -> Self { diff --git a/crates/core/src/pattern/some.rs b/crates/core/src/pattern/some.rs index 69425ede6..a8d3202e6 100644 --- a/crates/core/src/pattern/some.rs +++ b/crates/core/src/pattern/some.rs @@ -1,8 +1,3 @@ -use anyhow::{anyhow, Result}; -use im::vector; - -use crate::{binding::Binding, context::Context, resolve}; - use super::{ compiler::CompilationContext, patterns::{Matcher, Name, Pattern}, @@ -10,6 +5,9 @@ use super::{ variable::VariableSourceLocations, State, }; +use crate::{context::Context, resolve}; +use anyhow::{anyhow, Result}; +use im::vector; use marzano_util::analysis_logs::AnalysisLogs; use std::collections::BTreeMap; use tree_sitter::Node; @@ -67,37 +65,27 @@ impl Matcher for Some { match binding { ResolvedPattern::Binding(bindings) => { let binding = resolve!(bindings.last()); + let Some(list_items) = binding.list_items() else { + return Ok(false); + }; + let mut did_match = false; - let pattern = &self.pattern; - match binding { - Binding::Empty(_, _, _) => Ok(false), - Binding::Node(_, _node) => Ok(false), - Binding::String(_, _) => Ok(false), - Binding::List(src, node, field_id) => { - let mut cur_state = init_state.clone(); - let mut cursor = node.walk(); - let children = node - .children_by_field_id(*field_id, &mut cursor) - .filter(|c| c.is_named()); - for child in children { - let state = cur_state.clone(); - if pattern.execute( - &ResolvedPattern::from_node(src, child), - &mut cur_state, - context, - logs, - )? { - did_match = true; - } else { - cur_state = state; - } - } - *init_state = cur_state; - Ok(did_match) + let mut cur_state = init_state.clone(); + for item in list_items { + let state = cur_state.clone(); + if self.pattern.execute( + &ResolvedPattern::from_node(item), + &mut cur_state, + context, + logs, + )? { + did_match = true; + } else { + cur_state = state; } - Binding::ConstantRef(_) => Ok(false), - Binding::FileName(_) => Ok(false), } + *init_state = cur_state; + Ok(did_match) } ResolvedPattern::List(elements) => { let pattern = &self.pattern; diff --git a/crates/core/src/pattern/state.rs b/crates/core/src/pattern/state.rs index b6d126cb0..50613f251 100644 --- a/crates/core/src/pattern/state.rs +++ b/crates/core/src/pattern/state.rs @@ -246,9 +246,7 @@ impl<'a> State<'a> { if let ResolvedPattern::Binding(bindings) = value { for binding in bindings.iter() { bindings_count += 1; - if is_binding_suppressed(binding, lang, current_name) - .unwrap_or_default() - { + if is_binding_suppressed(binding, lang, current_name) { suppressed_count += 1; continue; } diff --git a/crates/core/src/pattern/string_constant.rs b/crates/core/src/pattern/string_constant.rs index 31c91dab3..e154c54ca 100644 --- a/crates/core/src/pattern/string_constant.rs +++ b/crates/core/src/pattern/string_constant.rs @@ -3,15 +3,12 @@ use super::{ resolved_pattern::ResolvedPattern, State, }; -use crate::{ - binding::{node_text, Binding}, - context::Context, -}; +use crate::{binding::Binding, context::Context}; use anyhow::{anyhow, Result}; use core::fmt::Debug; +use grit_util::AstNode; use marzano_language::language::{Language, LeafEquivalenceClass, SortId}; use marzano_util::analysis_logs::AnalysisLogs; -use marzano_util::tree_sitter_util::named_children_by_field_id; #[derive(Debug, Clone)] pub struct StringConstant { @@ -92,37 +89,18 @@ impl Matcher for AstLeafNode { _context: &'a impl Context, _logs: &mut AnalysisLogs, ) -> Result { - if let ResolvedPattern::Binding(b) = binding { - let (src, n) = match b.last() { - Some(Binding::Node(src, n)) => (src, n.to_owned()), - Some(Binding::List(src, n, f)) => { - let mut w = n.walk(); - let mut l = named_children_by_field_id(n, &mut w, *f); - if let (Some(n), None) = (l.next(), l.next()) { - (src, n) - } else { - return Ok(false); - } - } - Some(Binding::ConstantRef(..)) - | Some(Binding::Empty(..)) - | Some(Binding::FileName(..)) - | Some(Binding::String(..)) - | None => return Ok(false), - }; - if let Some(e) = &self.equivalence_class { - let text = node_text(src, &n); - return Ok(e.are_equivalent(n.kind_id(), text.trim())); - } else if self.sort != n.kind_id() { - return Ok(false); - } - let text = node_text(src, &n); - if text.trim() == self.text { - return Ok(true); - } else { - return Ok(false); - } + let ResolvedPattern::Binding(b) = binding else { + return Ok(false); + }; + let Some(node) = b.last().and_then(Binding::singleton) else { + return Ok(false); + }; + if let Some(e) = &self.equivalence_class { + Ok(e.are_equivalent(node.node.kind_id(), node.text().trim())) + } else if self.sort != node.node.kind_id() { + Ok(false) + } else { + Ok(node.text().trim() == self.text) } - Ok(false) } } diff --git a/crates/core/src/pattern/within.rs b/crates/core/src/pattern/within.rs index 409cd31d9..72ca15de6 100644 --- a/crates/core/src/pattern/within.rs +++ b/crates/core/src/pattern/within.rs @@ -9,7 +9,7 @@ use crate::{context::Context, resolve}; use anyhow::{anyhow, Result}; use core::fmt::Debug; use marzano_language::parent_traverse::{ParentTraverse, TreeSitterParentCursor}; -use marzano_util::analysis_logs::AnalysisLogs; +use marzano_util::{analysis_logs::AnalysisLogs, node_with_source::NodeWithSource}; use std::collections::BTreeMap; use tree_sitter::Node; @@ -88,7 +88,7 @@ impl Matcher for Within { for n in ParentTraverse::new(TreeSitterParentCursor::new(node.node)) { let state = cur_state.clone(); if self.pattern.execute( - &ResolvedPattern::from_node(node.source, n), + &ResolvedPattern::from_node(NodeWithSource::new(n, node.source)), &mut cur_state, context, logs, diff --git a/crates/core/src/suppress.rs b/crates/core/src/suppress.rs index 6b30ccf81..3cc6a00ac 100644 --- a/crates/core/src/suppress.rs +++ b/crates/core/src/suppress.rs @@ -1,3 +1,4 @@ +use itertools::{EitherOrBoth, Itertools}; use marzano_language::{ language::Language, parent_traverse::{ParentTraverse, TreeSitterParentCursor}, @@ -5,21 +6,17 @@ use marzano_language::{ use tree_sitter::{Node, Range}; use crate::binding::Binding; -use crate::resolve; -use anyhow::Result; pub(crate) fn is_binding_suppressed( binding: &Binding, lang: &impl Language, current_name: Option<&str>, -) -> Result { +) -> bool { let (src, node) = match binding { - Binding::Node(src, node) => (src, node), - Binding::String(_, _) => return Ok(false), - Binding::List(src, node, _) => (src, node), - Binding::Empty(src, node, _) => (src, node), - Binding::FileName(_) => return Ok(false), - Binding::ConstantRef(_) => return Ok(false), + Binding::Node(src, node) | Binding::List(src, node, _) | Binding::Empty(src, node, _) => { + (src, node) + } + Binding::String(_, _) | Binding::FileName(_) | Binding::ConstantRef(_) => return false, }; let target_range = node.range(); for n in @@ -34,13 +31,13 @@ pub(crate) fn is_binding_suppressed( if !(lang.is_comment(c.kind_id()) || lang.is_comment_wrapper(&c)) { continue; } - if is_suppress_comment(&c, src, &target_range, current_name, lang)? { - return Ok(true); + if is_suppress_comment(&c, src, &target_range, current_name, lang) { + return true; } } } - Ok(false) + false } fn is_suppress_comment( @@ -49,28 +46,30 @@ fn is_suppress_comment( target_range: &Range, current_name: Option<&str>, lang: &impl Language, -) -> Result { +) -> bool { let child_range = comment_node.range(); - let text = comment_node.utf8_text(src.as_bytes())?; + let Ok(text) = comment_node.utf8_text(src.as_bytes()) else { + return false; + }; let inline_suppress = child_range.end_point().row() >= target_range.start_point().row() && child_range.end_point().row() <= target_range.end_point().row(); if !inline_suppress { - let pre_suppress = comment_applies_to_range(comment_node, target_range, lang, src)? - && comment_occupies_entire_line(text.as_ref(), &comment_node.range(), src)?; + let pre_suppress = comment_applies_to_range(comment_node, target_range, lang, src) + && comment_occupies_entire_line(text.as_ref(), &comment_node.range(), src); if !pre_suppress { - return Ok(false); + return false; } } if !text.contains("grit-ignore") { - return Ok(false); + return false; } let comment_text = text.trim(); - let ignore_spec = match comment_text.split("grit-ignore").collect::>().get(1) { - Some(s) => match s.split(':').next() { + let ignore_spec = match comment_text.split_once("grit-ignore") { + Some((_, s)) => match s.split(':').next() { Some(s) => s.trim(), - None => return Ok(true), + None => return true, }, - None => return Ok(true), + None => return true, }; if ignore_spec.is_empty() || ignore_spec @@ -78,13 +77,15 @@ fn is_suppress_comment( .next() .is_some_and(|c| !c.is_alphanumeric() && c != '_') { - return Ok(true); + return true; } let Some(current_name) = current_name else { - return Ok(false); + return false; }; - let ignored_rules = ignore_spec.split(',').map(|s| s.trim()).collect::>(); - Ok(ignored_rules.contains(¤t_name)) + ignore_spec + .split(',') + .map(str::trim) + .contains(¤t_name) } fn comment_applies_to_range( @@ -92,28 +93,31 @@ fn comment_applies_to_range( range: &Range, lang: &impl Language, src: &str, -) -> Result { - let mut applicable = resolve!(comment_node.next_named_sibling()); +) -> bool { + let Some(mut applicable) = comment_node.next_named_sibling() else { + return false; + }; while let Some(next) = applicable.next_named_sibling() { if !lang.is_comment(applicable.kind_id()) && !lang.is_comment_wrapper(&applicable) // Some languages have significant whitespace; continue until we find a non-whitespace non-comment node - && !applicable.utf8_text(src.as_bytes())?.trim().is_empty() + && !applicable.utf8_text(src.as_bytes()).map_or(true, |text| text.trim().is_empty()) { break; } applicable = next; } let applicable_range = applicable.range(); - Ok(applicable_range.start_point().row() == range.start_point().row()) + applicable_range.start_point().row() == range.start_point().row() } -fn comment_occupies_entire_line(text: &str, range: &Range, src: &str) -> Result { - let code = src - .lines() +fn comment_occupies_entire_line(text: &str, range: &Range, src: &str) -> bool { + src.lines() .skip(range.start_point().row() as usize) .take((range.end_point().row() - range.start_point().row() + 1) as usize) - .collect::>() - .join("\n"); - Ok(code.trim() == text.trim()) + .zip_longest(text.split("\n")) + .all(|zipped| match zipped { + EitherOrBoth::Both(src_line, text_line) => src_line.trim() == text_line.trim(), + _ => false, + }) } diff --git a/crates/core/src/test.rs b/crates/core/src/test.rs index a2d037d24..46c54ad68 100644 --- a/crates/core/src/test.rs +++ b/crates/core/src/test.rs @@ -12885,3 +12885,34 @@ fn limit_export_default_match() { }) .unwrap(); } + +#[test] +fn python_support_empty_line() { + run_test_expected(TestArgExpected { + pattern: r#" + |engine marzano(0.1) + |language python + |`class $name: $body` => $body + |"# + .trim_margin() + .unwrap(), + source: r#" + |class MyClass: + | def function(self): + | result = 1 + 1 + | + | return result + |"# + .trim_margin() + .unwrap(), + expected: r#" + |def function(self): + | result = 1 + 1 + | + | return result + |"# + .trim_margin() + .unwrap(), + }) + .unwrap(); +} diff --git a/crates/core/src/text_unparser.rs b/crates/core/src/text_unparser.rs index b9429cf1f..ac5aeb9bf 100644 --- a/crates/core/src/text_unparser.rs +++ b/crates/core/src/text_unparser.rs @@ -28,20 +28,17 @@ pub(crate) fn apply_effects<'a>( current_name: Option<&str>, logs: &mut AnalysisLogs, ) -> Result<(String, Option>>)> { - let mut our_effects = Vec::new(); - for effect in effects { - let disabled = is_binding_suppressed(&effect.binding, language, current_name)?; - if !disabled { - our_effects.push(effect); - } - } - if our_effects.is_empty() { + let effects: Vec<_> = effects + .into_iter() + .filter(|effect| !is_binding_suppressed(&effect.binding, language, current_name)) + .collect(); + if effects.is_empty() { return Ok((code.to_string(), None)); } let mut memo: HashMap> = HashMap::new(); let (from_inline, ranges) = linearize_binding( language, - &our_effects, + &effects, files, &mut memo, code, @@ -49,17 +46,12 @@ pub(crate) fn apply_effects<'a>( language.should_pad_snippet().then_some(0), logs, )?; - for effect in our_effects.iter() { + for effect in effects.iter() { if let Binding::FileName(c) = effect.binding { if std::ptr::eq(c, the_filename) { - let snippet = effect.pattern.linearized_text( - language, - &our_effects, - files, - &mut memo, - false, - logs, - )?; + let snippet = effect + .pattern + .linearized_text(language, &effects, files, &mut memo, false, logs)?; *new_filename = PathBuf::from(snippet.to_string()); } } diff --git a/crates/grit-util/src/ast_node.rs b/crates/grit-util/src/ast_node.rs index e3873657f..b9644d02f 100644 --- a/crates/grit-util/src/ast_node.rs +++ b/crates/grit-util/src/ast_node.rs @@ -7,4 +7,7 @@ pub trait AstNode: Sized { /// Returns the previous node, ignoring trivia such as whitespace. fn previous_non_trivia_node(&self) -> Option; + + /// Returns the text representation of the node. + fn text(&self) -> &str; } diff --git a/crates/util/src/node_with_source.rs b/crates/util/src/node_with_source.rs index 5da7e5706..51e223491 100644 --- a/crates/util/src/node_with_source.rs +++ b/crates/util/src/node_with_source.rs @@ -1,3 +1,4 @@ +use crate::position::Range; use grit_util::AstNode; use tree_sitter::Node; @@ -34,4 +35,9 @@ impl<'a> AstNode for NodeWithSource<'a> { current_node = current_node.parent()?; } } + + fn text(&self) -> &str { + let range = Range::from(self.node.range()); + &self.source[range.start_byte as usize..range.end_byte as usize] + } }