Skip to content
This repository has been archived by the owner on Sep 1, 2020. It is now read-only.

Commit

Permalink
length and generated iteratorsize for chain
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 19, 2016
1 parent dd9024c commit 7e2744e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
26 changes: 22 additions & 4 deletions src/Iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ if VERSION < v"0.5.0-dev+3305"
error("Do not call this on older versions")
end
else
import Base: iteratorsize, SizeUnknown, IsInfinite, HasLength
import Base: iteratorsize, SizeUnknown, IsInfinite,
HasLength, HasShape
end


Expand Down Expand Up @@ -105,17 +106,34 @@ done(it::RepeatCallForever, state) = false


# Concatenate the output of n iterators

immutable Chain{T<:Tuple}
xss::T
end

iteratorsize{T<:Chain}(::Type{T}) = SizeUnknown()
if VERSION >= v"0.5.0-dev+3305"
# corner case of empty chain
iteratorsize(c::Chain) = iteratorsize(typeof(c))
iteratorsize{T}(::Type{Chain{T}}) = _chain_is(T)

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)
@generated function _chain_is{T}(t::Type{T})
for itype in T.types
if iteratorsize(itype) == IsInfinite()
return :(IsInfinite())
elseif iteratorsize(itype) == SizeUnknown()
return :(SizeUnknown())
end
end
return :(HasLength())
end
end

chain(xss...) = Chain(xss)

length(it::Chain{Tuple{}}) = 0
length(it::Chain) = sum(length, it.xss)

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)

function start(it::Chain)
i = 1
xs_state = nothing
Expand Down
54 changes: 51 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
include("../src/Iterators.jl")
using Iterators, Base.Test

if VERSION >= v"0.5.0-dev+3305"
import Base: IsInfinite, SizeUnknown, HasLength, iteratorsize, HasShape
end
# count
# -----

Expand Down Expand Up @@ -106,7 +109,6 @@ end

# chain
# -----

ch1 = chain(1:2:5, 0.2:0.1:1.6)

@test eltype(ch1) == typejoin(Int, Float64)
Expand All @@ -116,6 +118,53 @@ ch2 = chain(1:0, 1:2:5, 0.2:0.1:1.6)

@test eltype(ch2) == typejoin(Int, Float64)
@test collect(ch2) == [1:2:5; 0.2:0.1:1.6]
@test length(ch2) == length(collect(ch2))
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch2) == HasLength()
end

ch3 = chain(1:10, 1:10, 1:10)
@test length(ch3) == 30
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch3) == HasLength()
end

r = countfrom(1)
ch4 = chain(1:10, countfrom(1))
@test_throws MethodError length(ch4)
if VERSION >= v"0.5.0-dev+3305"
@assert iteratorsize(r) == IsInfinite()
@test iteratorsize(ch4) == IsInfinite()
end

ch5 = chain()
@test length(ch5) == 0
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch5) == HasLength()
end

c = chain(ch1, ch2, ch3)
@test length(c) == length(ch1) + length(ch2) + length(ch3)
#@test collect(c) == [ch1; ch2; ch3] # why doesn't work?
@test collect(c) == [collect(ch1); collect(ch2); collect(ch3)]

r = rand(2,2)
c = chain(r, r)
@test length(c) == 8
@test collect(c) == [vec(r); vec(r)]
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(r) == HasShape()
@test iteratorsize(c) == HasLength()
end

if VERSION >= v"0.5.0-dev+3305"
r = Base.flatten(rand(2,2))
c = chain(1:10, r)
@test_throws MethodError length(c)
@test length(collect(c)) == 14
@assert iteratorsize(r) == SizeUnknown()
@test iteratorsize(c) == SizeUnknown()
end

# product
# -------
Expand Down Expand Up @@ -557,4 +606,3 @@ end
@test_chain [1,2,3] Any[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] Union{}[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] 4 [('w',3), ('x',2), ('y',1), ('z',0)]

0 comments on commit 7e2744e

Please sign in to comment.