diff --git a/efficientdet/utils.py b/efficientdet/utils.py
index 340e63342..f8578a0bb 100644
--- a/efficientdet/utils.py
+++ b/efficientdet/utils.py
@@ -42,6 +42,8 @@ def activation_fn(features: tf.Tensor, act_type: Text):
     return tf.nn.relu(features)
   elif act_type == 'relu6':
     return tf.nn.relu6(features)
+  elif act_type == 'mish':
+    return features * tf.math.tanh(tf.math.softplus(features))
   else:
     raise ValueError('Unsupported act_type {}'.format(act_type))
 
diff --git a/efficientdet/utils_test.py b/efficientdet/utils_test.py
index 462b2ea69..75a3aedc2 100644
--- a/efficientdet/utils_test.py
+++ b/efficientdet/utils_test.py
@@ -114,6 +114,33 @@ def _model(inputs):
     self.assertIs(out.dtype, tf.float16)  # output should be float16.
 
 
+class ActivationTest(tf.test.TestCase):
+
+  def test_swish(self):
+    features = tf.constant([.5, 10])
+
+    result = utils.activation_fn(features, "swish")
+    expected = features * tf.sigmoid(features)
+    self.assertAllClose(result, expected)
+
+    result = utils.activation_fn(features, "swish_native")
+    self.assertAllClose(result, expected)
+
+  def test_relu(self):
+    features = tf.constant([.5, 10])
+    result = utils.activation_fn(features, "relu")
+    self.assertAllClose(result, [0.5, 10])
+
+  def test_relu6(self):
+    features = tf.constant([.5, 10])
+    result = utils.activation_fn(features, "relu6")
+    self.assertAllClose(result, [0.5, 6])
+
+  def test_mish(self):
+    features = tf.constant([.5, 10])
+    result = utils.activation_fn(features, "mish")
+    self.assertAllClose(result, [0.37524524, 10.0])
+
 if __name__ == '__main__':
   logging.set_verbosity(logging.WARNING)
   tf.disable_eager_execution()
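
As a sanity check on the new activation (not part of the patch; a minimal NumPy sketch, assuming only numpy is available), the expected value in test_mish can be reproduced directly from the formula mish(x) = x * tanh(softplus(x)):

import numpy as np

def mish(x):
  # Mish activation: x * tanh(softplus(x)), with softplus(x) = ln(1 + e^x).
  # tf.math.softplus is numerically safer for large x; np.log1p(np.exp(x))
  # is fine for the small inputs used here.
  return x * np.tanh(np.log1p(np.exp(x)))

print(mish(np.array([0.5, 10.0])))  # approx. [0.37524524, 10.0]
# mish(0.5) ~= 0.37524524 and mish(10) ~= 10.0 (tanh saturates at 1 and
# softplus(x) ~= x for large x), matching the assertAllClose values above.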