Skip to content

Commit

Permalink
llvm: Implement range integer generation to match Numpy's
Browse files Browse the repository at this point in the history
Uses Lemire's algorithm [0].
Applies to the Philox PRNG when using the "integers" API call;
"randint" still uses the older "masked rejection sampling" approach.

[0] https://arxiv.org/abs/1805.10941

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Nov 13, 2024
1 parent 60d3c50 commit 1a38fbd
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 2 deletions.
89 changes: 87 additions & 2 deletions psyneulink/core/llvm/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@


def _setup_builtin_func_builder(ctx, name, args, *, return_type=ir.VoidType()):
builder = ctx.create_llvm_function(args, None, _BUILTIN_PREFIX + name,
return_type=return_type)
if not name.startswith(_BUILTIN_PREFIX):
name = _BUILTIN_PREFIX + name

builder = ctx.create_llvm_function(args, None, name, return_type=return_type)

# Add noalias attribute
for a in builder.function.args:
Expand Down Expand Up @@ -758,6 +760,88 @@ def _setup_mt_rand_integer(ctx, state_ty):
return builder.function


def _setup_rand_lemire_int32(ctx, state_ty, gen_int32):
    """
    Build the bounded integer generator named ``gen_int32.name + "_bounded"``.

    Uses Lemire's algorithm - https://arxiv.org/abs/1805.10941
    As implemented in Numpy to match Numpy results.

    The generated function takes ``(state*, lower, upper, out*)`` and writes
    ``lower + ((raw * (upper - lower)) >> 32)`` to ``out``, where ``raw`` is
    a value produced by ``gen_int32``.  Draws whose low 32 bits of the
    product fall below ``(0xffffffff - rng) % rng_excl`` are rejected and
    redrawn to remove bias.

    When ``upper - lower`` wraps to 0 (the full 32-bit range) the raw value
    produced by ``gen_int32`` is left in ``out`` unchanged.

    Args:
        ctx: builder context passed through to _setup_builtin_func_builder.
        state_ty: LLVM type of the PRNG state; the function takes a pointer to it.
        gen_int32: LLVM function called as ``gen_int32(state*, out*)``; its
            second argument's pointee type is used for lower/upper/out.

    Returns:
        The generated LLVM function.
    """

    out_ty = gen_int32.args[1].type.pointee
    int64_ty = ir.IntType(64)

    builder = _setup_builtin_func_builder(ctx,
                                          gen_int32.name + "_bounded",
                                          (state_ty.as_pointer(), out_ty, out_ty, out_ty.as_pointer()))
    state, lower, upper, out_ptr = builder.function.args

    # Reserve the loop-carried 'leftover' slot up front, ahead of any
    # branching, so it is a plain function-scoped stack slot.
    leftover_ptr = builder.alloca(int64_ty)

    # rng_excl = upper - lower; rng = rng_excl - 1
    rand_range_excl = builder.sub(upper, lower)
    rand_range_excl_64 = builder.zext(rand_range_excl, int64_ty)
    rand_range = builder.sub(rand_range_excl, rand_range_excl.type(1))

    def _scale_and_store(raw):
        """Store lower + ((raw * rng_excl) >> 32) to out; return the low 32 bits of the product."""
        # m = ((uint64_t)raw) * rng_excl
        m = builder.mul(builder.zext(raw, int64_ty), rand_range_excl_64)

        # result = lower + (m >> 32)
        result = builder.lshr(m, m.type(32))
        result = builder.trunc(result, out_ty)
        result = builder.add(result, lower)
        builder.store(result, out_ptr)

        # leftover = m & 0xffffffff
        return builder.and_(m, m.type(0xffffffff))

    builder.call(gen_int32, [state, out_ptr])
    val = builder.load(out_ptr)

    # rng == 0xffffffff means rng_excl wrapped to 0; the raw draw already
    # sitting in 'out' is the result.
    is_full_range = builder.icmp_unsigned("==", rand_range, rand_range.type(0xffffffff))
    with builder.if_then(is_full_range):
        builder.ret_void()

    # Store current result as output. It will be overwritten below if needed.
    leftover = _scale_and_store(val)

    # 'rng_excl' is an upper bound for the threshold, so no rejection
    # is possible when leftover >= rng_excl.
    is_good = builder.icmp_unsigned(">=", leftover, rand_range_excl_64)
    with builder.if_then(is_good):
        builder.ret_void()

    # Apply rejection sampling
    builder.store(leftover, leftover_ptr)

    # threshold = (UINT32_MAX - rng) % rng_excl
    # The subtraction must be 0xffffffff - rng (not rng - 0xffffffff, which
    # wraps around in 64 bits and yields a different remainder).
    rand_range_64 = builder.zext(rand_range, int64_ty)
    threshold = builder.sub(rand_range_64.type(0xffffffff), rand_range_64)
    threshold = builder.urem(threshold, rand_range_excl_64)

    cond_block = builder.append_basic_block("bounded_cond_block")
    loop_block = builder.append_basic_block("bounded_loop_block")
    out_block = builder.append_basic_block("bounded_out_block")

    builder.branch(cond_block)

    # Condition: leftover < threshold
    builder.position_at_end(cond_block)
    leftover = builder.load(leftover_ptr)
    do_next = builder.icmp_unsigned("<", leftover, threshold)
    builder.cbranch(do_next, loop_block, out_block)

    # Loop block:
    # m = ((uint64_t)next_uint32(bitgen_state)) * rng_excl;
    # leftover = m & 0xffffffff
    # result = m >> 32
    builder.position_at_end(loop_block)
    builder.call(gen_int32, [state, out_ptr])
    val = builder.load(out_ptr)
    leftover = _scale_and_store(val)
    builder.store(leftover, leftover_ptr)
    builder.branch(cond_block)

    builder.position_at_end(out_block)
    builder.ret_void()

    return builder.function

