Skip to content

Commit

Permalink
Merge pull request #2226 from kmantel/control
Browse files Browse the repository at this point in the history
Control
  • Loading branch information
kmantel authored Dec 10, 2021
2 parents 78fe80a + cef8bc4 commit 82f6379
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1756,7 +1756,13 @@ def _instantiate_control_signal_type(self, control_signal_spec, context):
# tests/composition/test_control.py::TestModelBasedOptimizationControlMechanisms::test_stateful_mechanism_in_simulation
allocation_parameter_default = np.ones(np.asarray(allocation_parameter_default).shape)
except (KeyError, IndexError, TypeError):
allocation_parameter_default = self.parameters.control_allocation.default_value
# if control allocation is a single value specified from
# default_variable for example, it should be used here
# instead of the "global default" defaultControlAllocation
if len(self.defaults.control_allocation) == 1:
allocation_parameter_default = copy.deepcopy(self.defaults.control_allocation)
else:
allocation_parameter_default = copy.deepcopy(defaultControlAllocation)

control_signal = _instantiate_port(port_type=ControlSignal,
owner=self,
Expand Down
65 changes: 65 additions & 0 deletions tests/composition/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,71 @@ def test_state_input_ports_for_two_input_nodes(self):
assert all([node in [input_port.shadow_inputs.owner for input_port in ocomp.controller.state_input_ports]
for node in {oa, ob}])

@pytest.mark.parametrize(
'ocm_control_signals',
[
'None',
"[pnl.ControlSignal(modulates=('slope', a))]",
"[pnl.ControlSignal(modulates=('slope', a), allocation_samples=[1, 2])]",
]
)
@pytest.mark.parametrize('ocm_num_estimates', [None, 1])
@pytest.mark.parametrize(
'slope, intercept',
[
((1.0, pnl.CONTROL), None),
((1.0, pnl.CONTROL), (1.0, pnl.CONTROL)),
]
)
def test_transfer_mechanism_and_ocm_variations(
self,
slope,
intercept,
ocm_num_estimates,
ocm_control_signals,
):
a = pnl.TransferMechanism(
name='a',
function=pnl.Linear(
slope=slope,
intercept=intercept,
)
)

comp = pnl.Composition()
comp.add_node(a)

ocm_control_signals = eval(ocm_control_signals)
ocm = pnl.OptimizationControlMechanism(
agent_rep=comp,
search_space=[[0, 1]],
num_estimates=ocm_num_estimates,
control_signals=ocm_control_signals,
)
comp.add_controller(ocm)

# assume tuple is a control spec
if (
isinstance(slope, tuple)
or (
ocm_control_signals is not None
and any(cs.name == 'slope' for cs in ocm_control_signals)
)
):
assert 'a[slope] ControlSignal' in ocm.control.names
else:
assert 'a[slope] ControlSignal' not in ocm.control.names

if (
isinstance(intercept, tuple)
or (
ocm_control_signals is not None
and any(cs.name == 'intercept' for cs in ocm_control_signals)
)
):
assert 'a[intercept] ControlSignal' in ocm.control.names
else:
assert 'a[intercept] ControlSignal' not in ocm.control.names

class TestControlMechanisms:

Expand Down

0 comments on commit 82f6379

Please sign in to comment.