diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py index eccea3765d0..74f5925ab30 100644 --- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py @@ -3159,7 +3159,19 @@ def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in all_out = builder.alloca(arg_out.type.pointee) builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags) - # The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only + if self.parameters.per_item.get(): + assert isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType) + assert isinstance(arg_out.type.pointee.element, pnlvm.ir.ArrayType) + for i in range(arg_in.type.pointee.count): + inner_all_out = builder.gep(all_out, [ctx.int32_ty(0), ctx.int32_ty(i)]) + inner_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(i)]) + builder = self.__gen_llvm_apply_derivative(ctx, builder, params, state, inner_all_out, inner_out, tags=tags) + return builder + else: + return self.__gen_llvm_apply_derivative(ctx, builder, params, state, all_out, arg_out, tags=tags) + + def __gen_llvm_apply_derivative(self, ctx, builder, params, state, all_out, arg_out, *, tags:frozenset): + assert self.output in {MAX_VAL, MAX_INDICATOR}, \ "Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! 
({})".format(self.output) @@ -3292,44 +3304,51 @@ def derivative(self, input=None, output=None, context=None): else: assert not np.any(np.equal(0, output)) - sm = np.squeeze(output) - size = len(sm) - assert (len(output) == 1 and len(output[0]) == size) or len(output) == size + per_item = self._get_current_parameter_value(PER_ITEM, context) + if not per_item: + output = [output] + + result = [] + for sm in output: + size = len(sm) + + output_type = self._get_current_parameter_value(OUTPUT_TYPE, context) + if output_type == ALL: + # Return full Jacobian matrix of derivatives using Kronecker's delta method: + derivative = np.empty([size, size]) + for i, j in np.ndindex(size, size): + if i == j: + d = 1 + else: + d = 0 + derivative[j, i] = sm[i] * (d - sm[j]) + elif output_type in {MAX_VAL, MAX_INDICATOR}: + # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax) + derivative = np.empty(size) + # Get the element of output returned as non-zero (max val) when output_type is not ALL + # IMPLEMENTATION NOTES: + # if there is a tie for max, this chooses the item in sm with the lowest index in sm: + index_of_max = int(np.where(sm == np.max(sm))[-1][0]) + # the following would randomly choose a value in case of a tie, + # but may cause problems with compilation: + # index_of_max = np.where(sm == np.max(sm))[0] + # if len(index_of_max)>1: + # index_of_max = int(np.random.choice(index_of_max)) + max_val = sm[index_of_max] + for i in range(size): + if i == index_of_max: + d = 1 + else: + d = 0 + derivative[i] = sm[i] * (d - max_val) + else: + raise FunctionError("Can't assign derivative for SoftMax function{} since OUTPUT_TYPE is PROB " + "(and therefore the relevant element is ambiguous)".format(self.owner_name)) - output_type = self._get_current_parameter_value(OUTPUT_TYPE, context) - if output_type == ALL: - # Return full Jacobian matrix of derivatives using Kronecker's delta method: - derivative = np.empty([size, size]) - for i, j in 
np.ndindex(size, size): - if i == j: - d = 1 - else: - d = 0 - derivative[j, i] = sm[i] * (d - sm[j]) - elif output_type in {MAX_VAL, MAX_INDICATOR}: - # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax) - derivative = np.empty(size) - # Get the element of output returned as non-zero (max val) when output_type is not ALL - # IMPLEMENTATION NOTES: - # if there is a tie for max, this chooses the item in sm with the lowest index in sm: - index_of_max = int(np.where(sm == np.max(sm))[-1][0]) - # the following would randomly choose a value in case of a tie, - # but may cause problems with compilation: - # index_of_max = np.where(sm == np.max(sm))[0] - # if len(index_of_max)>1: - # index_of_max = int(np.random.choice(index_of_max)) - max_val = sm[index_of_max] - for i in range(size): - if i == index_of_max: - d = 1 - else: - d = 0 - derivative[i] = sm[i] * (d - max_val) - else: - raise FunctionError("Can't assign derivative for SoftMax function{} since OUTPUT_TYPE is PROB " - "(and therefore the relevant element is ambiguous)".format(self.owner_name)) + result.append(derivative) - return derivative + assert per_item or len(result) == 1 + return result[0] if not per_item else result # ********************************************************************************************************************** diff --git a/tests/functions/test_transfer.py b/tests/functions/test_transfer.py index 55124005a1c..b9f2535ee1f 100644 --- a/tests/functions/test_transfer.py +++ b/tests/functions/test_transfer.py @@ -129,35 +129,56 @@ def test_execute(func, variable, params, expected, benchmark, func_mode): [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]), (Functions.SoftMax, test_var, {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.ALL, kw.PER_ITEM:False}, - [[ 0.08863569, -0.01005855, 
-0.00978921, -0.00965338, -0.00937495, -0.00989168, -0.00940653, -0.01049662, -0.01068039, -0.00928437], - [-0.01005855, 0.09185608, -0.01019041, -0.01004901, -0.00975917, -0.01029708, -0.00979205, -0.01092681, -0.01111811, -0.00966488], - [-0.00978921, -0.01019041, 0.08966934, -0.00977993, -0.00949785, -0.01002135, -0.00952985, -0.01063423, -0.0108204, -0.00940609], - [-0.00965338, -0.01004901, -0.00977993, 0.08856078, -0.00936606, -0.0098823, -0.00939761, -0.01048667, -0.01067026, -0.00927557], - [-0.00937495, -0.00975917, -0.00949785, -0.00936606, 0.08627659, -0.00959726, -0.00912656, -0.0101842, -0.0103625, -0.00900804], - [-0.00989168, -0.01029708, -0.01002135, -0.0098823, -0.00959726, 0.09050301, -0.0096296, -0.01074554, -0.01093366, -0.00950454], - [-0.00940653, -0.00979205, -0.00952985, -0.00939761, -0.00912656, -0.0096296, 0.08653653, -0.01021852, -0.01039741, -0.00903839], - [-0.01049662, -0.01092681, -0.01063423, -0.01048667, -0.0101842, -0.01074554, -0.01021852, 0.09538073, -0.01160233, -0.01008581], - [-0.01068039, -0.01111811, -0.0108204, -0.01067026, -0.0103625, -0.01093366, -0.01039741, -0.01160233, 0.09684744, -0.01026238], - [-0.00928437, -0.00966488, -0.00940609, -0.00927557, -0.00900804, -0.00950454, -0.00903839, -0.01008581, -0.01026238, 0.08553008]]), - - # SoftMax per-tem=True 2D single element + [[ 0.088635686173821480, -0.010058549286956951, -0.009789214523259433, -0.009653377599514660, -0.009374948470179183, + -0.009891677863509920, -0.009406534609578588, -0.010496622361458180, -0.010680386821751540, -0.009284374637613039], + [-0.010058549286956951, 0.091856076128865180, -0.010190413769852785, -0.010049009732287338, -0.009759169518165271, + -0.010297076447528582, -0.009792050177702091, -0.010926813872042194, -0.011118109698906910, -0.009664883625423075], + [-0.009789214523259433, -0.010190413769852785, 0.089669339130699100, -0.009779930406389987, -0.009497851156931268, + -0.010021354713444461, -0.009529851380888969, -0.010634229847424508, 
-0.010820403403188785, -0.009406089929318929], + [-0.009653377599514660, -0.010049009732287338, -0.009779930406389987, 0.088560779144081720, -0.009366057244326959, + -0.009882296570138368, -0.009397613427348460, -0.010486667337129447, -0.010670257514724050, -0.009275569312222474], + [-0.009374948470179183, -0.009759169518165271, -0.009497851156931268, -0.009366057244326959, 0.08627659236704915, + -0.009597264807784339, -0.009126561218167337, -0.010184203911638403, -0.010362498859374313, -0.009008037180482098], + [-0.009891677863509920, -0.010297076447528582, -0.010021354713444461, -0.009882296570138368, -0.009597264807784339, + 0.090503011588098000, -0.009629599976882700, -0.010745537931292683, -0.010933660158663310, -0.009504543118853646], + [-0.009406534609578588, -0.009792050177702091, -0.009529851380888969, -0.009397613427348460, -0.009126561218167337, + -0.009629599976882700, 0.086536526770559590, -0.010218516599910580, -0.010397412260182810, -0.009038387119898062], + [-0.010496622361458180, -0.010926813872042194, -0.010634229847424508, -0.010486667337129447, -0.010184203911638403, + -0.010745537931292683, -0.010218516599910580, 0.095380732590004670, -0.011602329078808723, -0.01008581165029997], + [-0.010680386821751540, -0.011118109698906910, -0.010820403403188785, -0.010670257514724050, -0.010362498859374313, + -0.010933660158663310, -0.010397412260182810, -0.011602329078808723, 0.096847441839448930, -0.010262384043848514], + [-0.009284374637613039, -0.009664883625423075, -0.009406089929318929, -0.009275569312222474, -0.009008037180482098, + -0.009504543118853646, -0.009038387119898062, -0.010085811650299970, -0.010262384043848514, 0.08553008061795979]]), + # SoftMax per-item=True (Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True}, [[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, -0.010933660158663306, -0.010397412260182806, -0.011602329078808718,
0.09684744183944892, -0.010262384043848513]]), - (Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True}, + (Functions.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True}, [[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, + -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513], + [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]]), (Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.ALL, kw.PER_ITEM:True}, - [[ 0.08863569, -0.01005855, -0.00978921, -0.00965338, -0.00937495, -0.00989168, -0.00940653, -0.01049662, -0.01068039, -0.00928437], - [-0.01005855, 0.09185608, -0.01019041, -0.01004901, -0.00975917, -0.01029708, -0.00979205, -0.01092681, -0.01111811, -0.00966488], - [-0.00978921, -0.01019041, 0.08966934, -0.00977993, -0.00949785, -0.01002135, -0.00952985, -0.01063423, -0.0108204, -0.00940609], - [-0.00965338, -0.01004901, -0.00977993, 0.08856078, -0.00936606, -0.0098823, -0.00939761, -0.01048667, -0.01067026, -0.00927557], - [-0.00937495, -0.00975917, -0.00949785, -0.00936606, 0.08627659, -0.00959726, -0.00912656, -0.0101842, -0.0103625, -0.00900804], - [-0.00989168, -0.01029708, -0.01002135, -0.0098823, -0.00959726, 0.09050301, -0.0096296, -0.01074554, -0.01093366, -0.00950454], - [-0.00940653, -0.00979205, -0.00952985, -0.00939761, -0.00912656, -0.0096296, 0.08653653, -0.01021852, -0.01039741, -0.00903839], - [-0.01049662, -0.01092681, -0.01063423, -0.01048667, -0.0101842, -0.01074554, -0.01021852, 0.09538073, -0.01160233, -0.01008581], - [-0.01068039, -0.01111811, -0.0108204, -0.01067026, -0.0103625, -0.01093366, -0.01039741, -0.01160233, 0.09684744, 
-0.01026238], - [-0.00928437, -0.00966488, -0.00940609, -0.00927557, -0.00900804, -0.00950454, -0.00903839, -0.01008581, -0.01026238, 0.08553008]]), + [[[ 0.088635686173821480, -0.010058549286956951, -0.009789214523259433, -0.009653377599514660, -0.009374948470179183, + -0.009891677863509920, -0.009406534609578588, -0.010496622361458180, -0.010680386821751540, -0.009284374637613039], + [-0.010058549286956951, 0.091856076128865180, -0.010190413769852785, -0.010049009732287338, -0.009759169518165271, + -0.010297076447528582, -0.009792050177702091, -0.010926813872042194, -0.011118109698906910, -0.009664883625423075], + [-0.009789214523259433, -0.010190413769852785, 0.089669339130699100, -0.009779930406389987, -0.009497851156931268, + -0.010021354713444461, -0.009529851380888969, -0.010634229847424508, -0.010820403403188785, -0.009406089929318929], + [-0.009653377599514660, -0.010049009732287338, -0.009779930406389987, 0.088560779144081720, -0.009366057244326959, + -0.009882296570138368, -0.009397613427348460, -0.010486667337129447, -0.010670257514724050, -0.009275569312222474], + [-0.009374948470179183, -0.009759169518165271, -0.009497851156931268, -0.009366057244326959, 0.08627659236704915, + -0.009597264807784339, -0.009126561218167337, -0.010184203911638403, -0.010362498859374313, -0.009008037180482098], + [-0.009891677863509920, -0.010297076447528582, -0.010021354713444461, -0.009882296570138368, -0.009597264807784339, + 0.090503011588098000, -0.009629599976882700, -0.010745537931292683, -0.010933660158663310, -0.009504543118853646], + [-0.009406534609578588, -0.009792050177702091, -0.009529851380888969, -0.009397613427348460, -0.009126561218167337, + -0.009629599976882700, 0.086536526770559590, -0.010218516599910580, -0.010397412260182810, -0.009038387119898062], + [-0.010496622361458180, -0.010926813872042194, -0.010634229847424508, -0.010486667337129447, -0.010184203911638403, + -0.010745537931292683, -0.010218516599910580, 0.095380732590004670, 
-0.011602329078808723, -0.01008581165029997], + [-0.010680386821751540, -0.011118109698906910, -0.010820403403188785, -0.010670257514724050, -0.010362498859374313, + -0.010933660158663310, -0.010397412260182810, -0.011602329078808723, 0.096847441839448930, -0.010262384043848514], + [-0.009284374637613039, -0.009664883625423075, -0.009406089929318929, -0.009275569312222474, -0.009008037180482098, + -0.009504543118853646, -0.009038387119898062, -0.010085811650299970, -0.010262384043848514, 0.08553008061795979]]]), ] @pytest.mark.function @@ -180,7 +201,14 @@ def test_transfer_derivative(func, variable, params, expected, benchmark, func_m assert False, "unknown function mode: {}".format(func_mode) res = benchmark(ex, variable) - assert np.allclose(res, expected) + + # Tanh and Logistic need reduced accuracy in single precision mode + if func_mode != 'Python' and func in {Functions.Tanh, Functions.Logistic} and pytest.helpers.llvm_current_fp_precision() == 'fp32': + tolerance = {'rtol': 5e-7, 'atol': 1e-8} + else: + tolerance = {} + + np.testing.assert_allclose(res, expected, **tolerance) derivative_out_test_data = [ @@ -189,8 +217,10 @@ def test_transfer_derivative(func, variable, params, expected, benchmark, func_m (Functions.SoftMax, softmax_helper, {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:False}, [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]), - (Functions.SoftMax, [softmax_helper], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True}, + (Functions.SoftMax, [softmax_helper, softmax_helper], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True}, [[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, + -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, 
-0.010262384043848513], + [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309, -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]]), ] @pytest.mark.function @@ -214,9 +244,14 @@ def ex(x): assert False, "unknown function mode: {}".format(func_mode) res = benchmark(ex, variable) - # FIX: THIS FAILS FOR func_mode=Python, func=SoftMax, and kw.PER_ITEM:True: - # EXPECTS 2d BUT ONLY 1D IS RETURNED - assert np.allclose(res, expected) + + # Logistic needs reduced accuracy in single precision mode + if func_mode != 'Python' and func is Functions.Logistic and pytest.helpers.llvm_current_fp_precision() == 'fp32': + tolerance = {'rtol': 1e-7, 'atol': 1e-8} + else: + tolerance = {} + + np.testing.assert_allclose(res, expected, **tolerance) def test_transfer_with_costs_function(): f = Functions.TransferWithCosts()