Skip to content

Commit

Permalink
Merge pull request #2689 from kmantel/parameter-parsing-errors
Browse files Browse the repository at this point in the history
Rework Parameter parse/validation method calling
  • Loading branch information
kmantel authored Jun 7, 2023
2 parents f4850b5 + 17b896f commit 1582baa
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 59 deletions.
18 changes: 5 additions & 13 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,20 +2061,15 @@ def _initialize_parameters(self, context=None, **param_defaults):

if parameter_obj.modulable:
# later, validate this
try:
modulable_param_parser = self.parameters._get_prefixed_method(
parse=True,
modulable=True
)
modulable_param_parser = self.parameters._get_parse_method('modulable')
if modulable_param_parser is not None:
parsed = modulable_param_parser(name, value)

if parsed is not value:
# we have a modulable param spec
parameter_obj.spec = value
value = parsed
param_defaults[name] = parsed
except AttributeError:
pass

if value is not None or parameter_obj.specify_none:
defaults[name] = value
Expand Down Expand Up @@ -2694,14 +2689,11 @@ def _validate_params(self, request_set, target_set=None, context=None):
try:
target_set[param_name] = param_value.copy()
except AttributeError:
try:
modulable_param_parser = self.parameters._get_prefixed_method(
parse=True,
modulable=True
)
modulable_param_parser = self.parameters._get_parse_method('modulable')
if modulable_param_parser is not None:
param_value = modulable_param_parser(param_name, param_value)
target_set[param_name] = param_value
except AttributeError:
else:
target_set[param_name] = param_value.copy()

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet
from psyneulink.core.globals.utilities import ValidParamSpecType, all_within_range, \
convert_all_elements_to_np_array, parse_valid_identifier
convert_all_elements_to_np_array, parse_valid_identifier, safe_len

__all__ = ['SimpleIntegrator', 'AdaptiveIntegrator', 'DriftDiffusionIntegrator', 'DriftOnASphereIntegrator',
'OrnsteinUhlenbeckIntegrator', 'FitzHughNagumoIntegrator', 'AccumulatorIntegrator',
Expand Down Expand Up @@ -276,7 +276,7 @@ def _validate_params(self, request_set, target_set=None, context=None):
# If param is in Parameter class for function and it is a function_arg:
if (param in self.parameters.names() and getattr(self.parameters, param).function_arg
and getattr(self.parameters, param)._user_specified):
if value is not None and isinstance(value, (list, np.ndarray)) and len(value)>1:
if value is not None and isinstance(value, (list, np.ndarray)) and safe_len(value)>1:
# Store ones with length > 1 in dict for evaluation below
params_to_check.update({param:value})

Expand Down Expand Up @@ -305,7 +305,7 @@ def _instantiate_attributes_before_function(self, function=None, context=None):
values_with_a_len = [param.default_value for param in self.parameters if
param.function_arg and
isinstance(param.default_value, (list, np.ndarray)) and
len(param.default_value)>1]
safe_len(param.default_value)>1]
# One or more parameters are specified with length > 1 in the inner dimension
if values_with_a_len:
# If shape already matches,
Expand Down Expand Up @@ -2427,6 +2427,7 @@ class Parameters(IntegratorFunction.Parameters):
random_draw = Parameter()

def _parse_initializer(self, initializer):
initializer = np.array(initializer)
if initializer.ndim > 1:
return np.atleast_1d(initializer.squeeze())
else:
Expand Down Expand Up @@ -2997,6 +2998,7 @@ def _validate_initializer(self, initializer):

def _parse_initializer(self, initializer):
"""Assign initial value as array of random values of length dimension-1"""
initializer = np.array(initializer)
initializer_dim = self.dimension.default_value - 1
if initializer.ndim != 1 or len(initializer) != initializer_dim:
initializer = np.random.random(initializer_dim)
Expand Down Expand Up @@ -3519,6 +3521,7 @@ class Parameters(IntegratorFunction.Parameters):
)

def _parse_initializer(self, initializer):
initializer = np.array(initializer)
if initializer.ndim > 1:
return np.atleast_1d(initializer.squeeze())
else:
Expand Down
70 changes: 28 additions & 42 deletions psyneulink/core/globals/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,13 @@ class ParameterError(Exception):
pass


def _get_prefixed_method(obj, prefix, name, sep=''):
try:
return getattr(obj, f'{prefix}{sep}{name}')
except AttributeError:
return None


def get_validator_by_function(function):
"""
Arguments
Expand Down Expand Up @@ -2290,41 +2297,23 @@ def _param_is_specified_in_class(self, param_name):
)
)

def _get_prefixed_method(
self,
parse=False,
validate=False,
modulable=False,
parameter_name=None
):
def _get_parse_method(self, parameter):
"""
Returns the parsing or validation method for the Parameter named
**parameter_name** or for any modulable Parameter
Returns:
the parsing method for the **parameter** or for any Parameter
attribute (ex: 'modulable') if it exists, or None if it does
not
"""
return _get_prefixed_method(self, self._parsing_method_prefix, parameter)

if (
parse and validate
or (not parse and not validate)
):
raise ValueError('Exactly one of parse or validate must be True')

if parse:
prefix = self._parsing_method_prefix
elif validate:
prefix = self._validation_method_prefix

if (
modulable and parameter_name is not None
or not modulable and parameter_name is None
):
raise ValueError('modulable must be True or parameter_name must be specified, but not both.')

if modulable:
suffix = 'modulable'
elif parameter_name is not None:
suffix = parameter_name

return getattr(self, '{0}{1}'.format(prefix, suffix))
def _get_validate_method(self, parameter):
"""
Returns:
the validation method for the **parameter** or for any
Parameter attribute (ex: 'modulable') if it exists, or None
if it does not
"""
return _get_prefixed_method(self, self._validation_method_prefix, parameter)

def _validate(self, attr, value):
err_msg = None
Expand All @@ -2337,16 +2326,13 @@ def _validate(self, attr, value):
valid_types
)

try:
validation_method = self._get_prefixed_method(validate=True, parameter_name=attr)
validation_method = self._get_validate_method(attr)
if validation_method is not None:
err_msg = validation_method(value)
# specifically check for False because None indicates a valid assignment
if err_msg is False:
err_msg = '{0} returned False'.format(validation_method)

except AttributeError:
# parameter does not have a validation method
pass

if err_msg is not None:
raise ParameterError(
"Value ({0}) assigned to parameter '{1}' of {2}.parameters is not valid: {3}".format(
Expand All @@ -2358,7 +2344,7 @@ def _validate(self, attr, value):
)

def _parse(self, attr, value):
try:
return self._get_prefixed_method(parse=True, parameter_name=attr)(value)
except AttributeError:
return value
parse_method = self._get_parse_method(attr)
if parse_method is not None:
value = parse_method(value)
return value
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class Parameters(RecurrentTransferMechanism.Parameters):
def _validate_competition(self, competition):
if competition < 0:
warnings.warn(
f"The 'competition' arg specified for {self.name} is a negative value ({competition}); "
f"The 'competition' arg specified for {self._owner.name} is a negative value ({competition}); "
f"note that this will result in a matrix that has positive off-diagonal elements "
f"since 'competition' is assumed to specify the magnitude of inhibition."
)
Expand Down

0 comments on commit 1582baa

Please sign in to comment.