Skip to content

Commit

Permalink
erge remote-tracking branch 'origin/devel' into devel-llvm
Browse files Browse the repository at this point in the history
  • Loading branch information
jvesely committed Apr 29, 2022
2 parents 4aa8dae + c7f21f8 commit 47be5f5
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 15 deletions.
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
jupyter<=1.0.0
pytest<7.1.2
pytest<7.1.3
pytest-benchmark<3.4.2
pytest-cov<3.0.1
pytest-helpers-namespace<2021.12.30
Expand Down
20 changes: 10 additions & 10 deletions psyneulink/core/components/functions/stateful/memoryfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@
class MemoryFunction(StatefulFunction): # -----------------------------------------------------------------------------
componentType = MEMORY_FUNCTION

# TODO: refactor to avoid skip of direct super
def _update_default_variable(self, new_default_variable, context=None):
if not self.parameters.initializer._user_specified:
# use * 0 instead of zeros_like to deal with ragged arrays
self._initialize_previous_value([new_default_variable * 0], context)

# bypass the additional _initialize_previous_value call used by
# other stateful functions
super(StatefulFunction, self)._update_default_variable(new_default_variable, context=context)


class Buffer(MemoryFunction): # ------------------------------------------------------------------------------
"""
Expand Down Expand Up @@ -259,16 +269,6 @@ def _initialize_previous_value(self, initializer, context=None):

return previous_value

# TODO: Buffer variable fix: remove this or refactor to avoid skip
# of direct super
def _update_default_variable(self, new_default_variable, context=None):
if not self.parameters.initializer._user_specified:
self._initialize_previous_value([np.zeros_like(new_default_variable)], context)

# bypass the additional _initialize_previous_value call used by
# other stateful functions
super(StatefulFunction, self)._update_default_variable(new_default_variable, context=context)

def _instantiate_attributes_before_function(self, function=None, context=None):
self.parameters.previous_value._set(
self._initialize_previous_value(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ class Parameters(ModulatoryMechanism_Base.Parameters):
)

monitor_for_control = Parameter(
[OUTCOME],
[],
stateful=False,
loggable=False,
read_only=True,
Expand All @@ -1167,6 +1167,7 @@ class Parameters(ModulatoryMechanism_Base.Parameters):
aliases=[CONTROL, CONTROL_SIGNALS],
constructor_argument=CONTROL
)
function = Parameter(Identity, stateful=False, loggable=False)

def _parse_output_ports(self, output_ports):
def is_2tuple(o):
Expand Down Expand Up @@ -1278,8 +1279,6 @@ def __init__(self,
f"creating unnecessary and/or duplicated Components.")
control = convert_to_list(args)

function = function or Identity

super(ControlMechanism, self).__init__(
default_variable=default_variable,
size=size,
Expand Down
1 change: 0 additions & 1 deletion psyneulink/core/globals/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,6 @@ class Parameter(ParameterBase):
'default_value',
'history_max_length',
'log_condition',
'delivery_condition',
'spec',
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set
from psyneulink.core.globals.utilities import is_numeric_or_none
from psyneulink.library.components.mechanisms.processing.transfer.recurrenttransfermechanism import RecurrentTransferMechanism
from psyneulink.library.components.projections.pathway.autoassociativeprojection import get_auto_matrix, get_hetero_matrix

__all__ = [
'KWTAMechanism', 'KWTAError',
Expand Down Expand Up @@ -414,6 +415,17 @@ def _instantiate_attributes_before_function(self, function=None, context=None):
# so it shouldn't be a problem)
self.indexOfInhibitionInputPort = len(self.input_ports) - 1

# NOTE: this behavior matches what kwta tests assert. Values for
# auto and hetero were ALWAYS "user_specified" due to using
# values set in KWTAMechanism.__init__. To change this and use
# default RecurrentTransferMechanism behavior, the test values
# must be changed
matrix = (
get_auto_matrix(self.defaults.auto, self.recurrent_size)
+ get_hetero_matrix(self.defaults.hetero, self.recurrent_size)
)
self.parameters.matrix._set(matrix, context)

def _kwta_scale(self, current_input, context=None):
k_value = self._get_current_parameter_value(self.parameters.k_value, context)
threshold = self._get_current_parameter_value(self.parameters.threshold, context)
Expand Down
11 changes: 11 additions & 0 deletions tests/misc/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,17 @@ def test_function_user_specified(kwargs, parameter, is_user_specified):
assert getattr(t.function.parameters, parameter)._user_specified == is_user_specified


# sort param names or pytest-xdist may cause failure
# see https://github.com/pytest-dev/pytest/issues/4101
@pytest.mark.parametrize('attr', sorted(pnl.Parameter._additional_param_attr_properties))
def test_additional_param_attrs(attr):
assert hasattr(pnl.Parameter, f'_set_{attr}'), (
f'To include {attr} in Parameter._additional_param_attr_properties, you'
f' must add a _set_{attr} method on Parameter. If this is unneeded,'
' remove it from Parameter._additional_param_attr_properties.'
)


class TestSharedParameters:

recurrent_mech = pnl.RecurrentTransferMechanism(default_variable=[0, 0], enable_learning=True)
Expand Down

0 comments on commit 47be5f5

Please sign in to comment.