Skip to content
This repository has been archived by the owner on Sep 1, 2020. It is now read-only.

Commit

Permalink
length for chain
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 17, 2016
1 parent c3f0692 commit 40d0d3d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
26 changes: 22 additions & 4 deletions src/Iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ if VERSION < v"0.5.0-dev+3305"
error("Do not call this on older versions")
end
else
import Base: iteratorsize, SizeUnknown, IsInfinite, HasLength
import Base: IteratorSize, iteratorsize, SizeUnknown, IsInfinite,
HasLength, HasShape
end


Expand Down Expand Up @@ -106,17 +107,34 @@ done(it::RepeatCallForever, state) = false


# Concatenate the output of n iterators

immutable Chain{T<:Tuple}
xss::T
end

iteratorsize{T<:Chain}(::Type{T}) = SizeUnknown()
if VERSION >= v"0.5.0-dev+3305"
# corner case of empty chain
iteratorsize(c::Chain{Tuple{}}) = HasLength()

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)
iteratorsize(c::Chain) =_chain_is(iteratorsize(c.xss[1]), c.xss[2:end]...)

_chain_is{I<:Union{HasLength, HasShape}}(isx::I, y...) =
_chain_is(iteratorsize(y[1]), y[2:end]...)

# have to define twice becouse of a bug in julia (#18985)
# _chain_is{I<:Union{HasLength, HasShape}}(isx::I) = HasLength()
_chain_is(isx::HasShape) = HasLength()
_chain_is(isx::HasLength) = HasLength()
# fallback
_chain_is(isx, y...) = SizeUnknown()
end

chain(xss...) = Chain(xss)

length(it::Chain{Tuple{}}) = 0
length(it::Chain) = sum(length, it.xss)

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)

function start(it::Chain)
i = 1
xs_state = nothing
Expand Down
29 changes: 25 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,38 @@ end

# chain
# -----
ch2 = chain(1:0, 1:2:5, 0.2:0.1:1.6)

@test eltype(ch2) == typejoin(Int, Float64)
@test collect(ch2) == [1:2:5; 0.2:0.1:1.6]

ch1 = chain(1:2:5, 0.2:0.1:1.6)

@test eltype(ch1) == typejoin(Int, Float64)
@test collect(ch1) == [1:2:5; 0.2:0.1:1.6]

ch2 = chain(1:0, 1:2:5, 0.2:0.1:1.6)

@test eltype(ch2) == typejoin(Int, Float64)
@test collect(ch2) == [1:2:5; 0.2:0.1:1.6]
ch3 = chain(1:10, 1:10, 1:10)
@test length(ch3) == 30
if VERSION >= v"0.5.0-dev+3305"
@test Base.iteratorsize(ch3) == Base.HasLength()
end

# test fails on 0.4 (see PR 85)
# add back when dropping 0.4 support
# if VERSION >= v"0.5.0-dev+3305"
# r = ( x for x in 1:10 if isodd(x))
# ch4 = chain(1:10, r)
# @assert Base.iteratorsize(r) == Base.SizeUnknown()
# @test Base.iteratorsize(ch4) == Base.SizeUnknown()
# @test_throws MethodError length(ch4)
# end

ch5 = chain()
@test length(ch5) == 0
if VERSION >= v"0.5.0-dev+3305"
@test Base.iteratorsize(ch5) == Base.HasLength()
end

# product
# -------
Expand Down Expand Up @@ -561,4 +583,3 @@ end
@test_chain [1,2,3] Any[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] Union{}[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] 4 [('w',3), ('x',2), ('y',1), ('z',0)]

0 comments on commit 40d0d3d

Please sign in to comment.