Skip to content

Commit

Permalink
Make Distributed.Worker threadsafe (#37905)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-schulze authored and vchuravy committed Oct 21, 2020
1 parent 7ac9ac3 commit 8d4a408
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
8 changes: 4 additions & 4 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mutable struct Worker
add_msgs::Array{Any,1}
gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
c_state::Event # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

Expand Down Expand Up @@ -133,7 +133,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, [], [], false, W_CREATED, Event(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
Expand All @@ -144,7 +144,7 @@ end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
notify(w.c_state)
end

function check_worker_state(w::Worker)
Expand Down Expand Up @@ -189,7 +189,7 @@ function wait_for_conn(w)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
@async (sleep(timeout); notify(w.c_state))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
Expand Down
1 change: 1 addition & 0 deletions stdlib/Distributed/test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1711,4 +1711,5 @@ include("splitrange.jl")
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
54 changes: 54 additions & 0 deletions stdlib/Distributed/test/threads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")

function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
schedule(t)
@assert threadid(t) == tid
t
end
end

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(ffetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
timeout = 10.0
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads
p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end
timedwait(() -> isdone(send) && isdone(recv), timeout)
@test isdone(send)
@test isdone(recv)
@test !isfailed(send)
@test !isfailed(recv)
rmprocs(procs_added)
end
end
end

0 comments on commit 8d4a408

Please sign in to comment.