Skip to content

Commit

Permalink
Fix return type of Iterator#chunk and Enumerable#chunks without `…
Browse files Browse the repository at this point in the history
…Drop` (#13506)
  • Loading branch information
straight-shoota authored Jun 14, 2023
1 parent 15cd748 commit 2f38e16
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 33 deletions.
8 changes: 5 additions & 3 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ describe "Enumerable" do

it "drop all" do
result = [1, 2].chunk { Enumerable::Chunk::Drop }.to_a
result.should be_a(Array(Tuple(NoReturn, Array(Int32))))
result.size.should eq 0
end

Expand All @@ -192,14 +193,14 @@ describe "Enumerable" do

it "reuses true" do
iter = [1, 1, 2, 3, 3].chunk(reuse: true, &.itself)
a = iter.next.as(Tuple)
a = iter.next.should be_a(Tuple(Int32, Array(Int32)))
a.should eq({1, [1, 1]})

b = iter.next.as(Tuple)
b = iter.next.should be_a(Tuple(Int32, Array(Int32)))
b.should eq({2, [2]})
b[1].should be(a[1])

c = iter.next.as(Tuple)
c = iter.next.should be_a(Tuple(Int32, Array(Int32)))
c.should eq({3, [3, 3]})
c[1].should be(a[1])
end
Expand Down Expand Up @@ -239,6 +240,7 @@ describe "Enumerable" do

it "drop all" do
result = [1, 2].chunks { Enumerable::Chunk::Drop }
result.should be_a(Array(Tuple(NoReturn, Array(Int32))))
result.size.should eq 0
end

Expand Down
42 changes: 26 additions & 16 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ module Enumerable(T)
#
# See also: `Iterator#chunk`.
def chunks(&block : T -> U) forall U
res = [] of Tuple(U, Array(T))
chunks_internal(block) { |k, v| res << {k, v} }
res = [] of Tuple(typeof(Chunk.key_type(self, block)), Array(T))
chunks_internal(block) { |*kv| res << kv }
res
end

Expand Down Expand Up @@ -166,7 +166,6 @@ module Enumerable(T)
end

def init(key, val)
return if key == Drop
@key = key

if @reuse
Expand All @@ -190,33 +189,44 @@ module Enumerable(T)

def same_as?(key) : Bool
return false unless @initialized
return false if key.in?(Alone, Drop)
return false if key.is_a?(Alone.class) || key.is_a?(Drop.class)
@key == key
end

def reset
@initialized = false
@data.clear
def acc(key, val, &)
if same_as?(key)
add(val)
else
if tuple = fetch
yield *tuple
end

init(key, val) unless key.is_a?(Drop.class)
end
end
end

def self.key_type(ary, block)
ary.each do |item|
key = block.call(item)
::raise "" if key.is_a?(Drop.class)
return key
end
::raise ""
end
end

private def chunks_internal(original_block : T -> U, &) forall U
acc = Chunk::Accumulator(T, U).new
acc = Chunk::Accumulator(T, typeof(Chunk.key_type(self, original_block))).new
each do |val|
key = original_block.call(val)
if acc.same_as?(key)
acc.add(val)
else
if tuple = acc.fetch
yield(*tuple)
end
acc.init(key, val)
acc.acc(key, val) do |*tuple|
yield *tuple
end
end

if tuple = acc.fetch
yield(*tuple)
yield *tuple
end
end

Expand Down
24 changes: 10 additions & 14 deletions src/iterator.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1455,21 +1455,22 @@ module Iterator(T)
#
# See also: `Enumerable#chunks`.
def chunk(reuse = false, &block : T -> U) forall T, U
ChunkIterator(typeof(self), T, U).new(self, reuse, &block)
ChunkIterator(typeof(self), T, U, typeof(::Enumerable::Chunk.key_type(self, block))).new(self, reuse, &block)
end

private class ChunkIterator(I, T, U)
include Iterator(Tuple(U, Array(T)))
private class ChunkIterator(I, T, U, V)
include Iterator(Tuple(V, Array(T)))
@iterator : I
@init : {U, T}?
@init : {V, T}?

def initialize(@iterator : Iterator(T), reuse, &@original_block : T -> U)
@acc = Enumerable::Chunk::Accumulator(T, U).new(reuse)
@acc = ::Enumerable::Chunk::Accumulator(T, V).new(reuse)
end

def next
if init = @init
@acc.init(*init)
k, v = init
@acc.init(k, v)
@init = nil
end

Expand All @@ -1481,24 +1482,19 @@ module Iterator(T)
else
tuple = @acc.fetch
if tuple
@init = {key, val}
@init = {key, val} unless key.is_a?(::Enumerable::Chunk::Drop.class)
return tuple
else
@acc.init(key, val)
@acc.init(key, val) unless key.is_a?(::Enumerable::Chunk::Drop.class)
end
end
end

if tuple = @acc.fetch
return tuple
end
stop
end

private def init_state
@init = nil
@acc.reset
self
stop
end
end

Expand Down

0 comments on commit 2f38e16

Please sign in to comment.