diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index e3ba7e7c8e9..46bd33d1db4 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -120,6 +120,12 @@ def code_units_offset(byte_offset, encoding) end end + # Generate a cache that targets a specific encoding for calculating code + # unit offsets. + def code_units_cache(encoding) + CodeUnitsCache.new(source, encoding) + end + # Returns the column number in code units for the given encoding for the # given byte offset. def code_units_column(byte_offset, encoding) @@ -149,6 +155,76 @@ def find_line(byte_offset) end end + # A cache that can be used to quickly compute code unit offsets from byte + # offsets. It purposefully provides only a single #[] method to access the + # cache in order to minimize surface area. + # + # Note that there are some known issues here that may or may not be addressed + # in the future: + # + # * The first is that there are issues when the cache computes values that are + # not on character boundaries. This can result in subsequent computations + # being off by one or more code units. + # * The second is that this cache is currently unbounded. In theory we could + # introduce some kind of LRU cache to limit the number of entries, but this + # has not yet been implemented. 
+ # + class CodeUnitsCache + class UTF16Counter # :nodoc: + def initialize(source, encoding) + @source = source + @encoding = encoding + end + + def count(byte_offset, byte_length) + @source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).bytesize / 2 + end + end + + class LengthCounter # :nodoc: + def initialize(source, encoding) + @source = source + @encoding = encoding + end + + def count(byte_offset, byte_length) + @source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).length + end + end + + private_constant :UTF16Counter, :LengthCounter + + # Initialize a new cache with the given source and encoding. + def initialize(source, encoding) + @source = source + @counter = + if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE + UTF16Counter.new(source, encoding) + else + LengthCounter.new(source, encoding) + end + + @cache = {} + @offsets = [] + end + + # Retrieve the code units offset from the given byte offset. + def [](byte_offset) + @cache[byte_offset] ||= + if (index = @offsets.bsearch_index { |offset| offset > byte_offset }).nil? + @offsets << byte_offset + @counter.count(0, byte_offset) + elsif index == 0 + @offsets.unshift(byte_offset) + @counter.count(0, byte_offset) + else + @offsets.insert(index, byte_offset) + offset = @offsets[index - 1] + @cache[offset] + @counter.count(offset, byte_offset - offset) + end + end + end + # Specialized version of Prism::Source for source code that includes ASCII # characters only. This class is used to apply performance optimizations that # cannot be applied to sources that include multibyte characters. @@ -178,6 +254,13 @@ def code_units_offset(byte_offset, encoding) byte_offset end + # Returns a cache that is the identity function in order to maintain the + # same interface. We can do this because code units are always equivalent to + # byte offsets for ASCII-only sources. 
+ def code_units_cache(encoding) + ->(byte_offset) { byte_offset } + end + # Specialized version of `code_units_column` that does not depend on # `code_units_offset`, which is a more expensive operation. This is # essentially the same as `Prism::Source#column`. @@ -287,6 +370,12 @@ def start_code_units_offset(encoding = Encoding::UTF_16LE) source.code_units_offset(start_offset, encoding) end + # The start offset from the start of the file in code units using the given + # cache to fetch or calculate the value. + def cached_start_code_units_offset(cache) + cache[start_offset] + end + # The byte offset from the beginning of the source where this location ends. def end_offset start_offset + length @@ -303,6 +392,12 @@ def end_code_units_offset(encoding = Encoding::UTF_16LE) source.code_units_offset(end_offset, encoding) end + # The end offset from the start of the file in code units using the given + # cache to fetch or calculate the value. + def cached_end_code_units_offset(cache) + cache[end_offset] + end + # The line number where this location starts. def start_line source.line(start_offset) @@ -337,6 +432,12 @@ def start_code_units_column(encoding = Encoding::UTF_16LE) source.code_units_column(start_offset, encoding) end + # The start column in code units using the given cache to fetch or calculate + # the value. + def cached_start_code_units_column(cache) + cache[start_offset] - cache[source.line_start(start_offset)] + end + # The column number in bytes where this location ends from the start of the # line. def end_column @@ -355,6 +456,12 @@ def end_code_units_column(encoding = Encoding::UTF_16LE) source.code_units_column(end_offset, encoding) end + # The end column in code units using the given cache to fetch or calculate + # the value. + def cached_end_code_units_column(cache) + cache[end_offset] - cache[source.line_start(end_offset)] + end + # Implement the hash pattern matching interface for Location. 
def deconstruct_keys(keys) { start_offset: start_offset, end_offset: end_offset } @@ -604,6 +711,11 @@ def success? def failure? !success? end + + # Create a code units cache for the given encoding. + def code_units_cache(encoding) + source.code_units_cache(encoding) + end end # This is a result specific to the `parse` and `parse_file` methods. diff --git a/rbi/prism/parse_result.rbi b/rbi/prism/parse_result.rbi index ef47e93bd1a..ef1f051e766 100644 --- a/rbi/prism/parse_result.rbi +++ b/rbi/prism/parse_result.rbi @@ -40,10 +40,21 @@ class Prism::Source sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_offset(byte_offset, encoding); end + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) } + def code_units_cache(encoding); end + sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_column(byte_offset, encoding); end end +class Prism::CodeUnitsCache + sig { params(source: String, encoding: Encoding).void } + def initialize(source, encoding); end + + sig { params(byte_offset: Integer).returns(Integer) } + def [](byte_offset); end +end + class Prism::ASCIISource < Prism::Source sig { params(byte_offset: Integer).returns(Integer) } def character_offset(byte_offset); end @@ -54,6 +65,9 @@ class Prism::ASCIISource < Prism::Source sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_offset(byte_offset, encoding); end + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) } + def code_units_cache(encoding); end + sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_column(byte_offset, encoding); end end @@ -107,6 +121,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def start_code_units_offset(encoding = Encoding::UTF_16LE); end + sig { params(cache: 
T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_start_code_units_offset(cache); end + sig { returns(Integer) } def end_offset; end @@ -116,6 +133,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def end_code_units_offset(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_end_code_units_offset(cache); end + sig { returns(Integer) } def start_line; end @@ -134,6 +154,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def start_code_units_column(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_start_code_units_column(cache); end + sig { returns(Integer) } def end_column; end @@ -143,6 +166,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def end_code_units_column(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_end_code_units_column(cache); end + sig { params(keys: T.nilable(T::Array[Symbol])).returns(T::Hash[Symbol, T.untyped]) } def deconstruct_keys(keys); end @@ -296,6 +322,9 @@ class Prism::Result sig { returns(T::Boolean) } def failure?; end + + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) } + def code_units_cache(encoding); end end class Prism::ParseResult < Prism::Result diff --git a/sig/prism/_private/parse_result.rbs b/sig/prism/_private/parse_result.rbs index 62e0cdc9177..659bedcfe34 100644 --- a/sig/prism/_private/parse_result.rbs +++ b/sig/prism/_private/parse_result.rbs @@ -5,6 +5,18 @@ module Prism def find_line: (Integer) -> Integer end + class 
CodeUnitsCache + class UTF16Counter + def initialize: (String source, Encoding encoding) -> void + def count: (Integer byte_offset, Integer byte_length) -> Integer + end + + class LengthCounter + def initialize: (String source, Encoding encoding) -> void + def count: (Integer byte_offset, Integer byte_length) -> Integer + end + end + class Location private diff --git a/sig/prism/parse_result.rbs b/sig/prism/parse_result.rbs index d5b9767a01b..d81fe90966b 100644 --- a/sig/prism/parse_result.rbs +++ b/sig/prism/parse_result.rbs @@ -1,4 +1,8 @@ module Prism + interface _CodeUnitsCache + def []: (Integer byte_offset) -> Integer + end + class Source attr_reader source: String attr_reader start_line: Integer @@ -16,15 +20,22 @@ module Prism def character_offset: (Integer byte_offset) -> Integer def character_column: (Integer byte_offset) -> Integer def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer def self.for: (String source) -> Source end + class CodeUnitsCache + def initialize: (String source, Encoding encoding) -> void + def []: (Integer byte_offset) -> Integer + end + class ASCIISource < Source def character_offset: (Integer byte_offset) -> Integer def character_column: (Integer byte_offset) -> Integer def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer end @@ -45,15 +56,23 @@ module Prism def slice: () -> String def slice_lines: () -> String def start_character_offset: () -> Integer + def start_code_units_offset: (Encoding encoding) -> Integer + def cached_start_code_units_offset: (_CodeUnitsCache cache) -> Integer def end_offset: () -> Integer def end_character_offset: () -> Integer + def end_code_units_offset: (Encoding encoding) -> Integer + 
def cached_end_code_units_offset: (_CodeUnitsCache cache) -> Integer def start_line: () -> Integer def start_line_slice: () -> String def end_line: () -> Integer def start_column: () -> Integer def start_character_column: () -> Integer + def start_code_units_column: (Encoding encoding) -> Integer + def cached_start_code_units_column: (_CodeUnitsCache cache) -> Integer def end_column: () -> Integer def end_character_column: () -> Integer + def end_code_units_column: (Encoding encoding) -> Integer + def cached_end_code_units_column: (_CodeUnitsCache cache) -> Integer def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped] def pretty_print: (untyped q) -> untyped def join: (Location other) -> Location @@ -125,6 +144,7 @@ module Prism def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped] def success?: () -> bool def failure?: () -> bool + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache end class ParseResult < Result diff --git a/test/prism/ruby/location_test.rb b/test/prism/ruby/location_test.rb index 3d3e7dd5623..33f844243c0 100644 --- a/test/prism/ruby/location_test.rb +++ b/test/prism/ruby/location_test.rb @@ -140,6 +140,52 @@ def test_code_units assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE) end + def test_cached_code_units + result = Prism.parse("😀 + 😀\n😍 ||= 😍") + + utf8_cache = result.code_units_cache(Encoding::UTF_8) + utf16_cache = result.code_units_cache(Encoding::UTF_16LE) + utf32_cache = result.code_units_cache(Encoding::UTF_32LE) + + # first 😀 + location = result.value.statements.body.first.receiver.location + + assert_equal 0, location.cached_start_code_units_offset(utf8_cache) + assert_equal 0, location.cached_start_code_units_offset(utf16_cache) + assert_equal 0, location.cached_start_code_units_offset(utf32_cache) + + assert_equal 1, location.cached_end_code_units_offset(utf8_cache) + assert_equal 2, location.cached_end_code_units_offset(utf16_cache) + assert_equal 1, 
location.cached_end_code_units_offset(utf32_cache) + + assert_equal 0, location.cached_start_code_units_column(utf8_cache) + assert_equal 0, location.cached_start_code_units_column(utf16_cache) + assert_equal 0, location.cached_start_code_units_column(utf32_cache) + + assert_equal 1, location.cached_end_code_units_column(utf8_cache) + assert_equal 2, location.cached_end_code_units_column(utf16_cache) + assert_equal 1, location.cached_end_code_units_column(utf32_cache) + + # second 😀 + location = result.value.statements.body.first.arguments.arguments.first.location + + assert_equal 4, location.cached_start_code_units_offset(utf8_cache) + assert_equal 5, location.cached_start_code_units_offset(utf16_cache) + assert_equal 4, location.cached_start_code_units_offset(utf32_cache) + + assert_equal 5, location.cached_end_code_units_offset(utf8_cache) + assert_equal 7, location.cached_end_code_units_offset(utf16_cache) + assert_equal 5, location.cached_end_code_units_offset(utf32_cache) + + assert_equal 4, location.cached_start_code_units_column(utf8_cache) + assert_equal 5, location.cached_start_code_units_column(utf16_cache) + assert_equal 4, location.cached_start_code_units_column(utf32_cache) + + assert_equal 5, location.cached_end_code_units_column(utf8_cache) + assert_equal 7, location.cached_end_code_units_column(utf16_cache) + assert_equal 5, location.cached_end_code_units_column(utf32_cache) + end + def test_code_units_binary_valid_utf8 program = Prism.parse(<<~RUBY).value # -*- encoding: binary -*-