Skip to content

Commit

Permalink
Merge pull request #1763 from PrincetonUniversity/fix/OCM/various
Browse files Browse the repository at this point in the history
Fix/ocm/various
  • Loading branch information
dillontsmith authored Sep 28, 2020
2 parents 8708b66 + 9ffa81e commit 98e5c2c
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 28 deletions.
1 change: 1 addition & 0 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,6 +1652,7 @@ def _deferred_init(self, context=None):
# Complete initialization
# MODIFIED 10/27/18 OLD:
super(self.__class__,self).__init__(**self._init_args)

# MODIFIED 10/27/18 NEW: FOLLOWING IS NEEDED TO HANDLE FUNCTION DEFERRED INIT (JDC)
# try:
# super(self.__class__,self).__init__(**self._init_args)
Expand Down
13 changes: 9 additions & 4 deletions psyneulink/core/components/functions/optimizationfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ def _validate_params(self, request_set, target_set=None, context=None):

if SEARCH_SPACE in request_set and request_set[SEARCH_SPACE] is not None:
search_space = request_set[SEARCH_SPACE]
if not all(isinstance(s, (SampleIterator, type(None))) for s in search_space):
if not all(isinstance(s, (SampleIterator, type(None), list, tuple, np.ndarray)) for s in search_space):
raise OptimizationFunctionError("All entries in list specified for {} arg of {} must be a {}".
format(repr(SEARCH_SPACE),
self.__class__.__name__,
SampleIterator.__name__))
"SampleIterator, list, tuple, or ndarray"))

