From 1a38fbde04c20465546d95511252e74f5dd63e06 Mon Sep 17 00:00:00 2001
From: Jan Vesely
Date: Wed, 13 Nov 2024 10:32:59 -0500
Subject: [PATCH] llvm: Implement range integer generation to match Numpy's

Uses Lemire's algorithm [0].
Applies to Philox PRNG using "integers" API call, "randint" uses older,
"masked rejection sampling" approach.

[0] https://arxiv.org/abs/1805.10941

Signed-off-by: Jan Vesely
---
 psyneulink/core/llvm/builtins.py          | 89 ++++++++++++++++++++++-
 tests/llvm/test_builtins_philox_random.py | 55 ++++++++++++++
 2 files changed, 142 insertions(+), 2 deletions(-)

diff --git a/psyneulink/core/llvm/builtins.py b/psyneulink/core/llvm/builtins.py
index 20920ccf59..13bd76e97b 100644
--- a/psyneulink/core/llvm/builtins.py
+++ b/psyneulink/core/llvm/builtins.py
@@ -17,8 +17,10 @@
 
 
 def _setup_builtin_func_builder(ctx, name, args, *, return_type=ir.VoidType()):
-    builder = ctx.create_llvm_function(args, None, _BUILTIN_PREFIX + name,
-                                       return_type=return_type)
+    if not name.startswith(_BUILTIN_PREFIX):
+        name = _BUILTIN_PREFIX + name
+
+    builder = ctx.create_llvm_function(args, None, name, return_type=return_type)
 
     # Add noalias attribute
     for a in builder.function.args:
@@ -758,6 +760,88 @@ def _setup_mt_rand_integer(ctx, state_ty):
     return builder.function
 
 
