functions/ReLU,SoftMax: Restore functionality broken in #2528 (#2532)
PR #2528 introduced two breaking changes:
1.) Renamed ReLU's derivative argument 'input' -> 'variable', which broke existing callers of the ReLU derivative.
2.) Changed the SoftMax derivative calculation to use 'input' instead of 'output'. This is incorrect, as the algorithm reuses the existing outputs to calculate the derivative results.

Both are restored to their original form.
Moreover, a single-input codepath is added to the SoftMax derivative, calculating the results if they are not provided.
The compiled version and tests are adjusted accordingly.
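
As a point of reference, here is a minimal standalone NumPy sketch of the restored ReLU derivative behavior (an illustration only, not the library's class; the parameter names gain, bias, and leak mirror those in the diff below):

import numpy as np

def relu_derivative(input, gain=1.0, bias=0.0, leak=0.0):
    # Callers pass the forward-pass value under the restored keyword 'input'
    # (not 'variable' as in #2528); the slope is 'gain' above the bias and
    # 'gain * leak' at or below it.
    value = np.empty_like(input)
    value[(input - bias) > 0] = gain
    value[(input - bias) <= 0] = gain * leak
    return value

# Example call using the restored keyword:
# relu_derivative(input=np.array([-1.0, 0.5, 2.0]), gain=1.0, leak=0.1)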
jvesely authored Nov 11, 2022
2 parents e462f10 + a663d2c commit 4a40149
Showing 2 changed files with 25 additions and 12 deletions.
@@ -1593,7 +1593,7 @@ def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags
builder.store(val, ptro)

@handle_external_context()
- def derivative(self, variable, output=None, context=None):
+ def derivative(self, input=None, output=None, context=None):
"""
derivative(input)
@@ -1615,9 +1615,9 @@ def derivative(self, variable, output=None, context=None):
leak = self._get_current_parameter_value(LEAK, context)
bias = self._get_current_parameter_value(BIAS, context)

- value = np.empty_like(variable)
- value[(variable - bias) > 0] = gain
- value[(variable - bias) <= 0] = gain * leak
+ value = np.empty_like(input)
+ value[(input - bias) > 0] = gain
+ value[(input - bias) <= 0] = gain * leak

return value

@@ -2700,8 +2700,18 @@ def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_
def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
assert "derivative" in tags
forward_tags = tags.difference({"derivative"})

+ # SoftMax derivative is calculated from the results. Recalculate them here
+ base_out = builder.alloca(arg_out.type.pointee)
+ builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, base_out, output_type=self.output, tags=forward_tags)


all_out = builder.alloca(arg_out.type.pointee)
- builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)
+ builder = self._gen_llvm_function_body(ctx, builder, params, state, base_out, all_out, output_type=ALL, tags=forward_tags)

+ # The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only
+ assert self.output in {MAX_VAL, MAX_INDICATOR}, \
+     "Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! ({})".format(self.output)

max_pos_ptr = builder.alloca(ctx.int32_ty)
builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
@@ -2819,9 +2829,12 @@ def derivative(self, input=None, output=None, context=None):
derivative of values returned by SoftMax : 1d or 2d array (depending on *OUTPUT_TYPE* of SoftMax)
"""

+ if output is None:
+     output = self.function(input, context=context)

output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
- size = len(input)
- sm = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
+ size = len(output)
+ sm = self.function(output, params={OUTPUT_TYPE: ALL}, context=context)
sm = np.squeeze(sm)

if output_type == ALL:
@@ -2839,7 +2852,7 @@ def derivative(self, input=None, output=None, context=None):
# Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
derivative = np.empty(size)
# Get the element of output returned as non-zero when output_type is not ALL
- index_of_max = int(np.where(sm == np.max(sm))[0])
+ index_of_max = int(np.where(output == np.max(output))[0][0])
max_val = sm[index_of_max]
for i in range(size):
if i == index_of_max:
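
Why reusing the forward results is the correct approach here: the derivative of SoftMax can be written purely in terms of its own outputs, d s_i / d x_j = s_i * (delta_ij - s_j), so no reference to the original input is needed. A small standalone NumPy sketch of that identity (illustrative only, not the library code above):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))        # shift for numerical stability
    return e / e.sum()

def softmax_jacobian_from_output(s):
    # d s_i / d x_j = s_i * (delta_ij - s_j): only the outputs s appear,
    # which is why the derivative can reuse the forward-pass results.
    return np.diag(s) - np.outer(s, s)

x = np.array([0.1, 0.2, 0.7])
s = softmax(x)
J = softmax_jacobian_from_output(s)  # computed without touching x
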
8 changes: 4 additions & 4 deletions tests/functions/test_transfer.py
@@ -81,11 +81,11 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
(Functions.ReLU, test_var, {'gain':RAND1, 'bias':RAND2, 'leak':RAND3}, np.where((test_var - RAND2) > 0, RAND1, RAND1 * RAND3)),
(Functions.Tanh, test_var, {'gain':RAND1, 'bias':RAND2, 'offset':RAND3, 'scale':RAND4}, tanh_derivative_helper),
(Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_VAL}, 'per_item': False},
- [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-  -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+ [-0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652,
+  -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, 0.09190284400769985, -0.010211427111966652]),
(Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_INDICATOR}, 'per_item': False},
- [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-  -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+ [-0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685,
+  -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, 0.10856507949987917, -0.012062786611097685]),
]

@pytest.mark.function
