diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
index da3bfd300b0..381d80d9c6b 100644
--- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
@@ -1593,7 +1593,7 @@ def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags
         builder.store(val, ptro)
 
     @handle_external_context()
-    def derivative(self, variable, output=None, context=None):
+    def derivative(self, input=None, output=None, context=None):
         """
         derivative(input)
 
@@ -1615,9 +1615,9 @@ def derivative(self, variable, output=None, context=None):
         leak = self._get_current_parameter_value(LEAK, context)
         bias = self._get_current_parameter_value(BIAS, context)
 
-        value = np.empty_like(variable)
-        value[(variable - bias) > 0] = gain
-        value[(variable - bias) <= 0] = gain * leak
+        value = np.empty_like(input)
+        value[(input - bias) > 0] = gain
+        value[(input - bias) <= 0] = gain * leak
 
         return value
 
@@ -2700,8 +2700,18 @@ def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_
     def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
         assert "derivative" in tags
         forward_tags = tags.difference({"derivative"})
+
+        # SoftMax derivative is calculated from the results. Recalculate them here
+        base_out = builder.alloca(arg_out.type.pointee)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, base_out, output_type=self.output, tags=forward_tags)
+
         all_out = builder.alloca(arg_out.type.pointee)
-        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, base_out, all_out, output_type=ALL, tags=forward_tags)
+
+        # The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only
+        assert self.output in {MAX_VAL, MAX_INDICATOR}, \
+            "Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! ({})".format(self.output)
 
         max_pos_ptr = builder.alloca(ctx.int32_ty)
         builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
@@ -2819,9 +2829,12 @@ def derivative(self, input=None, output=None, context=None):
         derivative of values returned by SoftMax : 1d or 2d array (depending on *OUTPUT_TYPE* of SoftMax)
         """
+        if output is None:
+            output = self.function(input, context=context)
+
         output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
-        size = len(input)
-        sm = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
+        size = len(output)
+        sm = self.function(output, params={OUTPUT_TYPE: ALL}, context=context)
         sm = np.squeeze(sm)
 
         if output_type == ALL:
@@ -2839,7 +2852,7 @@ def derivative(self, input=None, output=None, context=None):
             # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
             derivative = np.empty(size)
             # Get the element of output returned as non-zero when output_type is not ALL
-            index_of_max = int(np.where(sm == np.max(sm))[0])
+            index_of_max = int(np.where(output == np.max(output))[0][0])
             max_val = sm[index_of_max]
             for i in range(size):
                 if i == index_of_max:
diff --git a/tests/functions/test_transfer.py b/tests/functions/test_transfer.py
index 47dc5c0a94e..3c4cef5beef 100644
--- a/tests/functions/test_transfer.py
+++ b/tests/functions/test_transfer.py
@@ -81,11 +81,11 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
     (Functions.ReLU, test_var, {'gain':RAND1, 'bias':RAND2, 'leak':RAND3}, np.where((test_var - RAND2) > 0, RAND1, RAND1 * RAND3)),
     (Functions.Tanh, test_var, {'gain':RAND1, 'bias':RAND2, 'offset':RAND3, 'scale':RAND4}, tanh_derivative_helper),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_VAL}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652,
+      -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, 0.09190284400769985, -0.010211427111966652]),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_INDICATOR}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685,
+      -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, 0.10856507949987917, -0.012062786611097685]),
 ]
 
 @pytest.mark.function