Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement faster thread local rng for scheduler #55501

Merged
merged 9 commits into from
Sep 9, 2024
55 changes: 54 additions & 1 deletion base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,60 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]


# Return a random UInt32 in `1:max`, or 0 when `max` is 0.
function cong(max::UInt32)
    # An empty range has nothing to draw from.
    iszero(max) && return UInt32(0)
    # Shift the runtime's draw by one so the result is 1-based.
    return ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1)
end
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
"""
    cong(max::UInt32)

Draw a random `UInt32` from `1:max`. When `max` is zero there is nothing to draw
from, and `UInt32(0)` is returned instead.
"""
function cong(max::UInt32)
    # TODO: make sure users don't use 0 and remove this check
    iszero(max) && return UInt32(0)
    # `rand_ptls` is 0-based (`0:max-1`); shift the draw into `1:max`.
    return rand_ptls(max) + UInt32(1)
end

# Fetch the RNG state word through the runtime's `jl_get_ptls_rng` entry point.
function get_ptls_rng()
    return ccall(:jl_get_ptls_rng, UInt64, ())
end

# Store `seed` as the new RNG state through the runtime's `jl_set_ptls_rng` entry point.
function set_ptls_rng(seed::UInt64)
    return ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
end

"""
    rand_ptls(max::UInt32)

Draw a random `UInt32` from `0:max-1` using the thread-local RNG state, which is
read, stepped, and written back as part of the call. `max` must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
    # Load the current state, derive one draw plus the successor state from it,
    # then persist the successor so the next call continues the sequence.
    state = get_ptls_rng()
    draw, newstate = rand_uniform_max_int32(max, state)
    set_ptls_rng(newstate)
    return draw % UInt32
end

# This implementation is based on OpenSSL's implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original author's work, see the Swift PR https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
    rand_uniform_max_int32(max::UInt32, seed::UInt64)

Map `seed` to a random `UInt32` in `0:max-1` and return the tuple
`(value, newseed)`, where `newseed` is the generator state to use for the next
draw. When `max == 1` the only possible value is 0 and `seed` is returned
unchanged. `max` must be greater than 0.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
    # A range of one has a single outcome; do not consume any randomness for it.
    max == UInt32(1) && return (UInt32(0), seed)

    # Step the linear congruential generator (arithmetic wraps modulo 2^64).
    next = UInt64(69069) * seed + UInt64(362437)

    # Treat the low 32 bits of the state as a fixed point fraction on [0, 1).
    # Scaling that fraction by `max` gives a number on [0, max) whose integral
    # part lives in the high word of the 64 bit product.
    # This is not completely unbiased: the original implementation also refines
    # the fractional part, which is skipped here, but it is good enough for our
    # purposes.
    fraction = next % UInt32
    scaled = UInt64(max) * fraction
    integral = (scaled >> 32) % UInt32
    return (integral, next)
end



function multiq_sift_up(heap::taskheap, idx::Int32)
Expand Down
32 changes: 32 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ TRANSFORMED_CCALL_STAT(jl_cpu_wake);
TRANSFORMED_CCALL_STAT(jl_gc_safepoint);
TRANSFORMED_CCALL_STAT(jl_get_ptls_states);
TRANSFORMED_CCALL_STAT(jl_threadid);
TRANSFORMED_CCALL_STAT(jl_get_ptls_rng);
TRANSFORMED_CCALL_STAT(jl_set_ptls_rng);
TRANSFORMED_CCALL_STAT(jl_get_tls_world_age);
TRANSFORMED_CCALL_STAT(jl_get_world_counter);
TRANSFORMED_CCALL_STAT(jl_gc_enable_disable_finalizers_internal);
Expand Down Expand Up @@ -1692,6 +1694,36 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
ai.decorateInst(tid);
return mark_or_box_ccall_result(ctx, tid, retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(jl_get_ptls_rng)) {
++CCALL_STAT(jl_get_ptls_rng);
assert(lrt == getInt64Ty(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 0);
JL_GC_POP();
Value *ptls_p = get_current_ptls(ctx);
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
LoadInst *rng_value = ctx.builder.CreateAlignedLoad(getInt64Ty(ctx.builder.getContext()), rng_ptr, Align(sizeof(void*)));
setName(ctx.emission_context, rng_value, "rngseed");
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
ai.decorateInst(rng_value);
return mark_or_box_ccall_result(ctx, rng_value, retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(jl_set_ptls_rng)) {
++CCALL_STAT(jl_set_ptls_rng);
assert(lrt == getVoidTy(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 1);
JL_GC_POP();
Value *ptls_p = get_current_ptls(ctx);
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
assert(argv[0].V->getType() == getInt64Ty(ctx.builder.getContext()));
auto store = ctx.builder.CreateAlignedStore(argv[0].V, rng_ptr, Align(sizeof(void*)));
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
ai.decorateInst(store);
return ghostValue(ctx, jl_nothing_type);
}
else if (is_libjulia_func(jl_get_tls_world_age)) {
bool toplevel = !(ctx.linfo && jl_is_method(ctx.linfo->def.method));
if (!toplevel) { // top level code does not see a stable world age during execution
Expand Down
2 changes: 2 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
XX(jl_threadpoolid) \
XX(jl_get_ptls_rng) \
XX(jl_set_ptls_rng) \
XX(jl_throw) \
XX(jl_throw_out_of_memory_error) \
XX(jl_too_few_args) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ extern "C" {

JL_DLLEXPORT int16_t jl_threadid(void);
JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT;
// Read the `rngseed` value held in the current task's thread-local state (ptls).
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT;
// Overwrite the `rngseed` value held in the current task's ptls with `new_seed`.
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT;

// JULIA_ENABLE_THREADING may be controlled by altering JULIA_THREADS in Make.user

Expand Down
9 changes: 0 additions & 9 deletions src/scheduler.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,6 @@ JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSA
extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
jl_gc_markqueue_t *mq, jl_value_t *obj) JL_NOTSAFEPOINT;

// parallel task runtime
// ---

// Draw a random number using the RNG state (`rngseed`) stored in the current
// task's thread-local state. The result lies in [0, max).
JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n)
{
    jl_ptls_t ptls = jl_current_task->ptls;
    // `cong` is handed a pointer to the seed, presumably so it can advance it
    // in place (its definition is not visible here).
    return cong(max, &ptls->rngseed);
}

// initialize the threading infrastructure
// (called only by the main thread)
void jl_init_threadinginfra(void)
Expand Down
12 changes: 12 additions & 0 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT
return -1; // everything else uses threadpool -1 (does not belong to any threadpool)
}

// Get the thread-local RNG state: returns the `rngseed` field of the current
// task's ptls.
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT
{
return jl_current_task->ptls->rngseed;
}

// Set the thread-local RNG state: stores `new_seed` into the `rngseed` field of
// the current task's ptls.
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT
{
jl_current_task->ptls->rngseed = new_seed;
}

jl_ptls_t jl_init_threadtls(int16_t tid)
{
#ifndef _OS_WINDOWS_
Expand Down