Skip to content

Commit

Permalink
llvm: Implement range integer generation to match Python's
Browse files Browse the repository at this point in the history
Uses bit masked rejection sampling.
Matches Python's Random.randrange API call.

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Nov 15, 2024
1 parent 87054dc commit 2b4a77d
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 4 deletions.
48 changes: 44 additions & 4 deletions psyneulink/core/llvm/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,11 @@ def _setup_mt_rand_init(ctx, state_ty, init_scalar):
return builder.function


def _setup_mt_rand_integer(ctx, state_ty):
def _setup_mt_rand_int32(ctx, state_ty):
int64_ty = ir.IntType(64)

# Generate random number generator function.
# It produces random 32bit numberin a 64bit word
# It produces random 32bit number in a 64bit word
builder = _setup_builtin_func_builder(ctx, "mt_rand_int32", (state_ty.as_pointer(), int64_ty.as_pointer()))
state, out = builder.function.args

Expand Down Expand Up @@ -759,6 +760,44 @@ def _setup_mt_rand_integer(ctx, state_ty):

return builder.function


def _setup_rand_bounded_int32(ctx, state_ty, gen_int32):
    """
    Build a builtin that draws a random integer from [lower, upper).

    The generated function is named after ``gen_int32`` with a "_bounded"
    suffix and takes four arguments: a pointer to the generator state,
    a 32-bit lower bound, a 32-bit (exclusive) upper bound, and a pointer
    to the output word. The output word has the same type as the output
    of ``gen_int32``.

    Sampling uses bit-masked rejection: each raw draw is reduced to the
    number of significant bits of the range width and redrawn until it is
    below the width. The lower bound is then added to the accepted value.
    This is intended to mirror Python's ``Random.randrange``.

    Arguments:
        ctx: LLVM builder context used to create the function.
        state_ty: LLVM type of the generator state structure.
        gen_int32: previously built generator function; it is called as
            ``gen_int32(state, out_ptr)`` and is assumed to leave a 32-bit
            random value in the (possibly wider) output word.

    Returns:
        The generated LLVM function, like the other ``_setup_*`` builders.
    """
    out_ty = gen_int32.args[1].type.pointee
    builder = _setup_builtin_func_builder(ctx,
                                          gen_int32.name + "_bounded",
                                          (state_ty.as_pointer(),
                                           ctx.int32_ty,
                                           ctx.int32_ty,
                                           out_ty.as_pointer()))
    state, lower, upper, out_ptr = builder.function.args

    # Width of the half-open interval, computed in 32-bit arithmetic.
    rand_range_excl = builder.sub(upper, lower)

    # Leading zeros are counted on the 32-bit width, before it is widened,
    # so that the shift below keeps exactly the significant bits of a
    # 32-bit draw.
    # NOTE(review): the second ctlz argument marks a zero input as
    # undefined, and an empty range (upper == lower) could never leave the
    # rejection loop anyway; callers presumably guarantee upper > lower.
    range_leading_zeros_in32bit = builder.ctlz(rand_range_excl, ctx.bool_ty(1))
    range_leading_zeros_in32bit = builder.zext(range_leading_zeros_in32bit, out_ty)

    rand_range_excl = builder.zext(rand_range_excl, out_ty)

    loop_block = builder.append_basic_block("bounded_loop_block")
    out_block = builder.append_basic_block("bounded_out_block")

    builder.branch(loop_block)

    # Loop:
    # do:
    #   r = random() >> INT_BITS - num_bits # num_bits = INT_BITS - leading zeros
    # while r >= limit
    builder.position_at_end(loop_block)

    builder.call(gen_int32, [state, out_ptr])
    val = builder.load(out_ptr)
    val = builder.lshr(val, range_leading_zeros_in32bit)

    # Reject and redraw values that fall outside the requested width.
    is_above_limit = builder.icmp_unsigned(">=", val, rand_range_excl)
    builder.cbranch(is_above_limit, loop_block, out_block)

    builder.position_at_end(out_block)
    # The lower bound is zero-extended, i.e. treated as unsigned.
    offset = builder.zext(lower, val.type)
    result = builder.add(val, offset)
    builder.store(result, out_ptr)
    builder.ret_void()

    # Hand the function back so it can be reused by other builders,
    # consistent with the sibling generator setup helpers.
    return builder.function

