Skip to content

Commit

Permalink
threading: support more than nthreads at runtime
Browse files Browse the repository at this point in the history
Hook a couple functions (notably cfunction) to handle adopting
foreign threads automatically when used.

n.b. If returning an object pointer, we do not gc_unsafe_leave
afterwards as that would render the pointer invalid. However, this means
that it can be a long time before the next safepoint (if ever). We
should look into ways of improving this bad situation, such as pinning
only that specific object temporarily.

n.b. There are some remaining issues to clean up. For example, we may
trap pages in the ptls after GC to keep them "warm", and trap other
pages in the unwind buffer, etc.
  • Loading branch information
vtjnash authored and JeffBezanson committed Sep 2, 2022
1 parent 2415f83 commit 01a7ad9
Show file tree
Hide file tree
Showing 48 changed files with 852 additions and 511 deletions.
30 changes: 30 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,34 @@ function setproperty!(ci::CodeInfo, s::Symbol, v)
return setfield!(ci, s, convert(fieldtype(CodeInfo, s), v))
end

@eval Threads nthreads() = threadpoolsize()

@eval Threads begin
"""
resize_nthreads!(A, copyvalue=A[1])
Resize the array `A` to length [`nthreads()`](@ref). Any new
elements that are allocated are initialized to `deepcopy(copyvalue)`,
where `copyvalue` defaults to `A[1]`.
This is typically used to allocate per-thread variables, and
should be called in `__init__` if `A` is a global constant.
!!! warning
This function is deprecated, since as of Julia v1.9 the number of
threads can change at run time. Instead, per-thread state should be
created as needed based on the thread id of the caller.
"""
function resize_nthreads!(A::AbstractVector, copyvalue=A[1])
nthr = nthreads()
nold = length(A)
resize!(A, nthr)
for i = nold+1:nthr
A[i] = deepcopy(copyvalue)
end
return A
end
end

# END 1.9 deprecations
2 changes: 1 addition & 1 deletion base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module Partr

using ..Threads: SpinLock, nthreads, threadid
using ..Threads: SpinLock, maxthreadid, threadid

# a task minheap
mutable struct taskheap
Expand Down
4 changes: 2 additions & 2 deletions base/pcre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ THREAD_MATCH_CONTEXTS::Vector{Ptr{Cvoid}} = [C_NULL]
PCRE_COMPILE_LOCK = nothing

_tid() = Int(ccall(:jl_threadid, Int16, ())) + 1
_nth() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
_mth() = Int(Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire))

function get_local_match_context()
tid = _tid()
Expand All @@ -41,7 +41,7 @@ function get_local_match_context()
try
ctxs = THREAD_MATCH_CONTEXTS
if length(ctxs) < tid
global THREAD_MATCH_CONTEXTS = ctxs = copyto!(fill(C_NULL, _nth()), ctxs)
global THREAD_MATCH_CONTEXTS = ctxs = copyto!(fill(C_NULL, length(ctxs) + _mth()), ctxs)
end
finally
unlock(l)
Expand Down
4 changes: 2 additions & 2 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ function workqueue_for(tid::Int)
@lock l begin
qs = Workqueues
if length(qs) < tid
nt = Threads.nthreads()
nt = Threads.maxthreadid()
@assert tid <= nt
global Workqueues = qs = copyto!(typeof(qs)(undef, length(qs) + nt - 1), qs)
end
Expand All @@ -767,7 +767,7 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
if t.sticky || Threads.nthreads() == 1
if t.sticky || Threads.threadpoolsize() == 1
tid = Threads.threadid(t)
if tid == 0
# Issue #41324
Expand Down
46 changes: 32 additions & 14 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,27 @@ ID `1`.
"""
threadid() = Int(ccall(:jl_threadid, Int16, ())+1)

# lower bound on the largest threadid()
"""
Threads.nthreads([:default|:interactive]) -> Int
Threads.maxthreadid() -> Int
Get the number of threads (across all thread pools or within the specified
thread pool) available to Julia. The number of threads across all thread
pools is the inclusive upper bound on [`threadid()`](@ref).
Get a lower bound on the number of threads (across all thread pools) available
to the Julia process, with atomic-acquire semantics. The result will always be
greater than or equal to [`threadid()`](@ref) as well as `threadid(task)` for
any task you were able to observe before calling `maxthreadid`.
"""
maxthreadid() = Int(Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire))

See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
[`Distributed`](@ref man-distributed) standard library.
"""
function nthreads end
Threads.nthreads(:default | :interactive) -> Int
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
Get the current number of threads within the specified thread pool. The threads in default
have id numbers `1:nthreads(:default)`.
See also `BLAS.get_num_threads` and `BLAS.set_num_threads` in the [`LinearAlgebra`](@ref
man-linalg) standard library, and `nprocs()` in the [`Distributed`](@ref man-distributed)
standard library and [`Threads.maxthreadid()`](@ref).
"""
function nthreads(pool::Symbol)
if pool === :default
tpid = Int8(0)
Expand All @@ -35,6 +42,7 @@ function nthreads(pool::Symbol)
end
return _nthreads_in_pool(tpid)
end

function _nthreads_in_pool(tpid::Int8)
p = unsafe_load(cglobal(:jl_n_threads_per_pool, Ptr{Cint}))
return Int(unsafe_load(p, tpid + 1))
Expand All @@ -57,10 +65,20 @@ Returns the number of threadpools currently configured.
"""
nthreadpools() = Int(unsafe_load(cglobal(:jl_n_threadpools, Cint)))

