Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use simpler faster Rabin-Karp-like search for short needle #13820

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
133 changes: 75 additions & 58 deletions src/string.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3312,15 +3312,20 @@ class String
# Update rolling hash for Rabin-Karp algorithm `String#index`.
private macro update_hash(n)
{% for i in 1..n %}
{% if i != 1 %}
byte = head_pointer.value
{% end %}
byte = head_pointer.value
HertzDevil marked this conversation as resolved.
Show resolved Hide resolved
hash = hash &* PRIME_RK &+ pointer.value &- pow &* byte
pointer += 1
head_pointer += 1
{% end %}
end

private macro update_simplehash(n)
{% for i in 1..n %}
hash = (hash << 8) | pointer.value
pointer += 1
{% end %}
end

# Returns the index of the _first_ occurrence of *search* in the string, or `nil` if not present.
# If *offset* is present, it defines the position to start the search.
#
Expand Down Expand Up @@ -3354,23 +3359,45 @@ class String
nil
end

private def index_short(hash_type, offset : Int32, pointer : UInt8*, end_pointer : UInt8*, search, &)
search_hash = hash_type.new(0)
hash = hash_type.new(0)
mask = hash_type.new(0)

search.each_byte do |b|
search_hash = (search_hash << 8) | b
hash = (hash << 8) | pointer.value
mask = (mask << 8) | 0xff
pointer += 1
end

while true
straight-shoota marked this conversation as resolved.
Show resolved Hide resolved
return offset if (hash & mask) == search_hash

char_bytesize = yield pointer
return if pointer + char_bytesize > end_pointer
case char_bytesize
when 1 then update_simplehash 1
when 2 then update_simplehash 2
when 3 then update_simplehash 3
else update_simplehash 4
end
straight-shoota marked this conversation as resolved.
Show resolved Hide resolved

offset &+= 1
end
end

# :ditto:
def index(search : String, offset = 0)
offset += size if offset < 0
return if offset < 0

return size < offset ? nil : offset if search.empty?
return index(search[0], offset) if search.size == 1 && search.valid_encoding?

# Rabin-Karp algorithm
# https://en.wikipedia.org/wiki/Rabin%E2%80%93Karp_algorithm

# calculate a rolling hash of search text (needle)
search_hash = 0u32
search.each_byte do |b|
search_hash = search_hash &* PRIME_RK &+ b
end
pow = PRIME_RK &** search.bytesize

# Find start index with offset
char_index = 0
pointer = to_unsafe
Expand All @@ -3381,24 +3408,33 @@ class String
char_index += 1
end

head_pointer = pointer
return if pointer + search.bytesize > end_pointer

# calculate a rolling hash of this text (haystack)
if search.bytesize <= 8
search_bytesize = search.bytesize
return index_short(UInt64, char_index, pointer, end_pointer, search) do |pointer|
String.char_bytesize_at(pointer - search_bytesize)
end
end

head_pointer = pointer
search_hash = 0u32
hash = 0u32
hash_end_pointer = pointer + search.bytesize
return if hash_end_pointer > end_pointer
while pointer < hash_end_pointer

# calculate a rolling hash of search text (needle) and this text (haystack)
search.each_byte do |b|
search_hash = search_hash &* PRIME_RK &+ b
hash = hash &* PRIME_RK &+ pointer.value
pointer += 1
end
pow = PRIME_RK &** search.bytesize

while true
# check hash equality and real string equality
if hash == search_hash && head_pointer.memcmp(search.to_unsafe, search.bytesize) == 0
return char_index
end

byte = head_pointer.value
char_bytesize = String.char_bytesize_at(head_pointer)
return if pointer + char_bytesize > end_pointer
case char_bytesize
Expand Down Expand Up @@ -3685,12 +3721,7 @@ class String
offset += bytesize if offset < 0
return if offset < 0

offset.upto(bytesize - 1) do |i|
if to_unsafe[i] == byte
return i
end
end
nil
to_slice.fast_index(byte.to_u8, offset)
end

# Returns the index of the _first_ occurrence of *char* in the string, or `nil` if not present.
Expand All @@ -3708,32 +3739,16 @@ class String
# "Dizzy Miss Lizzy".byte_index('z', -4) # => 13
# "Dizzy Miss Lizzy".byte_index('z', -17) # => nil
# ```
def byte_index(char : Char, offset = 0) : Int32?
return byte_index(char.ord, offset) if char.ascii?
def byte_index(char search : Char, offset = 0) : Int32?
return byte_index(search.ord, offset) if search.ascii?

offset += bytesize if offset < 0
return if offset < 0
return if offset + char.bytesize > bytesize

# Simplified "Rabin-Karp" algorithm
search_hash = 0u32
search_mask = 0u32
hash = 0u32
char.each_byte do |byte|
search_hash = (search_hash << 8) | byte
search_mask = (search_mask << 8) | 0xff
hash = (hash << 8) | to_unsafe[offset]
offset += 1
end
return if offset + search.bytesize > bytesize

offset.upto(bytesize) do |i|
if (hash & search_mask) == search_hash
return i - char.bytesize
end
# rely on zero terminating byte
hash = (hash << 8) | to_unsafe[i]
end
nil
pointer = to_unsafe + offset
end_pointer = to_unsafe + bytesize
index_short(UInt32, offset, pointer, end_pointer, search) { 1 }
end

# Returns the byte index of *search* in the string, or `nil` if the string is not present.
Expand All @@ -3754,27 +3769,31 @@ class String
return if offset < 0

return bytesize < offset ? nil : offset if search.empty?
return byte_index(search.to_unsafe[0], offset) if search.bytesize == 1

# Rabin-Karp algorithm
# https://en.wikipedia.org/wiki/Rabin%E2%80%93Karp_algorithm

# calculate a rolling hash of search text (needle)
pointer = to_unsafe + offset
end_pointer = to_unsafe + bytesize
return if pointer + search.bytesize > end_pointer

search_hash = 0u32
search.each_byte do |b|
search_hash = search_hash &* PRIME_RK &+ b
hash = 0u32

if search.bytesize <= 8
return index_short(UInt64, offset, pointer, end_pointer, search) { 1 }
end
pow = PRIME_RK &** search.bytesize

# calculate a rolling hash of this text (haystack)
pointer = head_pointer = to_unsafe + offset
hash_end_pointer = pointer + search.bytesize
end_pointer = to_unsafe + bytesize
hash = 0u32
return if hash_end_pointer > end_pointer
while pointer < hash_end_pointer
head_pointer = pointer

# calculate a rolling hash of search text (needle) and this text (haystack)
search.each_byte do |b|
search_hash = search_hash &* PRIME_RK &+ b
hash = hash &* PRIME_RK &+ pointer.value
pointer += 1
end
pow = PRIME_RK &** search.bytesize

while true
# check hash equality and real string equality
Expand All @@ -3785,9 +3804,7 @@ class String
return if pointer >= end_pointer

# update a rolling hash of this text (haystack)
hash = hash &* PRIME_RK &+ pointer.value &- pow &* head_pointer.value
pointer += 1
head_pointer += 1
update_hash 1
offset += 1
end

Expand Down
Loading