From 9b2ab4f16d80a6409760754fe3949586b5dc79ef Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Thu, 30 Sep 2021 00:48:55 -0400 Subject: [PATCH] llvm: Add helper function to retrieve and optionally reseed random state Store used seed in compiled random state. Check the last used seed with the most recent value of the 'seed' param, reseed the random state if the seeds don't match. Add simple test. Signed-off-by: Jan Vesely --- psyneulink/core/components/component.py | 2 +- .../nonstateful/distributionfunctions.py | 4 ++-- .../nonstateful/selectionfunctions.py | 4 ++-- .../nonstateful/transferfunctions.py | 2 +- .../functions/stateful/integratorfunctions.py | 2 +- .../functions/stateful/memoryfunctions.py | 2 +- psyneulink/core/llvm/builder_context.py | 23 ++++++++++++++++++- psyneulink/core/llvm/builtins.py | 9 +++++++- .../mechanisms/processing/integrator/ddm.py | 2 +- tests/composition/test_control.py | 21 +++++++++++++++++ 10 files changed, 60 insertions(+), 11 deletions(-) diff --git a/psyneulink/core/components/component.py b/psyneulink/core/components/component.py index 0636e8990fc..13b84ee84d7 100644 --- a/psyneulink/core/components/component.py +++ b/psyneulink/core/components/component.py @@ -1332,7 +1332,7 @@ def _convert(p): x = p.get(context) if isinstance(x, np.random.RandomState): # Skip first element of random state (id string) - val = pnlvm._tupleize(x.get_state()[1:]) + val = pnlvm._tupleize((*x.get_state()[1:], x.used_seed[0])) elif isinstance(x, Time): val = tuple(getattr(x, graph_scheduler.time._time_scale_to_attr_str(t)) for t in TimeScale) elif isinstance(x, Component): diff --git a/psyneulink/core/components/functions/nonstateful/distributionfunctions.py b/psyneulink/core/components/functions/nonstateful/distributionfunctions.py index b62d2d96788..31af21a94d1 100644 --- a/psyneulink/core/components/functions/nonstateful/distributionfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/distributionfunctions.py @@ -195,7 +195,7 @@ def _function(self, return self.convert_output_type(result) def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset): - random_state = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + random_state = ctx.get_random_state_ptr(builder, self, state, params) mean_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, "mean") std_dev_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, "standard_deviation") ret_val_ptr = builder.alloca(ctx.float_ty) @@ -620,7 +620,7 @@ def _function(self, return self.convert_output_type(result) def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset): - random_state = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + random_state = ctx.get_random_state_ptr(builder, self, state, params) low_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, LOW) high_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, HIGH) ret_val_ptr = builder.alloca(ctx.float_ty) diff --git a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py index a277adbd373..5b580c5a32f 100644 --- a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py @@ -253,14 +253,14 @@ def _validate_params(self, request_set, target_set=None, context=None): "array of probabilities that sum to 1". format(MODE, self.__class__.__name__, Function.__name__, PROB, prob_dist)) - def _gen_llvm_function_body(self, ctx, builder, _, state, arg_in, arg_out, *, tags:frozenset): + def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset): idx_ptr = builder.alloca(ctx.int32_ty) builder.store(ctx.int32_ty(0), idx_ptr) if self.mode in {PROB, PROB_INDICATOR}: rng_f = ctx.import_llvm_function("__pnl_builtin_mt_rand_double") dice_ptr = builder.alloca(ctx.float_ty) - mt_state_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + mt_state_ptr = ctx.get_random_state_ptr(builder, self, state, params) builder.call(rng_f, [mt_state_ptr, dice_ptr]) dice = builder.load(dice_ptr) sum_ptr = builder.alloca(ctx.float_ty) diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py index b5fa0b2f659..4b6e67ebe2b 100644 --- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py @@ -2249,7 +2249,7 @@ def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr) rvalp = builder.alloca(ptri.type.pointee) - rand_state_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params) normal_f = ctx.import_llvm_function("__pnl_builtin_mt_rand_normal") builder.call(normal_f, [rand_state_ptr, rvalp]) diff --git a/psyneulink/core/components/functions/stateful/integratorfunctions.py b/psyneulink/core/components/functions/stateful/integratorfunctions.py index 13f2f894da1..b266362b5c1 100644 --- a/psyneulink/core/components/functions/stateful/integratorfunctions.py +++ b/psyneulink/core/components/functions/stateful/integratorfunctions.py @@ -2485,7 +2485,7 @@ def _gen_llvm_integrate(self, builder, index, ctx, vi, vo, params, state): threshold = self._gen_llvm_load_param(ctx, builder, params, index, THRESHOLD) time_step_size = self._gen_llvm_load_param(ctx, builder, params, index, TIME_STEP_SIZE) - random_state = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + random_state = ctx.get_random_state_ptr(builder, self, state, params) rand_val_ptr = builder.alloca(ctx.float_ty) rand_f = ctx.import_llvm_function("__pnl_builtin_mt_rand_normal") builder.call(rand_f, [random_state, rand_val_ptr]) diff --git a/psyneulink/core/components/functions/stateful/memoryfunctions.py b/psyneulink/core/components/functions/stateful/memoryfunctions.py index bd95a6573b3..c21260ef478 100644 --- a/psyneulink/core/components/functions/stateful/memoryfunctions.py +++ b/psyneulink/core/components/functions/stateful/memoryfunctions.py @@ -2243,7 +2243,7 @@ def _get_state_initializer(self, context): def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset): # PRNG - rand_struct = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + rand_struct = ctx.get_random_state_ptr(builder, self, state, params) uniform_f = ctx.import_llvm_function("__pnl_builtin_mt_rand_double") # Ring buffer diff --git a/psyneulink/core/llvm/builder_context.py b/psyneulink/core/llvm/builder_context.py index fc6d1318553..9c7e80cd232 100644 --- a/psyneulink/core/llvm/builder_context.py +++ b/psyneulink/core/llvm/builder_context.py @@ -24,6 +24,7 @@ from psyneulink.core.globals.utilities import ContentAddressableList from psyneulink.core import llvm as pnlvm from . import codegen +from . import helpers from .debug import debug_env __all__ = ['LLVMBuilderContext', '_modules', '_find_llvm_function'] @@ -51,7 +52,7 @@ def module_count(): _BUILTIN_PREFIX = "__pnl_builtin_" -_builtin_intrinsics = frozenset(('pow', 'log', 'exp', 'tanh', 'coth', 'csch', 'is_close')) +_builtin_intrinsics = frozenset(('pow', 'log', 'exp', 'tanh', 'coth', 'csch', 'is_close', 'mt_rand_init')) class _node_wrapper(): @@ -188,6 +189,26 @@ def import_llvm_function(self, fun, *, tags:frozenset=frozenset()) -> ir.Functio return decl_f return f + def get_random_state_ptr(self, builder, component, state, params): + random_state_ptr = helpers.get_state_ptr(builder, component, state, "random_state") + used_seed_ptr = builder.gep(random_state_ptr, [self.int32_ty(0), self.int32_ty(4)]) + used_seed = builder.load(used_seed_ptr) + + seed_ptr = helpers.get_param_ptr(builder, component, params, "seed") + if isinstance(seed_ptr.type.pointee, ir.ArrayType): + # Modulated params are usually single element arrays + seed_ptr = builder.gep(seed_ptr, [self.int32_ty(0), self.int32_ty(0)]) + new_seed = builder.load(seed_ptr) + # FIXME: the seed should ideally be integer already + new_seed = builder.fptoui(new_seed, used_seed.type) + + seeds_cmp = builder.icmp_unsigned("!=", used_seed, new_seed) + with builder.if_then(seeds_cmp, likely=False): + reseed_f = self.get_builtin("mt_rand_init") + builder.call(reseed_f, [random_state_ptr, new_seed]) + + return random_state_ptr + @staticmethod def get_debug_location(func: ir.Function, component): if "debug_info" not in debug_env: diff --git a/psyneulink/core/llvm/builtins.py b/psyneulink/core/llvm/builtins.py index cb3ea246fc7..683e7cec11f 100644 --- a/psyneulink/core/llvm/builtins.py +++ b/psyneulink/core/llvm/builtins.py @@ -527,6 +527,8 @@ def _setup_mt_rand_init_scalar(ctx, state_ty): pidx = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(1)]) builder.store(pidx.type.pointee(_MERSENNE_N), pidx) + seed_p = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(4)]) + builder.store(seed, seed_p) builder.ret_void() return builder.function @@ -615,6 +617,10 @@ def _setup_mt_rand_init(ctx, state_ty, init_scalar): # set the 0th element to INT_MIN builder.store(a_0.type.pointee(0x80000000), a_0) + + # store used seed + used_seed_p = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(4)]) + builder.store(seed, used_seed_p) builder.ret_void() return builder.function @@ -842,7 +848,8 @@ def get_mersenne_twister_state_struct(ctx): ir.ArrayType(ctx.int32_ty, _MERSENNE_N), # array ctx.int32_ty, # index ctx.int32_ty, # last_gauss available - ctx.float_ty]) # last_gauss + ctx.float_ty, # last_gauss + ctx.int32_ty]) # used seed def setup_mersenne_twister(ctx): diff --git a/psyneulink/library/components/mechanisms/processing/integrator/ddm.py b/psyneulink/library/components/mechanisms/processing/integrator/ddm.py index 1f09f20b40c..111cab6aafe 100644 --- a/psyneulink/library/components/mechanisms/processing/integrator/ddm.py +++ b/psyneulink/library/components/mechanisms/processing/integrator/ddm.py @@ -1128,7 +1128,7 @@ def _gen_llvm_invoke_function(self, ctx, builder, function, params, state, varia threshold_ptr) # Load mechanism state to generate random numbers state = builder.function.args[1] - random_state = pnlvm.helpers.get_state_ptr(builder, self, state, "random_state") + random_state = ctx.get_random_state_ptr(builder, self, state, params) random_f = ctx.import_llvm_function("__pnl_builtin_mt_rand_double") random_val_ptr = builder.alloca(random_f.args[1].type.pointee) builder.call(random_f, [random_state, random_val_ptr]) diff --git a/tests/composition/test_control.py b/tests/composition/test_control.py index 125fe2971a2..8470995386b 100644 --- a/tests/composition/test_control.py +++ b/tests/composition/test_control.py @@ -1085,6 +1085,27 @@ def test_modulation_simple(self, cost, expected, exp_values): assert np.allclose(ret, expected) assert np.allclose([float(x) for x in comp.controller.function.saved_values], exp_values) + @pytest.mark.benchmark + @pytest.mark.control + @pytest.mark.composition + def test_modulation_of_random_state_direct(self, comp_mode, benchmark): + # set explicit seed to make sure modulation is different + mech = pnl.ProcessingMechanism(function=pnl.UniformDist(seed=0)) + ctl_mech = pnl.ControlMechanism(control_signals=pnl.ControlSignal(modulates=('seed', mech), + modulation=pnl.OVERRIDE)) + comp = pnl.Composition() + comp.add_node(mech) + comp.add_node(ctl_mech) + + seeds = [13, 13, 14] + prngs = {s:np.random.RandomState([s]) for s in seeds} + expected = [prngs[s].uniform() for s in seeds] * 2 + # cycle over the seeds twice setting and resetting the random state + benchmark(comp.run, inputs={ctl_mech:seeds}, num_trials=len(seeds) * 2, execution_mode=comp_mode) + + assert np.allclose(np.squeeze(comp.results[:len(seeds) * 2]), expected) + + @pytest.mark.control @pytest.mark.composition @pytest.mark.parametrize("mode", [pnl.ExecutionMode.Python])