diff --git a/base/error.jl b/base/error.jl
index 9609293db78182..e3ded22eac3c04 100644
--- a/base/error.jl
+++ b/base/error.jl
@@ -52,7 +52,7 @@ end
 
 """
-    retry(f, [condition]; n=3; max_delay=10)
+    retry(f, [condition]; n=3; max_delay=10) -> Function
 
 Returns a lambda that retries function `f` up to `n` times in the
 event of an exception. If `condition` is a `Type` then retry only
@@ -84,7 +84,7 @@ retry(f::Function, t::Type; kw...) = retry(f, e->isa(e, t); kw...)
 
 """
-    @catch(f)
+    @catch(f) -> Function
 
 Returns a lambda that executes `f` and returns either the result of `f` or
 an `Exception` thrown by `f`.
diff --git a/base/generator.jl b/base/generator.jl
index dff08ea9d25422..c0c6f59ccfc164 100644
--- a/base/generator.jl
+++ b/base/generator.jl
@@ -20,4 +20,6 @@ function next(g::Generator, s)
     g.f(v), s2
 end
 
+generate(f, c...) = Generator(f, c...)
+
 collect(g::Generator) = map(g.f, g.iter)
diff --git a/base/iterator.jl b/base/iterator.jl
index 5e6bfab3786428..6318e01a501ee3 100644
--- a/base/iterator.jl
+++ b/base/iterator.jl
@@ -123,6 +123,25 @@ done(i::Rest, st) = done(i.itr, st)
 
 eltype{I}(::Type{Rest{I}}) = eltype(I)
 
+
+"""
+    head_and_tail(c, n) -> head, tail
+
+Returns `head`: the first `n` elements of `c`;
+and `tail`: an iterator over the remaining elements.
+"""
+function head_and_tail(c, n)
+    head = Vector{eltype(c)}(n)
+    s = start(c)
+    i = 0
+    while i < n && !done(c, s)
+        i += 1
+        head[i], s = next(c, s)
+    end
+    return resize!(head, i), rest(c, s)
+end
+
+
 # Count -- infinite counting
 
 immutable Count{S<:Number}
diff --git a/base/mapiterator.jl b/base/mapiterator.jl
index 75bde41ed6c85f..3b4f2fc4d79c61 100644
--- a/base/mapiterator.jl
+++ b/base/mapiterator.jl
@@ -90,7 +90,7 @@ end
 
 Apply f to each element of c using at most 100 asynchronous tasks.
 For multiple collection arguments, apply f elementwise.
-The iterator returns results as the become available.
+Results are returned by the iterator as they become available.
 Note: `collect(StreamMapIterator(f, c...; ntasks=1))` is equivalent to
 `map(f, c...)`.
 """
@@ -144,3 +144,24 @@ function next(itr::StreamMapIterator, state::StreamMapState)
 
     return (r, state)
 end
+
+
+
+"""
+    asyncgenerate(f, c...) -> iterator
+
+Apply `@async f` to each element of `c`.
+For multiple collection arguments, apply f elementwise.
+Results are returned by the iterator as they become available.
+"""
+asyncgenerate(f, c...) = StreamMapIterator(f, c...)
+
+
+
+"""
+    asyncmap(f, c...) -> collection
+
+Transform collection `c` by applying `@async f` to each element.
+For multiple collection arguments, apply f elementwise.
+"""
+asyncmap(f, c...) = collect(asyncgenerate(f, c...))
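For reviewers, a rough usage sketch of the iterator helpers added above (`head_and_tail`, `asyncgenerate`, `asyncmap`). It assumes a 0.5-dev session with this branch loaded; the sleepy workload is invented purely for illustration:

```julia
# head_and_tail: grab the first n elements, keep the rest as a lazy iterator.
head, tail = Base.head_and_tail(1:10, 3)
# head == [1, 2, 3]; collect(tail) == [4, 5, 6, 7, 8, 9, 10]

# asyncmap: run an I/O-bound function on up to 100 concurrent tasks.
# Every call sleeps 0.1s, but the whole map finishes in roughly one sleep period.
squares = Base.asyncmap(1:20) do i
    sleep(0.1)              # stand-in for a network or file request
    i^2
end

# asyncgenerate is the lazy form: results stream out as tasks complete,
# so iteration order follows completion order, not input order.
for r in Base.asyncgenerate(i -> (sleep(rand()); i), 1:5)
    println(r)
end
```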
diff --git a/base/multi.jl b/base/multi.jl
index d3bd2b4ddeca73..fb21671c792174 100644
--- a/base/multi.jl
+++ b/base/multi.jl
@@ -1514,101 +1514,16 @@ macro everywhere(ex)
     end
 end
 
+#FIXME delete?
 function pmap_static(f, lsts...)
     np = nprocs()
     n = length(lsts[1])
     Any[ remotecall(f, PGRP.workers[(i-1)%np+1].id, map(L->L[i], lsts)...) for i = 1:n ]
 end
 
+#FIXME delete?
 pmap(f) = f()
 
-# dynamic scheduling by creating a local task to feed work to each processor
-# as it finishes.
-# example unbalanced workload:
-# rsym(n) = (a=rand(n,n);a*a')
-# L = {rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000)};
-# pmap(eig, L);
-function pmap(f, lsts...; err_retry=true, err_stop=false, pids = workers())
-    len = length(lsts)
-
-    results = Dict{Int,Any}()
-
-    busy_workers = fill(false, length(pids))
-    busy_workers_ntfy = Condition()
-
-    retryqueue = []
-    task_in_err = false
-    is_task_in_error() = task_in_err
-    set_task_in_error() = (task_in_err = true)
-
-    nextidx = 0
-    getnextidx() = (nextidx += 1)
-
-    states = [start(lsts[idx]) for idx in 1:len]
-    function getnext_tasklet()
-        if is_task_in_error() && err_stop
-            return nothing
-        elseif !any(idx->done(lsts[idx],states[idx]), 1:len)
-            nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
-            for idx in 1:len; states[idx] = nxts[idx][2]; end
-            nxtvals = [x[1] for x in nxts]
-            return (getnextidx(), nxtvals)
-        elseif !isempty(retryqueue)
-            return shift!(retryqueue)
-        elseif err_retry
-            # Handles the condition where we have finished processing the requested lsts as well
-            # as any retryqueue entries, but there are still some jobs active that may result
-            # in an error and have to be retried.
-            while any(busy_workers)
-                wait(busy_workers_ntfy)
-                if !isempty(retryqueue)
-                    return shift!(retryqueue)
-                end
-            end
-            return nothing
-        else
-            return nothing
-        end
-    end
-
-    @sync begin
-        for (pididx, wpid) in enumerate(pids)
-            @async begin
-                tasklet = getnext_tasklet()
-                while (tasklet !== nothing)
-                    (idx, fvals) = tasklet
-                    busy_workers[pididx] = true
-                    try
-                        results[idx] = remotecall_fetch(f, wpid, fvals...)
-                    catch ex
-                        if err_retry
-                            push!(retryqueue, (idx,fvals, ex))
-                        else
-                            results[idx] = ex
-                        end
-                        set_task_in_error()
-
-                        busy_workers[pididx] = false
-                        notify(busy_workers_ntfy; all=true)
-
-                        break # remove this worker from accepting any more tasks
-                    end
-
-                    busy_workers[pididx] = false
-                    notify(busy_workers_ntfy; all=true)
-
-                    tasklet = getnext_tasklet()
-                end
-            end
-        end
-    end
-
-    for failure in retryqueue
-        results[failure[1]] = failure[3]
-    end
-    [results[x] for x in 1:nextidx]
-end
-
 # Statically split range [1,N] into equal sized chunks for np processors
 function splitrange(N::Int, np::Int)
     each = div(N,np)
diff --git a/base/pmap.jl b/base/pmap.jl
new file mode 100644
index 00000000000000..40bc2056a010ba
--- /dev/null
+++ b/base/pmap.jl
@@ -0,0 +1,89 @@
+# This file is a part of Julia. License is MIT: http://julialang.org/license
+
+
+"""
+    pgenerate([::WorkerPool,] f, c...) -> iterator
+
+Apply `f` to each element of `c` in parallel using available workers and tasks.
+For multiple collection arguments, apply f elementwise.
+Results are returned by the iterator as they become available.
+"""
+function pgenerate(p::WorkerPool, f, c)
+    if length(p) == 0
+        return asyncgenerate(f, c)
+    end
+    batches = batchsplit(c, min_batch_count = length(p) * 3)
+    return flatten(asyncgenerate(remote(p, b -> asyncmap(f, b)), batches))
+end
+
+pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
+
+pgenerate(f, c) = pgenerate(default_worker_pool(), f, c)
+pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))
+
+
+
+"""
+    pmap([::WorkerPool,] f, c...) -> collection
+
+Transform collection `c` by applying `f` to each element using available
+workers and tasks.
+For multiple collection arguments, apply f elementwise.
+"""
+pmap(p::WorkerPool, f, c...) = collect(pgenerate(p, f, c...))
+pmap(f, c...) = pmap(default_worker_pool(), f, c...)
+
+function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing)
+
+    if err_retry != nothing
+        depwarn("`err_retry` is deprecated, use `pmap(retry(f), c...)`.", :pmap)
+        if err_retry == true
+            f = retry(f)
+        end
+    end
+
+    if err_stop != nothing
+        depwarn("`err_stop` is deprecated, use `pmap(@catch(f), c...)`.", :pmap)
+        if err_stop == false
+            f = @catch(f)
+        end
+    end
+
+    if pids == nothing
+        p = default_worker_pool()
+    else
+        depwarn("`pids` is deprecated, use `pmap(::WorkerPool, f, c...)`.", :pmap)
+        p = WorkerPool(pids)
+    end
+
+    return pmap(p, f, c...)
+end
+
+
+
+"""
+    batchsplit(c; min_batch_count=1, max_batch_size=100)
+
+Split a collection into at least `min_batch_count` batches.
+
+Equivalent to `split(c, batch_size)` when `length(c) >> max_batch_size`.
+"""
+function batchsplit(c; min_batch_count=1, max_batch_size=100)
+
+    # FIXME Use @require per https://github.com/JuliaLang/julia/pull/15495
+    @assert min_batch_count > 0
+    @assert max_batch_size > 1
+
+    # Split collection into batches, then peek at the first few batches...
+    batches = split(c, max_batch_size)
+    head, tail = head_and_tail(batches, min_batch_count)
+
+    # If there are not enough batches, use a smaller batch size...
+    if length(head) < min_batch_count
+        head = vcat(head...)
+        batch_size = max(1, div(length(head), min_batch_count))
+        return split(head, batch_size)
+    end
+
+    return flatten((head, tail))
+end
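A hedged sketch of how the replacement API is meant to compose, based only on the definitions above; `flaky` and its failure rate are invented for illustration, and `retry`/`@catch` are the helpers documented in base/error.jl:

```julia
addprocs(2)                      # assumes a fresh session with no workers yet

# The old keyword arguments become function composition:
#   pmap(f, c; err_retry=true)  ->  pmap(retry(f), c)
#   pmap(f, c; err_stop=false)  ->  pmap(@catch(f), c)
@everywhere flaky(x) = rand() < 0.05 ? error("transient failure") : x^2

# Retry each failing element (up to the default n=3 attempts) before giving up.
results = pmap(retry(flaky), 1:100)

# Or keep going on error and collect the Exception objects as results.
results = pmap(@catch(flaky), 1:100)
```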
diff --git a/base/sysimg.jl b/base/sysimg.jl
index 8ad07fae6b7820..d66a6564c7bceb 100644
--- a/base/sysimg.jl
+++ b/base/sysimg.jl
@@ -234,6 +234,7 @@ importall .Serializer
 include("channels.jl")
 include("multi.jl")
 include("workerpool.jl")
+include("pmap.jl")
 include("managers.jl")
 include("mapiterator.jl")
 
diff --git a/base/workerpool.jl b/base/workerpool.jl
index 1a8f486789b544..d88d840df7c4dc 100644
--- a/base/workerpool.jl
+++ b/base/workerpool.jl
@@ -1,11 +1,11 @@
 # This file is a part of Julia. License is MIT: http://julialang.org/license
 
-
 type WorkerPool
     channel::RemoteChannel{Channel{Int}}
+    count::Int
 
     # Create a shared queue of available workers...
-    WorkerPool() = new(RemoteChannel(()->Channel{Int}(typemax(Int))))
+    WorkerPool() = new(RemoteChannel(()->Channel{Int}(typemax(Int))), 0)
 end
 
 
@@ -16,7 +16,6 @@ Create a WorkerPool from a vector of worker ids.
 """
 function WorkerPool(workers::Vector{Int})
-    # Create a shared queue of available workers...
     pool = WorkerPool()
 
     # Add workers to the pool...
@@ -28,12 +27,12 @@ function WorkerPool(workers::Vector{Int})
 end
 
 
-put!(pool::WorkerPool, w::Int) = put!(pool.channel, w)
-put!(pool::WorkerPool, w::Worker) = put!(pool.channel, w.id)
+put!(pool::WorkerPool, w::Int) = (pool.count += 1; put!(pool.channel, w))
+put!(pool::WorkerPool, w::Worker) = put!(pool, w.id)
 
-isready(pool::WorkerPool) = isready(pool.channel)
+length(pool::WorkerPool) = pool.count
 
-take!(pool::WorkerPool) = take!(pool.channel)
+isready(pool::WorkerPool) = isready(pool.channel)
 
 
 """
@@ -42,11 +41,11 @@ take!(pool::WorkerPool) = take!(pool.channel)
 
 Call `f(args...)` on one of the workers in `pool`.
 """
 function remotecall_fetch(f, pool::WorkerPool, args...)
-    worker = take!(pool)
+    worker = take!(pool.channel)
     try
         remotecall_fetch(f, worker, args...)
     finally
-        put!(pool, worker)
+        put!(pool.channel, worker)
     end
 end
 
 
@@ -70,9 +69,10 @@ end
 
 
 """
-    remote(f)
+    remote([::WorkerPool,] f) -> Function
 
 Returns a lambda that executes function `f` on an available worker using
 `remotecall_fetch`.
 """
 remote(f) = (args...)->remotecall_fetch(f, default_worker_pool(), args...)
+remote(p::WorkerPool, f) = (args...)->remotecall_fetch(f, p, args...)
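A small sketch of the reworked `WorkerPool` plumbing above; the worker ids are illustrative and assume `addprocs(2)` has already produced processes 2 and 3:

```julia
pool = Base.WorkerPool([2, 3])
length(pool)                 # 2 -- `count` tracks workers added to the pool;
                             # remotecall_fetch borrows/returns via `channel`,
                             # so `length` stays stable while jobs run.

# remote(pool, f) wraps f in a closure that fetches a free worker,
# runs f there, and puts the worker back when done.
which = Base.remote(pool, myid)
which()                      # returns 2 or 3, whichever worker was free
```

The one-argument `remote(f)` keeps its old meaning and draws from `default_worker_pool()`.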
diff --git a/test/parallel_exec.jl b/test/parallel_exec.jl
index c72893cdf34e54..dd3a9ba358a250 100644
--- a/test/parallel_exec.jl
+++ b/test/parallel_exec.jl
@@ -516,7 +516,7 @@ testcpt()
 @test_throws ArgumentError timedwait(()->false, 0.1, pollint=-0.5)
 
 # specify pids for pmap
-@test sort(workers()[1:2]) == sort(unique(pmap(x->(sleep(0.1);myid()), 1:10, pids = workers()[1:2])))
+@test sort(workers()[1:2]) == sort(unique(pmap(Base.WorkerPool(workers()[1:2]), x->(sleep(0.1);myid()), 1:10)))
 
 # Testing buffered and unbuffered reads
 # This large array should write directly to the socket
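The test change above is the migration path for callers who used `pids=`: a hedged sketch of the same pattern outside the test harness (worker counts are illustrative):

```julia
addprocs(3)                                      # assumes a fresh session

# Deprecated spelling, still accepted via the keyword shim in base/pmap.jl:
#   pmap(x -> myid(), 1:10, pids = workers()[1:2])

# New spelling: pass an explicit WorkerPool as the first argument.
pool = Base.WorkerPool(workers()[1:2])
ids  = pmap(pool, x -> (sleep(0.1); myid()), 1:10)

sort(unique(ids)) == sort(workers()[1:2])        # only the two pooled workers ran jobs
```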