From ee61d35dbb0a15766101c29685697a8f5994d634 Mon Sep 17 00:00:00 2001
From: jdcpni
Date: Thu, 14 Nov 2024 06:32:35 -0500
Subject: [PATCH] Fix/matrix transform l0 (#3113)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• transformfunctions.py
  - MatrixTransform: allow normalization for L0

• emcomposition.py
  - enforce normalize_memories for len(keys)==1

• test_emcomposition.py
  - test_simple_execution_without_learning(): add tests for scalar keys & use of L0 in MatrixTransform
---
 docs/source/CombinationFunctions.rst        | 11 ---
 docs/source/Core.rst                        |  4 +-
 docs/source/NonStatefulFunctions.rst        |  5 +-
 docs/source/TransformFunctions.rst          | 11 +++
 .../nonstateful/transformfunctions.py       | 94 +++++++++----------
 .../library/compositions/emcomposition.py   |  7 +-
 tests/composition/test_emcomposition.py     | 64 +++++++------
 7 files changed, 101 insertions(+), 95 deletions(-)
 delete mode 100644 docs/source/CombinationFunctions.rst
 create mode 100644 docs/source/TransformFunctions.rst

diff --git a/docs/source/CombinationFunctions.rst b/docs/source/CombinationFunctions.rst
deleted file mode 100644
index 31a55947cc0..00000000000
--- a/docs/source/CombinationFunctions.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-CombinationFunctions
-====================
-
-.. toctree::
-   :maxdepth: 3
-
-.. automodule:: psyneulink.core.components.functions.combinationfunctions
-   :members: Concatenate, Rearrange, Reduce, LinearCombination, CombineMeans, PredictionErrorDeltaFunction
-   :private-members:
-   :exclude-members: Parameters
-
diff --git a/docs/source/Core.rst b/docs/source/Core.rst
index 292689884fd..ea1f1b7105d 100644
--- a/docs/source/Core.rst
+++ b/docs/source/Core.rst
@@ -57,8 +57,6 @@ Core
 
   - `NonStatefulFunctions`
 
-    - `CombinationFunctions`
-
     - `DistributionFunctions`
 
     - `LearningFunctions`
@@ -71,6 +69,8 @@ Core
 
     - `TransferFunctions`
 
+    - `TransformFunctions`
+
   - `StatefulFunctions`
 
     - `IntegratorFunctions`
diff --git a/docs/source/NonStatefulFunctions.rst b/docs/source/NonStatefulFunctions.rst
index 55ad922776c..f69c780ad63 100644
--- a/docs/source/NonStatefulFunctions.rst
+++ b/docs/source/NonStatefulFunctions.rst
@@ -8,10 +8,11 @@ Functions that do *not* depend on a previous value.
 .. toctree::
    :maxdepth: 1
 
-   CombinationFunctions
+
    DistributionFunctions
    LearningFunctions
    ObjectiveFunctions
    OptimizationFunctions
    SelectionFunctions
-   TransferFunctions
\ No newline at end of file
+   TransferFunctions
+   TransformFunctions
\ No newline at end of file
diff --git a/docs/source/TransformFunctions.rst b/docs/source/TransformFunctions.rst
new file mode 100644
index 00000000000..70cd2194ad7
--- /dev/null
+++ b/docs/source/TransformFunctions.rst
@@ -0,0 +1,11 @@
+TransformFunctions
+==================
+
+.. toctree::
+   :maxdepth: 3
+
+.. automodule:: psyneulink.core.components.functions.transformfunctions
+   :members: Concatenate, Rearrange, Reduce, LinearCombination, CombineMeans, MatrixTransform, PredictionErrorDeltaFunction
+   :private-members:
+   :exclude-members: Parameters
+
diff --git a/psyneulink/core/components/functions/nonstateful/transformfunctions.py b/psyneulink/core/components/functions/nonstateful/transformfunctions.py
index bd0403bfcf5..86c1db6b7b5 100644
--- a/psyneulink/core/components/functions/nonstateful/transformfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transformfunctions.py
@@ -1628,20 +1628,48 @@ class MatrixTransform(TransformFunction):  # -----------------------------------
 
     Matrix transform of `variable <MatrixTransform.variable>`.
-    `function <MatrixTransform.function>` returns dot product of variable with matrix:
-
-    .. math::
-        variable \\bullet matrix
-
-    If *DOT_PRODUCT* is specified as the **operation**, the result is the dot product of `variable
-    <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>`; if *L0* is specified, the result is the
-    difference between `variable <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>` (see
-    `operation <MatrixTransform.operation>` for additional details).
-
-    If **normalize** is True, the result is normalized by the product of the norms of the variable and matrix:
-
-    .. math::
-        \\frac{variable \\bullet matrix}{\\|variable\\| \\cdot \\|matrix\\|}
+    `function <MatrixTransform.function>` returns a matrix transform of `variable <MatrixTransform.variable>`
+    based on the **operation** argument.
+
+    **operation** = *DOT_PRODUCT*:
+
+    Returns the dot (inner) product of `variable <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>`:
+
+    .. math::
+        variable \\bullet matrix
+
+    If **normalize** is True, the result is normalized by the product of the norms of the variable and matrix:
+
+    .. math::
+        \\frac{variable \\bullet matrix}{\\|variable\\| \\cdot \\|matrix\\|}
+
+    .. note::
+       With **normalize** set to True, the result is the same as the cosine of the angle between pairs of vectors.
+
+    **operation** = *L0*:
+
+    Returns the absolute values of the differences between `variable <MatrixTransform.variable>` and `matrix
+    <MatrixTransform.matrix>`:
+
+    .. math::
+        |variable - matrix|
+
+    If **normalize** is True, each difference is normalized by the sum of the absolute differences between the
+    variable and matrix, and the result is subtracted from 1:
+
+    .. math::
+        1 - \\frac{|variable - matrix|}{\\sum{|variable - matrix|}}
+
+    .. note::
+       With **normalize** set to True, the result has the same effect as the normalized *DOT_PRODUCT* operation,
+       with more similar pairs of vectors producing larger values (closer to 1).
+
+    .. warning::
+       With **normalize** set to False, the result is smaller (closer to 0) for more similar pairs of vectors,
+       which is **opposite** to the effect of the *DOT_PRODUCT* and normalized *L0* operations. If the desired
+       result is that more similar pairs of vectors produce larger values, set **normalize** to True or
+       use the *DOT_PRODUCT* operation.
 
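The normalized forms documented above are easy to sanity-check outside of PsyNeuLink. The following is a minimal numpy sketch (an editorial illustration, not part of the patch); shapes follow the function's convention of one stored pattern per matrix column:

```python
import numpy as np

v = np.array([1., 2., 3.])        # "variable": a query vector
M = np.array([[1.,  10.],
              [2.,   0.],
              [5., 100.]])        # "matrix": one stored pattern per column

# DOT_PRODUCT: raw, then normalized by the product of the norms (cosine similarity)
dot = v @ M                                                  # [ 20., 310.]
cos = dot / (np.linalg.norm(v) * np.linalg.norm(M, axis=0))  # ~[0.98, 0.82]

# L0: raw distance per column (smaller = more similar) ...
d = np.abs(v[:, None] - M)
l0 = d.sum(axis=0)                                           # [  2., 108.]
# ... and the normalized form documented above: larger for closer columns
l0_norm = (1 - d / d.sum()).sum(axis=0)                      # ~[2.98, 2.02]
```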
     COMMENT:
     [CONVERT TO FIGURE]
     ----------------------------------------------------------------------------------------------------------
@@ -1679,7 +1707,7 @@ class MatrixTransform(TransformFunction):  # -----------------------------------
         specifies matrix used to transform `variable <MatrixTransform.variable>`
         (see `matrix <MatrixTransform.matrix>` for specification details).
 
     When MatrixTransform is the `function <Projection_Base.function>` of a projection:
 
         - the matrix specification must be compatible with the variables of the
           `sender <Projection_Base.sender>` and `receiver <Projection_Base.receiver>`
@@ -1795,15 +1823,6 @@ class Parameters(TransformFunction.Parameters):
         normalize = Parameter(False)
         bounds = None
 
-        # def is_matrix_spec(m):
-        #     if m is None:
-        #         return True
-        #     if m in MATRIX_KEYWORD_VALUES:
-        #         return True
-        #     if isinstance(m, (list, np.ndarray, types.FunctionType)):
-        #         return True
-        #     return False
-
     @check_user_specified
     @beartype
     def __init__(self,
@@ -1833,25 +1852,6 @@ def __init__(self,
             skip_log=True,
         )
 
-    # def _validate_variable(self, variable, context=None):
-    #     """Insure that variable passed to MatrixTransform is a max 2D array
-    #
-    #     :param variable: (max 2D array)
-    #     :param context:
-    #     :return:
-    #     """
-    #     variable = super()._validate_variable(variable, context)
-    #
-    #     # Check that variable <= 2D
-    #     try:
-    #         if not variable.ndim <= 2:
-    #             raise FunctionError("variable ({0}) for {1} must be a numpy.ndarray of dimension at most 2".format(variable, self.__class__.__name__))
-    #     except AttributeError:
-    #         raise FunctionError("PROGRAM ERROR: variable ({0}) for {1} should be a numpy.ndarray".
-    #                             format(variable, self.__class__.__name__))
-    #
-    #     return variable
-
     def _validate_params(self, request_set, target_set=None, context=None):
         """Validate params and assign to targets
 
@@ -2013,15 +2013,6 @@ def _validate_params(self, request_set, target_set=None, context=None):
                                              self.name, self.owner_name, MATRIX_KEYWORD_NAMES))
 
-            # operation param
-            elif param_name == OPERATION:
-                if param_value == L0 and NORMALIZE in param_set and param_set[NORMALIZE]:
-                    raise FunctionError(f"The 'operation' parameter for the {self.name} function of "
-                                        f"{self.owner_name} is set to 'L0', so the 'normalize' parameter "
-                                        f"should not be set to True "
-                                        f"(normalization is not needed, and can cause a divide by zero error). "
-                                        f"Set 'normalize' to False or change 'operation' to 'DOT_PRODUCT'.")
             else:
                 continue
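With the guard above removed, operation=L0 may now be combined with normalize=True (the EMComposition change further below depends on this). The following is a hypothetical usage sketch: it assumes MatrixTransform and the L0 keyword are importable from the top-level psyneulink namespace, as other Functions and keywords are, and is not a snippet from the patch:

```python
import numpy as np
import psyneulink as pnl

f = pnl.MatrixTransform(default_variable=np.zeros(3),
                        matrix=np.array([[1.,  10.],
                                         [2.,   0.],
                                         [5., 100.]]),
                        operation=pnl.L0,
                        normalize=True)   # no longer raises FunctionError
result = f.execute([1., 2., 3.])          # one similarity value per matrix column
```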
" - f"Set 'normalize' to False or change 'operation' to 'DOT_PRODUCT'.") else: continue @@ -2176,7 +2167,7 @@ def diff_with_normalization(vector, matrix): if normalize: return diff_with_normalization else: - return lambda x, y: torch.sum((1 - torch.abs(x - y)),axis=0) + return lambda x, y: torch.sum(torch.abs(x - y),axis=0) else: from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError @@ -2224,10 +2215,11 @@ def _function(self, result = np.dot(vector, matrix) elif operation == L0: - normalization = 1 if normalize: normalization = np.sum(np.abs(vector - matrix)) - result = np.sum(((1 - np.abs(vector - matrix)) / normalization),axis=0) + result = np.sum((1 - (np.abs(vector - matrix)) / normalization),axis=0) + else: + result = np.sum((np.abs(vector - matrix)),axis=0) return self.convert_output_type(result) diff --git a/psyneulink/library/compositions/emcomposition.py b/psyneulink/library/compositions/emcomposition.py index 46acce0308d..8992e053c9f 100644 --- a/psyneulink/library/compositions/emcomposition.py +++ b/psyneulink/library/compositions/emcomposition.py @@ -2201,7 +2201,11 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q """ OPERATION = 0 NORMALIZE = 1 - args = [(L0,False) if len(key) == 1 else (DOT_PRODUCT,normalize_memories) for key in memory_template[0]] + # Enforce normalization of memories if key is a scalar + # (this is to allow 1-L0 distance to be used as similarity measure, so that better matches + # (more similar memories) have higher match values; see `MatrixTransform` for explanation) + args = [(L0,True) if len(key) == 1 else (DOT_PRODUCT,normalize_memories) + for key in memory_template[0]] if concatenate_queries: # Get fields of memory structure corresponding to the keys @@ -2238,7 +2242,6 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q for i in range(self.num_keys) ] - return match_nodes # FIX: CONVERT TO _construct_weight_control_nodes diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py index 55c01ad7b51..024076d9ec6 100644 --- a/tests/composition/test_emcomposition.py +++ b/tests/composition/test_emcomposition.py @@ -258,7 +258,7 @@ class TestExecution: # NOTE: None => use default value (i.e., don't specify in constructor, rather than forcing None as value of arg) # ---------------------------------------- SPECS ----------------------------------- ----- EXPECTED --------- # memory_template mem mem mem fld concat nlz sm str inputs expected_retrieval - # fill cap decay wts keys gain prob + # fill cap decay wts keys mem gain prob # ---------------------------------------------------------------------------------- ------------------------ (0, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], @@ -266,26 +266,26 @@ class TestExecution: [4., 6.16540637]]), (1, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[[1, 2, 3]], - [[4, 6]]], [[1., 2., 3.16585899], + [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[1, 2, 3], + [4, 6]], [[1., 2., 3.16585899], [4., 6.16540637]]), (2, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[[1, 2, 3]], - [[4, 8]]], [[1., 2., 3.16585899], + [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[1, 2, 3], + [4, 8]], [[1., 2., 3.16585899], [4., 6.16540637]]), (3, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], (0,.01), 4, 0, [1,0], None, None, 100, 0, [[[1, 2, 3]], - [[4, 8]]], [[0.99998628, + 
diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py
index 55c01ad7b51..024076d9ec6 100644
--- a/tests/composition/test_emcomposition.py
+++ b/tests/composition/test_emcomposition.py
@@ -258,7 +258,7 @@ class TestExecution:
     # NOTE: None => use default value (i.e., don't specify in constructor, rather than forcing None as value of arg)
     # ---------------------------------------- SPECS ----------------------------------- ----- EXPECTED ---------
     #   memory_template       mem     mem  mem   fld    concat nlz   sm   str  inputs     expected_retrieval
-    #                         fill    cap  decay wts    keys         gain prob
+    #                         fill    cap  decay wts    keys   mem   gain prob
     # ---------------------------------------------------------------------------------- ------------------------
     (0, [[[1,2,3],[4,6]],
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  None,   3,   0,    [1,0], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 6]],  [[1., 2., 3.16585899],
                                                                                           [4., 6.16540637]]),
     (1, [[[1,2,3],[4,6]],
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  None,   3,   0,    [1,0], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 6]],  [[1., 2., 3.16585899],
                                                                                           [4., 6.16540637]]),
     (2, [[[1,2,3],[4,6]],
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  None,   3,   0,    [1,0], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 8]],  [[1., 2., 3.16585899],
                                                                                           [4., 6.16540637]]),
     (3, [[[1,2,3],[4,6]],
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  0,    [1,0], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 8]],  [[0.99998628,
                                                                                            1.99997247,
                                                                                            3.1658154 ],
                                                                                           [3.99994492,
                                                                                            6.16532141]]),
     (4, [[[1,2,3],[4,6]],     # Equal field_weights (but not concatenated)
          [[1,2,5],[4,6]],
          [[1,2,10],[4,6]]],   (0,.01), 4,  0,    [1,1], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99750462,
                                                                                            1.99499376,
                                                                                            3.51623568],
                                                                                           [3.98998465,
                                                                                            ]]),
     (5, [[[1,2,3],[4,6]],     # Equal field_weights with concatenation
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  0,    [1,1], True,  None, 100, 0,  [[1, 2, 4],
                                                                                [4, 6]],  [[0.99898504,
                                                                                            1.99796378,
                                                                                            4.00175037],
                                                                                           [3.99592639,
                                                                                            6.97406456]]),
     (6, [[[1,2,3],[4,6]],     # Unequal field_weights
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  0,    [9,1], None,  None, 100, 0,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99996025,
                                                                                            1.99992024,
                                                                                            3.19317783],
                                                                                           [3.99984044,
                                                                                            6.19219795]]),
     (7, [[[1,2,3],[4,6]],     # Store + no decay
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  0,    [9,1], None,  None, 100, 1,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99996025,
                                                                                            1.99992024,
                                                                                            3.19317783],
                                                                                           [3.99984044,
                                                                                            6.19219795]]),
     (8, [[[1,2,3],[4,6]],     # Store + default decay (should be AUTO)
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  None, [9,1], None,  None, 100, 1,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99996025,
                                                                                            1.99992024,
                                                                                            3.19317783],
                                                                                           [3.99984044,
                                                                                            6.19219795]]),
     (9, [[[1,2,3],[4,6]],     # Store + explicit AUTO decay
          [[1,2,5],[4,8]],
          [[1,2,10],[4,10]]],  (0,.01), 4,  AUTO, [9,1], None,  None, 100, 1,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99996025,
                                                                                            1.99992024,
                                                                                            3.19317783],
                                                                                           [3.99984044,
                                                                                            6.19219795]]),
     (10, [[[1,2,3],[4,6]],    # Store + numerical decay
           [[1,2,5],[4,8]],
           [[1,2,10],[4,10]]], (0,.01), 4,  .1,   [9,1], None,  None, 100, 1,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99996025,
                                                                                            1.99992024,
                                                                                            3.19317783],
                                                                                           [3.99984044,
                                                                                            6.19219795]]),
     (11, [[[1,2,3],[4,6]],    # Same as 10, but with equal weights and concatenate keys
           [[1,2,5],[4,8]],
           [[1,2,10],[4,10]]], (0,.01), 4,  .1,   [1,1], True,  None, 100, 1,  [[1, 2, 3],
                                                                                [4, 6]],  [[0.99922544,
                                                                                            1.99844608,
                                                                                            3.38989346],
                                                                                           [3.99689126,
                                                                                            6.38682264]]),
 #                                                                                         [3.99984044,
 #                                                                                          6.19219795]]),
+
+    (12, [[[1],[2],[3]],      # Scalar keys - exact match (this tests use of L0 for retrieval in MEMORY matrix)
+          [[10],[0],[100]]],  (0,.01), 3,  0,   [1,1,0], None, None, pnl.ARG_MAX, 1, [[10],[0],[100]],
+                                                                                      [[10],[0],[100]]),
+
+    (13, [[[1],[2],[3]],      # Scalar keys - close match (this tests use of L0 for retrieval in MEMORY matrix)
+          [[10],[0],[100]]],  (0,.01), 3,  0,   [1,1,0], None, None, pnl.ARG_MAX, 1, [[2],[3],[4]],  [[1],[2],[3]]),
 ]
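For reference, the following is a hypothetical construction mirroring new test case 12 above (scalar keys, exact match). The argument names are taken from the params dict assembled in test_simple_execution_without_learning below; treat it as a sketch, not a verified recipe:

```python
import psyneulink as pnl

em = pnl.EMComposition(
    memory_template=[[[1], [2], [3]],
                     [[10], [0], [100]]],   # two entries, three scalar fields each
    memory_fill=(0, .01),                   # unused capacity filled with small random values
    memory_capacity=3,
    memory_decay_rate=0,
    field_weights=[1, 1, 0],                # first two fields are keys, third is a value
    softmax_choice=pnl.ARG_MAX,             # deterministic retrieval of the best match
    softmax_gain=100,
    softmax_threshold=None,
    storage_prob=1,
)
# Presenting [[10], [0], [100]] as input should retrieve that entry exactly, via the
# 1 - normalized-L0 match on each scalar key (expected_retrieval in case 12 above).
```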
 args_names = "test_num, memory_template, memory_fill, memory_capacity, memory_decay_rate, field_weights, " \
              "concatenate_queries, normalize_memories, softmax_gain, storage_prob, inputs, expected_retrieval"
@@ -401,7 +406,11 @@ def test_simple_execution_without_learning(self,
         if normalize_memories is not None:
             params.update({'normalize_memories': normalize_memories})
         if softmax_gain is not None:
-            params.update({'softmax_gain': softmax_gain})
+            if softmax_gain == pnl.ARG_MAX:
+                params.update({'softmax_choice': softmax_gain})
+                params.update({'softmax_gain': 100})
+            else:
+                params.update({'softmax_gain': softmax_gain})
         if storage_prob is not None:
             params.update({'storage_prob': storage_prob})
         params.update({'softmax_threshold': None})
@@ -432,7 +441,8 @@ def test_simple_execution_without_learning(self,
 
         # Validate storage
         if storage_prob:
-            for actual, expected in zip(em.memory[-1], [[1,2,3],[4,6]]):
+            # for actual, expected in zip(em.memory[-1], [[1,2,3],[4,6]]):
+            for actual, expected in zip(em.memory[-1], list(inputs.values())):
                 np.testing.assert_array_equal(actual, expected)
 
         if memory_decay_rate in {None, AUTO}: