diff --git a/lib/prism/node_ext.rb b/lib/prism/node_ext.rb index 10fc1decbfb..d4c9b21079e 100644 --- a/lib/prism/node_ext.rb +++ b/lib/prism/node_ext.rb @@ -172,7 +172,7 @@ def child DEPRECATED: ConstantPathNode#child is deprecated and will be removed \ in the next major version. Use \ ConstantPathNode#name/ConstantPathNode#name_loc instead. Called from \ - #{caller(1..1).first}. + #{caller(1, 1)&.first}. MSG name ? ConstantReadNode.new(source, name, name_loc) : MissingNode.new(source, location) @@ -214,7 +214,7 @@ def child DEPRECATED: ConstantPathTargetNode#child is deprecated and will be \ removed in the next major version. Use \ ConstantPathTargetNode#name/ConstantPathTargetNode#name_loc instead. \ - Called from #{caller(1..1).first}. + Called from #{caller(1, 1)&.first}. MSG name ? ConstantReadNode.new(source, name, name_loc) : MissingNode.new(source, location) diff --git a/lib/prism/pattern.rb b/lib/prism/pattern.rb index 91b23afe3e1..03fec267896 100644 --- a/lib/prism/pattern.rb +++ b/lib/prism/pattern.rb @@ -149,7 +149,10 @@ def compile_constant_path_node(node) parent = node.parent if parent.is_a?(ConstantReadNode) && parent.slice == "Prism" - compile_constant_name(node, node.name) + name = node.name + raise CompilationError, node.inspect if name.nil? + + compile_constant_name(node, name) else compile_error(node) end diff --git a/sig/prism/_private/pattern.rbs b/sig/prism/_private/pattern.rbs index e0e0117054d..244206d15d1 100644 --- a/sig/prism/_private/pattern.rbs +++ b/sig/prism/_private/pattern.rbs @@ -15,6 +15,7 @@ module Prism def compile_alternation_pattern_node: (AlternationPatternNode) -> Proc def compile_constant_path_node: (ConstantPathNode) -> Proc def compile_constant_read_node: (ConstantReadNode) -> Proc + def compile_constant_name: (Prism::node, Symbol) -> Proc def compile_hash_pattern_node: (HashPatternNode) -> Proc def compile_nil_node: (NilNode) -> Proc def compile_regular_expression_node: (RegularExpressionNode) -> Proc diff --git a/templates/lib/prism/node.rb.erb b/templates/lib/prism/node.rb.erb index 6ffbf7b1466..4e07c397e6d 100644 --- a/templates/lib/prism/node.rb.erb +++ b/templates/lib/prism/node.rb.erb @@ -76,6 +76,43 @@ module Prism DotVisitor.new.tap { |visitor| accept(visitor) }.to_dot end + # Returns a list of nodes that are descendants of this node that contain the + # given line and column. This is useful for locating a node that is selected + # based on the line and column of the source code. + # + # Important to note is that the column given to this method should be in + # bytes, as opposed to characters or code units. + def tunnel(line, column) + queue = [self] + result = [] + + while (node = queue.shift) + result << node + + node.compact_child_nodes.each do |child_node| + child_location = child_node.location + + start_line = child_location.start_line + end_line = child_location.end_line + + if start_line == end_line + if line == start_line && column >= child_location.start_column && column < child_location.end_column + queue << child_node + break + end + elsif (line == start_line && column >= child_location.start_column) || (line == end_line && column < child_location.end_column) + queue << child_node + break + elsif line > start_line && line < end_line + queue << child_node + break + end + end + end + + result + end + # Returns a list of the fields that exist for this node class. Fields # describe the structure of the node. This kind of reflection is useful for # things like recursively visiting each node _and_ field in the tree. diff --git a/templates/rbi/prism/node.rbi.erb b/templates/rbi/prism/node.rbi.erb index 9184ac6440a..923324ff20f 100644 --- a/templates/rbi/prism/node.rbi.erb +++ b/templates/rbi/prism/node.rbi.erb @@ -31,6 +31,9 @@ class Prism::Node sig { returns(String) } def to_dot; end + sig { params(line: Integer, column: Integer).returns(T::Array[Prism::Node]) } + def tunnel(line, column); end + sig { abstract.params(visitor: Prism::Visitor).returns(T.untyped) } def accept(visitor); end diff --git a/templates/sig/prism/node.rbs.erb b/templates/sig/prism/node.rbs.erb index 30c91eef0ed..5354dc3cb07 100644 --- a/templates/sig/prism/node.rbs.erb +++ b/templates/sig/prism/node.rbs.erb @@ -18,6 +18,7 @@ module Prism def slice_lines: () -> String def pretty_print: (untyped q) -> untyped def to_dot: () -> String + def tunnel: (Integer line, Integer column) -> Array[Prism::node] end type node_singleton = singleton(Node) & _NodeSingleton diff --git a/test/prism/ruby_api_test.rb b/test/prism/ruby_api_test.rb index 9e408d1edd4..538e715ca3f 100644 --- a/test/prism/ruby_api_test.rb +++ b/test/prism/ruby_api_test.rb @@ -266,6 +266,25 @@ def test_node_equality refute_operator parse_expression(complex_source_1), :===, parse_expression(complex_source_2) end + def test_node_tunnel + program = Prism.parse("foo(1) +\n bar(2, 3) +\n baz(3, 4, 5)").value + + tunnel = program.tunnel(1, 4).last + assert_kind_of IntegerNode, tunnel + assert_equal 1, tunnel.value + + tunnel = program.tunnel(2, 6).last + assert_kind_of IntegerNode, tunnel + assert_equal 2, tunnel.value + + tunnel = program.tunnel(3, 9).last + assert_kind_of IntegerNode, tunnel + assert_equal 4, tunnel.value + + tunnel = program.tunnel(3, 8) + assert_equal [ProgramNode, StatementsNode, CallNode, ArgumentsNode, CallNode, ArgumentsNode], tunnel.map(&:class) + end + private def parse_expression(source)