Skip to content

Commit

Permalink
llvm, mechanism: Use Parameter class to build a list of needed parame…
Browse files Browse the repository at this point in the history
…ter ports

Names of Parameter and ParameterPort don't have to match if there are
multiple parameters with the same name. (Like 'seed' in DDM with DDI
function).
Add test modulating seed of DDI within DDM.

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Oct 7, 2021
1 parent 9f83049 commit 53f1cd2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
18 changes: 10 additions & 8 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -2915,24 +2915,26 @@ def _gen_llvm_param_ports_for_obj(self, obj, params_in, ctx, builder,
# Few extra copies will be eliminated by the compiler.
builder.store(builder.load(params_in), params_out)

# Filter out param ports without corresponding params for this function
param_ports = [p for p in self._parameter_ports if p.name in obj.llvm_param_ids]
# This should be faster than 'obj._get_compilation_params'
compilation_params = (getattr(obj.parameters, p_id, None) for p_id in obj.llvm_param_ids)
# Filter out param ports without corresponding param for this function
param_ports = [self._parameter_ports[param] for param in compilation_params if param in self._parameter_ports]

def _get_output_ptr(b, i):
ptr = pnlvm.helpers.get_param_ptr(b, obj, params_out,
param_ports[i].name)
param_ports[i].source.name)
return b, ptr

def _fill_input(b, p_input, i):
param_in_ptr = pnlvm.helpers.get_param_ptr(b, obj, params_in,
param_ports[i].name)
param_ptr = pnlvm.helpers.get_param_ptr(b, obj, params_in,
param_ports[i].source.name)
# Parameter port inputs are {original parameter, [modulations]},
# fill in the first one.
# here we fill in the first one.
data_ptr = builder.gep(p_input, [ctx.int32_ty(0), ctx.int32_ty(0)])
assert data_ptr.type == param_in_ptr.type, \
assert data_ptr.type == param_ptr.type, \
"Mishandled modulation type for: {} in '{}' in '{}'".format(
param_ports[i].name, obj.name, self.name)
b.store(b.load(param_in_ptr), data_ptr)
b.store(b.load(param_ptr), data_ptr)
return b

builder = self._gen_llvm_ports(ctx, builder, param_ports, "_parameter_ports",
Expand Down
22 changes: 22 additions & 0 deletions tests/composition/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,28 @@ def test_modulation_of_random_state_direct(self, comp_mode, benchmark):

assert np.allclose(np.squeeze(comp.results[:len(seeds) * 2]), expected)

@pytest.mark.benchmark
@pytest.mark.control
@pytest.mark.composition
# 'LLVM' mode is not supported, because synchronization of compiler and
# python values during execution is not implemented.
@pytest.mark.usefixtures("comp_mode_no_llvm")
def test_modulation_of_random_state_DDM(self, comp_mode, benchmark):
# set explicit seed to make sure modulation is different
mech = pnl.DDM(function=pnl.DriftDiffusionIntegrator(noise=5.),
reset_stateful_function_when=pnl.AtPass(0),
execute_until_finished=True)
ctl_mech = pnl.ControlMechanism(control_signals=pnl.ControlSignal(modulates=('seed-function', mech),
modulation=pnl.OVERRIDE))
comp = pnl.Composition()
comp.add_node(mech, required_roles=pnl.NodeRole.INPUT)
comp.add_node(ctl_mech)

seeds = [13, 13, 14]
# cycle over the seeds twice setting and resetting the random state
benchmark(comp.run, inputs={ctl_mech:seeds, mech:5.0}, num_trials=len(seeds) * 2, execution_mode=comp_mode)

assert np.allclose(np.squeeze(comp.results[:len(seeds) * 2]), [[100, 21], [100, 23], [100, 20]] * 2)

@pytest.mark.control
@pytest.mark.composition
Expand Down

0 comments on commit 53f1cd2

Please sign in to comment.