Skip to content

Commit

Permalink
[ci skip] WIP: Refactor pmap
Browse files Browse the repository at this point in the history
New functions:

 - head_and_tail  -- like take and rest but atomic
 - batchsplit     -- like split, but aware of nworkers
 - generate       -- shorthand for creating a Generator
 - asyncgenerate  -- generate using tasks
 - asyncmap       -- map using tasks
 - pgenerate      -- generate using tasks and workers.

Reimplement pmap:

    pmap(f, c...) = collect(pgenerate(f, c...))

workerpool.jl
 - add length(pool::WorkerPool)
 - add remote(p::WorkerPool, f)

new pmap.jl (to avoid circular dependency between type definitions in
multi.jl and workerpool.jl)
  • Loading branch information
samoconnor committed Mar 24, 2016
1 parent f880834 commit 87a0dad
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 101 deletions.
4 changes: 2 additions & 2 deletions base/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end


"""
retry(f, [condition]; n=3; max_delay=10)
retry(f, [condition]; n=3; max_delay=10) -> Function
Returns a lambda that retries function `f` up to `n` times in the
event of an exception. If `condition` is a `Type` then retry only
Expand Down Expand Up @@ -84,7 +84,7 @@ retry(f::Function, t::Type; kw...) = retry(f, e->isa(e, t); kw...)


"""
@catch(f)
@catch(f) -> Function
Returns a lambda that executes `f` and returns either the result of `f` or
an `Exception` thrown by `f`.
Expand Down
2 changes: 2 additions & 0 deletions base/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ function next(g::Generator, s)
g.f(v), s2
end

generate(f, c...) = Generator(f, c...)

collect(g::Generator) = map(g.f, g.iter)
19 changes: 19 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ done(i::Rest, st) = done(i.itr, st)

eltype{I}(::Type{Rest{I}}) = eltype(I)


"""
head_and_tail(c, n) -> head, tail
Returns `head`: the first `n` elements of `c`;
and `tail`: an iterator over the remaining elements.
"""
function head_and_tail(c, n)
head = Vector{eltype(c)}(n)
s = start(c)
i = 0
while i < n && !done(c, s)
i += 1
head[i], s = next(c, s)
end
return resize!(head, i), rest(c, s)
end


# Count -- infinite counting

