functions/ReLU,SoftMax: Restore functionality broken in #2528 (#2532)

Merged: 2 commits, Nov 11, 2022
@@ -1593,7 +1593,7 @@ def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags
         builder.store(val, ptro)
 
     @handle_external_context()
-    def derivative(self, variable, output=None, context=None):
+    def derivative(self, input=None, output=None, context=None):
         """
         derivative(input)
 
@@ -1615,9 +1615,9 @@ def derivative(self, variable, output=None, context=None):
         leak = self._get_current_parameter_value(LEAK, context)
         bias = self._get_current_parameter_value(BIAS, context)
 
-        value = np.empty_like(variable)
-        value[(variable - bias) > 0] = gain
-        value[(variable - bias) <= 0] = gain * leak
+        value = np.empty_like(input)
+        value[(input - bias) > 0] = gain
+        value[(input - bias) <= 0] = gain * leak
 
         return value
 
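For context, the behavior restored by this hunk can be sketched in plain NumPy. This is a minimal illustration, not PsyNeuLink code; the helper name `leaky_relu_derivative` and the example parameter values are made up here:

```python
import numpy as np

def leaky_relu_derivative(x, gain=1.0, bias=0.0, leak=0.0):
    """Slope is `gain` above the bias point and `gain * leak` at or below it."""
    d = np.empty_like(np.asarray(x, dtype=float))
    d[(x - bias) > 0] = gain
    d[(x - bias) <= 0] = gain * leak
    return d

# Example: with leak=0.1, inputs at or below the bias get 10% of the gain as slope.
print(leaky_relu_derivative(np.array([-2.0, 0.0, 3.0]), gain=2.0, leak=0.1))
# [0.2 0.2 2. ]
```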
@@ -2700,8 +2700,18 @@ def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_
     def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
         assert "derivative" in tags
         forward_tags = tags.difference({"derivative"})
 
+        # SoftMax derivative is calculated from the results. Recalculate them here
+        base_out = builder.alloca(arg_out.type.pointee)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, base_out, output_type=self.output, tags=forward_tags)
+
         all_out = builder.alloca(arg_out.type.pointee)
-        builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)
+        builder = self._gen_llvm_function_body(ctx, builder, params, state, base_out, all_out, output_type=ALL, tags=forward_tags)
+
+        # The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only
+        assert self.output in {MAX_VAL, MAX_INDICATOR}, \
+            "Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! ({})".format(self.output)
 
         max_pos_ptr = builder.alloca(ctx.int32_ty)
         builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
@@ -2819,9 +2829,12 @@ def derivative(self, input=None, output=None, context=None):
         derivative of values returned by SoftMax : 1d or 2d array (depending on *OUTPUT_TYPE* of SoftMax)
         """
 
+        if output is None:
+            output = self.function(input, context=context)
+
         output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
-        size = len(input)
-        sm = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
+        size = len(output)
+        sm = self.function(output, params={OUTPUT_TYPE: ALL}, context=context)
         sm = np.squeeze(sm)
 
         if output_type == ALL:
@@ -2839,7 +2852,7 @@ def derivative(self, input=None, output=None, context=None):
             # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
             derivative = np.empty(size)
             # Get the element of output returned as non-zero when output_type is not ALL
-            index_of_max = int(np.where(sm == np.max(sm))[0])
+            index_of_max = int(np.where(output == np.max(output))[0][0])
             max_val = sm[index_of_max]
             for i in range(size):
                 if i == index_of_max:
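In plain NumPy terms, the restored derivative path amounts to roughly the sketch below. It is an illustration only, not the PsyNeuLink API: the function name is made up, the forward recomputation is hard-coded to a MAX_INDICATOR-style output, `per_item` handling and the ALL branch are omitted, and the softmax is assumed to take the usual exp(gain * x) / sum form:

```python
import numpy as np

def softmax_max_derivative(x, output=None, gain=1.0):
    """Rough sketch of the MAX_VAL/MAX_INDICATOR derivative path after the fix."""
    # If the forward result was not supplied, recompute it first
    # (here a MAX_INDICATOR-style output: 1 at the max, 0 elsewhere).
    if output is None:
        full = np.exp(gain * x) / np.sum(np.exp(gain * x))
        output = np.where(full == np.max(full), 1.0, 0.0)

    # The full softmax used for the derivative is computed from the
    # forward *output*, not from the raw input.
    sm = np.exp(gain * output) / np.sum(np.exp(gain * output))

    index_of_max = int(np.where(output == np.max(output))[0][0])
    max_val = sm[index_of_max]

    d = np.empty_like(sm)
    for i in range(len(sm)):
        if i == index_of_max:
            d[i] = max_val * (1.0 - max_val)  # diagonal entry of the Jacobian row
        else:
            d[i] = -sm[i] * max_val           # off-diagonal entries
    return d

print(softmax_max_derivative(np.arange(5, dtype=float)))
```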
tests/functions/test_transfer.py: 8 changes (4 additions & 4 deletions)

@@ -81,11 +81,11 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
     (Functions.ReLU, test_var, {'gain':RAND1, 'bias':RAND2, 'leak':RAND3}, np.where((test_var - RAND2) > 0, RAND1, RAND1 * RAND3)),
     (Functions.Tanh, test_var, {'gain':RAND1, 'bias':RAND2, 'offset':RAND3, 'scale':RAND4}, tanh_derivative_helper),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_VAL}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, -0.010211427111966652,
+      -0.010211427111966652, -0.010211427111966652, -0.010211427111966652, 0.09190284400769985, -0.010211427111966652]),
     (Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_INDICATOR}, 'per_item': False},
-     [-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-      -0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
+     [-0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, -0.012062786611097685,
+      -0.012062786611097685, -0.012062786611097685, -0.012062786611097685, 0.10856507949987917, -0.012062786611097685]),
 ]
 
 @pytest.mark.function
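The structure of the updated expected values follows from the fix: the derivative is now taken from the forward output, which is zero everywhere except at the max, so the recomputed softmax, and with it every off-max derivative entry, is identical across the non-max positions. The snippet below is a rough check of that structure only, assuming the exp(gain * x) / sum softmax form; the gain value is a stand-in, not the test's random RAND1:

```python
import numpy as np

gain = 0.95                       # stand-in value; the test uses a random RAND1
out = np.zeros(10)
out[8] = 1.0                      # MAX_INDICATOR-style forward output

sm = np.exp(gain * out) / np.sum(np.exp(gain * out))
max_val = sm[8]

d = -sm * max_val                 # off-diagonal Jacobian entries
d[8] = max_val * (1 - max_val)    # diagonal entry for the max element

# All nine off-max entries come out equal, matching the shape of the new
# expected values above (the exact numbers depend on the gain).
print(d)
```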