Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameters: fix Parameter inheritance when setting attributes #2402

Merged
merged 1 commit into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions psyneulink/core/globals/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,6 @@ def __getattr__(self, attr):
def __setattr__(self, attr, value):
if (attr[:1] != '_'):
param = getattr(self._owner.parameters, attr)
param._inherited = False
param.default_value = value
else:
super().__setattr__(attr, value)
Expand Down Expand Up @@ -927,6 +926,7 @@ def __init__(
_inherited=_inherited,
_inherited_source=_inherited_source,
_user_specified=_user_specified,
_temp_uninherited=set(),
**kwargs
)

Expand Down Expand Up @@ -1011,10 +1011,15 @@ def __getattr__(self, attr):

def __setattr__(self, attr, value):
if attr in self._additional_param_attr_properties:
self._temp_uninherited.add(attr)
self._inherited = False

try:
getattr(self, '_set_{0}'.format(attr))(value)
except AttributeError:
super().__setattr__(attr, value)

self._temp_uninherited.remove(attr)
else:
super().__setattr__(attr, value)

Expand Down Expand Up @@ -1053,6 +1058,7 @@ def _inherited(self, value):
if value is not self._inherited:
# invalid if set to inherited
self._is_invalid_source = value
self.__inherited = value

if value:
self._cache_inherited_attrs()
Expand All @@ -1078,14 +1084,14 @@ def _inherited(self, value):

self._restore_inherited_attrs()

self.__inherited = value

def _inherit_from(self, parent):
self._inherited_source = weakref.ref(parent)

def _cache_inherited_attrs(self, exclusions=None):
if exclusions is None:
exclusions = self._uninherited_attrs
exclusions = set()

exclusions = self._uninherited_attrs.union(self._temp_uninherited).union(exclusions)

for attr in self._param_attrs:
if attr not in exclusions:
Expand All @@ -1094,7 +1100,9 @@ def _cache_inherited_attrs(self, exclusions=None):

def _restore_inherited_attrs(self, exclusions=None):
if exclusions is None:
exclusions = self._uninherited_attrs
exclusions = set()

exclusions = self._uninherited_attrs.union(self._temp_uninherited).union(exclusions)

for attr in self._param_attrs:
if (
Expand Down Expand Up @@ -1786,12 +1794,12 @@ def __setattr__(self, attr, value):

def _cache_inherited_attrs(self):
super()._cache_inherited_attrs(
exclusions=self._uninherited_attrs.union(self._sourced_attrs)
exclusions=self._sourced_attrs
)

def _restore_inherited_attrs(self):
super()._restore_inherited_attrs(
exclusions=self._uninherited_attrs.union(self._sourced_attrs)
exclusions=self._sourced_attrs
)

def _set_name(self, name):
Expand Down
13 changes: 13 additions & 0 deletions tests/misc/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ def test_parameter_values_overriding(ancestor, child, should_override, reset_var
assert child.parameters.variable.default_value == original_child_variable


def test_unspecified_inheritance():
class NewTM(pnl.TransferMechanism):
class Parameters(pnl.TransferMechanism.Parameters):
pass

assert NewTM.parameters.variable._inherited
NewTM.parameters.variable.default_value = -1
assert not NewTM.parameters.variable._inherited

NewTM.parameters.variable.reset()
assert NewTM.parameters.variable._inherited


@pytest.mark.parametrize('obj, param_name, alias_name', param_alias_data)
def test_aliases(obj, param_name, alias_name):
obj = obj()
Expand Down