diff --git a/Steepfile b/Steepfile index 5b18ffab86f..e710262e719 100644 --- a/Steepfile +++ b/Steepfile @@ -13,5 +13,6 @@ target :lib do ignore "lib/prism/lex_compat.rb" ignore "lib/prism/serialize.rb" ignore "lib/prism/ffi.rb" + ignore "lib/prism/polyfill/byteindex.rb" ignore "lib/prism/translation" end diff --git a/lib/prism.rb b/lib/prism.rb index c733baf8c92..19774538e7a 100644 --- a/lib/prism.rb +++ b/lib/prism.rb @@ -67,6 +67,7 @@ def self.load(source, serialized) end end +require_relative "prism/polyfill/byteindex" require_relative "prism/node" require_relative "prism/node_ext" require_relative "prism/parse_result" diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index e8d77172283..63cc72a9664 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -347,6 +347,18 @@ def join(other) Location.new(source, start_offset, other.end_offset - start_offset) end + + # Join this location with the first occurrence of the string in the source + # that occurs after this location on the same line, and return the new + # location. This will raise an error if the string does not exist. + def adjoin(string) + line_suffix = source.slice(end_offset, source.line_end(end_offset) - end_offset) + + line_suffix_index = line_suffix.byteindex(string) + raise "Could not find #{string}" if line_suffix_index.nil? + + Location.new(source, start_offset, length + line_suffix_index + string.bytesize) + end end # This represents a comment that was encountered during parsing. It is the diff --git a/lib/prism/polyfill/byteindex.rb b/lib/prism/polyfill/byteindex.rb new file mode 100644 index 00000000000..98c6089f141 --- /dev/null +++ b/lib/prism/polyfill/byteindex.rb @@ -0,0 +1,13 @@ +# frozen_string_literal: true + +# Polyfill for String#byteindex, which didn't exist until Ruby 3.2. 
+if !("".respond_to?(:byteindex))
+  String.include(Module.new {
+    def byteindex(needle, offset = 0) # offset is in bytes; returns the byte index of the match, or nil
+      offset += bytesize if offset < 0 # a negative offset counts back from the end, in bytes
+      return if offset < 0 || offset > bytesize # out of range means no match, as in the native method
+      charindex = index(needle, byteslice(0, offset).length) # String#index expects a character offset (mid-character offsets are not rejected here)
+      slice(0...charindex).bytesize if charindex # convert the character index of the match back to bytes
+    end
+  })
+end
diff --git a/lib/prism/polyfill/string.rb b/lib/prism/polyfill/unpack1.rb similarity index 100% rename from lib/prism/polyfill/string.rb rename to lib/prism/polyfill/unpack1.rb diff --git a/prism.gemspec b/prism.gemspec index 18d0c00a411..8a3efe826d3 100644 --- a/prism.gemspec +++ b/prism.gemspec @@ -86,7 +86,8 @@ Gem::Specification.new do |spec| "lib/prism/parse_result/comments.rb", "lib/prism/parse_result/newlines.rb", "lib/prism/pattern.rb", - "lib/prism/polyfill/string.rb", + "lib/prism/polyfill/byteindex.rb", + "lib/prism/polyfill/unpack1.rb", "lib/prism/reflection.rb", "lib/prism/serialize.rb", "lib/prism/translation.rb", diff --git a/rbi/prism/parse_result.rbi b/rbi/prism/parse_result.rbi index 61d125c331e..73e0be1bc70 100644 --- a/rbi/prism/parse_result.rbi +++ b/rbi/prism/parse_result.rbi @@ -44,7 +44,7 @@ class Prism::Source def code_units_column(byte_offset, encoding); end end -class Prism::ASCIISource < Source +class Prism::ASCIISource < Prism::Source sig { params(byte_offset: Integer).returns(Integer) } def character_offset(byte_offset); end @@ -154,6 +154,9 @@ class Prism::Location sig { params(other: Prism::Location).returns(Prism::Location) } def join(other); end + + sig { params(string: String).returns(Prism::Location) } + def adjoin(string); end end class Prism::Comment diff --git a/sig/prism/parse_result.rbs b/sig/prism/parse_result.rbs index c475fa597b9..3487d5412ca 100644 --- a/sig/prism/parse_result.rbs +++ b/sig/prism/parse_result.rbs @@ -55,6 +55,7 @@ module Prism def deconstruct_keys: (Array[Symbol]? 
keys) -> Hash[Symbol, untyped] def pretty_print: (untyped q) -> untyped def join: (Location other) -> Location + def adjoin: (String string) -> Location end class Comment diff --git a/templates/lib/prism/node.rb.erb b/templates/lib/prism/node.rb.erb index 4e07c397e6d..50e622a8d12 100644 --- a/templates/lib/prism/node.rb.erb +++ b/templates/lib/prism/node.rb.erb @@ -83,7 +83,7 @@ module Prism # Important to note is that the column given to this method should be in # bytes, as opposed to characters or code units. def tunnel(line, column) - queue = [self] + queue = [self] #: Array[Prism::node] result = [] while (node = queue.shift) diff --git a/templates/lib/prism/serialize.rb.erb b/templates/lib/prism/serialize.rb.erb index 29ae5356ba4..578e7d2e70f 100644 --- a/templates/lib/prism/serialize.rb.erb +++ b/templates/lib/prism/serialize.rb.erb @@ -1,5 +1,5 @@ require "stringio" -require_relative "polyfill/string" +require_relative "polyfill/unpack1" module Prism # A module responsible for deserializing parse results. diff --git a/test/prism/ruby_api_test.rb b/test/prism/ruby_api_test.rb index 538e715ca3f..a1e2592d3d4 100644 --- a/test/prism/ruby_api_test.rb +++ b/test/prism/ruby_api_test.rb @@ -285,6 +285,19 @@ def test_node_tunnel assert_equal [ProgramNode, StatementsNode, CallNode, ArgumentsNode, CallNode, ArgumentsNode], tunnel.map(&:class) end + def test_location_adjoin + program = Prism.parse("foo.bar = 1").value + + location = program.statements.body.first.message_loc + adjoined = location.adjoin("=") + + assert_kind_of Location, adjoined + refute_equal location, adjoined + + assert_equal 4, adjoined.start_offset + assert_equal 9, adjoined.end_offset + end + private def parse_expression(source)