if SEARCH_TERMINATION_FUNCTION in request_set and request_set[SEARCH_TERMINATION_FUNCTION] is not None:
if not is_function_type(request_set[SEARCH_TERMINATION_FUNCTION]):
Expand Down Expand Up @@ -495,7 +495,11 @@ def _function(self,
self._unspecified_args = []

current_sample = self._check_args(variable=variable, context=context, params=params)
current_value = self.owner.objective_mechanism.parameters.value._get(context) if self.owner else 0.

try:
current_value = self.owner.objective_mechanism.parameters.value._get(context)
except AttributeError:
current_value = 0

samples = []
values = []
Expand Down Expand Up @@ -766,6 +770,7 @@ class GradientOptimization(OptimizationFunction):
"""

componentName = GRADIENT_OPTIMIZATION_FUNCTION
bounds = None

class Parameters(OptimizationFunction.Parameters):
"""
Expand Down Expand Up @@ -924,7 +929,7 @@ def reset(self, *args, context=None):
super().reset(*args)

# Differentiate objective_function using autograd.grad()
if OBJECTIVE_FUNCTION in args[0]:
if OBJECTIVE_FUNCTION in args[0] and not self.gradient_function:
try:
from autograd import grad
self.gradient_function = grad(self.objective_function)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@
from psyneulink.core.components.component import DefaultsFlexibility
from psyneulink.core.components.functions.function import is_function_type, FunctionError
from psyneulink.core.components.functions.optimizationfunctions import \
OBJECTIVE_FUNCTION, SEARCH_SPACE
GridSearch, OBJECTIVE_FUNCTION, SEARCH_SPACE
from psyneulink.core.components.functions.combinationfunctions import LinearCombination
from psyneulink.core.components.functions.transferfunctions import CostFunctions
from psyneulink.core.components.mechanisms.mechanism import Mechanism
Expand All @@ -417,6 +417,7 @@
from psyneulink.core.globals.parameters import Parameter, ParameterAlias
from psyneulink.core.globals.preferences.preferenceset import PreferenceLevel
from psyneulink.core.globals.context import handle_external_context
from psyneulink.core.globals.sampleiterator import SampleIterator, SampleSpec

from psyneulink.core import llvm as pnlvm

Expand All @@ -442,8 +443,11 @@ def __str__(self):


def _control_allocation_search_space_getter(owning_component=None, context=None):
return [c.parameters.allocation_samples._get(context) for c in owning_component.control_signals]

search_space = owning_component.parameters.search_space._get(context)
if not search_space:
return [c.parameters.allocation_samples._get(context) for c in owning_component.control_signals]
else:
return search_space

class OptimizationControlMechanism(ControlMechanism):
"""OptimizationControlMechanism( \
Expand All @@ -454,7 +458,7 @@ class OptimizationControlMechanism(ControlMechanism):
terminal_objective_mechanism=False \
features=None, \
feature_function=None, \
function=None, \
function=GridSearch, \
agent_rep=None, \
search_function=None, \
search_termination_function=None, \
Expand Down Expand Up @@ -507,11 +511,13 @@ class OptimizationControlMechanism(ControlMechanism):
`control_allocation <ControlMechanism.control_allocation>`, and the second the current iteration of the
`optimization process <OptimizationFunction_Process>`); it must return `True` or `False`.
search_space : list or ndarray
search_space : iterable [list, tuple, ndarray, SampleSpec, or SampleIterator] | list, tuple, ndarray, SampleSpec, or SampleIterator
specifies the `search_space <OptimizationFunction.search_space>` parameter for `function
<OptimizationControlMechanism.function>`, unless that is specified in a constructor for `function
<OptimizationControlMechanism.function>`. Each item must have the same shape as `control_allocation
<ControlMechanism.control_allocation>`.
<OptimizationControlMechanism.function>`. An element at index i should correspond to an element at index i in
`control_allocation <ControlMechanism.control_allocation>`. If
`control_allocation <ControlMechanism.control_allocation>` contains only one element, then search_space can be
specified as a single element without an enclosing iterable.
function : OptimizationFunction, function or method
specifies the function used to optimize the `control_allocation <ControlMechanism.control_allocation>`;
Expand Down Expand Up @@ -684,9 +690,10 @@ class Parameters(ControlMechanism.Parameters):
:default value: None
:type:
"""
function = Parameter(None, stateful=False, loggable=False)
function = Parameter(GridSearch, stateful=False, loggable=False)
feature_function = Parameter(None, reference=True, stateful=False, loggable=False)
search_function = Parameter(None, stateful=False, loggable=False)
search_space = Parameter(None, read_only=True)
search_termination_function = Parameter(None, stateful=False, loggable=False)
comp_execution_mode = Parameter('Python', stateful=False, loggable=False, pnl_internal=True)
search_statefulness = Parameter(True, stateful=False, loggable=False)
Expand Down Expand Up @@ -727,6 +734,8 @@ def __init__(self,
**kwargs):
"""Implement OptimizationControlMechanism"""

function = function or GridSearch

# If agent_rep hasn't been specified, put into deferred init
if agent_rep is None:
if context.source==ContextFlags.COMMAND_LINE:
Expand Down Expand Up @@ -846,6 +855,27 @@ def _instantiate_attributes_after_function(self, context=None):
"""Instantiate OptimizationControlMechanism's OptimizatonFunction attributes"""

super()._instantiate_attributes_after_function(context=context)

search_space = self.parameters.search_space._get(context)
if type(search_space) == np.ndarray:
search_space = search_space.tolist()
if search_space:
corrected_search_space = []
try:
if type(search_space) == SampleIterator:
corrected_search_space.append(search_space)
elif type(search_space) == SampleSpec:
corrected_search_space.append(SampleIterator(search_space))
else:
for i in self.parameters.search_space._get(context):
if not type(i) == SampleIterator:
corrected_search_space.append(SampleIterator(specification=i))
continue
corrected_search_space.append(i)
except AssertionError:
corrected_search_space = [SampleIterator(specification=search_space)]
self.parameters.search_space._set(corrected_search_space, context)

# Assign parameters to function (OptimizationFunction) that rely on OptimizationControlMechanism
self.function.reset({
DEFAULT_VARIABLE: self.parameters.control_allocation._get(context),
Expand Down Expand Up @@ -1311,11 +1341,6 @@ def _parse_feature_specs(self, input_ports, feature_function, context=None):

return parsed_features

@property
def control_allocation_search_space(self):
"""Return list of SampleIterators for allocation_samples of control_signals"""
return [c.allocation_samples for c in self.control_signals]

@property
def _model_spec_parameter_blacklist(self):
# default_variable is hidden in constructor arguments,
Expand Down
6 changes: 4 additions & 2 deletions psyneulink/core/globals/sampleiterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,8 @@ class SampleIterator(Iterator):
"""

@tc.typecheck
def __init__(self,
specification:tc.any(*allowable_specs)):
specification):
"""
Arguments
Expand Down Expand Up @@ -315,6 +314,9 @@ def __init__(self,

specification = SampleSpec(function=specification)

elif isinstance(specification, np.ndarray):
specification = specification.tolist()

if isinstance(specification, list):
self.start = specification[0]
self.stop = None
Expand Down
90 changes: 88 additions & 2 deletions tests/composition/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,15 @@ def test_agent_rep_assignement_as_controller_and_replacement(self):
comp = pnl.Composition(name='comp',
pathways=[mech],
controller=pnl.OptimizationControlMechanism(agent_rep=None,
control_signals=(pnl.SLOPE, mech)))
control_signals=(pnl.SLOPE, mech),
search_space=[1]))
assert comp.controller.composition == comp
assert any(pnl.SLOPE in p_name for p_name in comp.projections.names)
assert not any(pnl.INTERCEPT in p_name for p_name in comp.projections.names)

new_ocm = pnl.OptimizationControlMechanism(agent_rep=None, control_signals=(pnl.INTERCEPT, mech))
new_ocm = pnl.OptimizationControlMechanism(agent_rep=None,
control_signals=(pnl.INTERCEPT, mech),
search_space=[1])
old_ocm = comp.controller
comp.add_controller(new_ocm)

Expand Down Expand Up @@ -1064,6 +1067,89 @@ def test_control_of_mech_port(self, mode):


class TestModelBasedOptimizationControlMechanisms:
def test_ocm_default_function(self):
a = pnl.ProcessingMechanism()
comp = pnl.Composition(
controller_mode=pnl.BEFORE,
nodes=[a],
controller=pnl.OptimizationControlMechanism(
control=pnl.ControlSignal(
modulates=(pnl.SLOPE, a),
intensity_cost_function=lambda x: 0,
adjustment_cost_function=lambda x: 0,
allocation_samples=[1, 10]
),
features=[a.input_port],
objective_mechanism=pnl.ObjectiveMechanism(
monitor=[a.output_port]
),
)
)
assert type(comp.controller.function) == pnl.GridSearch
assert comp.run([1]) == [10]

def test_ocm_searchspace_arg(self):
a = pnl.ProcessingMechanism()
comp = pnl.Composition(
controller_mode=pnl.BEFORE,
nodes=[a],
controller=pnl.OptimizationControlMechanism(
control=pnl.ControlSignal(
modulates=(pnl.SLOPE, a),
intensity_cost_function=lambda x: 0,
adjustment_cost_function=lambda x: 0,
),
features=[a.input_port],
objective_mechanism=pnl.ObjectiveMechanism(
monitor=[a.output_port]
),
search_space=[pnl.SampleIterator([1, 10])]
)
)
assert type(comp.controller.function) == pnl.GridSearch
assert comp.run([1]) == [10]

@pytest.mark.parametrize("format,nested",
[("list", True), ("list", False),
("tuple", True), ("tuple", False),
("SampleIterator", True), ("SampleIterator", False),
("SampleSpec", True), ("SampleSpec", False),
("ndArray", True), ("ndArray", False),
],)
def test_ocm_searchspace_format_equivalence(self, format, nested):
if format == "list":
search_space = [1, 10]
elif format == "tuple":
search_space = (1, 10)
elif format == "SampleIterator":
search_space = SampleIterator((1,10))
elif format == "SampleSpec":
search_space = SampleSpec(1, 10, 9)
elif format == "ndArray":
search_space = np.array((1, 10))

if nested:
search_space = [search_space]

a = pnl.ProcessingMechanism()
comp = pnl.Composition(
controller_mode=pnl.BEFORE,
nodes=[a],
controller=pnl.OptimizationControlMechanism(
control=pnl.ControlSignal(
modulates=(pnl.SLOPE, a),
intensity_cost_function=lambda x: 0,
adjustment_cost_function=lambda x: 0,
),
features=[a.input_port],
objective_mechanism=pnl.ObjectiveMechanism(
monitor=[a.output_port]
),
search_space=search_space
)
)
assert type(comp.controller.function) == pnl.GridSearch
assert comp.run([1]) == [10]

def test_evc(self):
# Mechanisms
Expand Down
6 changes: 5 additions & 1 deletion tests/composition/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,8 @@ def test_parameter_CIM_port_order(self):
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, ia)]),
]
],
search_space=[[1], [1], [1]]
)
ocomp.add_controller(ocm)