immutable Count{S<:Number}
Expand Down
23 changes: 22 additions & 1 deletion base/mapiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
Apply f to each element of c using at most 100 asynchronous tasks.
For multiple collection arguments, apply f elementwise.
The iterator returns results as they become available.
Results are returned by the iterator as they become available.
Note: `collect(StreamMapIterator(f, c...; ntasks=1))` is equivalent to
`map(f, c...)`.
"""
Expand Down Expand Up @@ -144,3 +144,24 @@ function next(itr::StreamMapIterator, state::StreamMapState)

return (r, state)
end



"""
asyncgenerate(f, c...) -> iterator
Apply `@async f` to each element of `c`.
For multiple collection arguments, apply f elementwise.
Results are returned by the iterator as they become available.
"""
asyncgenerate(f, c...) = StreamMapIterator(f, c...)



"""
asyncmap(f, c...) -> collection
Transform collection `c` by applying `@async f` to each element.
For multiple collection arguments, apply f elementwise.
"""
asyncmap(f, c...) = collect(asyncgenerate(f, c...))
89 changes: 2 additions & 87 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1514,101 +1514,16 @@ macro everywhere(ex)
end
end

#FIXME delete?
# Static round-robin scheduling: element i of each list is sent to worker
# (i-1)%np+1 with no load balancing. Returns a Vector{Any} of the remotecall
# results (remote references, not fetched values).
# NOTE(review): only lsts[1] determines the length — assumes all lists are at
# least that long; verify against callers.
function pmap_static(f, lsts...)
    np = nprocs()
    n = length(lsts[1])
    Any[ remotecall(f, PGRP.workers[(i-1)%np+1].id, map(L->L[i], lsts)...) for i = 1:n ]
end

#FIXME delete?
# Degenerate zero-collection case: with nothing to map over, just evaluate
# `f` once locally and return its result.
function pmap(f)
    return f()
end

# dynamic scheduling by creating a local task to feed work to each processor
# as it finishes.
# example unbalanced workload:
# rsym(n) = (a=rand(n,n);a*a')
# L = {rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000)};
# pmap(eig, L);
function pmap(f, lsts...; err_retry=true, err_stop=false, pids = workers())
    len = length(lsts)

    # Results keyed by input index so out-of-order completion can be
    # reassembled into input order at the end.
    results = Dict{Int,Any}()

    # busy_workers[i] is true while pids[i] has a call in flight; the
    # condition is notified whenever a worker finishes (see err_retry branch).
    busy_workers = fill(false, length(pids))
    busy_workers_ntfy = Condition()

    # Failed work items (idx, values, exception) awaiting a retry.
    retryqueue = []
    task_in_err = false
    is_task_in_error() = task_in_err
    set_task_in_error() = (task_in_err = true)

    # Monotonic counter handing out the next input index.
    nextidx = 0
    getnextidx() = (nextidx += 1)

    # One iteration state per input collection (pre-1.0 start/done/next protocol).
    states = [start(lsts[idx]) for idx in 1:len]

    # Produce the next (index, values) work item: fresh input first, then the
    # retry queue, then (if retrying) wait for in-flight calls that may still
    # fail; returns nothing when all work is exhausted or err_stop applies.
    function getnext_tasklet()
        if is_task_in_error() && err_stop
            return nothing
        elseif !any(idx->done(lsts[idx],states[idx]), 1:len)
            nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
            for idx in 1:len; states[idx] = nxts[idx][2]; end
            nxtvals = [x[1] for x in nxts]
            return (getnextidx(), nxtvals)
        elseif !isempty(retryqueue)
            return shift!(retryqueue)
        elseif err_retry
            # Handles the condition where we have finished processing the requested lsts as well
            # as any retryqueue entries, but there are still some jobs active that may result
            # in an error and have to be retried.
            while any(busy_workers)
                wait(busy_workers_ntfy)
                if !isempty(retryqueue)
                    return shift!(retryqueue)
                end
            end
            return nothing
        else
            return nothing
        end
    end

    # One feeder task per worker pid: each pulls work items and executes them
    # remotely on its dedicated worker until no work remains.
    @sync begin
        for (pididx, wpid) in enumerate(pids)
            @async begin
                tasklet = getnext_tasklet()
                while (tasklet !== nothing)
                    (idx, fvals) = tasklet
                    busy_workers[pididx] = true
                    try
                        results[idx] = remotecall_fetch(f, wpid, fvals...)
                    catch ex
                        # On failure either queue the item for retry on
                        # another worker, or record the exception as result.
                        if err_retry
                            push!(retryqueue, (idx,fvals, ex))
                        else
                            results[idx] = ex
                        end
                        set_task_in_error()

                        busy_workers[pididx] = false
                        notify(busy_workers_ntfy; all=true)

                        break # remove this worker from accepting any more tasks
                    end

                    busy_workers[pididx] = false
                    notify(busy_workers_ntfy; all=true)

                    tasklet = getnext_tasklet()
                end
            end
        end
    end

    # Anything still queued for retry failed permanently; record its
    # exception as that index's result.
    for failure in retryqueue
        results[failure[1]] = failure[3]
    end
    # Reassemble results in input order.
    [results[x] for x in 1:nextidx]
end

# Statically split range [1,N] into equal sized chunks for np processors
function splitrange(N::Int, np::Int)
each = div(N,np)
Expand Down
89 changes: 89 additions & 0 deletions base/pmap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license


"""
pgenerate([::WorkerPool,] f, c...) -> iterator
Apply `f` to each element of `c` in parallel using available workers and tasks.
For multiple collection arguments, apply f elementwise.
Results are returned by the iterator as they become available.
"""
function pgenerate(p::WorkerPool, f, c)
if length(p) == 0
return asyncgenerate(f, c)
end
batches = batchsplit(c, min_batch_count = length(p) * 3)
return flatten(asyncgenerate(remote(p, b -> asyncmap(f, b)), batches))
end

# Multiple collections with an explicit pool: zip them and apply f elementwise.
pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))

# BUG FIX: the original splatted the single collection (`c...`), which turned
# `pgenerate(f, v)` into the multi-collection form over v's elements; `c`
# must be passed whole to the default worker pool.
pgenerate(f, c) = pgenerate(default_worker_pool(), f, c)
pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))



"""
pmap([::WorkerPool,] f, c...)
Transform collection `c` by applying `f` to each element using available
workers and tasks.
For multiple collection arguments, apply f elementwise.
"""
pmap(p::WorkerPool, f, c...) = collect(pgenerate(p, f, c...))
pmap(f, c...) = pmap(p, f, c...)

# Deprecation shim: accepts the old keyword arguments of pmap, warns, and
# forwards to the WorkerPool-based pmap above.
function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing)

    # `=== nothing` / `!== nothing` instead of `==`/`!=`: identity comparison
    # is the correct (and non-overloadable) test for the nothing sentinel.
    if err_retry !== nothing
        depwarn("`err_retry` is deprecated, use `pmap(retry(f), c...)`.", :pmap)
        if err_retry == true
            f = retry(f)
        end
    end

    if err_stop !== nothing
        # BUG FIX: closing backtick was missing from the deprecation message.
        depwarn("`err_stop` is deprecated, use `pmap(@catch(f), c...)`.", :pmap)
        if err_stop == false
            # err_stop=false was the old default: errors are captured and
            # returned as results rather than thrown.
            f = @catch(f)
        end
    end

    if pids === nothing
        p = default_worker_pool()
    else
        # BUG FIX: closing backtick was missing from the deprecation message.
        depwarn("`pids` is deprecated, use `pmap(::WorkerPool, f, c...)`.", :pmap)
        p = WorkerPool(pids)
    end

    return pmap(p, f, c...)
end



"""
batchsplit(c; min_batch_count=0, max_batch_size=100)
Split a collection into at least `min_batch_count` batches.
Equivalent to `split(c, batch_size)` when `length(c) >> max_batch_size`.
"""
function batchsplit(c; min_batch_count=1, max_batch_size=100)

# FIXME Use @require per https://github.com/JuliaLang/julia/pull/15495
@assert min_batch_count > 0
@assert max_batch_size > 1

# Split collection into batches, then peek at the first few batches...
batches = split(c, max_batch_size)
head, tail = head_and_tail(batches, min_batch_count)

# If there are not enough batches, use a smaller batch size...
if length(head) < min_batch_count
head = vcat(head...)
batch_size = max(1, div(length(head), min_batch_count))
return split(head, batch_size)
end

return flatten((head, tail))
end
1 change: 1 addition & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ importall .Serializer
include("channels.jl")
include("multi.jl")
include("workerpool.jl")
include("pmap.jl")
include("managers.jl")
include("mapiterator.jl")

Expand Down
20 changes: 10 additions & 10 deletions base/workerpool.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license


# Pool of worker process ids, shared via a remote channel: workers are
# checked out with take! and returned with put!; `count` records how many
# workers have been added to the pool.
type WorkerPool
    channel::RemoteChannel{Channel{Int}}
    count::Int

    # Create a shared queue of available workers...
    # NOTE(review): this page shows both the diff's old and new constructor
    # lines; the second definition (initialising count to 0) supersedes the
    # first.
    WorkerPool() = new(RemoteChannel(()->Channel{Int}(typemax(Int))))
    WorkerPool() = new(RemoteChannel(()->Channel{Int}(typemax(Int))), 0)
end


Expand All @@ -16,7 +16,6 @@ Create a WorkerPool from a vector of worker ids.
"""
function WorkerPool(workers::Vector{Int})

