From 66f44e4d1dc64cf0bbdc7c65df00a6a7eba1fc08 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Tue, 20 Nov 2018 14:05:14 +0000 Subject: [PATCH] fix ActivationGradCompute on CPU fixes #13333 --- src/operator/nn/activation-inl.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h index 2705177f951d..8ea6b7563ef7 100644 --- a/src/operator/nn/activation-inl.h +++ b/src/operator/nn/activation-inl.h @@ -199,12 +199,20 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const ActivationParam& param = nnvm::get(attrs.parsed); + + const int act_type = param.act_type; #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1) - bool relu = param.act_type == activation::kReLU; - CHECK_EQ(inputs.size(), relu ? 2U : 3U); + if (act_type != activation::kReLU && act_type != activation::kSoftSign) { + CHECK_EQ(inputs.size(), 3U); + } else { + CHECK_EQ(inputs.size(), 2U); + } #else - bool softsign = param.act_type == activation::kSoftSign; - CHECK_EQ(inputs.size(), softsign ? 3U : 2U); + if (act_type == activation::kSoftSign) { + CHECK_EQ(inputs.size(), 3U); + } else { + CHECK_EQ(inputs.size(), 2U); + } #endif CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U);