Skip to content

Commit

Permalink
Parameters: allow _validate_ methods to reference other parameters (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel authored and jdcpni committed Oct 28, 2022
1 parent 714b6c4 commit 8d7ece9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
50 changes: 37 additions & 13 deletions psyneulink/core/globals/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,11 @@ def __init__(p=None, q=1.0):
`Parameter attributes <Parameter_Attributes_Table>`
- default values for the parameters can be specified in the Parameters class body, or in the
arguments for *B*.__init__. If both are specified and the values differ, an exception will be raised
- if you want assignments to parameter *p* to be validated, add a method _validate_p(value),
- if you want assignments to parameter *p* to be validated, add a method _validate_p(value), \
that returns None if value is a valid assignment, or an error string if value is not a valid assignment
- NOTE: a validation method for *p* may reference other parameters \
only if they are listed in *p*'s \
`dependencies <Parameter.dependencies>`
- if you want all values set to *p* to be parsed beforehand, add a method _parse_p(value) that returns the parsed value
- for example, convert to a numpy array or float
Expand All @@ -123,6 +126,8 @@ def __init__(p=None, q=1.0):
def _parse_p(value):
return np.asarray(value)
- NOTE: parsers may not reference other parameters
- setters and getters (used for more advanced behavior than parsing) should both return the final value to return (getter) or set (setter)
For example, `costs <ControlMechanism.costs>` of `ControlMechanism <ControlMechanism>` has a special
Expand Down Expand Up @@ -607,13 +612,15 @@ def _owner(self, value):
except TypeError:
self._owner_ref = value

@property
def _in_dependency_order(self):
def _dependency_order_key(self, names=False):
"""
Returns:
list[Parameter] - a list of Parameters such that any
Parameter is placed before all of its
`dependencies <Parameter.dependencies>`
Args:
names (bool, optional): Whether sorting key is based on
Parameter names or Parameter objects. Defaults to False.
Returns:
types.FunctionType: a function that may be passed in as sort
key so that any Parameter is placed before its dependencies
"""
parameter_function_ordering = list(toposort.toposort({
p.name: p.dependencies for p in self if p.dependencies is not None
Expand All @@ -622,13 +629,30 @@ def _in_dependency_order(self):
itertools.chain.from_iterable(parameter_function_ordering)
)

def ordering(p):
try:
return parameter_function_ordering.index(p.name)
except ValueError:
return -1
if names:
def ordering(p):
try:
return parameter_function_ordering.index(p)
except ValueError:
return -1
else:
def ordering(p):
try:
return parameter_function_ordering.index(p.name)
except ValueError:
return -1

return ordering

return sorted(self, key=ordering)
@property
def _in_dependency_order(self):
"""
Returns:
list[Parameter] - a list of Parameters such that any
Parameter is placed before all of its
`dependencies <Parameter.dependencies>`
"""
return sorted(self, key=self._dependency_order_key())


class Defaults(ParametersTemplate):
Expand Down
9 changes: 8 additions & 1 deletion psyneulink/core/globals/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,14 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result

for k, v in self.__dict__.items():
try:
# follow dependency order for Parameters to allow validation involving other parameters
ordered_dict_keys = sorted(self.__dict__, key=self._dependency_order_key(names=True))
except AttributeError:
ordered_dict_keys = self.__dict__

for k in ordered_dict_keys:
v = self.__dict__[k]
if k in shared_keys or isinstance(v, shared_types):
res_val = v
else:
Expand Down
38 changes: 38 additions & 0 deletions tests/misc/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,41 @@ def set_p_default(obj, val):
assert TestParent.defaults.p == 0
assert TestChild.defaults.p == 1
assert TestGrandchild.defaults.p == 20


def test_dependent_parameter_validate():
# using 3 parameters to reduce chance of random success
class NewF(pnl.Function_Base):
class Parameters(pnl.Function_Base.Parameters):
a = pnl.Parameter(1)
b = pnl.Parameter(2, dependencies='a')
c = pnl.Parameter(3, dependencies='b')
d = pnl.Parameter(4, dependencies='c')

def _validate_b(self, b):
if b != self.a.default_value + 1:
return 'invalid'

def _validate_c(self, c):
if c != self.b.default_value + 1:
return 'invalid'

def _validate_d(self, d):
if d != self.c.default_value + 1:
return 'invalid'

def __init__(self, **kwargs):
return super().__init__(0, {}, **kwargs)

def _function(self, variable=None, context=None, params=None):
return 0

pnl.ProcessingMechanism(function=NewF(a=2, b=3, c=4, d=5))

with pytest.raises(pnl.ParameterError) as err:
# b should be first error to occur
pnl.ProcessingMechanism(function=NewF(b=3, c=5, d=7))
assert re.match(
r"Value \(3\) assigned to parameter 'b'.*is not valid: invalid",
str(err.value)
)

0 comments on commit 8d7ece9

Please sign in to comment.