Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return type of Iterator#chunk and Enumerable#chunks without Drop #13506

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ describe "Enumerable" do

it "drop all" do
result = [1, 2].chunk { Enumerable::Chunk::Drop }.to_a
result.should be_a(Array(Tuple(NoReturn, Array(Int32))))
result.size.should eq 0
end

Expand All @@ -192,14 +193,14 @@ describe "Enumerable" do

it "reuses true" do
iter = [1, 1, 2, 3, 3].chunk(reuse: true, &.itself)
a = iter.next.as(Tuple)
a = iter.next.should be_a(Tuple(Int32, Array(Int32)))
a.should eq({1, [1, 1]})

b = iter.next.as(Tuple)
b = iter.next.should be_a(Tuple(Int32, Array(Int32)))
b.should eq({2, [2]})
b[1].should be(a[1])

c = iter.next.as(Tuple)
c = iter.next.should be_a(Tuple(Int32, Array(Int32)))
c.should eq({3, [3, 3]})
c[1].should be(a[1])
end
Expand Down Expand Up @@ -239,6 +240,7 @@ describe "Enumerable" do

it "drop all" do
result = [1, 2].chunks { Enumerable::Chunk::Drop }
result.should be_a(Array(Tuple(NoReturn, Array(Int32))))
result.size.should eq 0
end

Expand Down
42 changes: 26 additions & 16 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ module Enumerable(T)
#
# See also: `Iterator#chunk`.
def chunks(&block : T -> U) forall U
res = [] of Tuple(U, Array(T))
chunks_internal(block) { |k, v| res << {k, v} }
res = [] of Tuple(typeof(Chunk.key_type(self, block)), Array(T))
chunks_internal(block) { |*kv| res << kv }
straight-shoota marked this conversation as resolved.
Show resolved Hide resolved
res
end

Expand Down Expand Up @@ -166,7 +166,6 @@ module Enumerable(T)
end

def init(key, val)
return if key == Drop
@key = key

if @reuse
Expand All @@ -190,33 +189,44 @@ module Enumerable(T)

def same_as?(key) : Bool
return false unless @initialized
return false if key.in?(Alone, Drop)
return false if key.is_a?(Alone.class) || key.is_a?(Drop.class)
@key == key
end

def reset
@initialized = false
@data.clear
def acc(key, val, &)
straight-shoota marked this conversation as resolved.
Show resolved Hide resolved
if same_as?(key)
add(val)
else
if tuple = fetch
yield *tuple
end

init(key, val) unless key.is_a?(Drop.class)
end
end
end

def self.key_type(ary, block)
ary.each do |item|
key = block.call(item)
::raise "" if key.is_a?(Drop.class)
return key
end
::raise ""
end
end

private def chunks_internal(original_block : T -> U, &) forall U
acc = Chunk::Accumulator(T, U).new
acc = Chunk::Accumulator(T, typeof(Chunk.key_type(self, original_block))).new
each do |val|
key = original_block.call(val)
if acc.same_as?(key)
acc.add(val)
else
if tuple = acc.fetch
yield(*tuple)
end
acc.init(key, val)
acc.acc(key, val) do |*tuple|
yield *tuple
end
end

if tuple = acc.fetch
yield(*tuple)
yield *tuple
end
end

Expand Down
24 changes: 10 additions & 14 deletions src/iterator.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1455,21 +1455,22 @@ module Iterator(T)
#
# See also: `Enumerable#chunks`.
def chunk(reuse = false, &block : T -> U) forall T, U
ChunkIterator(typeof(self), T, U).new(self, reuse, &block)
ChunkIterator(typeof(self), T, U, typeof(::Enumerable::Chunk.key_type(self, block))).new(self, reuse, &block)
end

private class ChunkIterator(I, T, U)
include Iterator(Tuple(U, Array(T)))
private class ChunkIterator(I, T, U, V)
include Iterator(Tuple(V, Array(T)))
@iterator : I
@init : {U, T}?
@init : {V, T}?

def initialize(@iterator : Iterator(T), reuse, &@original_block : T -> U)
@acc = Enumerable::Chunk::Accumulator(T, U).new(reuse)
@acc = ::Enumerable::Chunk::Accumulator(T, V).new(reuse)
end

def next
if init = @init
@acc.init(*init)
k, v = init
@acc.init(k, v)
@init = nil
end

Expand All @@ -1481,24 +1482,19 @@ module Iterator(T)
else
tuple = @acc.fetch
if tuple
@init = {key, val}
@init = {key, val} unless key.is_a?(::Enumerable::Chunk::Drop.class)
return tuple
else
@acc.init(key, val)
@acc.init(key, val) unless key.is_a?(::Enumerable::Chunk::Drop.class)
end
end
end

if tuple = @acc.fetch
return tuple
end
stop
end

private def init_state
@init = nil
@acc.reset
self
stop
end
end

Expand Down