This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[operator] Add Mish Activation Function #20320

Merged (11 commits) on Jun 8, 2021
7 changes: 4 additions & 3 deletions benchmark/opperf/nd_operations/nn_activation_operators.py
@@ -37,9 +37,10 @@
8.1 relu
8.2 sigmoid
8.3 log_sigmoid
8.4 softrelu
8.5 softsign
8.6 tanh
8.4 mish
8.5 softrelu
8.6 softsign
8.7 tanh

"""

2 changes: 1 addition & 1 deletion benchmark/opperf/rules/default_params.py
@@ -375,7 +375,7 @@

# For NN operators
DEFAULT_ACT_TYPE_LR = ['leaky', 'elu', 'selu', 'gelu']
DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'softrelu', 'softsign', 'tanh']
DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'mish', 'softrelu', 'softsign', 'tanh']
DEFAULT_LABEL_SOFTMAX = [(1024, 1024), (10000, 1), (10000, 100)]

DEFAULT_LABEL_SOFTMAX_LARGE_TENSOR = [(2**32, 1)]
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -293,6 +293,7 @@
'max',
'min',
'min_axis',
'mish',
'mp_sgd_mom_update',
'mp_sgd_update',
'multi_all_finite',
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -398,6 +398,7 @@
'log_sigmoid',
'max',
'min',
'mish',
'mp_lamb_update_phase1',
'mp_lamb_update_phase2',
'mp_nag_mom_update',
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -2267,6 +2267,14 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def mish(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`mish`.

The arguments are the same as for :py:func:`mish`, with
this array as data.
"""
return op.mish(self, *args, **kwargs)

def squeeze(self, axis=None, inplace=False):
"""Remove dimensions with size 1 from this array without altering any data.

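As a quick illustration of the fluent method added above, here is a minimal sketch (assuming an MXNet build that includes this change) showing that the fluent spelling and the operator-namespace spelling are interchangeable:

    import mxnet as mx

    x = mx.nd.array([-2.0, -0.5, 0.0, 0.5, 2.0])

    # The fluent method simply forwards to op.mish with this array as data,
    # so the two calls below compute the same thing.
    y_fluent = x.mish()
    y_op = mx.nd.mish(x)

    print(y_fluent)  # element-wise mish(x) = x * tanh(log(1 + exp(x)))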
4 changes: 3 additions & 1 deletion python/mxnet/ndarray/numpy_extension/_op.py
@@ -217,6 +217,8 @@ def activation(data, act_type='relu', **kwargs):

The following activation functions are supported:

- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
- `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
@@ -227,7 +229,7 @@ def activation(data, act_type='relu', **kwargs):
----------
data : NDArray
The input array.
act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
act_type : {'log_sigmoid', 'mish', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
Activation function to be applied.

Returns
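To make the documented formula concrete, here is a small sanity-check sketch (illustrative only, assuming a build that includes this change) comparing npx.activation with act_type='mish' against a plain NumPy reference:

    import numpy as onp
    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()

    x = np.array([-3.0, -1.0, 0.0, 1.0, 3.0])
    y = npx.activation(x, act_type='mish')

    # Reference: mish(x) = x * tanh(softplus(x)), with softplus(x) = log(1 + exp(x))
    ref = x.asnumpy() * onp.tanh(onp.log1p(onp.exp(x.asnumpy())))
    assert onp.allclose(y.asnumpy(), ref, atol=1e-5)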
8 changes: 8 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -2356,6 +2356,14 @@ def softmin(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin')

def mish(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`mish`.

The arguments are the same as for :py:func:`mish`, with
this array as data.
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute mish')

def squeeze(self, axis=None): # pylint: disable=arguments-differ
"""Remove single-dimensional entries from the shape of a."""
return squeeze(self, axis=axis)
4 changes: 3 additions & 1 deletion python/mxnet/numpy_extension/_op.py
@@ -202,6 +202,8 @@ def activation(data, act_type='relu', **kwargs):

The following activation functions are supported:

- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
- `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
@@ -212,7 +214,7 @@ def activation(data, act_type='relu', **kwargs):
----------
data : NDArray
The input array.
act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
act_type : {'log_sigmoid', 'mish', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
Activation function to be applied.

Returns
8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -2527,6 +2527,14 @@ def log_sigmoid(self, *args, **kwargs):
"""
return op.log_sigmoid(self, *args, **kwargs)

def mish(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`mish`.

The arguments are the same as for :py:func:`mish`, with
this array as data.
"""
return op.mish(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.

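The same fluent spelling is available on symbols; a short sketch (assuming a build with this change) that only constructs the graph without executing it:

    import mxnet as mx

    data = mx.sym.Variable('data')
    out = data.mish()            # fluent spelling
    same = mx.sym.mish(data)     # operator-namespace spelling of the same node

    print(out.list_arguments())  # ['data']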
2 changes: 2 additions & 0 deletions src/api/operator/numpy_extension/npx_activation_op.cc
@@ -36,6 +36,8 @@ inline int String2MXNetActType(const std::string& s) {
return activation::kSigmoid;
} else if (s == "log_sigmoid") {
return activation::kLogSigmoid;
} else if (s == "mish") {
return activation::kMish;
} else if (s == "tanh") {
return activation::kTanh;
} else if (s == "softrelu") {
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -50,6 +50,15 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_mish(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
const auto softrelu = op::log(1 + exp(v));
const auto tanh = op::tanh(softrelu);
return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
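For reference, the closed form used in backward_mish above follows from the chain rule; writing s(x) = log(1 + exp(x)) for softplus, so s'(x) = sigmoid(x):

    \mathrm{mish}(x) = x\,\tanh(s(x)), \qquad s(x) = \log(1 + e^{x}), \qquad s'(x) = \sigma(x)

    \frac{d}{dx}\,\mathrm{mish}(x)
        = \tanh(s(x)) + x\,s'(x)\,\bigl(1 - \tanh^{2}(s(x))\bigr)
        = \tanh(s(x)) + x\,\sigma(x)\,\bigl(1 - \tanh^{2}(s(x))\bigr)

This matches grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh)) in the kernel, where tanh denotes tanh(softrelu(v)).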
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -694,6 +694,15 @@ __device__ inline DType log_sigmoid(const DType val) {
}
}

template <typename DType>
__device__ inline DType mish(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
return val * ::tanh(::log(1 + ::exp(val)));
Review comment (Member): One thing that could be improved here (I did not notice this PR earlier, sorry for the late feedback) is the numerical stability of the softrelu part: see the implementation of softrelu below, which switches to softrelu(x) = x for large values of x to avoid overflow. @Adnios, could you open another PR changing e.g. this function to

    return val * op::tanh(op::softrelu(val));

(double vs. float is handled in op::tanh and op::softrelu anyway, so this function would also become simpler as a result), and similarly for the backward function?

Reply: Yes, agreed; usually Softplus has an upper bound of 20.

Reply (Contributor Author): Yes. Thanks for your advice.

} else {
return val * ::tanhf(logf(1 + expf(val)));
}
}

template <typename DType>
__device__ inline DType softrelu(const DType val) {
// Avoid overflow of exp for large inputs.
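To make the overflow concern from the review comments above concrete, here is a small NumPy sketch (illustrative only, not the kernel code) contrasting the naive softplus used in mish above with a thresholded form like the one softrelu uses:

    import numpy as np

    def softplus_naive(x):
        # log(1 + exp(x)): exp overflows to inf for large x, which numpy reports
        # as a floating-point overflow; the result only stays finite because inf
        # happens to propagate through log and tanh.
        return np.log(1.0 + np.exp(x))

    def softplus_stable(x, threshold=20.0):
        # For x above the threshold, log(1 + exp(x)) ~= x to within float precision,
        # so switch to the identity there (the "upper bound of 20" from the review).
        # np.minimum guards the exp because np.where evaluates both branches.
        return np.where(x > threshold, x, np.log1p(np.exp(np.minimum(x, threshold))))

    def mish(x, softplus):
        return x * np.tanh(softplus(x))

    x = np.array([1.0, 20.0, 1000.0])
    with np.errstate(over='raise'):
        try:
            print(mish(x, softplus_naive))   # raises FloatingPointError: overflow in exp
        except FloatingPointError as e:
            print('naive softplus overflowed:', e)
        print(mish(x, softplus_stable))      # ~[0.8651, 20.0, 1000.0], no overflow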
2 changes: 2 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -57,6 +57,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"relu" , {{"op::relu(%)", "_0"}}},
{"sigmoid" , {{"op::sigmoid(%)", "_0"}}},
{"log_sigmoid" , {{"op::log_sigmoid(%)", "_0"}}},
{"mish" , {{"op::mish(%)", "_0"}}},
{"softsign" , {{"op::softsign(%)", "_0"}}},
{"exp" , {{"op::exp(%)", "_0"}}},
{"expm1" , {{"op::expm1(%)", "_0"}}},
@@ -137,6 +138,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}},
{"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_log_sigmoid" , {{"op::backward_log_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_mish" , {{"op::backward_mish(%, %)", "_0", "_1"}}},
{"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}},
{"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}},
{"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}},
6 changes: 6 additions & 0 deletions src/operator/mshadow_op.h
@@ -415,6 +415,12 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));

MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
a * (1.0f / (1.0f + math::exp(-a))) *
(1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
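Because the mish_grad expression above is easy to mistype, a quick cross-check against a central finite difference can be reassuring; a plain NumPy sketch, independent of MXNet:

    import numpy as np

    def mish(x):
        return x * np.tanh(np.log1p(np.exp(x)))

    def mish_grad(x):
        # Same closed form as mshadow_op::mish_grad:
        # tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
        sp = np.log1p(np.exp(x))
        t = np.tanh(sp)
        sig = 1.0 / (1.0 + np.exp(-x))
        return t + x * sig * (1.0 - t * t)

    x = np.linspace(-4.0, 4.0, 9)
    eps = 1e-6
    fd = (mish(x + eps) - mish(x - eps)) / (2 * eps)
    assert np.allclose(fd, mish_grad(x), atol=1e-5)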
13 changes: 12 additions & 1 deletion src/operator/nn/activation-inl.h
@@ -47,7 +47,7 @@ namespace activation {
enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kTanh, kSoftReLU, kSoftSign};
enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kMish, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
@@ -61,6 +61,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
.add_enum("relu", activation::kReLU)
.add_enum("sigmoid", activation::kSigmoid)
.add_enum("log_sigmoid", activation::kLogSigmoid)
.add_enum("mish", activation::kMish)
.add_enum("tanh", activation::kTanh)
.add_enum("softrelu", activation::kSoftReLU)
.add_enum("softsign", activation::kSoftSign)
@@ -78,6 +79,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
return "sigmoid";
case activation::kLogSigmoid:
return "log_sigmoid";
case activation::kMish:
return "mish";
case activation::kTanh:
return "tanh";
case activation::kSoftReLU:
@@ -166,6 +169,10 @@ void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ActivationForward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kMish:
ActivationForward<xpu, mshadow_op::mish, mshadow_op::mish_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], req[0], outputs[0]);
@@ -201,6 +208,10 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct
ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kMish:
ActivationBackward<xpu, mshadow_op::mish, mshadow_op::mish_grad>(
ctx, inputs[0], inputs[2], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
3 changes: 3 additions & 0 deletions src/operator/nn/activation.cc
@@ -52,6 +52,7 @@ int GradNumInputs(int act_type) {
case kTanh:
case kSigmoid:
case kLogSigmoid:
case kMish:
return 3;
default:
CHECK(false) << "missing activation type";
@@ -93,6 +94,7 @@ struct ActivationGrad {
case kTanh:
case kSigmoid:
case kLogSigmoid:
case kMish:
heads.push_back(n->inputs[activation::kData]);
break;
default:
@@ -171,6 +173,7 @@ The following activation functions are supported:
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
- `mish`: :math:`y = x * tanh(log(1 + exp(x)))`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
- `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
10 changes: 8 additions & 2 deletions src/operator/nn/activation.cu
@@ -56,13 +56,16 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU and kSoftSign are both not supported by CUDNN yet
// SoftReLU, kSoftSign and Mish are not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationForward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
get_cudnn_op<DType>(param).Forward(ctx, inputs[0], req[0], outputs[0]);
@@ -84,10 +87,13 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0);

// both SoftReLU and SoftSign not supported by CUDNN yet
// SoftReLU, SoftSign and Mish not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationBackward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(
ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
if (do_memory_opt) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
3 changes: 3 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -44,6 +44,7 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
return param.act_type == activation::kReLU
|| param.act_type == activation::kSigmoid
|| param.act_type == activation::kLogSigmoid
|| param.act_type == activation::kMish
|| param.act_type == activation::kSoftReLU
|| param.act_type == activation::kTanh;
}
@@ -86,6 +87,8 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
return mkldnn::algorithm::eltwise_logistic;
case activation::kLogSigmoid:
return mkldnn::algorithm::eltwise_logsigmoid;
case activation::kMish:
return mkldnn::algorithm::eltwise_mish;
case activation::kTanh:
return mkldnn::algorithm::eltwise_tanh;
case activation::kSoftReLU:
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -238,6 +238,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log_sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mish); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mish_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
17 changes: 17 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -166,6 +166,23 @@ The storage type of ``log_sigmoid`` output is always dense
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);

// mish
MXNET_OPERATOR_REGISTER_UNARY(mish)
MXNET_ADD_SPARSE_OP_ALIAS(mish)
.describe(R"code(Computes mish of x element-wise.

.. math::
y = x * tanh(log(1 + exp(x)))

The storage type of ``mish`` output is always dense

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::mish>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mish"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_mish,
unary_bwd<mshadow_op::mish_grad>);



DMLC_REGISTER_PARAMETER(HardSigmoidParam);
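With this registration, mish is exposed both as a standalone element-wise operator and through the generic Activation operator; a short sketch (assuming a build that contains this change) checking that the two agree:

    import mxnet as mx

    x = mx.nd.random.normal(shape=(2, 3))

    y1 = mx.nd.mish(x)                         # standalone operator registered here
    y2 = mx.nd.Activation(x, act_type='mish')  # same math through the Activation op

    assert mx.nd.abs(y1 - y2).max().asscalar() < 1e-6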
6 changes: 6 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -45,6 +45,12 @@ NNVM_REGISTER_OP(log_sigmoid)
NNVM_REGISTER_OP(_backward_log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_log_sigmoid"});

NNVM_REGISTER_OP(mish)
.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"mish"});

NNVM_REGISTER_OP(_backward_mish)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_mish"});

NNVM_REGISTER_OP(hard_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);

1 change: 1 addition & 0 deletions tests/cpp/operator/activation_perf.cc
@@ -44,6 +44,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
"relu",
"sigmoid",
"log_sigmoid",
"mish",
"tanh",
"softrelu",
"softsign"
4 changes: 4 additions & 0 deletions tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -108,6 +108,7 @@ def hybrid_forward(self, F, x):
("relu", False), #TODO(bgawrych): investigate
("sigmoid", True),
("log_sigmoid", False),
("mish", False),
("tanh", False), #TODO(bgawrych): investigate
#("softrelu", True), #TODO(bgawrych): bug in oneDNN with AVX
("relu6", False), #TODO(bgawrych): investigate
@@ -149,6 +150,7 @@ def hybrid_forward(self, F, x):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("mish", True),
("tanh", True),
("softrelu", True),
("relu6", True),
@@ -186,6 +188,7 @@ def hybrid_forward(self, F, x):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("mish", True),
("tanh", True),
#("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element
("relu6", True),
@@ -293,6 +296,7 @@ def hybrid_forward(self, F, x, shared_weight):
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("mish", True),
("tanh", True),
("softrelu", True),
("relu6", True),