From b28e61551768536c8601a7418e923f75b08cba33 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Mon, 6 Dec 2021 23:43:06 -0500 Subject: [PATCH 1/2] ControlMechanism: fix default allocation for new control signal if no allocation samples are specified for a control signal _instantiate_control_signal_type, defaultControlAllocation should be used instead of the default value for control_allocation (value), because defaultControlAllocation refers to a single element (signal) in control_allocation, which refers to all the elements (signals) --- .../mechanisms/modulatory/control/controlmechanism.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/psyneulink/core/components/mechanisms/modulatory/control/controlmechanism.py b/psyneulink/core/components/mechanisms/modulatory/control/controlmechanism.py index 1e69300fc24..579c692cacf 100644 --- a/psyneulink/core/components/mechanisms/modulatory/control/controlmechanism.py +++ b/psyneulink/core/components/mechanisms/modulatory/control/controlmechanism.py @@ -1756,7 +1756,13 @@ def _instantiate_control_signal_type(self, control_signal_spec, context): # tests/composition/test_control.py::TestModelBasedOptimizationControlMechanisms::test_stateful_mechanism_in_simulation allocation_parameter_default = np.ones(np.asarray(allocation_parameter_default).shape) except (KeyError, IndexError, TypeError): - allocation_parameter_default = self.parameters.control_allocation.default_value + # if control allocation is a single value specified from + # default_variable for example, it should be used here + # instead of the "global default" defaultControlAllocation + if len(self.defaults.control_allocation) == 1: + allocation_parameter_default = copy.deepcopy(self.defaults.control_allocation) + else: + allocation_parameter_default = copy.deepcopy(defaultControlAllocation) control_signal = _instantiate_port(port_type=ControlSignal, owner=self, From cef8bc4c89f885bed8f609c20efbd495d388f141 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Thu, 2 Dec 2021 18:58:05 -0500 Subject: [PATCH 2/2] tests: add minimal OCM control variations --- tests/composition/test_control.py | 65 +++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/composition/test_control.py b/tests/composition/test_control.py index 6bb99e19f56..bae8848368b 100644 --- a/tests/composition/test_control.py +++ b/tests/composition/test_control.py @@ -574,6 +574,71 @@ def test_state_input_ports_for_two_input_nodes(self): assert all([node in [input_port.shadow_inputs.owner for input_port in ocomp.controller.state_input_ports] for node in {oa, ob}]) + @pytest.mark.parametrize( + 'ocm_control_signals', + [ + 'None', + "[pnl.ControlSignal(modulates=('slope', a))]", + "[pnl.ControlSignal(modulates=('slope', a), allocation_samples=[1, 2])]", + ] + ) + @pytest.mark.parametrize('ocm_num_estimates', [None, 1]) + @pytest.mark.parametrize( + 'slope, intercept', + [ + ((1.0, pnl.CONTROL), None), + ((1.0, pnl.CONTROL), (1.0, pnl.CONTROL)), + ] + ) + def test_transfer_mechanism_and_ocm_variations( + self, + slope, + intercept, + ocm_num_estimates, + ocm_control_signals, + ): + a = pnl.TransferMechanism( + name='a', + function=pnl.Linear( + slope=slope, + intercept=intercept, + ) + ) + + comp = pnl.Composition() + comp.add_node(a) + + ocm_control_signals = eval(ocm_control_signals) + ocm = pnl.OptimizationControlMechanism( + agent_rep=comp, + search_space=[[0, 1]], + num_estimates=ocm_num_estimates, + control_signals=ocm_control_signals, + ) + comp.add_controller(ocm) + + # assume tuple is a control spec + if ( + isinstance(slope, tuple) + or ( + ocm_control_signals is not None + and any(cs.name == 'slope' for cs in ocm_control_signals) + ) + ): + assert 'a[slope] ControlSignal' in ocm.control.names + else: + assert 'a[slope] ControlSignal' not in ocm.control.names + + if ( + isinstance(intercept, tuple) + or ( + ocm_control_signals is not None + and any(cs.name == 'intercept' for cs in ocm_control_signals) + ) + ): + assert 'a[intercept] ControlSignal' in ocm.control.names + else: + assert 'a[intercept] ControlSignal' not in ocm.control.names class TestControlMechanisms: