Skip to content

Commit

Permalink
llvm: Check for correct use of get_{param,state}_ptr helpers (#2198)
Browse files Browse the repository at this point in the history
Pass the mechanism params and state when retrieving DDM mechanism random state
Use a correct helper to get composition state.
Add state id for autodiff optimizer state.
Add basic sanity checks to get_{param,state}_ptr helpers.
  • Loading branch information
jvesely authored Nov 10, 2021
2 parents 5f79271 + 695dc51 commit 91f710a
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 6 deletions.
4 changes: 3 additions & 1 deletion psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ def get_random_state_ptr(self, builder, component, state, params):
# Modulated params are usually single element arrays
seed_ptr = builder.gep(seed_ptr, [self.int32_ty(0), self.int32_ty(0)])
new_seed = builder.load(seed_ptr)
# FIXME: the seed should ideally be integer already
# FIXME: The seed should ideally be integer already.
# However, it can be modulated and we don't support,
# passing integer values as computed results.
new_seed = builder.fptoui(new_seed, used_seed.type)

seeds_cmp = builder.icmp_unsigned("!=", used_seed, new_seed)
Expand Down
2 changes: 1 addition & 1 deletion psyneulink/core/llvm/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
with _gen_composition_exec_context(ctx, composition, tags=tags) as (builder, data, params, cond_gen):
state, _, comp_in, _, cond = builder.function.args

nodes_states = helpers.get_param_ptr(builder, composition, state, "nodes")
nodes_states = helpers.get_state_ptr(builder, composition, state, "nodes")

# Allocate temporary output storage
output_storage = builder.alloca(data.type.pointee, name="output_storage")
Expand Down
6 changes: 6 additions & 0 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,18 @@ def uint_min(builder, val, other):


def get_param_ptr(builder, component, params_ptr, param_name):
# check if the passed location matches expected size
assert len(params_ptr.type.pointee) == len(component.llvm_param_ids)

idx = ir.IntType(32)(component.llvm_param_ids.index(param_name))
return builder.gep(params_ptr, [ir.IntType(32)(0), idx],
name="ptr_param_{}_{}".format(param_name, component.name))


def get_state_ptr(builder, component, state_ptr, stateful_name, hist_idx=0):
# check if the passed location matches expected size
assert len(state_ptr.type.pointee) == len(component.llvm_state_ids)

idx = ir.IntType(32)(component.llvm_state_ids.index(stateful_name))
ptr = builder.gep(state_ptr, [ir.IntType(32)(0), idx],
name="ptr_state_{}_{}".format(stateful_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1127,8 +1127,9 @@ def _gen_llvm_invoke_function(self, ctx, builder, function, params, state, varia
threshold = pnlvm.helpers.load_extract_scalar_array_one(builder,
threshold_ptr)
# Load mechanism state to generate random numbers
state = builder.function.args[1]
random_state = ctx.get_random_state_ptr(builder, self, state, params)
mech_params = builder.function.args[0]
mech_state = builder.function.args[1]
random_state = ctx.get_random_state_ptr(builder, self, mech_state, mech_params)
random_f = ctx.get_uniform_dist_function_by_state(random_state)
random_val_ptr = builder.alloca(random_f.args[1].type.pointee)
builder.call(random_f, [random_state, random_val_ptr])
Expand Down
3 changes: 3 additions & 0 deletions psyneulink/library/compositions/autodiffcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,9 @@ def execute(self,
report_num=report_num
)

def _get_state_ids(self):
return super()._get_state_ids() + ["optimizer"]

def _get_state_struct_type(self, ctx):
comp_state_type_list = ctx.get_state_struct_type(super())
pytorch_representation = self._build_pytorch_representation()
Expand Down
2 changes: 0 additions & 2 deletions tests/mechanisms/test_ddm_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ def test_selected_input_array(self):
@pytest.mark.benchmark
@pytest.mark.parametrize('prng', ['Default', 'Philox'])
def test_DDM_Integrator_Bogacz(benchmark, mech_mode, prng):
if prng == 'Philox':
pytest.skip("Known broken")
stim = 10
T = DDM(
name='DDM',
Expand Down

0 comments on commit 91f710a

Please sign in to comment.