"""
Threads.threadpoolsize()
Get the number of threads available to the Julia default worker-thread pool.
See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
[`Distributed`](@ref man-distributed) standard library.
"""
threadpoolsize() = Threads._nthreads_in_pool(Int8(0))

function threading_run(fun, static)
ccall(:jl_enter_threaded_region, Cvoid, ())
n = nthreads()
n = threadpoolsize()
tasks = Vector{Task}(undef, n)
for i = 1:n
t = Task(() -> fun(i)) # pass in tid
Expand Down Expand Up @@ -93,7 +111,7 @@ function _threadsfor(iter, lbody, schedule)
tid = 1
len, rem = lenr, 0
else
len, rem = divrem(lenr, nthreads())
len, rem = divrem(lenr, threadpoolsize())
end
# not enough iterations for all the threads?
if len == 0
Expand Down Expand Up @@ -185,7 +203,7 @@ assumption may be removed in the future.
This scheduling option is merely a hint to the underlying execution mechanism. However, a
few properties can be expected. The number of `Task`s used by `:dynamic` scheduler is
bounded by a small constant multiple of the number of available worker threads
([`nthreads()`](@ref Threads.nthreads)). Each task processes contiguous regions of the
([`Threads.threadpoolsize()`](@ref)). Each task processes contiguous regions of the
iteration space. Thus, `@threads :dynamic for x in xs; f(x); end` is typically more
efficient than `@sync for x in xs; @spawn f(x); end` if `length(xs)` is significantly
larger than the number of the worker threads and the run-time of `f(x)` is relatively
Expand Down Expand Up @@ -222,15 +240,15 @@ julia> function busywait(seconds)
julia> @time begin
Threads.@spawn busywait(5)
Threads.@threads :static for i in 1:Threads.nthreads()
Threads.@threads :static for i in 1:Threads.threadpoolsize()
busywait(1)
end
end
6.003001 seconds (16.33 k allocations: 899.255 KiB, 0.25% compilation time)
julia> @time begin
Threads.@spawn busywait(5)
Threads.@threads :dynamic for i in 1:Threads.nthreads()
Threads.@threads :dynamic for i in 1:Threads.threadpoolsize()
busywait(1)
end
end
Expand Down
21 changes: 0 additions & 21 deletions base/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,4 @@ include("threadingconstructs.jl")
include("atomics.jl")
include("locks-mt.jl")


"""
resize_nthreads!(A, copyvalue=A[1])
Resize the array `A` to length [`nthreads()`](@ref). Any new
elements that are allocated are initialized to `deepcopy(copyvalue)`,
where `copyvalue` defaults to `A[1]`.
This is typically used to allocate per-thread variables, and
should be called in `__init__` if `A` is a global constant.
"""
function resize_nthreads!(A::AbstractVector, copyvalue=A[1])
nthr = nthreads()
nold = length(A)
resize!(A, nthr)
for i = nold+1:nthr
A[i] = deepcopy(copyvalue)
end
return A
end

