diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 2705177f951d..8ea6b7563ef7 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -199,12 +199,20 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+
+  const int act_type = param.act_type;
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  bool relu = param.act_type == activation::kReLU;
-  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
+  if (act_type != activation::kReLU && act_type != activation::kSoftSign) {
+    CHECK_EQ(inputs.size(), 3U);
+  } else {
+    CHECK_EQ(inputs.size(), 2U);
+  }
 #else
-  bool softsign = param.act_type == activation::kSoftSign;
-  CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
+  if (act_type == activation::kSoftSign) {
+    CHECK_EQ(inputs.size(), 3U);
+  } else {
+    CHECK_EQ(inputs.size(), 2U);
+  }
 #endif
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);