Expand Down Expand Up @@ -598,6 +599,9 @@ def test_nested_control_projection_count_controller(self):
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, ia)]),
],
search_space=[
[1], [1], [1]
]
)
ocomp.add_controller(ocm)
Expand Down
17 changes: 11 additions & 6 deletions tests/composition/test_show_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def test_no_nested_and_controler_name_with_space_in_it(self):
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, ib)])
])
],
search_space=[[1],[1],[1]])
comp = Composition(name='ocomp', pathways=[ia, ib], controller=ocm)

gv = comp.show_graph(show_controller=False, output_fmt='source')
Expand Down Expand Up @@ -277,8 +278,9 @@ def test_nested_learning_test_with_user_specified_target_in_outer_composition(se
name='CONTROLLER',
objective_mechanism=ObjectiveMechanism(name='OBJECTIVE MECHANISM',
monitor=[input_mech, output_mech]),
control=(SLOPE, internal_mech))
)
control=(SLOPE, internal_mech),
search_space=[1]
))
ocomp.add_node(target)
ocomp.add_projection(sender=target, receiver=p.target)

Expand Down Expand Up @@ -363,7 +365,8 @@ def test_of_show_nested_show_cim_and_show_node_structure(self):
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, oa)]),
])
],
search_space=[[1],[1],[1]])
ocomp.add_controller(ocm)

# ocomp.show_graph(show_cim=True, show_nested=INSET)
Expand Down Expand Up @@ -414,7 +417,8 @@ def test_of_show_3_level_nested_show_cim_and_show_node_structure(self):
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, oa)]),
])
],
search_space=[[1],[1],[1]])
ocomp.add_controller(ocm)

gv = ocomp.show_graph(show_nested=False, output_fmt='source')
Expand Down Expand Up @@ -461,7 +465,8 @@ def test_of_show_nested_show_cim_and_show_node_structure_with_singleton_in_outer
ControlSignal(projections=[(NOISE, ia)]),
ControlSignal(projections=[(INTERCEPT, ia)]),
ControlSignal(projections=[(SLOPE, oa)]),
])
],
search_space=[[1],[1],[1]])
ocomp.add_controller(ocm)

ocomp.show_graph(show_cim=True, show_nested=INSET)
Expand Down

0 comments on commit 98e5c2c

Please sign in to comment.