end
4 changes: 2 additions & 2 deletions base/threads_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
Threads.foreach(f, channel::Channel;
schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
ntasks=Threads.nthreads())
ntasks=Base.threadpoolsize())
Similar to `foreach(f, channel)`, but iteration over `channel` and calls to
`f` are split across `ntasks` tasks spawned by `Threads.@spawn`. This function
Expand Down Expand Up @@ -40,7 +40,7 @@ collect(d) = [1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256
"""
function Threads.foreach(f, channel::Channel;
schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
ntasks=Threads.nthreads())
ntasks=Threads.threadpoolsize())
apply = _apply_for_schedule(schedule)
stop = Threads.Atomic{Bool}(false)
@sync for _ in 1:ntasks
Expand Down
2 changes: 1 addition & 1 deletion cli/loader_exe.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" {
JULIA_DEFINE_FAST_TLS

#ifdef _COMPILER_ASAN_ENABLED_
JL_DLLEXPORT const char* __asan_default_options()
JL_DLLEXPORT const char* __asan_default_options(void)
{
return "allow_user_segv_handler=1:detect_leaks=0";
// FIXME: enable LSAN after fixing leaks & defining __lsan_default_suppressions(),
Expand Down
6 changes: 3 additions & 3 deletions contrib/generate_precompile.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

if Threads.nthreads() != 1
@warn "Running this file with multiple Julia threads may lead to a build error" Threads.nthreads()
if Threads.maxthreadid() != 1
@warn "Running this file with multiple Julia threads may lead to a build error" Base.maxthreadid()
end

if Base.isempty(Base.ARGS) || Base.ARGS[1] !== "0"
Expand Down Expand Up @@ -340,7 +340,7 @@ function generate_precompile_statements()
# wait for the next prompt-like to appear
readuntil(output_copy, "\n")
strbuf = ""
while true
while !eof(output_copy)
strbuf *= String(readavailable(output_copy))
occursin(JULIA_PROMPT, strbuf) && break
occursin(PKG_PROMPT, strbuf) && break
Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/multi-threading.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ Base.Threads.@threads
Base.Threads.foreach
Base.Threads.@spawn
Base.Threads.threadid
Base.Threads.maxthreadid
Base.Threads.nthreads
Base.Threads.threadpool
Base.Threads.nthreadpools
Base.Threads.threadpoolsize
```

See also [Multi-Threading](@ref man-multithreading).
Expand Down
6 changes: 3 additions & 3 deletions doc/src/manual/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ The second condition above implies that you can not safely call `jl_...()` funct
void *func(void*)
{
// Wrong, jl_eval_string() called from thread that was not started by Julia
jl_eval_string("println(Threads.nthreads())");
jl_eval_string("println(Threads.threadid())");
return NULL;
}

Expand All @@ -630,7 +630,7 @@ void *func(void*)
// Okay, all jl_...() calls from the same thread,
// even though it is not the main application thread
jl_init();
jl_eval_string("println(Threads.nthreads())");
jl_eval_string("println(Threads.threadid())");
jl_atexit_hook(0);
return NULL;
}
Expand Down Expand Up @@ -670,7 +670,7 @@ int main()
jl_eval_string("func(i) = ccall(:c_func, Float64, (Int32,), i)");

// Call func() multiple times, using multiple threads to do so
jl_eval_string("println(Threads.nthreads())");
jl_eval_string("println(Base.threadpoolsize())");
jl_eval_string("use(i) = println(\"[J $(Threads.threadid())] i = $(i) -> $(func(i))\")");
jl_eval_string("Threads.@threads for i in 1:5 use(i) end");

Expand Down
8 changes: 4 additions & 4 deletions doc/src/manual/multi-threading.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ of Julia multi-threading features.
## Starting Julia with multiple threads

By default, Julia starts up with a single thread of execution. This can be verified by using the
command [`Threads.nthreads()`](@ref):
command [`Threads.threadpoolsize()`](@ref):

```jldoctest
julia> Threads.nthreads()
julia> Threads.threadpoolsize()
1
```

Expand Down Expand Up @@ -38,7 +38,7 @@ $ julia --threads 4
Let's verify there are 4 threads at our disposal.

```julia-repl
julia> Threads.nthreads()
julia> Threads.threadpoolsize()
4
```

Expand Down Expand Up @@ -267,7 +267,7 @@ avoid the race:
```julia-repl
julia> using Base.Threads
julia> nthreads()
julia> Threads.threadpoolsize()
4
julia> acc = Ref(0)
Expand Down
6 changes: 4 additions & 2 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
assert(lrt == getVoidTy(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 0);
JL_GC_POP();
emit_gc_safepoint(ctx);
ctx.builder.CreateCall(prepare_call(gcroot_flush_func));
emit_gc_safepoint(ctx.builder, get_current_ptls(ctx), ctx.tbaa().tbaa_const);
return ghostValue(ctx, jl_nothing_type);
}
else if (is_libjulia_func("jl_get_ptls_states")) {
Expand Down Expand Up @@ -1656,7 +1657,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
ctx.builder.SetInsertPoint(checkBB);
ctx.builder.CreateLoad(
getSizeTy(ctx.builder.getContext()),
ctx.builder.CreateConstInBoundsGEP1_32(getSizeTy(ctx.builder.getContext()), get_current_signal_page(ctx), -1),
ctx.builder.CreateConstInBoundsGEP1_32(getSizeTy(ctx.builder.getContext()),
get_current_signal_page_from_ptls(ctx.builder, get_current_ptls(ctx), ctx.tbaa().tbaa_const), -1),
true);
ctx.builder.CreateBr(contBB);
ctx.f->getBasicBlockList().push_back(contBB);
Expand Down
Loading

0 comments on commit 01a7ad9

Please sign in to comment.