# Create a shared queue of available workers...
pool = WorkerPool()

# Add workers to the pool...
Expand All @@ -28,12 +27,12 @@ function WorkerPool(workers::Vector{Int})
end


# NOTE(review): diff residue — this page interleaves the pre- and post-change
# method definitions; for duplicated signatures the later definition wins.

# Old forms: adding a worker id did not maintain pool.count.
put!(pool::WorkerPool, w::Int) = put!(pool.channel, w)
put!(pool::WorkerPool, w::Worker) = put!(pool.channel, w.id)

# New forms: adding a worker also increments the pool's count.
put!(pool::WorkerPool, w::Int) = (pool.count += 1; put!(pool.channel, w))
put!(pool::WorkerPool, w::Worker) = put!(pool, w.id)

isready(pool::WorkerPool) = isready(pool.channel)
# Number of workers that have been added to the pool.
length(pool::WorkerPool) = pool.count

# Check a worker id out of the pool (blocks until one is available).
take!(pool::WorkerPool) = take!(pool.channel)
isready(pool::WorkerPool) = isready(pool.channel)


"""
Expand All @@ -42,11 +41,11 @@ take!(pool::WorkerPool) = take!(pool.channel)
Call `f(args...)` on one of the workers in `pool`.
"""
# Call f(args...) on a worker checked out of `pool`; the finally clause
# guarantees the worker id is returned to the pool even if the call throws.
# NOTE(review): diff residue — the adjacent take!/put! statement pairs below
# are the diff's old (pool) and new (pool.channel) forms of the same line;
# only one of each pair belongs in the final source.
function remotecall_fetch(f, pool::WorkerPool, args...)
    worker = take!(pool)
    worker = take!(pool.channel)
    try
        remotecall_fetch(f, worker, args...)
    finally
        put!(pool, worker)
        put!(pool.channel, worker)
    end
end

Expand All @@ -70,9 +69,10 @@ end


"""
remote(f)
remote([::WorkerPool,] f) -> Function
Returns a lambda that executes function `f` on an available worker
using `remotecall_fetch`.
"""
# Returns a closure that runs f(args...) on a worker drawn from the default
# worker pool via remotecall_fetch.
remote(f) = (args...)->remotecall_fetch(f, default_worker_pool(), args...)
# Same, but drawing workers from the given pool `p`.
remote(p::WorkerPool, f) = (args...)->remotecall_fetch(f, p, args...)
2 changes: 1 addition & 1 deletion test/parallel_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ testcpt()
@test_throws ArgumentError timedwait(()->false, 0.1, pollint=-0.5)

# specify pids for pmap
@test sort(workers()[1:2]) == sort(unique(pmap(x->(sleep(0.1);myid()), 1:10, pids = workers()[1:2])))
@test sort(workers()[1:2]) == sort(unique(pmap(Base.WorkerPool(workers()[1:2]), x->(sleep(0.1);myid()), 1:10)))

# Testing buffered and unbuffered reads
# This large array should write directly to the socket
Expand Down

0 comments on commit 87a0dad

Please sign in to comment.