def _setup_mt_rand_float(ctx, state_ty, gen_int):
"""
Mersenne Twister double precision random number generation.
Expand Down Expand Up @@ -893,8 +932,9 @@ def setup_mersenne_twister(ctx):
init_scalar = _setup_mt_rand_init_scalar(ctx, state_ty)
_setup_mt_rand_init(ctx, state_ty, init_scalar)

gen_int = _setup_mt_rand_integer(ctx, state_ty)
gen_float = _setup_mt_rand_float(ctx, state_ty, gen_int)
gen_int32 = _setup_mt_rand_int32(ctx, state_ty)
_setup_rand_bounded_int32(ctx, state_ty, gen_int32)
gen_float = _setup_mt_rand_float(ctx, state_ty, gen_int32)
_setup_mt_rand_normal(ctx, state_ty, gen_float)
_setup_rand_binomial(ctx, state_ty, gen_float, prefix="mt")

Expand Down
58 changes: 58 additions & 0 deletions tests/llvm/test_builtins_mt_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,64 @@

SEED = 0

@pytest.mark.benchmark(group="Mersenne Twister bounded integer PRNG")
@pytest.mark.parametrize('mode', ['Python',
                                  pytest.param('LLVM', marks=pytest.mark.llvm),
                                  pytest.helpers.cuda_param('PTX')])
@pytest.mark.parametrize("bounds, expected",
                         [((0xffffffff,), [3626764237, 1654615998, 3255389356, 3823568514, 1806341205]),
                          ((14,), [13, 6, 12, 6, 0]),
                          ((0,14), [13, 6, 12, 6, 0]),
                          ((5,0xffff), [55345, 25252, 49678, 58348, 27567]),
                         ], ids=lambda x: str(x) if len(x) != 5 else "")
def test_random_int32_bounded(benchmark, mode, bounds, expected):
    """
    Check bounded integer generation against reference values.

    For a generator seeded with SEED, the first five draws in every
    execution mode must equal ``expected``. ``bounds`` is either
    ``(upper,)`` or ``(lower, upper)``, following the call convention of
    Python's ``randrange``. The draw function is benchmarked afterwards.
    """
    # The compiled builtin always takes explicit (lower, upper) arguments;
    # normalize once instead of on every (benchmarked) call.
    lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])

    if mode == 'Python':
        state = random.Random(SEED)

        def f():
            # Python randrange is exclusive upper bound
            return state.randrange(*bounds)

    elif mode == 'LLVM':
        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_mt_rand_init')
        state = init_fun.np_buffer_for_arg(0)

        init_fun(state, SEED)

        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_mt_rand_int32_bounded')

        def f():
            # A fresh output buffer per call keeps the collected results
            # from aliasing each other.
            out = gen_fun.np_buffer_for_arg(3)
            gen_fun(state, lower, upper, out)
            return out

    elif mode == 'PTX':
        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_mt_rand_init')

        state_size = init_fun.np_buffer_for_arg(0).nbytes
        gpu_state = pnlvm.jit_engine.pycuda.driver.mem_alloc(state_size)

        init_fun.cuda_call(gpu_state, np.int32(SEED))

        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_mt_rand_int32_bounded')
        out = gen_fun.np_buffer_for_arg(3)
        gpu_out = pnlvm.jit_engine.pycuda.driver.Out(out)

        def f():
            gen_fun.cuda_call(gpu_state, np.uint32(lower), np.uint32(upper), gpu_out)
            # The host buffer is reused across calls, so return a copy.
            return out.copy()

    else:
        assert False, "Unknown mode: {}".format(mode)

    res = [f(), f(), f(), f(), f()]
    np.testing.assert_allclose(res, expected)
    benchmark(f)

@pytest.mark.benchmark(group="Mersenne Twister integer PRNG")
@pytest.mark.parametrize('mode', ['Python', 'numpy',
pytest.param('LLVM', marks=pytest.mark.llvm),
Expand Down

0 comments on commit 2b4a77d

Please sign in to comment.