functions/TransferFunction: Add support for 'per-item' mode derivative
Increase precision of expected results and use
np.testing.assert_allclose.
Use multiple elements in per-item derivative test inputs.

Signed-off-by: Jan Vesely <[email protected]>
jvesely committed Apr 14, 2023
1 parent e96bbc8 commit 657e737
Showing 2 changed files with 119 additions and 65 deletions.
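Context for the test changes (an illustrative sketch, not code from this commit): np.allclose uses loose defaults (rtol=1e-05, atol=1e-08), so rounded expected literals still pass, while np.testing.assert_allclose defaults to rtol=1e-07 with atol=0 and raises with a diagnostic diff on mismatch. That is why the expected values below were extended to full precision.

import numpy as np

computed = 0.0886356861738

# np.allclose defaults (rtol=1e-05, atol=1e-08) accept ~5 matching digits,
# so a rounded literal like 0.088636 passes:
assert np.allclose(computed, 0.088636)

# np.testing.assert_allclose defaults to rtol=1e-07 and atol=0, so the
# expected literal must carry ~7 correct digits; on mismatch it reports
# both arrays and the maximum error instead of a bare AssertionError:
np.testing.assert_allclose(computed, 0.08863568617)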
@@ -3159,7 +3159,19 @@ def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in
all_out = builder.alloca(arg_out.type.pointee)
builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)

# The rest of the algorithm is for MAX_VAL and MAX_INDICATOR only
if self.parameters.per_item.get():
assert isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType)
assert isinstance(arg_out.type.pointee.element, pnlvm.ir.ArrayType)
for i in range(arg_in.type.pointee.count):
inner_all_out = builder.gep(all_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
inner_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
builder = self.__gen_llvm_apply_derivative(ctx, builder, params, state, inner_all_out, inner_out, tags=tags)
return builder
else:
return self.__gen_llvm_apply_derivative(ctx, builder, params, state, all_out, arg_out, tags=tags)

def __gen_llvm_apply_derivative(self, ctx, builder, params, state, all_out, arg_out, *, tags:frozenset):

assert self.output in {MAX_VAL, MAX_INDICATOR}, \
"Derivative of SoftMax is only implemented for MAX_VAL and MAX_INDICATOR! ({})".format(self.output)

@@ -3292,44 +3304,51 @@ def derivative(self, input=None, output=None, context=None):
else:
assert not np.any(np.equal(0, output))

sm = np.squeeze(output)
size = len(sm)
assert (len(output) == 1 and len(output[0]) == size) or len(output) == size
per_item = self._get_current_parameter_value(PER_ITEM, context)
if not per_item:
output = [output]

result = []
for sm in output:
size = len(sm)

output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
if output_type == ALL:
# Return full Jacobian matrix of derivatives using Kronecker's delta method:
derivative = np.empty([size, size])
for i, j in np.ndindex(size, size):
if i == j:
d = 1
else:
d = 0
derivative[j, i] = sm[i] * (d - sm[j])
elif output_type in {MAX_VAL, MAX_INDICATOR}:
# Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
derivative = np.empty(size)
# Get the element of output returned as non-zero (max val) when output_type is not ALL
# IMPLEMENTATION NOTES:
# if there is a tie for max, this chooses the item in sm with the lowest index in sm:
index_of_max = int(np.where(sm == np.max(sm))[-1][0])
# the following would randomly choose a value in case of a tie,
# but may cause problems with compilation:
# index_of_max = np.where(sm == np.max(sm))[0]
# if len(index_of_max)>1:
# index_of_max = int(np.random.choice(index_of_max))
max_val = sm[index_of_max]
for i in range(size):
if i == index_of_max:
d = 1
else:
d = 0
derivative[i] = sm[i] * (d - max_val)
else:
raise FunctionError("Can't assign derivative for SoftMax function{} since OUTPUT_TYPE is PROB "
"(and therefore the relevant element is ambiguous)".format(self.owner_name))

output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
if output_type == ALL:
# Return full Jacobian matrix of derivatives using Kronecker's delta method:
derivative = np.empty([size, size])
for i, j in np.ndindex(size, size):
if i == j:
d = 1
else:
d = 0
derivative[j, i] = sm[i] * (d - sm[j])
elif output_type in {MAX_VAL, MAX_INDICATOR}:
# Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
derivative = np.empty(size)
# Get the element of output returned as non-zero (max val) when output_type is not ALL
# IMPLEMENTATION NOTES:
# if there is a tie for max, this chooses the item in sm with the lowest index in sm:
index_of_max = int(np.where(sm == np.max(sm))[-1][0])
# the following would randomly choose a value in case of a tie,
# but may cause problems with compilation:
# index_of_max = np.where(sm == np.max(sm))[0]
# if len(index_of_max)>1:
# index_of_max = int(np.random.choice(index_of_max))
max_val = sm[index_of_max]
for i in range(size):
if i == index_of_max:
d = 1
else:
d = 0
derivative[i] = sm[i] * (d - max_val)
else:
raise FunctionError("Can't assign derivative for SoftMax function{} since OUTPUT_TYPE is PROB "
"(and therefore the relevant element is ambiguous)".format(self.owner_name))
result.append(derivative)

return derivative
assert per_item or len(result) == 1
return result[0] if not per_item else result
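
For reference, a minimal NumPy sketch of the math the Python path above implements (illustrative only; helper names are not from the repository): the full Jacobian for OUTPUT_TYPE=ALL, its row for the max element for MAX_VAL/MAX_INDICATOR, and the per-row application that PER_ITEM=True adds.

import numpy as np

def softmax_jacobian(sm):
    # Kronecker-delta form used above: J[j, i] = sm[i] * (delta(i, j) - sm[j]),
    # which reduces to diag(sm) - outer(sm, sm).
    sm = np.asarray(sm)
    return np.diag(sm) - np.outer(sm, sm)

def max_val_derivative(sm):
    # The MAX_VAL / MAX_INDICATOR derivative is the Jacobian row for the
    # max element; np.argmax breaks ties toward the lowest index, matching
    # the loop above.
    sm = np.asarray(sm)
    return softmax_jacobian(sm)[np.argmax(sm)]

# PER_ITEM=True treats a 2d softmax output as independent items and returns
# one derivative per row; PER_ITEM=False returns a single derivative:
softmax_out = np.array([[0.2, 0.3, 0.5], [0.5, 0.3, 0.2]])
per_item_result = [max_val_derivative(row) for row in softmax_out]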


# **********************************************************************************************************************
91 changes: 63 additions & 28 deletions tests/functions/test_transfer.py
@@ -129,35 +129,56 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
(Functions.SoftMax, test_var, {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.ALL, kw.PER_ITEM:False},
[[ 0.08863569, -0.01005855, -0.00978921, -0.00965338, -0.00937495, -0.00989168, -0.00940653, -0.01049662, -0.01068039, -0.00928437],
[-0.01005855, 0.09185608, -0.01019041, -0.01004901, -0.00975917, -0.01029708, -0.00979205, -0.01092681, -0.01111811, -0.00966488],
[-0.00978921, -0.01019041, 0.08966934, -0.00977993, -0.00949785, -0.01002135, -0.00952985, -0.01063423, -0.0108204, -0.00940609],
[-0.00965338, -0.01004901, -0.00977993, 0.08856078, -0.00936606, -0.0098823, -0.00939761, -0.01048667, -0.01067026, -0.00927557],
[-0.00937495, -0.00975917, -0.00949785, -0.00936606, 0.08627659, -0.00959726, -0.00912656, -0.0101842, -0.0103625, -0.00900804],
[-0.00989168, -0.01029708, -0.01002135, -0.0098823, -0.00959726, 0.09050301, -0.0096296, -0.01074554, -0.01093366, -0.00950454],
[-0.00940653, -0.00979205, -0.00952985, -0.00939761, -0.00912656, -0.0096296, 0.08653653, -0.01021852, -0.01039741, -0.00903839],
[-0.01049662, -0.01092681, -0.01063423, -0.01048667, -0.0101842, -0.01074554, -0.01021852, 0.09538073, -0.01160233, -0.01008581],
[-0.01068039, -0.01111811, -0.0108204, -0.01067026, -0.0103625, -0.01093366, -0.01039741, -0.01160233, 0.09684744, -0.01026238],
[-0.00928437, -0.00966488, -0.00940609, -0.00927557, -0.00900804, -0.00950454, -0.00903839, -0.01008581, -0.01026238, 0.08553008]]),

# SoftMax per-item=True 2D single element
[[ 0.088635686173821480, -0.010058549286956951, -0.009789214523259433, -0.009653377599514660, -0.009374948470179183,
-0.009891677863509920, -0.009406534609578588, -0.010496622361458180, -0.010680386821751540, -0.009284374637613039],
[-0.010058549286956951, 0.091856076128865180, -0.010190413769852785, -0.010049009732287338, -0.009759169518165271,
-0.010297076447528582, -0.009792050177702091, -0.010926813872042194, -0.011118109698906910, -0.009664883625423075],
[-0.009789214523259433, -0.010190413769852785, 0.089669339130699100, -0.009779930406389987, -0.009497851156931268,
-0.010021354713444461, -0.009529851380888969, -0.010634229847424508, -0.010820403403188785, -0.009406089929318929],
[-0.009653377599514660, -0.010049009732287338, -0.009779930406389987, 0.088560779144081720, -0.009366057244326959,
-0.009882296570138368, -0.009397613427348460, -0.010486667337129447, -0.010670257514724050, -0.009275569312222474],
[-0.009374948470179183, -0.009759169518165271, -0.009497851156931268, -0.009366057244326959, 0.08627659236704915,
-0.009597264807784339, -0.009126561218167337, -0.010184203911638403, -0.010362498859374313, -0.009008037180482098],
[-0.009891677863509920, -0.010297076447528582, -0.010021354713444461, -0.009882296570138368, -0.009597264807784339,
0.090503011588098000, -0.009629599976882700, -0.010745537931292683, -0.010933660158663310, -0.009504543118853646],
[-0.009406534609578588, -0.009792050177702091, -0.009529851380888969, -0.009397613427348460, -0.009126561218167337,
-0.009629599976882700, 0.086536526770559590, -0.010218516599910580, -0.010397412260182810, -0.009038387119898062],
[-0.010496622361458180, -0.010926813872042194, -0.010634229847424508, -0.010486667337129447, -0.010184203911638403,
-0.010745537931292683, -0.010218516599910580, 0.095380732590004670, -0.011602329078808723, -0.01008581165029997],
[-0.010680386821751540, -0.011118109698906910, -0.010820403403188785, -0.010670257514724050, -0.010362498859374313,
-0.010933660158663310, -0.010397412260182810, -0.011602329078808723, 0.096847441839448930, -0.010262384043848514],
[-0.009284374637613039, -0.009664883625423075, -0.009406089929318929, -0.009275569312222474, -0.009008037180482098,
-0.009504543118853646, -0.009038387119898062, -0.010085811650299970, -0.010262384043848514, 0.08553008061795979]]),
# SoftMax per-item=True
(Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
[[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]]),
(Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
(Functions.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
[[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513],
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]]),
(Functions.SoftMax, [test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.ALL, kw.PER_ITEM:True},
[[ 0.08863569, -0.01005855, -0.00978921, -0.00965338, -0.00937495, -0.00989168, -0.00940653, -0.01049662, -0.01068039, -0.00928437],
[-0.01005855, 0.09185608, -0.01019041, -0.01004901, -0.00975917, -0.01029708, -0.00979205, -0.01092681, -0.01111811, -0.00966488],
[-0.00978921, -0.01019041, 0.08966934, -0.00977993, -0.00949785, -0.01002135, -0.00952985, -0.01063423, -0.0108204, -0.00940609],
[-0.00965338, -0.01004901, -0.00977993, 0.08856078, -0.00936606, -0.0098823, -0.00939761, -0.01048667, -0.01067026, -0.00927557],
[-0.00937495, -0.00975917, -0.00949785, -0.00936606, 0.08627659, -0.00959726, -0.00912656, -0.0101842, -0.0103625, -0.00900804],
[-0.00989168, -0.01029708, -0.01002135, -0.0098823, -0.00959726, 0.09050301, -0.0096296, -0.01074554, -0.01093366, -0.00950454],
[-0.00940653, -0.00979205, -0.00952985, -0.00939761, -0.00912656, -0.0096296, 0.08653653, -0.01021852, -0.01039741, -0.00903839],
[-0.01049662, -0.01092681, -0.01063423, -0.01048667, -0.0101842, -0.01074554, -0.01021852, 0.09538073, -0.01160233, -0.01008581],
[-0.01068039, -0.01111811, -0.0108204, -0.01067026, -0.0103625, -0.01093366, -0.01039741, -0.01160233, 0.09684744, -0.01026238],
[-0.00928437, -0.00966488, -0.00940609, -0.00927557, -0.00900804, -0.00950454, -0.00903839, -0.01008581, -0.01026238, 0.08553008]]),
[[[ 0.088635686173821480, -0.010058549286956951, -0.009789214523259433, -0.009653377599514660, -0.009374948470179183,
-0.009891677863509920, -0.009406534609578588, -0.010496622361458180, -0.010680386821751540, -0.009284374637613039],
[-0.010058549286956951, 0.091856076128865180, -0.010190413769852785, -0.010049009732287338, -0.009759169518165271,
-0.010297076447528582, -0.009792050177702091, -0.010926813872042194, -0.011118109698906910, -0.009664883625423075],
[-0.009789214523259433, -0.010190413769852785, 0.089669339130699100, -0.009779930406389987, -0.009497851156931268,
-0.010021354713444461, -0.009529851380888969, -0.010634229847424508, -0.010820403403188785, -0.009406089929318929],
[-0.009653377599514660, -0.010049009732287338, -0.009779930406389987, 0.088560779144081720, -0.009366057244326959,
-0.009882296570138368, -0.009397613427348460, -0.010486667337129447, -0.010670257514724050, -0.009275569312222474],
[-0.009374948470179183, -0.009759169518165271, -0.009497851156931268, -0.009366057244326959, 0.08627659236704915,
-0.009597264807784339, -0.009126561218167337, -0.010184203911638403, -0.010362498859374313, -0.009008037180482098],
[-0.009891677863509920, -0.010297076447528582, -0.010021354713444461, -0.009882296570138368, -0.009597264807784339,
0.090503011588098000, -0.009629599976882700, -0.010745537931292683, -0.010933660158663310, -0.009504543118853646],
[-0.009406534609578588, -0.009792050177702091, -0.009529851380888969, -0.009397613427348460, -0.009126561218167337,
-0.009629599976882700, 0.086536526770559590, -0.010218516599910580, -0.010397412260182810, -0.009038387119898062],
[-0.010496622361458180, -0.010926813872042194, -0.010634229847424508, -0.010486667337129447, -0.010184203911638403,
-0.010745537931292683, -0.010218516599910580, 0.095380732590004670, -0.011602329078808723, -0.01008581165029997],
[-0.010680386821751540, -0.011118109698906910, -0.010820403403188785, -0.010670257514724050, -0.010362498859374313,
-0.010933660158663310, -0.010397412260182810, -0.011602329078808723, 0.096847441839448930, -0.010262384043848514],
[-0.009284374637613039, -0.009664883625423075, -0.009406089929318929, -0.009275569312222474, -0.009008037180482098,
-0.009504543118853646, -0.009038387119898062, -0.010085811650299970, -0.010262384043848514, 0.08553008061795979]]]),
]

@pytest.mark.function
@@ -180,7 +201,14 @@ def test_transfer_derivative(func, variable, params, expected, benchmark, func_mode):
assert False, "unknown function mode: {}".format(func_mode)

res = benchmark(ex, variable)
assert np.allclose(res, expected)

# Tanh and Logistic need reduced accuracy in single precision mode
if func_mode != 'Python' and func in {Functions.Tanh, Functions.Logistic} and pytest.helpers.llvm_current_fp_precision() == 'fp32':
tolerance = {'rtol': 5e-7, 'atol': 1e-8}
else:
tolerance = {}

np.testing.assert_allclose(res, expected, **tolerance)
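
Aside (not part of the commit): the fp32 tolerance above sits just above single-precision machine epsilon, which bounds how closely compiled fp32 results can match the float64 expected values.

import numpy as np

# Single-precision machine epsilon is ~1.19e-07, so results computed in
# fp32 can differ from float64 expectations by a few ULPs; rtol=5e-7
# leaves a small margin above that floor.
print(np.finfo(np.float32).eps)   # 1.1920929e-07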


derivative_out_test_data = [
@@ -189,8 +217,10 @@ def test_transfer_derivative(func, variable, params, expected, benchmark, func_mode):
(Functions.SoftMax, softmax_helper, {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:False},
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
(Functions.SoftMax, [softmax_helper], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
(Functions.SoftMax, [softmax_helper, softmax_helper], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
[[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513],
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]]),
]
@pytest.mark.function
@@ -214,9 +244,14 @@ def ex(x):
assert False, "unknown function mode: {}".format(func_mode)

res = benchmark(ex, variable)
# FIX: THIS FAILS FOR func_mode=Python, func=SoftMax, and kw.PER_ITEM:True:
# EXPECTS 2d BUT ONLY 1D IS RETURNED
assert np.allclose(res, expected)

# Logistic needs reduced accuracy in single precision mode
if func_mode != 'Python' and func is Functions.Logistic and pytest.helpers.llvm_current_fp_precision() == 'fp32':
tolerance = {'rtol': 1e-7, 'atol': 1e-8}
else:
tolerance = {}

np.testing.assert_allclose(res, expected, **tolerance)

def test_transfer_with_costs_function():
f = Functions.TransferWithCosts()
