diff --git a/stdlib/Distributed/src/cluster.jl b/stdlib/Distributed/src/cluster.jl index cea8258f36939..cccc377e6790f 100644 --- a/stdlib/Distributed/src/cluster.jl +++ b/stdlib/Distributed/src/cluster.jl @@ -99,10 +99,10 @@ mutable struct Worker del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels? add_msgs::Array{Any,1} @atomic gcflag::Bool - state::WorkerState - c_state::Condition # wait for state changes - ct_time::Float64 # creation time - conn_func::Any # used to setup connections lazily + @atomic state::WorkerState + c_state::Threads.Condition # wait for state changes, lock for state + ct_time::Float64 # creation time + conn_func::Any # used to setup connections lazily r_stream::IO w_stream::IO @@ -134,7 +134,7 @@ mutable struct Worker if haskey(map_pid_wrkr, id) return map_pid_wrkr[id] end - w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func) + w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func) w.initialized = Event() register_worker(w) w @@ -144,8 +144,10 @@ mutable struct Worker end function set_worker_state(w, state) - w.state = state - notify(w.c_state; all=true) + lock(w.c_state) do + @atomic w.state = state + notify(w.c_state; all=true) + end end function check_worker_state(w::Worker) @@ -170,6 +172,7 @@ function check_worker_state(w::Worker) wait_for_conn(w) end end + return nothing end exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker) @@ -191,9 +194,17 @@ function wait_for_conn(w) timeout = worker_timeout() - (time() - w.ct_time) timeout <= 0 && error("peer $(w.id) has not connected to $(myid())") - @async (sleep(timeout); notify(w.c_state; all=true)) - wait(w.c_state) - w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds") + T = Threads.@spawn begin + sleep($timeout) + lock(w.c_state) do + notify(w.c_state; all=true) + end + end + errormonitor(T) + lock(w.c_state) do + wait(w.c_state) + w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds") + end end nothing end @@ -488,7 +499,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...) while true if isempty(launched) istaskdone(t_launch) && break - @async (sleep(1); notify(launch_ntfy)) + @async begin + sleep(1) + notify(launch_ntfy) + end wait(launch_ntfy) end @@ -641,7 +655,12 @@ function create_worker(manager, wconfig) # require the value of config.connect_at which is set only upon connection completion for jw in PGRP.workers if (jw.id != 1) && (jw.id < w.id) - (jw.state === W_CREATED) && wait(jw.c_state) + # wait for wl to join + if jw.state === W_CREATED + lock(jw.c_state) do + wait(jw.c_state) + end + end push!(join_list, jw) end end @@ -664,7 +683,12 @@ function create_worker(manager, wconfig) end for wl in wlist - (wl.state === W_CREATED) && wait(wl.c_state) + lock(wl.c_state) do + if wl.state === W_CREATED + # wait for wl to join + wait(wl.c_state) + end + end push!(join_list, wl) end end @@ -681,7 +705,11 @@ function create_worker(manager, wconfig) @async manage(w.manager, w.id, w.config, :register) # wait for rr_ntfy_join with timeout timedout = false - @async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1)) + @async begin + sleep($timeout) + timedout = true + put!(rr_ntfy_join, 1) + end wait(rr_ntfy_join) if timedout error("worker did not connect within $timeout seconds") diff --git a/stdlib/Distributed/src/managers.jl b/stdlib/Distributed/src/managers.jl index 91a27aa95cb98..bd5a979603514 100644 --- a/stdlib/Distributed/src/managers.jl +++ b/stdlib/Distributed/src/managers.jl @@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy: # Wait for all launches to complete. @sync for (i, (machine, cnt)) in enumerate(manager.machines) let machine=machine, cnt=cnt - @async try + @async try launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy) catch e print(stderr, "exception launching on machine $(machine) : $(e)\n") diff --git a/stdlib/Distributed/test/distributed_exec.jl b/stdlib/Distributed/test/distributed_exec.jl index 749c18f6b61f0..3b99afac8cc15 100644 --- a/stdlib/Distributed/test/distributed_exec.jl +++ b/stdlib/Distributed/test/distributed_exec.jl @@ -1696,4 +1696,5 @@ include("splitrange.jl") # Run topology tests last after removing all workers, since a given # cluster at any time only supports a single topology. rmprocs(workers()) +include("threads.jl") include("topology.jl") diff --git a/stdlib/Distributed/test/threads.jl b/stdlib/Distributed/test/threads.jl new file mode 100644 index 0000000000000..57d99b7ea056c --- /dev/null +++ b/stdlib/Distributed/test/threads.jl @@ -0,0 +1,63 @@ +using Test +using Distributed, Base.Threads +using Base.Iterators: product + +exeflags = ("--startup-file=no", + "--check-bounds=yes", + "--depwarn=error", + "--threads=2") + +function call_on(f, wid, tid) + remotecall(wid) do + t = Task(f) + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1) + schedule(t) + @assert threadid(t) == tid + t + end +end + +# Run function on process holding the data to only serialize the result of f. +# This becomes useful for things that cannot be serialized (e.g. running tasks) +# or that would be unnecessarily big if serialized. +fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr) + +isdone(rr) = fetch_from_owner(istaskdone, rr) +isfailed(rr) = fetch_from_owner(istaskfailed, rr) + +@testset "RemoteChannel allows put!/take! from thread other than 1" begin + ws = ts = product(1:2, 1:2) + @testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws + @testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts + # We want (the default) lazyness, so that we wait for `Worker.c_state`! + procs_added = addprocs(2; exeflags, lazy=true) + @everywhere procs_added using Base.Threads + + p1 = procs_added[w1] + p2 = procs_added[w2] + chan_id = first(procs_added) + chan = RemoteChannel(chan_id) + send = call_on(p1, t1) do + put!(chan, nothing) + end + recv = call_on(p2, t2) do + take!(chan) + end + + # Wait on the spawned tasks on the owner + @sync begin + Threads.@spawn fetch_from_owner(wait, recv) + Threads.@spawn fetch_from_owner(wait, send) + end + + # Check the tasks + @test isdone(send) + @test isdone(recv) + + @test !isfailed(send) + @test !isfailed(recv) + + rmprocs(procs_added) + end + end +end