Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functions: _is_identity: add option for defaults #2136

Merged
merged 1 commit into from
Oct 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion psyneulink/core/components/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def owner_name(self):
except AttributeError:
return '<no owner>'

def _is_identity(self, context=None):
def _is_identity(self, context=None, defaults=False):
# should return True in subclasses if the parameters for context are such that
# the Function's output will be the same as its input
# Used to bypass execute when unnecessary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,15 @@ def derivative(self, input=None, output=None, context=None):

return self._get_current_parameter_value(SLOPE, context)

def _is_identity(self, context=None):
return (
self.parameters.slope._get(context) == 1
and self.parameters.intercept._get(context) == 0
)
def _is_identity(self, context=None, defaults=False):
if defaults:
slope = self.defaults.slope
intercept = self.defaults.intercept
else:
slope = self.parameters.slope._get(context)
intercept = self.parameters.intercept._get(context)

return slope == 1 and intercept == 0


# **********************************************************************************************************************
Expand Down Expand Up @@ -3278,8 +3282,11 @@ def param_function(owner, function):
receiver_len = len(owner.receiver.defaults.variable)
return function(sender_len, receiver_len)

def _is_identity(self, context=None):
matrix = self.parameters.matrix._get(context)
def _is_identity(self, context=None, defaults=False):
if defaults:
matrix = self.defaults.matrix
else:
matrix = self.parameters.matrix._get(context)

# if matrix is not an np array with at least one dimension,
# this isn't an identity matrix
Expand Down Expand Up @@ -4088,9 +4095,18 @@ def _function(self,

return intensity

def _is_identity(self, context=None):
return (self.parameters.transfer_fct.get()._is_identity(context) and
self.parameters.enabled_cost_functions.get(context) == CostFunctions.NONE)
def _is_identity(self, context=None, defaults=False):
transfer_fct = self.parameters.transfer_fct.get()

if defaults:
enabled_cost_functions = self.defaults.enabled_cost_functions
else:
enabled_cost_functions = self.parameters.enabled_cost_functions.get(context)

return (
transfer_fct._is_identity(context, defaults=defaults)
and enabled_cost_functions == CostFunctions.NONE
)

@tc.typecheck
def assign_costs(self, cost_functions: tc.any(CostFunctions, list), execution_context=None):
Expand Down