def _setup_mt_rand_float(ctx, state_ty, gen_int):
"""
Mersenne Twister double precision random number generation.
Expand Down Expand Up @@ -2087,6 +2171,7 @@ def setup_philox(ctx):
_setup_rand_binomial(ctx, state_ty, gen_double, prefix="philox")

gen_int32 = _setup_philox_rand_int32(ctx, state_ty, gen_int64)
_setup_rand_lemire_int32(ctx, state_ty, gen_int32)
gen_float = _setup_philox_rand_float(ctx, state_ty, gen_int32)
_setup_philox_rand_normal(ctx, state_ty, gen_float, gen_int32, _wi_float_data, _ki_i32_data, _fi_float_data)
_setup_rand_binomial(ctx, state_ty, gen_float, prefix="philox")
55 changes: 55 additions & 0 deletions tests/llvm/test_builtins_philox_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,61 @@ def f():
benchmark(f)


@pytest.mark.benchmark(group="Philox integer PRNG")
@pytest.mark.parametrize('mode', ['numpy',
                                  pytest.param('LLVM', marks=pytest.mark.llvm),
                                  pytest.helpers.cuda_param('PTX')])
@pytest.mark.parametrize("bounds, expected",
                         [((0xffffffff,), [582496169, 60417458, 4027530181, 1107101889, 1659784452, 2025357889]),
                          ((15,), [2, 0, 14, 3, 5, 7]),
                          ((0, 15), [2, 0, 14, 3, 5, 7]),
                          ((5, 0xffff), [8892, 926, 61454, 16896, 25328, 30906]),
                          ], ids=lambda x: str(x) if len(x) != 6 else "")
def test_random_int32_bounded(benchmark, mode, bounds, expected):
    """
    Check bounded uint32 draws from the Philox generator against fixed values.

    'bounds' is either (upper,) or (lower, upper); the upper bound is
    exclusive.  The same expected sequence is required from the numpy
    reference, the compiled LLVM builtin, and the PTX (CUDA) builtin.
    """
    # A single bound means [0, bound), mirroring Generator.integers(high).
    lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])

    if mode == 'numpy':
        state = np.random.Philox([SEED])
        prng = np.random.Generator(state)

        def f():
            return prng.integers(*bounds, dtype=np.uint32, endpoint=False)

    elif mode == 'LLVM':
        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
        state = init_fun.np_buffer_for_arg(0)
        init_fun(state, SEED)

        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')

        def f():
            out = gen_fun.np_buffer_for_arg(3)
            gen_fun(state, lower, upper, out)
            return out

    elif mode == 'PTX':
        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
        state_size = init_fun.np_buffer_for_arg(0).nbytes
        gpu_state = pnlvm.jit_engine.pycuda.driver.mem_alloc(state_size)
        init_fun.cuda_call(gpu_state, np.int64(SEED))

        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')
        out = gen_fun.np_buffer_for_arg(3)
        gpu_out = pnlvm.jit_engine.pycuda.driver.Out(out)

        def f():
            # Bounds go up to 0xffffffff, which does not fit a signed 32-bit
            # scalar; pass them as unsigned 32-bit values instead.
            gen_fun.cuda_call(gpu_state, np.uint32(lower), np.uint32(upper), gpu_out)
            return out.copy()

    else:
        assert False, "Unknown mode: {}".format(mode)

    # Get >4 samples to force regeneration of Philox buffer
    res = [f(), f(), f(), f(), f(), f()]
    np.testing.assert_allclose(res, expected)
    benchmark(f)

@pytest.mark.benchmark(group="Philox integer PRNG")
@pytest.mark.parametrize('mode', ['numpy',
pytest.param('LLVM', marks=pytest.mark.llvm),
Expand Down

0 comments on commit 1a38fbd

Please sign in to comment.