Skip to content

Commit

Permalink
add example of unstable batchnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
CAHEK7 committed Oct 19, 2023
1 parent 92deecf commit b2ce275
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/gtest/bn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ struct BNBwdTest : public ::testing::TestWithParam<std::tuple<BNTestCase, miopen
DscaleDbiasDataType,
MeanVarDataType>(bn_bwd_test_data);

test::CompareTensor<DxDataType>(bn_bwd_test_data.output, bn_bwd_test_data.ref_out, 5e-4);
test::CompareTensor<DxDataType>(bn_bwd_test_data.dScale, bn_bwd_test_data.dScale_ref, 5e-4);
test::CompareTensor<DxDataType>(bn_bwd_test_data.dBias, bn_bwd_test_data.dBias_ref, 5e-4);
test::CompareTensor<DxDataType>(bn_bwd_test_data.output, bn_bwd_test_data.ref_out, 1e-4);
test::CompareTensor<DxDataType>(bn_bwd_test_data.dScale, bn_bwd_test_data.dScale_ref, 1e-4);
test::CompareTensor<DxDataType>(bn_bwd_test_data.dBias, bn_bwd_test_data.dBias_ref, 1e-4);
}

BNTestCase bn_config;
Expand Down
14 changes: 14 additions & 0 deletions test/gtest/bn_test_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
#include "tensor_util.hpp"
#include "get_handle.hpp"

// setting this to "enabled" makes ./bin/test_bn_bwd --gtest_filter=*BnBwdCKFloat/0 failed
// and ./bin/test_bn_bwd --gtest_filter=*BnBwdCKFloat/24 (25 and 26) passed
// while setting it to "true" makes opposite: 0 starts to pass and 24, 25, 26 start to fail
// the other cases are not affected
// the same problem happens for half precision as well
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_UNSTABLE_BN)
struct BNTestCase
{
size_t N;
Expand Down Expand Up @@ -301,6 +307,14 @@ struct BNBwdTestData : public BNTestData<XDataType, DyDataType, TConfig>
void InitTensorsWithRandValue()
{
auto gen_value = [](auto...) {
if(miopen::IsEnabled(MIOPEN_DEBUG_UNSTABLE_BN{}))
{
// just advace PRNG to get slighly different sequence. but
// but with the same probability distribution and so on
prng::gen_canonical<int>();
prng::gen_canonical<int>();
prng::gen_canonical<int>();
}
return prng::gen_descreet_uniform_sign<ScaleDataType>(1e-2, 100);
};
dy.generate(gen_value);
Expand Down

0 comments on commit b2ce275

Please sign in to comment.