Skip to content

Commit

Permalink
allow default_addprocs_params to be specialized on ClusterManager (#38570)
Browse files Browse the repository at this point in the history

I made this change to allow the set of default `addprocs` keyword parameters to be expanded for my own package, but I noticed that it also helps unify #38353 with the existing SSH-only options.
  • Loading branch information
simonbyrne authored Nov 30, 2020
1 parent c00aae9 commit b186a31
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ Distributed.connect(::ClusterManager, ::Int, ::WorkerConfig)
Distributed.init_worker
Distributed.start_worker
Distributed.process_messages
Distributed.default_addprocs_params
```
14 changes: 9 additions & 5 deletions src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ function addprocs(manager::ClusterManager; kwargs...)
end

function addprocs_locked(manager::ClusterManager; kwargs...)
params = merge(default_addprocs_params(), Dict{Symbol,Any}(kwargs))
params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
topology(Symbol(params[:topology]))

if PGRP.topology !== :all_to_all
Expand Down Expand Up @@ -513,12 +513,16 @@ function set_valid_processes(plist::Array{Int})
end
end

"""
    default_addprocs_params(mgr::ClusterManager) -> Dict{Symbol, Any}

Implemented by cluster managers. Return the default keyword parameters used when calling
`addprocs(mgr)`. The minimal set of options is available by calling
`default_addprocs_params()`; managers that do not define their own method fall back to
that minimal set.
"""
default_addprocs_params(::ClusterManager) = default_addprocs_params()
default_addprocs_params() = Dict{Symbol,Any}(
:topology => :all_to_all,
:ssh => "ssh",
:shell => :posix,
:cmdline_cookie => false,
:env => [],
:dir => pwd(),
:exename => joinpath(Sys.BINDIR::String, julia_exename()),
:exeflags => ``,
Expand Down
27 changes: 20 additions & 7 deletions src/managers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ struct SSHManager <: ClusterManager
end


function check_addprocs_args(kwargs)
valid_kw_names = collect(keys(default_addprocs_params()))
function check_addprocs_args(manager, kwargs)
valid_kw_names = keys(default_addprocs_params(manager))
for keyname in keys(kwargs)
!(keyname in valid_kw_names) && throw(ArgumentError("Invalid keyword argument $(keyname)"))
end
Expand Down Expand Up @@ -137,11 +137,23 @@ This timeout can be controlled via environment variable `JULIA_WORKER_TIMEOUT`.
The value of `JULIA_WORKER_TIMEOUT` on the master process specifies the number of seconds a
newly launched worker waits for connection establishment.
"""
function addprocs(machines::AbstractVector; tunnel=false, multiplex=false, sshflags=``, max_parallel=10, kwargs...)
check_addprocs_args(kwargs)
addprocs(SSHManager(machines); tunnel=tunnel, multiplex=multiplex, sshflags=sshflags, max_parallel=max_parallel, kwargs...)
function addprocs(machines::AbstractVector; kwargs...)
manager = SSHManager(machines)
check_addprocs_args(manager, kwargs)
addprocs(manager; kwargs...)
end

# Defaults for `SSHManager`: the minimal set from `default_addprocs_params()` merged with
# the SSH-specific keywords listed below. On a key collision the entries in this Dict win,
# because they are the later argument to `merge`.
default_addprocs_params(::SSHManager) =
merge(default_addprocs_params(),
Dict{Symbol,Any}(
:ssh => "ssh",
:sshflags => ``,
:shell => :posix,
:cmdline_cookie => false,
:env => [],
:tunnel => false,
:multiplex => false,
:max_parallel => 10))

function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy::Condition)
# Launch one worker on each unique host in parallel. Additional workers are launched later.
Expand Down Expand Up @@ -426,8 +438,9 @@ processes on the local machine. If `restrict` is `true`, binding is restricted t
`enable_threaded_blas` have the same effect as documented for `addprocs(machines)`.
"""
function addprocs(np::Integer; restrict=true, kwargs...)
check_addprocs_args(kwargs)
addprocs(LocalManager(np, restrict); kwargs...)
manager = LocalManager(np, restrict)
check_addprocs_args(manager, kwargs)
addprocs(manager; kwargs...)
end

Base.show(io::IO, manager::LocalManager) = print(io, "LocalManager()")
Expand Down

0 comments on commit b186a31

Please sign in to comment.