functions/SoftMax: Restore correct computation of derivative
Commit cae1465
	("llvm, functions/SoftMax: Implement compiled 'derivative' variant")
incorrectly assumed that the use of 'output' was an oversight.
It wasn't; the SoftMax derivative can take advantage of the results if they are available.
This change restores the original functionality and adds a path to
compute the results if output is None.
This is used for testing where the results would need to be calculated
anyway.
The compiled variant is adapted in the same way, and the tests are
updated to reflect the new results.

Signed-off-by: Jan Vesely <[email protected]>
jvesely committed Nov 11, 2022
1 parent 20fdbf3 commit a663d2c
Showing 2 changed files with 21 additions and 8 deletions.
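For reference, here is a minimal NumPy sketch of the behavior this commit restores. The helper names (softmax, max_val_output, max_val_derivative) are illustrative stand-ins, not the PsyNeuLink API:

import numpy as np

def softmax(x, gain=1.0):
    # Numerically stable softmax with a multiplicative gain.
    e = np.exp(gain * (x - np.max(x)))
    return e / e.sum()

def max_val_output(x, gain=1.0):
    # Stand-in for a forward pass with OUTPUT_TYPE=MAX_VAL: zero everywhere
    # except the softmax value of the max element.
    s = softmax(x, gain)
    out = np.zeros_like(s)
    out[np.argmax(s)] = s.max()
    return out

def max_val_derivative(x, output=None, gain=1.0):
    # Reuse the forward results if the caller has them; otherwise recompute
    # them, mirroring the new 'output is None' path.
    if output is None:
        output = max_val_output(x, gain)
    sm = softmax(output, gain)  # forward pass with OUTPUT_TYPE=ALL on the results
    index_of_max = int(np.where(output == np.max(output))[0][0])
    max_val = sm[index_of_max]
    d = -max_val * sm                            # off-diagonal Jacobian terms
    d[index_of_max] = max_val * (1.0 - max_val)  # diagonal term
    return d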
@@ -2700,8 +2700,18 @@ def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_
     def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
         assert "derivative" in tags
         forward_tags = tags.difference({"derivative"})
+
+        # SoftMax derivative is calculated from the results. Recalculate them here
+        base_out = builder.alloca(arg_out.type.pointee)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, base_out, output_type=self.output, tags=forward_tags)
+
+
         all_out = builder.alloca(arg_out.type.pointee)
-        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, base_out, all_out, output_type=ALL, tags=forward_tags)
 
+        # The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only
+        assert self.output in {MAX_VAL, MAX_INDICATOR}, \
+            "Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! ({})".format(self.output)
+
         max_pos_ptr = builder.alloca(ctx.int32_ty)
         builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
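In the compiled variant above, the derivative body now emits two forward passes: the first materializes the function's own results into base_out (output_type=self.output), and the second feeds those results back through the function with output_type=ALL into all_out, matching the Python path below. A rough Python mirror of that dataflow, where forward is a hypothetical stand-in for the generated forward-pass body:

def compiled_derivative_dataflow(arg_in, forward, own_output_type):
    # First pass: the results of the function, as a caller would see them.
    base_out = forward(arg_in, output_type=own_output_type)
    # Second pass: full softmax recomputed over those results; the rest of
    # the derivative is read off all_out.
    all_out = forward(base_out, output_type="ALL")
    return base_out, all_out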
@@ -2819,9 +2829,12 @@ def derivative(self, input=None, output=None, context=None):
         derivative of values returned by SoftMax : 1d or 2d array (depending on *OUTPUT_TYPE* of SoftMax)
         """
 
+        if output is None:
+            output = self.function(input, context=context)
+
         output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
-        size = len(input)
-        sm = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
+        size = len(output)
+        sm = self.function(output, params={OUTPUT_TYPE: ALL}, context=context)
         sm = np.squeeze(sm)
 
         if output_type == ALL:
@@ -2839,7 +2852,7 @@ def derivative(self, input=None, output=None, context=None):
             # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
             derivative = np.empty(size)
             # Get the element of output returned as non-zero when output_type is not ALL
-            index_of_max = int(np.where(sm == np.max(sm))[0])
+            index_of_max = int(np.where(output == np.max(output))[0][0])
             max_val = sm[index_of_max]
             for i in range(size):
                 if i == index_of_max:
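The indexing change above is not only about where the max is looked up: np.where(...)[0] returns an array of every matching position, and int() on it succeeds only when exactly one element matches. Taking [0][0] selects the first match explicitly, which also keeps the tie case well-defined. A quick demonstration:

import numpy as np

output = np.array([0.2, 0.9, 0.9])             # two tied maxima
matches = np.where(output == np.max(output))[0]
print(matches)                                 # [1 2]
# int(matches) raises TypeError (a size-2 array cannot be converted to a
# Python scalar); matches[0] deterministically picks the first maximum.
index_of_max = int(matches[0])
print(index_of_max)                            # 1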
tests/functions/test_transfer.py: 8 changes (4 additions & 4 deletions)
@@ -81,11 +81,11 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
     (Functions.ReLU, test_var, {'gain':RAND1, 'bias':RAND2, 'leak':RAND3}, np.where((test_var - RAND2) > 0, RAND1, RAND1 * RAND3)),
     (Functions.Tanh, test_var, {'gain':RAND1, 'bias':RAND2, 'offset':RAND3, 'scale':RAND4}, tanh_derivative_helper),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_VAL}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652,
+      -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, 0.09190284400769985, -0.010211427111966652]),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_INDICATOR}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685,
+      -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, 0.10856507949987917, -0.012062786611097685]),
 ]
 
 @pytest.mark.function
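The structure of the new expected values follows directly from the fix: the derivative is now computed over the forward results rather than the raw input, and a MAX_VAL or MAX_INDICATOR result is zero everywhere except at the max element, so the recomputed softmax, and hence the derivative, is identical at the other nine positions. For MAX_INDICATOR the results vector holds 1.0 at the max instead of the softmax value, which is why its off-max entries (-0.01206...) differ from the MAX_VAL ones (-0.01021...). A sketch of the MAX_INDICATOR case, reusing the softmax helper from the sketch near the top of this page (test_var and RAND1 are the suite's random fixtures, elided here; this is not the test's actual helper):

def max_indicator_derivative(x, gain=1.0):
    # MAX_INDICATOR results: 1.0 at the max element, 0.0 elsewhere.
    s = softmax(x, gain)
    out = np.zeros_like(s)
    i = int(np.argmax(s))
    out[i] = 1.0
    # Same derivative path as MAX_VAL, applied to the indicator vector:
    # sm[i]*(1 - sm[i]) on the diagonal, -sm[i]*sm[j] off it.
    sm = softmax(out, gain)
    d = -sm[i] * sm
    d[i] = sm[i] * (1.0 - sm[i])
    return d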
