diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 221aa0f3c..1236e52e1 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/constants.h" #include "cutlass/complex.h" #include "cutlass/array.h" @@ -54,7 +55,7 @@ namespace thread { // Identity operator template struct Identity { - static const bool kIsHeavy=false; + static const bool kIsHeavy = false; CUTLASS_HOST_DEVICE T operator()(T value) const { @@ -128,7 +129,8 @@ struct Scale> { /// Always put threshold in the right hand side of max to propagate NaN. template struct ReLu { - static const bool kIsHeavy=false; + static const bool kIsHeavy = false; + CUTLASS_HOST_DEVICE T operator()(T const & threshold, T value) const { maximum mx; @@ -149,7 +151,8 @@ using ReLU = ReLu; template struct ReLu> { - static const bool kIsHeavy=false; + static const bool kIsHeavy = false; + CUTLASS_HOST_DEVICE Array operator()(T const & threshold, Array const &frag) const { maximum> mx; @@ -207,6 +210,9 @@ struct Clamp> { // Leaky Relu operator template struct LeakyReLU { + + static const bool kIsHeavy = false; + struct Arguments { T leaky_alpha = T(0); }; @@ -225,6 +231,9 @@ struct LeakyReLU { template struct LeakyReLU > { + + static const bool kIsHeavy = false; + using Arguments = typename LeakyReLU::Arguments; CUTLASS_HOST_DEVICE @@ -249,6 +258,8 @@ struct LeakyReLU > { // Tanh operator template struct Tanh { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &value) const { return fast_tanh(value); @@ -257,6 +268,8 @@ struct Tanh { template struct Tanh > { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Array y; @@ -274,6 +287,7 @@ struct Tanh > { template struct Tanh> { using T = half_t; + static const bool kIsHeavy = true; CUTLASS_HOST_DEVICE Array operator()(Array const& z) const { @@ -285,6 +299,8 @@ struct Tanh> { // Sigmoid operator template struct Sigmoid { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &value) const { return T(1) / (T(1) + fast_exp(-value)); @@ -293,6 +309,8 @@ struct Sigmoid { template struct Sigmoid > { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Array y; @@ -310,6 +328,7 @@ struct Sigmoid > { template struct Sigmoid> { using T = half_t; + static const bool kIsHeavy = true; CUTLASS_HOST_DEVICE Array operator()(Array const& z) const { @@ -338,6 +357,8 @@ struct Sigmoid> { // Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html template struct SiLu { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &value) const { Sigmoid sigmoid; @@ -347,6 +368,8 @@ struct SiLu { template struct SiLu> { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Sigmoid> sigmoid_op; @@ -362,6 +385,8 @@ struct SiLu> { // Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html template struct HardSwish { + static const bool kIsHeavy = false; + CUTLASS_HOST_DEVICE T operator()(T const &x) const { minimum mn; @@ -374,6 +399,7 @@ struct HardSwish { template <> struct HardSwish { using T = float; + static const bool kIsHeavy = false; CUTLASS_HOST_DEVICE T operator()(T const &x) const { @@ -386,6 +412,8 @@ struct HardSwish { template struct HardSwish > { + static const bool kIsHeavy = false; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Array y; @@ -403,6 +431,7 @@ struct HardSwish > { template struct HardSwish > { using T = half_t; + static const bool kIsHeavy = false; CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { @@ -427,6 +456,8 @@ struct HardSwish > { // GELU operator template struct GELU { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &value) const { return T(cutlass::constants::half() * value * @@ -436,6 +467,8 @@ struct GELU { template <> struct GELU { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE float operator()(float const &value) const { return cutlass::constants::half() * value * @@ -445,6 +478,8 @@ struct GELU { template <> struct GELU { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE double operator()(double const &value) const { return cutlass::constants::half() * value * @@ -454,6 +489,8 @@ struct GELU { template struct GELU > { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Array y; @@ -474,7 +511,8 @@ using ScaledGELU = Scale>; // GELU operator implemented using the Taylor series approximation template struct GELU_taylor { - static const bool kIsHeavy=true; + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &z) const { @@ -488,7 +526,8 @@ struct GELU_taylor { template struct GELU_taylor > { - static const bool kIsHeavy=true; + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &z) const { @@ -514,7 +553,8 @@ struct GELU_taylor > { template struct GELU_taylor > { - static const bool kIsHeavy=true; + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &value) const { Array y; @@ -536,6 +576,8 @@ using ScaledGELU_taylor = Scale>; /// z is computed from the forward pass. template struct dGELU { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE T operator()(T const &d_t, T const &z) const { @@ -554,6 +596,8 @@ struct dGELU { template struct dGELU > { + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE Array operator()(Array const &d_t, Array const &z) const { Array y; @@ -568,6 +612,45 @@ struct dGELU > { } }; +template +struct dReLU { + CUTLASS_HOST_DEVICE + T operator()(T const& d_t, bool d_relu) const { + return d_relu ? d_t : T(0); + } + + CUTLASS_HOST_DEVICE + T operator()(T const& d_t, uint1b_t d_relu) const { + return operator()(d_t, static_cast(d_relu)); + } +}; + +template +struct dReLU> { + CUTLASS_HOST_DEVICE + Array operator()(Array const& d_t, bool const (&d_relu)[N]) const { + Array y; + dReLU relu_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + y[i] = relu_op(d_t[i], d_relu[i]); + } + + return y; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& d_t, Array const& d_relu) const { + UnpackPredicates unpack_op; + + bool preds[N]; + unpack_op(preds, d_relu); + + return operator()(d_t, preds); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread