From 6aa9a3592a38a19f17b9ab414b420cacce8dab60 Mon Sep 17 00:00:00 2001 From: gbaraldi Date: Fri, 16 Aug 2024 10:40:17 -0300 Subject: [PATCH 1/5] Implement faster thread local rng for the scheduler. --- base/partr.jl | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/base/partr.jl b/base/partr.jl index 8c95e3668ee74..1656b56c21989 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -20,7 +20,56 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)] const heaps_lock = [SpinLock(), SpinLock()] -cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1) +""" + cong(max::UInt32) +Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. +""" +cong(max::UInt32) = iszero(max) ? UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check + +""" + jl_rand_ptls(max::UInt32) +Return a random UInt32 in the range `0:max-1` using the thread-local RNG +state. Max must be greater than 0. +""" +Base.@assume_effects :removable :inaccessiblememonly :notaskstate function jl_rand_ptls(max::UInt32) + # Are these effects correct? We are technically lying to the compiler + # Though these are the same lies we tell to say that an unexcaped allocation has no effects + ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls()) + rngseed = Base.unsafe_load(ptls, 2) # TODO: What's the best way to do this for 32bit. + val, seed = rand_uniform_max_int32(max, rngseed) + Base.unsafe_store!(ptls, seed, 2) + return val % UInt32 +end + +# This implementation is based on OpenSSLs implementation of rand_uniform +# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 +# Comments are vendored from their implementation as well. +# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143. 
+ +# Essentially it boils down to incrementally generating a fixed point +# number on the interval [0, 1) and multiplying this number by the upper +# range limit. Once it is certain what the fractional part contributes to +# the integral part of the product, the algorithm has produced a definitive +# result. +""" + rand_uniform_max_int32(max::UInt32, seed::UInt64) +Return a random UInt32 in the range `0:max-1` using the given seed. +Max must be greater than 0. +""" +Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64) + if max == UInt32(1) + return UInt32(0), seed + end +# We are generating a fixed point number on the interval [0, 1). +# Multiplying this by the range gives us a number on [0, upper). +# The high word of the multiplication result represents the integral part +# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes + seed = UInt64(69069) * seed + UInt64(362437) + prod = (UInt64(max)) * (seed % UInt32) # 64 bit product + i = prod >> 32 % UInt32 # integral part + return i % UInt32, seed +end + function multiq_sift_up(heap::taskheap, idx::Int32) From cd16a4905e5311756d0b2d5a58fe2458321f5cd6 Mon Sep 17 00:00:00 2001 From: gbaraldi Date: Fri, 16 Aug 2024 13:17:24 -0300 Subject: [PATCH 2/5] Fix implementation in 32 bit platforms --- base/partr.jl | 8 ++++---- src/init.c | 1 + src/jl_exported_data.inc | 1 + src/julia.h | 1 + 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/base/partr.jl b/base/partr.jl index 1656b56c21989..41213e85e9ba3 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -26,18 +26,18 @@ Return a random UInt32 in the range `1:max` except if max is 0, in that case ret """ cong(max::UInt32) = iszero(max) ? 
UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check +const rngseed_offset = unsafe_load(cglobal(:jl_ptls_rng_offset, Cint)) + """ jl_rand_ptls(max::UInt32) Return a random UInt32 in the range `0:max-1` using the thread-local RNG state. Max must be greater than 0. """ Base.@assume_effects :removable :inaccessiblememonly :notaskstate function jl_rand_ptls(max::UInt32) - # Are these effects correct? We are technically lying to the compiler - # Though these are the same lies we tell to say that an unexcaped allocation has no effects ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls()) - rngseed = Base.unsafe_load(ptls, 2) # TODO: What's the best way to do this for 32bit. + rngseed = Base.unsafe_load(ptls + rngseed_offset) val, seed = rand_uniform_max_int32(max, rngseed) - Base.unsafe_store!(ptls, seed, 2) + Base.unsafe_store!(ptls + rngseed_offset, seed) return val % UInt32 end diff --git a/src/init.c b/src/init.c index eff786b564e54..ef19e9218f559 100644 --- a/src/init.c +++ b/src/init.c @@ -785,6 +785,7 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel) // Important offset for external codegen. 
jl_task_gcstack_offset = offsetof(jl_task_t, gcstack); jl_task_ptls_offset = offsetof(jl_task_t, ptls); + jl_ptls_rng_offset = offsetof(jl_tls_states_t, rngseed); jl_prep_sanitizers(); void *stack_lo, *stack_hi; diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index ff79966b2b01b..0327872a0f9cc 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -159,5 +159,6 @@ XX(jl_options, jl_options_t) \ XX(jl_task_gcstack_offset, int) \ XX(jl_task_ptls_offset, int) \ + XX(jl_ptls_rng_offset, int) \ // end of file diff --git a/src/julia.h b/src/julia.h index e211f31c6512c..58f13412dfd0b 100644 --- a/src/julia.h +++ b/src/julia.h @@ -2256,6 +2256,7 @@ JL_DLLEXPORT JL_CONST_FUNC jl_gcframe_t **(jl_get_pgcstack)(void) JL_GLOBALLY_RO extern JL_DLLIMPORT int jl_task_gcstack_offset; extern JL_DLLIMPORT int jl_task_ptls_offset; +extern JL_DLLIMPORT int jl_ptls_rng_offset; #include "julia_locks.h" // requires jl_task_t definition From 1039ef4a98a7f2891e3ade5f4b571557497ac5bd Mon Sep 17 00:00:00 2001 From: gbaraldi Date: Mon, 26 Aug 2024 16:19:32 -0300 Subject: [PATCH 3/5] Update code to use ccalls instead of unsafe_load with fast ccall handling --- base/partr.jl | 9 +++++---- src/ccall.cpp | 32 ++++++++++++++++++++++++++++++++ src/init.c | 1 - src/jl_exported_data.inc | 1 - src/jl_exported_funcs.inc | 2 ++ src/julia.h | 1 - src/julia_threads.h | 2 ++ src/threading.c | 12 ++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) diff --git a/base/partr.jl b/base/partr.jl index 41213e85e9ba3..9a2d5dee96df8 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -26,7 +26,9 @@ Return a random UInt32 in the range `1:max` except if max is 0, in that case ret """ cong(max::UInt32) = iszero(max) ? 
UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check -const rngseed_offset = unsafe_load(cglobal(:jl_ptls_rng_offset, Cint)) +get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ()) + +set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed) """ jl_rand_ptls(max::UInt32) @@ -34,10 +36,9 @@ Return a random UInt32 in the range `0:max-1` using the thread-local RNG state. Max must be greater than 0. """ Base.@assume_effects :removable :inaccessiblememonly :notaskstate function jl_rand_ptls(max::UInt32) - ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls()) - rngseed = Base.unsafe_load(ptls + rngseed_offset) + rngseed = get_ptls_rng() val, seed = rand_uniform_max_int32(max, rngseed) - Base.unsafe_store!(ptls + rngseed_offset, seed) + set_ptls_rng(seed) return val % UInt32 end diff --git a/src/ccall.cpp b/src/ccall.cpp index 36808e13fdbf9..7ab8cfa974d6f 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -22,6 +22,8 @@ TRANSFORMED_CCALL_STAT(jl_cpu_wake); TRANSFORMED_CCALL_STAT(jl_gc_safepoint); TRANSFORMED_CCALL_STAT(jl_get_ptls_states); TRANSFORMED_CCALL_STAT(jl_threadid); +TRANSFORMED_CCALL_STAT(jl_get_ptls_rng); +TRANSFORMED_CCALL_STAT(jl_set_ptls_rng); TRANSFORMED_CCALL_STAT(jl_get_tls_world_age); TRANSFORMED_CCALL_STAT(jl_get_world_counter); TRANSFORMED_CCALL_STAT(jl_gc_enable_disable_finalizers_internal); @@ -1692,6 +1694,36 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs) ai.decorateInst(tid); return mark_or_box_ccall_result(ctx, tid, retboxed, rt, unionall, static_rt); } + else if (is_libjulia_func(jl_get_ptls_rng)) { + ++CCALL_STAT(jl_get_ptls_rng); + assert(lrt == getInt64Ty(ctx.builder.getContext())); + assert(!isVa && !llvmcall && nccallargs == 0); + JL_GC_POP(); + Value *ptls_p = get_current_ptls(ctx); + const int rng_offset = offsetof(jl_tls_states_t, rngseed); + Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, 
ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t))); + setName(ctx.emission_context, rng_ptr, "rngseed_ptr"); + LoadInst *rng_value = ctx.builder.CreateAlignedLoad(getInt64Ty(ctx.builder.getContext()), rng_ptr, Align(sizeof(void*))); + setName(ctx.emission_context, rng_value, "rngseed"); + jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + ai.decorateInst(rng_value); + return mark_or_box_ccall_result(ctx, rng_value, retboxed, rt, unionall, static_rt); + } + else if (is_libjulia_func(jl_set_ptls_rng)) { + ++CCALL_STAT(jl_set_ptls_rng); + assert(lrt == getVoidTy(ctx.builder.getContext())); + assert(!isVa && !llvmcall && nccallargs == 1); + JL_GC_POP(); + Value *ptls_p = get_current_ptls(ctx); + const int rng_offset = offsetof(jl_tls_states_t, rngseed); + Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t))); + setName(ctx.emission_context, rng_ptr, "rngseed_ptr"); + assert(argv[0].V->getType() == getInt64Ty(ctx.builder.getContext())); + auto store = ctx.builder.CreateAlignedStore(argv[0].V, rng_ptr, Align(sizeof(void*))); + jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + ai.decorateInst(store); + return ghostValue(ctx, jl_nothing_type); + } else if (is_libjulia_func(jl_get_tls_world_age)) { bool toplevel = !(ctx.linfo && jl_is_method(ctx.linfo->def.method)); if (!toplevel) { // top level code does not see a stable world age during execution diff --git a/src/init.c b/src/init.c index 4d995e18c7699..9e6a695c71eb0 100644 --- a/src/init.c +++ b/src/init.c @@ -778,7 +778,6 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel) // Important offset for external codegen. 
jl_task_gcstack_offset = offsetof(jl_task_t, gcstack); jl_task_ptls_offset = offsetof(jl_task_t, ptls); - jl_ptls_rng_offset = offsetof(jl_tls_states_t, rngseed); jl_prep_sanitizers(); void *stack_lo, *stack_hi; diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 0327872a0f9cc..ff79966b2b01b 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -159,6 +159,5 @@ XX(jl_options, jl_options_t) \ XX(jl_task_gcstack_offset, int) \ XX(jl_task_ptls_offset, int) \ - XX(jl_ptls_rng_offset, int) \ // end of file diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 1976dbe709733..583b0ea5e78fe 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -454,6 +454,8 @@ XX(jl_test_cpu_feature) \ XX(jl_threadid) \ XX(jl_threadpoolid) \ + XX(jl_get_ptls_rng) \ + XX(jl_set_ptls_rng) \ XX(jl_throw) \ XX(jl_throw_out_of_memory_error) \ XX(jl_too_few_args) \ diff --git a/src/julia.h b/src/julia.h index d13f6d8dae5fd..074c50fd0aa21 100644 --- a/src/julia.h +++ b/src/julia.h @@ -2252,7 +2252,6 @@ JL_DLLEXPORT JL_CONST_FUNC jl_gcframe_t **(jl_get_pgcstack)(void) JL_GLOBALLY_RO extern JL_DLLIMPORT int jl_task_gcstack_offset; extern JL_DLLIMPORT int jl_task_ptls_offset; -extern JL_DLLIMPORT int jl_ptls_rng_offset; #include "julia_locks.h" // requires jl_task_t definition diff --git a/src/julia_threads.h b/src/julia_threads.h index e56ff5edd6176..d7c2e707f6717 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -18,6 +18,8 @@ extern "C" { JL_DLLEXPORT int16_t jl_threadid(void); JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT; +JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT; // JULIA_ENABLE_THREADING may be controlled by altering JULIA_THREADS in Make.user diff --git a/src/threading.c b/src/threading.c index 8f37ee814056c..a6b5be27d9108 100644 --- a/src/threading.c +++ b/src/threading.c @@ -338,6 +338,18 @@ 
JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT return -1; // everything else uses threadpool -1 (does not belong to any threadpool) } +// get thread local rng +JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT +{ + return jl_current_task->ptls->rngseed; +} + +// set thread local rng +JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT +{ + jl_current_task->ptls->rngseed = new_seed; +} + jl_ptls_t jl_init_threadtls(int16_t tid) { #ifndef _OS_WINDOWS_ From cb49e1a6a801c4bdbdef00f0e926389912e71c9f Mon Sep 17 00:00:00 2001 From: gbaraldi Date: Tue, 27 Aug 2024 11:15:30 -0300 Subject: [PATCH 4/5] Apply suggestions from code review --- base/partr.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/base/partr.jl b/base/partr.jl index 9a2d5dee96df8..6053a584af5ba 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -22,20 +22,22 @@ const heaps_lock = [SpinLock(), SpinLock()] """ cong(max::UInt32) + Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. """ -cong(max::UInt32) = iszero(max) ? UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check +cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ()) set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed) """ - jl_rand_ptls(max::UInt32) + rand_ptls(max::UInt32) + Return a random UInt32 in the range `0:max-1` using the thread-local RNG state. Max must be greater than 0. """ -Base.@assume_effects :removable :inaccessiblememonly :notaskstate function jl_rand_ptls(max::UInt32) +Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32) rngseed = get_ptls_rng() val, seed = rand_uniform_max_int32(max, rngseed) set_ptls_rng(seed) @@ -54,6 +56,7 @@ end # result. 
""" rand_uniform_max_int32(max::UInt32, seed::UInt64) + Return a random UInt32 in the range `0:max-1` using the given seed. Max must be greater than 0. """ @@ -61,10 +64,10 @@ Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::U if max == UInt32(1) return UInt32(0), seed end -# We are generating a fixed point number on the interval [0, 1). -# Multiplying this by the range gives us a number on [0, upper). -# The high word of the multiplication result represents the integral part -# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes + # We are generating a fixed point number on the interval [0, 1). + # Multiplying this by the range gives us a number on [0, upper). + # The high word of the multiplication result represents the integral part + # This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes seed = UInt64(69069) * seed + UInt64(362437) prod = (UInt64(max)) * (seed % UInt32) # 64 bit product i = prod >> 32 % UInt32 # integral part From d072ad52bd95f859377372ddf41d14cdd3ec4dff Mon Sep 17 00:00:00 2001 From: gbaraldi Date: Tue, 27 Aug 2024 11:17:39 -0300 Subject: [PATCH 5/5] Remove jl_rand_ptls --- src/scheduler.c | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/scheduler.c b/src/scheduler.c index 3cf97ba108873..881cf29e3b49c 100644 --- a/src/scheduler.c +++ b/src/scheduler.c @@ -84,15 +84,6 @@ JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSA extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache, jl_gc_markqueue_t *mq, jl_value_t *obj) JL_NOTSAFEPOINT; -// parallel task runtime -// --- - -JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) -{ - jl_ptls_t ptls = jl_current_task->ptls; - return cong(max, &ptls->rngseed); -} - // initialize the threading infrastructure // (called only by the main thread) void 
jl_init_threadinginfra(void)