diff --git a/base/asyncmap.jl b/base/asyncmap.jl index a078021fd66ce..4eb0cd35b0a0b 100644 --- a/base/asyncmap.jl +++ b/base/asyncmap.jl @@ -209,8 +209,8 @@ function next(itr::AsyncGenerator, state::AsyncGeneratorState) return (r, state) end -iteratorsize(::Type{AsyncGenerator}) = SizeUnknown() - +iteratorsize(itr::AsyncGenerator) = iteratorsize(itr.collector.enumerator) +size(itr::AsyncGenerator) = size(itr.collector.enumerator) """ asyncmap(f, c...) -> collection diff --git a/base/channels.jl b/base/channels.jl index 2648df4403eb5..e7ebaf11a0e94 100644 --- a/base/channels.jl +++ b/base/channels.jl @@ -179,11 +179,19 @@ eltype{T}(::Type{Channel{T}}) = T show(io::IO, c::Channel) = print(io, "$(typeof(c))(sz_max:$(c.sz_max),sz_curr:$(n_avail(c)))") -start{T}(c::Channel{T}) = Ref{Nullable{T}}() -function done(c::Channel, state::Ref) +type ChannelState{T} + hasval::Bool + val::T + ChannelState(x) = new(x) +end + +start{T}(c::Channel{T}) = ChannelState{T}(false) +function done(c::Channel, state::ChannelState) try # we are waiting either for more data or channel to be closed - state[] = take!(c) + state.hasval && return false + state.val = take!(c) + state.hasval = true return false catch e if isa(e, InvalidStateException) && e.state==:closed @@ -193,6 +201,6 @@ function done(c::Channel, state::Ref) end end end -next{T}(c::Channel{T}, state) = (v=get(state[]); state[]=nothing; (v, state)) +next{T}(c::Channel{T}, state) = (v=state.val; state.hasval=false; (v, state)) iteratorsize{C<:Channel}(::Type{C}) = SizeUnknown() diff --git a/test/channels.jl b/test/channels.jl index 620b433b2795c..a134dacfd0576 100644 --- a/test/channels.jl +++ b/test/channels.jl @@ -61,6 +61,21 @@ results=[] end @test sum(results) == 15 +# Test channel iterator with done() being called multiple times +# This needs to be explicitly tested since `take!` is called +# in `done()` and not `next()` +c=Channel(32); foreach(i->put!(c,i), 1:10); close(c) +s=start(c) +@test done(c,s) == false +res = Int[] +while !done(c,s) + @test done(c,s) == false + v,s = next(c,s) + push!(res,v) +end +@test res == Int[1:10...] + + # Testing timedwait on multiple channels @sync begin rr1 = Channel(1) diff --git a/test/parallel_exec.jl b/test/parallel_exec.jl index fa3100c5fe480..84f02a8b64af6 100644 --- a/test/parallel_exec.jl +++ b/test/parallel_exec.jl @@ -777,6 +777,18 @@ end # Test asyncmap @test allunique(asyncmap(x->(sleep(1.0);object_id(current_task())), 1:10)) +# check whether shape is retained +a=rand(2,2) +b=asyncmap(identity, a) +@test a == b +@test size(a) == size(b) + +# check with an iterator that does not implement size() +c=Channel(32); foreach(i->put!(c,i), 1:10); close(c) +b=asyncmap(identity, c) +@test Int[1:10...] == b +@test size(b) == (10,) + # CachingPool tests wp = CachingPool(workers()) @test [1:100...] == pmap(wp, x->x, 1:100)