Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Distributed] Make worker state variable threadsafe #42239

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ mutable struct Worker
del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
add_msgs::Array{Any,1}
@atomic gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
@atomic state::WorkerState
c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
Expand Down Expand Up @@ -134,7 +134,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
Expand All @@ -144,8 +144,10 @@ mutable struct Worker
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
@atomic w.state = state
notify(w.c_state; all=true)
end
end

function check_worker_state(w::Worker)
Expand All @@ -170,6 +172,7 @@ function check_worker_state(w::Worker)
wait_for_conn(w)
end
end
return nothing
end

exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)
Expand All @@ -191,9 +194,17 @@ function wait_for_conn(w)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
end
nothing
end
Expand Down Expand Up @@ -488,7 +499,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
@async begin
sleep(1)
notify(launch_ntfy)
end
wait(launch_ntfy)
end

Expand Down Expand Up @@ -641,7 +655,12 @@ function create_worker(manager, wconfig)
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for wl to join
if jw.state === W_CREATED
lock(jw.c_state) do
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
Expand All @@ -664,7 +683,12 @@ function create_worker(manager, wconfig)
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
lock(wl.c_state) do
if wl.state === W_CREATED
# wait for wl to join
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
Expand All @@ -681,7 +705,11 @@ function create_worker(manager, wconfig)
@async manage(w.manager, w.id, w.config, :register)
# wait for rr_ntfy_join with timeout
timedout = false
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
@async begin
sleep($timeout)
timedout = true
put!(rr_ntfy_join, 1)
end
wait(rr_ntfy_join)
if timedout
error("worker did not connect within $timeout seconds")
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/managers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
@async try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
Expand Down
1 change: 1 addition & 0 deletions stdlib/Distributed/test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1696,4 +1696,5 @@ include("splitrange.jl")
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
63 changes: 63 additions & 0 deletions stdlib/Distributed/test/threads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test seems invalid for now


function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
schedule(t)
@assert threadid(t) == tid
t
end
end

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads

p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end

# Wait on the spawned tasks on the owner
@sync begin
Threads.@spawn fetch_from_owner(wait, recv)
Threads.@spawn fetch_from_owner(wait, send)
end
Comment on lines +47 to +51
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've had a very bad experience with this wait pattern, as it doesn't catch deadlocks: if the sender fails before actually sending a value, the receiver will wait forever. Similarly, if the receiver fails before retrieving the message, the sender will wait forever (if the channel is unbuffered). The @sync block will only "forward" that failure, if all tasks finished, which never happens in this case. That's why I've resorted to a timedwait.

IIRC, I saw some work on an eager version of @sync, Experimental.@sync maybe, but it's still experimental (i.e. discouraged for my use case). I couldn't find it quickly. #32677 describes some related ideas. I was looking into a custom @sync a la

tasks = ...
# new sync end:
@sync for t in tasks
    @async try
        wait(t)
    catch
        for t2 in tasks
            istaskdone(t2) || yieldto(t2, InterruptException())
        end
    end
end

but yieldto is, again, discouraged and people much more adept than me suggest this is a bad idea (golang/go#29011 (comment)).


# Check the tasks
@test isdone(send)
@test isdone(recv)

@test !isfailed(send)
@test !isfailed(recv)

rmprocs(procs_added)
end
end
end