treewide: remove use of dot notation for stateful parameters (#1760)
kmantel authored Sep 25, 2020
1 parent b076129 commit 8708b66
Showing 23 changed files with 215 additions and 184 deletions.
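
The pattern behind every hunk in this commit: reads of stateful values move from plain attribute access (self.x, hasattr(self, name)) to the Parameters collection (self.parameters.x._get(context), name in self.parameters), writes move from attribute assignment to self.parameters.x._set(value, context), and construction-time reads go through self.defaults. A self-contained toy sketch of why context-keyed access matters; these Parameter and Parameters classes are a mock-up for illustration, not PsyNeuLink's actual implementation:

    # Toy mock-up (assumed semantics): a stateful parameter keeps one value per
    # execution context, so concurrent or nested executions cannot clobber each
    # other the way a single plain attribute can.
    class Parameter:
        def __init__(self, default, stateful=True):
            self.default = default
            self.stateful = stateful
            self._values = {}                      # execution context -> value

        def _get(self, context):
            if not self.stateful:
                return self.default
            return self._values.get(context, self.default)

        def _set(self, value, context):
            if self.stateful:
                self._values[context] = value
            else:
                self.default = value               # one shared value for all contexts

    class Parameters:
        def __init__(self, **params):
            self._params = dict(params)

        def __contains__(self, name):              # supports: name in self.parameters
            return name in self._params

        def __getattr__(self, name):               # supports: self.parameters.noise
            try:
                return self._params[name]
            except KeyError:
                raise AttributeError(name)

    parameters = Parameters(noise=Parameter(0.0))
    parameters.noise._set(0.5, 'execution-1')
    print(parameters.noise._get('execution-1'))    # 0.5: only this context changed
    print(parameters.noise._get('execution-2'))    # 0.0: others keep the default
    print('noise' in parameters)                   # True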
8 changes: 4 additions & 4 deletions psyneulink/core/components/component.py
@@ -1786,7 +1786,7 @@ def generate_error(param_name):
         for param_name in runtime_params:
             if not isinstance(param_name, str):
                 generate_error(param_name)
-            elif hasattr(self, param_name):
+            elif param_name in self.parameters:
                 if param_name in {FUNCTION, INPUT_PORTS, OUTPUT_PORTS}:
                     generate_error(param_name)
                 if context.execution_id not in self._runtime_params_reset:
@@ -1797,7 +1797,7 @@ def generate_error(param_name):
             # Any remaining params should either belong to the Component's function
             # or, if the Component is a Function, to it or its owner
             elif (  # If Component is not a function, and its function doesn't have the parameter or
-                  (not is_function_type(self) and not hasattr(self.function, param_name))
+                  (not is_function_type(self) and param_name not in self.function.parameters)
                   # the Component is a standalone function:
                   or (is_function_type(self) and not self.owner)):
                 generate_error(param_name)
@@ -2856,8 +2856,8 @@ def _instantiate_function(self, function, function_params=None, context=None):
         # KAM added 6/14/18 for functions that do not pass their has_initializers status up to their owner via property
         # FIX: need comprehensive solution for has_initializers; need to determine whether ports affect mechanism's
         # has_initializers status
-        if self.function.has_initializers:
-            self.has_initializers = True
+        if self.function.parameters.has_initializers._get(context):
+            self.parameters.has_initializers._set(True, context)

         self._parse_param_port_sources()

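A note on the hasattr replacement in the hunks above: hasattr(self, name) is true for any attribute, including methods, properties, and unrelated instance state, so it could validate runtime parameter names that are not actually Parameters. Membership in the parameters collection is the narrower test. A minimal generic-Python illustration (the Mechanism class here is hypothetical):

    # hasattr() accepts any attribute name; parameter membership does not.
    class Mechanism:
        parameters = {'noise', 'rate'}    # stand-in for a Parameters collection

        def execute(self):
            pass

    m = Mechanism()
    print(hasattr(m, 'execute'))          # True, but 'execute' is not a parameter
    print('execute' in m.parameters)      # False: correctly rejected
    print('noise' in m.parameters)        # True: a declared parameter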
4 changes: 2 additions & 2 deletions psyneulink/core/components/functions/objectivefunctions.py
@@ -66,7 +66,7 @@ class Parameters(Function_Base.Parameters):
                    :default value: False
                    :type: ``bool``
        """
-        normalize = False
+        normalize = Parameter(False, stateful=False)
        metric = Parameter(None, stateful=False)


@@ -205,7 +205,7 @@ class Parameters(ObjectiveFunction.Parameters):
        metric = Parameter(ENERGY, stateful=False)
        metric_fct = Parameter(None, stateful=False, loggable=False)
        transfer_fct = Parameter(None, stateful=False, loggable=False)
-       normalize = False
+       normalize = Parameter(False, stateful=False)

    @tc.typecheck
    def __init__(self,
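Both normalize hunks are declaration-side versions of the same fix: a bare class attribute (normalize = False) carries no Parameter metadata, whereas Parameter(False, stateful=False) records that normalize is configuration rather than per-execution state, matching the neighboring metric declarations. A self-contained toy sketch of the stateful/non-stateful distinction (assumed semantics, not the library's actual classes):

    # Toy sketch: stateful=False means one shared value for every execution
    # context; stateful=True (the default) keeps isolated per-context copies.
    class Parameter:
        def __init__(self, default, stateful=True):
            self.default, self.stateful, self._values = default, stateful, {}

        def _get(self, context):
            return self._values.get(context, self.default) if self.stateful else self.default

        def _set(self, value, context):
            if self.stateful:
                self._values[context] = value
            else:
                self.default = value

    normalize = Parameter(False, stateful=False)
    normalize._set(True, 'execution-1')
    print(normalize._get('execution-2'))   # True: shared across contexts

    rate = Parameter(0.0)                  # stateful
    rate._set(1.0, 'execution-1')
    print(rate._get('execution-2'))        # 0.0: isolated per context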
@@ -607,7 +607,7 @@ def _accumulator_check_args(self, variable=None, context=None, params=None, targ
         runtime_params = params
         if runtime_params:
             for param_name in runtime_params:
-                if hasattr(self, param_name):
+                if param_name in self.parameters:
                     if param_name in {FUNCTION, INPUT_PORTS, OUTPUT_PORTS}:
                         continue
                     if context.execution_id not in self._runtime_params_reset:
@@ -442,13 +442,20 @@ def _instantiate_attributes_before_function(self, function=None, context=None):
         # use np.broadcast_to to guarantee that all initializer type attributes take on the same shape as variable
         if not np.isscalar(self.defaults.variable):
             for attr in self.initializers:
-                setattr(self, attr, np.broadcast_to(getattr(self, attr), self.defaults.variable.shape).copy())
+                param = getattr(self.parameters, attr)
+                param._set(
+                    np.broadcast_to(
+                        param._get(context),
+                        self.defaults.variable.shape
+                    ).copy(),
+                    context
+                )

         # create all stateful attributes and initialize their values to the current values of their
         # corresponding initializer attributes
         for i, attr_name in enumerate(self.stateful_attributes):
-            initializer_value = getattr(self, self.initializers[i]).copy()
-            setattr(self, attr_name, initializer_value)
+            initializer_value = getattr(self.parameters, self.initializers[i])._get(context).copy()
+            getattr(self.parameters, attr_name)._set(initializer_value, context)

         super()._instantiate_attributes_before_function(function=function, context=context)

@@ -551,7 +558,7 @@ def reset(self, *args, context=None):
                 setattr(self, attr, reinitialization_values[i])
                 getattr(self.parameters, attr).set(reinitialization_values[i],
                                                    context, override=True)
-            value.append(getattr(self, self.stateful_attributes[i]))
+            value.append(getattr(self.parameters, self.stateful_attributes[i])._get(context))

         self.parameters.value.set(value, context, override=True)
         return value
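One detail worth noting in the broadcast hunk above: np.broadcast_to returns a read-only view, so the trailing .copy() is what makes the seeded value writable and detaches it from the initializer's buffer. A standalone NumPy sketch of that step, with made-up shapes:

    import numpy as np

    # Expand a scalar initializer to the default variable's shape, then copy.
    default_variable = np.zeros((3, 2))
    initializer = np.array(0.5)

    seeded = np.broadcast_to(initializer, default_variable.shape).copy()
    print(seeded.shape)    # (3, 2)
    seeded[0, 0] = 1.0     # writable thanks to .copy(); the raw broadcast view
                           # is read-only and would raise ValueError here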
13 changes: 9 additions & 4 deletions psyneulink/core/components/functions/transferfunctions.py
@@ -2682,7 +2682,10 @@ def __init__(self,
             prefs=prefs,
         )

-        self.matrix = self.instantiate_matrix(self.matrix)
+        self.parameters.matrix.set(
+            self.instantiate_matrix(self.parameters.matrix.get()),
+            skip_log=True,
+        )

     # def _validate_variable(self, variable, context=None):
     #     """Insure that variable passed to LinearMatrix is a max 2D array
@@ -2921,10 +2924,12 @@ def _instantiate_attributes_before_function(self, function=None, context=None):
         if isinstance(self.owner, Projection):
             self.receiver = self.defaults.variable

-        if self.matrix is None and not hasattr(self.owner, "receiver"):
+        matrix = self.parameters.matrix._get(context)
+
+        if matrix is None and not hasattr(self.owner, "receiver"):
             variable_length = np.size(np.atleast_2d(self.defaults.variable), 1)
-            self.matrix = np.identity(variable_length)
-            self.matrix = self.instantiate_matrix(self.matrix)
+            matrix = np.identity(variable_length)
+            self.parameters.matrix._set(self.instantiate_matrix(matrix), context)

     def instantiate_matrix(self, specification, context=None):
         """Implements matrix indicated by specification
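For the fallback in the second hunk: when no matrix was specified and the owner has no receiver, the function sizes an identity matrix to the width of the (at-least-2D) default variable, which passes each input element through unchanged. A standalone sketch of the sizing rule with example values:

    import numpy as np

    # Identity fallback: size the matrix to the default variable's width.
    default_variable = np.array([0.0, 0.0, 0.0])
    variable_length = np.size(np.atleast_2d(default_variable), 1)

    matrix = np.identity(variable_length)
    print(matrix.shape)                          # (3, 3)
    print(matrix @ np.array([1.0, 2.0, 3.0]))    # [1. 2. 3.]: pass-through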
54 changes: 38 additions & 16 deletions psyneulink/core/components/mechanisms/mechanism.py
@@ -1066,6 +1066,7 @@ class `UserList <https://docs.python.org/3.6/library/collections.html?highlight=
 """

 import abc
+import copy
 import inspect
 import itertools
 import logging
@@ -1861,7 +1862,7 @@ def _handle_arg_input_ports(self, input_ports):
                 mech_variable_item = args[VARIABLE]
             else:
                 try:
-                    mech_variable_item = parsed_input_port_spec.value
+                    mech_variable_item = parsed_input_port_spec.defaults.value
                 except AttributeError:
                     mech_variable_item = parsed_input_port_spec.defaults.mech_variable_item
         else:
@@ -2083,37 +2084,57 @@ def _instantiate_function(self, function, function_params=None, context=None):

         super()._instantiate_function(function=function, function_params=function_params, context=context)

-        if self.input_ports and any(input_port.weight is not None for input_port in self.input_ports):
+        if (
+            self.input_ports
+            and any(
+                input_port.parameters.weight._get(context) is not None
+                for input_port in self.input_ports
+            )
+            and 'weights' in self.function.parameters
+        ):

             # Construct defaults:
             # from function.weights if specified else 1's
             try:
-                default_weights = self.function.weights
+                default_weights = self.function.defaults.weights
             except AttributeError:
                 default_weights = None
             if default_weights is None:
                 default_weights = default_weights or [1.0] * len(self.input_ports)

             # Assign any weights specified in input_port spec
-            weights = [[input_port.weight if input_port.weight is not None else default_weight]
+            weights = [[input_port.defaults.weight if input_port.defaults.weight is not None else default_weight]
                        for input_port, default_weight in zip(self.input_ports, default_weights)]
-            self.function._weights = weights
+            self.function.parameters.weights._set(weights, context)

-        if self.input_ports and any(input_port.exponent is not None for input_port in self.input_ports):
+        if (
+            self.input_ports
+            and any(
+                input_port.parameters.exponent._get(context) is not None
+                for input_port in self.input_ports
+            )
+            and 'exponents' in self.function.parameters
+        ):

             # Construct defaults:
             # from function.weights if specified else 1's
             try:
-                default_exponents = self.function.exponents
+                default_exponents = self.function.defaults.exponents
             except AttributeError:
                 default_exponents = None
             if default_exponents is None:
                 default_exponents = default_exponents or [1.0] * len(self.input_ports)

             # Assign any exponents specified in input_port spec
-            exponents = [[input_port.exponent if input_port.exponent is not None else default_exponent]
-                         for input_port, default_exponent in zip(self.input_ports, default_exponents)]
-            self.function._exponents = exponents
+            exponents = [
+                [
+                    input_port.parameters.exponent._get(context)
+                    if input_port.parameters.exponent._get(context) is not None
+                    else default_exponent
+                ]
+                for input_port, default_exponent in zip(self.input_ports, default_exponents)
+            ]
+            self.function.parameters.exponents._set(exponents, context)

         # this may be removed when the restriction making all Mechanism values 2D np arrays is lifted
         # ignore warnings of certain Functions that disable conversion
@@ -3607,7 +3628,10 @@ def add_ports(self, ports, update_variable=True, context=None):
         instantiated_output_ports = _instantiate_output_ports(self, output_ports, context=context)

         if update_variable:
-            self._update_default_variable(self.input_values, context)
+            self._update_default_variable(
+                [copy.deepcopy(port.defaults.value) for port in self.input_ports],
+                context
+            )

         return {INPUT_PORTS: instantiated_input_ports,
                 OUTPUT_PORTS: instantiated_output_ports}
@@ -3683,12 +3707,10 @@ def delete_port_Projections(proj_list, port):
                                    component=port)

         elif port in self.output_ports:
-            if isinstance(port, OutputPort):
-                index = self.output_ports.index(port)
-            else:
-                index = self.output_ports.index(self.output_ports[port])
             delete_port_Projections(port.efferents.copy(), port)
-            del self.output_values[index]
+            # NOTE: removed below del because output_values is
+            # generated on the fly. These comments can be removed
+            # del self.output_values[index]
             del self.output_ports[port]
             # If port is subclass of OutputPort:
             # check if regsistry has category for that class, and if so, use that
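The weights/exponents logic in _instantiate_function above follows one merge rule: each InputPort's own specification wins, and unspecified (None) slots fall back to the function's defaults, or to 1.0 when the function supplies none. The rule extracted on its own, with made-up values:

    # Merge rule: per-port weight if given, else the default for that slot.
    port_weights = [0.5, None, 2.0]                # per-port specifications
    default_weights = [1.0] * len(port_weights)    # function defaults (or 1.0s)

    weights = [
        [port_w if port_w is not None else default_w]
        for port_w, default_w in zip(port_weights, default_weights)
    ]
    print(weights)    # [[0.5], [1.0], [2.0]]: one singleton list per port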
@@ -1532,10 +1532,10 @@ def _instantiate_control_signal_type(self, control_signal_spec, context):

         control_signal = _instantiate_port(port_type=ControlSignal,
                                            owner=self,
-                                           variable=self.default_allocation # User specified value
+                                           variable=self.defaults.default_allocation # User specified value
                                                     or allocation_parameter_default, # Parameter default
                                            reference_value=allocation_parameter_default,
-                                           modulation=self.modulation,
+                                           modulation=self.defaults.modulation,
                                            port_spec=control_signal_spec,
                                            context=context)
         if not type(control_signal) in convert_to_list(self.outputPortTypes):
@@ -500,7 +500,7 @@ def _instantiate_control_signal_type(self, gating_signal_spec, context):
                                            variable=self.default_allocation # User specified value
                                                     or allocation_parameter_default, # Parameter default
                                            reference_value=allocation_parameter_default,
-                                           modulation=self.modulation,
+                                           modulation=self.defaults.modulation,
                                            port_spec=gating_signal_spec,
                                            context=context)
         if not type(gating_signal) in convert_to_list(self.outputPortTypes):
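The modulation and default_allocation changes in these control/gating hunks are the same fix seen elsewhere: during construction there is no execution context yet, so constructor-time reads go through self.defaults (the declared default values) instead of dot access. A toy sketch of the distinction (the Defaults structure here is assumed for illustration):

    from types import SimpleNamespace

    # Stand-in for a component's declared defaults, available before any run.
    defaults = SimpleNamespace(modulation='multiplicative', default_allocation=None)

    # During __init__/_instantiate_* no execution context exists, so a
    # context-keyed read has nothing to key on; the default is well-defined.
    print(defaults.modulation)    # 'multiplicative'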
@@ -847,12 +847,13 @@ def _instantiate_attributes_after_function(self, context=None):

         super()._instantiate_attributes_after_function(context=context)
         # Assign parameters to function (OptimizationFunction) that rely on OptimizationControlMechanism
-        self.function.reset({DEFAULT_VARIABLE: self.control_allocation,
-                             OBJECTIVE_FUNCTION: self.evaluation_function,
-                             # SEARCH_FUNCTION: self.search_function,
-                             # SEARCH_TERMINATION_FUNCTION: self.search_termination_function,
-                             SEARCH_SPACE: self.control_allocation_search_space
-                             })
+        self.function.reset({
+            DEFAULT_VARIABLE: self.parameters.control_allocation._get(context),
+            OBJECTIVE_FUNCTION: self.evaluation_function,
+            # SEARCH_FUNCTION: self.search_function,
+            # SEARCH_TERMINATION_FUNCTION: self.search_termination_function,
+            SEARCH_SPACE: self.parameters.control_allocation_search_space._get(context)
+        })

         if isinstance(self.agent_rep, type):
             self.agent_rep = self.agent_rep()
@@ -1232,7 +1232,7 @@ def _instantiate_output_ports(self, context=None):
                                           variable=(OWNER_VALUE,0),
                                           params=params,
                                           reference_value=self.parameters.learning_signal._get(context),
-                                          modulation=self.modulation,
+                                          modulation=self.defaults.modulation,
                                           # port_spec=self.learning_signal)
                                           port_spec=learning_signal,
                                           context=context)
@@ -761,15 +761,21 @@ def _instantiate_function_weights_and_exponents(self, context=None):
         DEFAULT_WEIGHT = 1
         DEFAULT_EXPONENT = 1

-        weights = [input_port.weight for input_port in self.input_ports]
-        exponents = [input_port.exponent for input_port in self.input_ports]
+        weights = [input_port.defaults.weight for input_port in self.input_ports]
+        exponents = [input_port.defaults.exponent for input_port in self.input_ports]

-        if hasattr(self.function, WEIGHTS):
+        if WEIGHTS in self.function.parameters:
             if any(weight is not None for weight in weights):
-                self.function.weights = [[weight or DEFAULT_WEIGHT] for weight in weights]
-        if hasattr(self.function, EXPONENTS):
+                self.function.parameters.weights._set(
+                    [[weight or DEFAULT_WEIGHT] for weight in weights],
+                    context
+                )
+        if EXPONENTS in self.function.parameters:
             if any(exponent is not None for exponent in exponents):
-                self.function.exponents = [[exponent or DEFAULT_EXPONENT] for exponent in exponents]
+                self.function.parameters.exponents._set(
+                    [[exponent or DEFAULT_EXPONENT] for exponent in exponents],
+                    context
+                )
         assert True

     # # MODIFIED 6/8/19 NEW: [JDC]
@@ -800,14 +806,14 @@ def monitor(self):

     @property
     def monitor_weights_and_exponents(self):
-        if hasattr(self.function, WEIGHTS) and self.function.weights is not None:
-            weights = self.function.weights
+        if hasattr(self.function, WEIGHTS) and self.function.weights.base is not None:
+            weights = self.function.weights.base
         else:
-            weights = [input_port.weight for input_port in self.input_ports]
-        if hasattr(self.function, EXPONENTS) and self.function.exponents is not None:
-            exponents = self.function.exponents
+            weights = [input_port.weight.base for input_port in self.input_ports]
+        if hasattr(self.function, EXPONENTS) and self.function.exponents.base is not None:
+            exponents = self.function.exponents.base
         else:
-            exponents = [input_port.exponent for input_port in self.input_ports]
+            exponents = [input_port.exponent.base for input_port in self.input_ports]
         return [(w,e) for w, e in zip(weights,exponents)]

     @monitor_weights_and_exponents.setter
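The .base suffix in monitor_weights_and_exponents suggests that dot access on these parameters now returns a wrapper object whose .base attribute holds the stored (unmodulated) value; that reading is inferred from this diff rather than documented here, so treat it as an assumption. The property's fallback logic itself, extracted with made-up values:

    # Prefer function-level weights/exponents when present, else per-port ones.
    function_weights = None                  # stand-in for self.function.weights.base
    port_weights = [1.0, 0.5]                # stand-ins for input_port.weight.base
    function_exponents = [[2.0], [1.0]]      # stand-in for self.function.exponents.base
    port_exponents = [1.0, 1.0]

    weights = function_weights if function_weights is not None else port_weights
    exponents = function_exponents if function_exponents is not None else port_exponents
    print([(w, e) for w, e in zip(weights, exponents)])    # [(1.0, [2.0]), (0.5, [1.0])]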