From 14408e5eab597ad494864acde969a0229d9557c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20M=C3=BCller?= Date: Wed, 14 Jun 2023 15:11:06 +0200 Subject: [PATCH] Fix return type of `Iterator#chunk` and `Enumerable#chunks` without `Drop` (#13506) --- spec/std/enumerable_spec.cr | 8 ++++--- src/enumerable.cr | 42 +++++++++++++++++++++++-------------- src/iterator.cr | 24 +++++++++------------ 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/spec/std/enumerable_spec.cr b/spec/std/enumerable_spec.cr index 41ba27e5e1f8..435a056191e3 100644 --- a/spec/std/enumerable_spec.cr +++ b/spec/std/enumerable_spec.cr @@ -177,6 +177,7 @@ describe "Enumerable" do it "drop all" do result = [1, 2].chunk { Enumerable::Chunk::Drop }.to_a + result.should be_a(Array(Tuple(NoReturn, Array(Int32)))) result.size.should eq 0 end @@ -192,14 +193,14 @@ describe "Enumerable" do it "reuses true" do iter = [1, 1, 2, 3, 3].chunk(reuse: true, &.itself) - a = iter.next.as(Tuple) + a = iter.next.should be_a(Tuple(Int32, Array(Int32))) a.should eq({1, [1, 1]}) - b = iter.next.as(Tuple) + b = iter.next.should be_a(Tuple(Int32, Array(Int32))) b.should eq({2, [2]}) b[1].should be(a[1]) - c = iter.next.as(Tuple) + c = iter.next.should be_a(Tuple(Int32, Array(Int32))) c.should eq({3, [3, 3]}) c[1].should be(a[1]) end @@ -239,6 +240,7 @@ describe "Enumerable" do it "drop all" do result = [1, 2].chunks { Enumerable::Chunk::Drop } + result.should be_a(Array(Tuple(NoReturn, Array(Int32)))) result.size.should eq 0 end diff --git a/src/enumerable.cr b/src/enumerable.cr index d71d93ea48c7..d7f978bd088d 100644 --- a/src/enumerable.cr +++ b/src/enumerable.cr @@ -131,8 +131,8 @@ module Enumerable(T) # # See also: `Iterator#chunk`. def chunks(&block : T -> U) forall U - res = [] of Tuple(U, Array(T)) - chunks_internal(block) { |k, v| res << {k, v} } + res = [] of Tuple(typeof(Chunk.key_type(self, block)), Array(T)) + chunks_internal(block) { |*kv| res << kv } res end @@ -166,7 +166,6 @@ module Enumerable(T) end def init(key, val) - return if key == Drop @key = key if @reuse @@ -190,33 +189,44 @@ module Enumerable(T) def same_as?(key) : Bool return false unless @initialized - return false if key.in?(Alone, Drop) + return false if key.is_a?(Alone.class) || key.is_a?(Drop.class) @key == key end - def reset - @initialized = false - @data.clear + def acc(key, val, &) + if same_as?(key) + add(val) + else + if tuple = fetch + yield *tuple + end + + init(key, val) unless key.is_a?(Drop.class) + end end end + + def self.key_type(ary, block) + ary.each do |item| + key = block.call(item) + ::raise "" if key.is_a?(Drop.class) + return key + end + ::raise "" + end end private def chunks_internal(original_block : T -> U, &) forall U - acc = Chunk::Accumulator(T, U).new + acc = Chunk::Accumulator(T, typeof(Chunk.key_type(self, original_block))).new each do |val| key = original_block.call(val) - if acc.same_as?(key) - acc.add(val) - else - if tuple = acc.fetch - yield(*tuple) - end - acc.init(key, val) + acc.acc(key, val) do |*tuple| + yield *tuple end end if tuple = acc.fetch - yield(*tuple) + yield *tuple end end diff --git a/src/iterator.cr b/src/iterator.cr index 22b215440065..1d7d54cdcb58 100644 --- a/src/iterator.cr +++ b/src/iterator.cr @@ -1455,21 +1455,22 @@ module Iterator(T) # # See also: `Enumerable#chunks`. def chunk(reuse = false, &block : T -> U) forall T, U - ChunkIterator(typeof(self), T, U).new(self, reuse, &block) + ChunkIterator(typeof(self), T, U, typeof(::Enumerable::Chunk.key_type(self, block))).new(self, reuse, &block) end - private class ChunkIterator(I, T, U) - include Iterator(Tuple(U, Array(T))) + private class ChunkIterator(I, T, U, V) + include Iterator(Tuple(V, Array(T))) @iterator : I - @init : {U, T}? + @init : {V, T}? def initialize(@iterator : Iterator(T), reuse, &@original_block : T -> U) - @acc = Enumerable::Chunk::Accumulator(T, U).new(reuse) + @acc = ::Enumerable::Chunk::Accumulator(T, V).new(reuse) end def next if init = @init - @acc.init(*init) + k, v = init + @acc.init(k, v) @init = nil end @@ -1481,10 +1482,10 @@ module Iterator(T) else tuple = @acc.fetch if tuple - @init = {key, val} + @init = {key, val} unless key.is_a?(::Enumerable::Chunk::Drop.class) return tuple else - @acc.init(key, val) + @acc.init(key, val) unless key.is_a?(::Enumerable::Chunk::Drop.class) end end end @@ -1492,13 +1493,8 @@ module Iterator(T) if tuple = @acc.fetch return tuple end - stop - end - private def init_state - @init = nil - @acc.reset - self + stop end end