+def _setup_rand_lemire_int32(ctx, state_ty, gen_int32):
+    """
+    Generate a bounded random integer: lower + x, x in [0, upper - lower).
+    Uses Lemire's algorithm - https://arxiv.org/abs/1805.10941
+    As implemented in Numpy to match Numpy results.
+    """
+
+    out_ty = gen_int32.args[1].type.pointee
+    builder = _setup_builtin_func_builder(ctx, gen_int32.name + "_bounded", (state_ty.as_pointer(), out_ty, out_ty, out_ty.as_pointer()))
+    state, lower, upper, out_ptr = builder.function.args
+
+    rand_range_excl = builder.sub(upper, lower)
+    rand_range_excl_64 = builder.zext(rand_range_excl, ir.IntType(64))
+    rand_range = builder.sub(rand_range_excl, rand_range_excl.type(1))
+
+    # Draw once; a full 32-bit range (upper - lower wraps to 0) needs no scaling.
+    builder.call(gen_int32, [state, out_ptr])
+    val = builder.load(out_ptr)
+
+    is_full_range = builder.icmp_unsigned("==", rand_range, rand_range.type(0xffffffff))
+    with builder.if_then(is_full_range):
+        builder.ret_void()
+
+    val64 = builder.zext(val, rand_range_excl_64.type)
+    m = builder.mul(val64, rand_range_excl_64)
+
+    # Store current result as output. It will be overwritten below if needed.
+    out_val = builder.lshr(m, m.type(32))
+    out_val = builder.trunc(out_val, out_ptr.type.pointee)
+    out_val = builder.add(out_val, lower)
+    builder.store(out_val, out_ptr)
+
+    leftover = builder.and_(m, m.type(0xffffffff))
+
+    is_good = builder.icmp_unsigned(">=", leftover, rand_range_excl_64)
+    with builder.if_then(is_good):
+        builder.ret_void()
+
+    # Apply rejection sampling, threshold = (UINT32_MAX - rng) % rng_excl
+    leftover_ptr = builder.alloca(leftover.type)
+    builder.store(leftover, leftover_ptr)
+
+    rand_range_64 = builder.zext(rand_range, ir.IntType(64))
+    threshold = builder.sub(rand_range_64.type(0xffffffff), rand_range_64)
+    threshold = builder.urem(threshold, rand_range_excl_64)
+
+    cond_block = builder.append_basic_block("bounded_cond_block")
+    loop_block = builder.append_basic_block("bounded_loop_block")
+    out_block = builder.append_basic_block("bounded_out_block")
+
+    builder.branch(cond_block)
+
+    # Condition: leftover < threshold
+    builder.position_at_end(cond_block)
+    leftover = builder.load(leftover_ptr)
+    do_next = builder.icmp_unsigned("<", leftover, threshold)
+    builder.cbranch(do_next, loop_block, out_block)
+
+    # Loop block:
+    # m = ((uint64_t)next_uint32(bitgen_state)) * rng_excl;
+    # leftover = m & 0xffffffff
+    # result = m >> 32
+    builder.position_at_end(loop_block)
+    builder.call(gen_int32, [state, out_ptr])
+
+    val = builder.load(out_ptr)
+    val64 = builder.zext(val, rand_range_excl_64.type)
+    m = builder.mul(val64, rand_range_excl_64)
+
+    leftover = builder.and_(m, m.type(0xffffffff))
+    builder.store(leftover, leftover_ptr)
+
+    out_val = builder.lshr(m, m.type(32))
+    out_val = builder.trunc(out_val, out_ptr.type.pointee)
+    out_val = builder.add(out_val, lower)
+    builder.store(out_val, out_ptr)
+    builder.branch(cond_block)
+
+    # Loop exit: the accepted sample is already stored in out_ptr.
+    builder.position_at_end(out_block)
+    builder.ret_void()
+
 def _setup_mt_rand_float(ctx, state_ty, gen_int):
     """
     Mersenne Twister double prcision random number generation.
@@ -2087,6 +2171,7 @@ def setup_philox(ctx):
     _setup_rand_binomial(ctx, state_ty, gen_double, prefix="philox")
 
     gen_int32 = _setup_philox_rand_int32(ctx, state_ty, gen_int64)
+    _setup_rand_lemire_int32(ctx, state_ty, gen_int32)
     gen_float = _setup_philox_rand_float(ctx, state_ty, gen_int32)
     _setup_philox_rand_normal(ctx, state_ty, gen_float, gen_int32, _wi_float_data, _ki_i32_data, _fi_float_data)
     _setup_rand_binomial(ctx, state_ty, gen_float, prefix="philox")
diff --git a/tests/llvm/test_builtins_philox_random.py b/tests/llvm/test_builtins_philox_random.py
index 56b6485b75..e1711dc7ad 100644
--- a/tests/llvm/test_builtins_philox_random.py
+++ b/tests/llvm/test_builtins_philox_random.py
@@ -60,6 +60,61 @@ def f():
     benchmark(f)
 
 
+@pytest.mark.benchmark(group="Philox integer PRNG")
+@pytest.mark.parametrize('mode', ['numpy',
+                                  pytest.param('LLVM', marks=pytest.mark.llvm),
+                                  pytest.helpers.cuda_param('PTX')])
+@pytest.mark.parametrize("bounds, expected",
+                         [((0xffffffff,), [582496169, 60417458, 4027530181, 1107101889, 1659784452, 2025357889]),
+                          ((15,), [2, 0, 14, 3, 5, 7]),
+                          ((0,15), [2, 0, 14, 3, 5, 7]),
+                          ((5,0xffff), [8892, 926, 61454, 16896, 25328, 30906]),
+                         ], ids=lambda x: str(x) if len(x) != 6 else "")
+def test_random_int32_bounded(benchmark, mode, bounds, expected):
+    """Compare six bounded uint32 draws from the selected backend with the expected values."""
+    if mode == 'numpy':
+        state = np.random.Philox([SEED])
+        prng = np.random.Generator(state)
+
+        def f():
+            return prng.integers(*bounds, dtype=np.uint32, endpoint=False)
+
+    elif mode == 'LLVM':
+        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
+        state = init_fun.np_buffer_for_arg(0)
+        init_fun(state, SEED)
+
+        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')
+
+        def f():
+            lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])
+            out = gen_fun.np_buffer_for_arg(3)
+            gen_fun(state, lower, upper, out)
+            return out
+
+    elif mode == 'PTX':
+        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
+        state_size = init_fun.np_buffer_for_arg(0).nbytes
+        gpu_state = pnlvm.jit_engine.pycuda.driver.mem_alloc(state_size)
+        init_fun.cuda_call(gpu_state, np.int64(SEED))
+
+        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')
+        out = gen_fun.np_buffer_for_arg(3)
+        gpu_out = pnlvm.jit_engine.pycuda.driver.Out(out)
+
+        def f():
+            lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])
+            gen_fun.cuda_call(gpu_state, np.uint32(lower), np.uint32(upper), gpu_out)
+            return out.copy()
+
+    else:
+        assert False, "Unknown mode: {}".format(mode)
+
+    # Get >4 samples to force regeneration of Philox buffer
+    res = [f(), f(), f(), f(), f(), f()]
+    np.testing.assert_allclose(res, expected)
+    benchmark(f)
+
 @pytest.mark.benchmark(group="Philox integer PRNG")
 @pytest.mark.parametrize('mode', ['numpy',
                                   pytest.param('LLVM', marks=pytest.mark.llvm),