Skip to content

Commit

Permalink
set kIsHeavy member variables (#1012)
Browse files Browse the repository at this point in the history
* set kIsHeavy member variables

* correct kIsHeavy value for Tanh

* set kIsHeavy=false for HardSwish

---------

Co-authored-by: Haicheng Wu <[email protected]>
  • Loading branch information
FabianSchuetze and hwu36 authored Oct 4, 2023
1 parent 61a38f8 commit 5f13dca
Showing 1 changed file with 89 additions and 6 deletions.
95 changes: 89 additions & 6 deletions include/cutlass/epilogue/thread/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/constants.h"
#include "cutlass/complex.h"
#include "cutlass/array.h"
Expand All @@ -54,7 +55,7 @@ namespace thread {
// Identity operator
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T value) const {
Expand Down Expand Up @@ -128,7 +129,8 @@ struct Scale<Activation<T>> {
/// Always put threshold in the right hand side of max to propagate NaN.
template <typename T>
struct ReLu {
static const bool kIsHeavy=false;
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T value) const {
maximum<T> mx;
Expand All @@ -149,7 +151,8 @@ using ReLU = ReLu<T>;

template <typename T, int N>
struct ReLu<Array<T, N>> {
static const bool kIsHeavy=false;
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
maximum<Array<T, N>> mx;
Expand Down Expand Up @@ -207,6 +210,9 @@ struct Clamp<Array<T,N>> {
// Leaky Relu operator
template <typename T>
struct LeakyReLU {

static const bool kIsHeavy = false;

struct Arguments {
T leaky_alpha = T(0);
};
Expand All @@ -225,6 +231,9 @@ struct LeakyReLU {

template <typename T, int N>
struct LeakyReLU<Array<T, N> > {

static const bool kIsHeavy = false;

using Arguments = typename LeakyReLU<T>::Arguments;

CUTLASS_HOST_DEVICE
Expand All @@ -249,6 +258,8 @@ struct LeakyReLU<Array<T, N> > {
// Tanh operator
template <typename T>
struct Tanh {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &value) const {
return fast_tanh(value);
Expand All @@ -257,6 +268,8 @@ struct Tanh {

template <typename T, int N>
struct Tanh<Array<T, N> > {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Expand All @@ -274,6 +287,7 @@ struct Tanh<Array<T, N> > {
template <int N>
struct Tanh<Array<half_t, N>> {
using T = half_t;
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const& z) const {
Expand All @@ -285,6 +299,8 @@ struct Tanh<Array<half_t, N>> {
// Sigmoid operator
template <typename T>
struct Sigmoid {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &value) const {
return T(1) / (T(1) + fast_exp(-value));
Expand All @@ -293,6 +309,8 @@ struct Sigmoid {

template <typename T, int N>
struct Sigmoid<Array<T, N> > {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Expand All @@ -310,6 +328,7 @@ struct Sigmoid<Array<T, N> > {
template <int N>
struct Sigmoid<Array<half_t, N>> {
using T = half_t;
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const& z) const {
Expand Down Expand Up @@ -338,6 +357,8 @@ struct Sigmoid<Array<half_t, N>> {
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
template <typename T>
struct SiLu {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &value) const {
Sigmoid<T> sigmoid;
Expand All @@ -347,6 +368,8 @@ struct SiLu {

template <typename T, int N>
struct SiLu<Array<T, N>> {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Sigmoid<Array<T, N>> sigmoid_op;
Expand All @@ -362,6 +385,8 @@ struct SiLu<Array<T, N>> {
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
template <typename T>
struct HardSwish {
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
minimum<T> mn;
Expand All @@ -374,6 +399,7 @@ struct HardSwish {
template <>
struct HardSwish<float> {
using T = float;
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
Expand All @@ -386,6 +412,8 @@ struct HardSwish<float> {

template <typename T, int N>
struct HardSwish<Array<T, N> > {
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Expand All @@ -403,6 +431,7 @@ struct HardSwish<Array<T, N> > {
template <int N>
struct HardSwish<Array<half_t, N> > {
using T = half_t;
static const bool kIsHeavy = false;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Expand All @@ -427,6 +456,8 @@ struct HardSwish<Array<half_t, N> > {
// GELU operator
template <typename T>
struct GELU {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &value) const {
return T(cutlass::constants::half<T>() * value *
Expand All @@ -436,6 +467,8 @@ struct GELU {

template <>
struct GELU<float> {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
float operator()(float const &value) const {
return cutlass::constants::half<float>() * value *
Expand All @@ -445,6 +478,8 @@ struct GELU<float> {

template <>
struct GELU<double> {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
double operator()(double const &value) const {
return cutlass::constants::half<double>() * value *
Expand All @@ -454,6 +489,8 @@ struct GELU<double> {

template <typename T, int N>
struct GELU<Array<T, N> > {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Expand All @@ -474,7 +511,8 @@ using ScaledGELU = Scale<GELU<T>>;
// GELU operator implemented using the Taylor series approximation
template <typename T>
struct GELU_taylor {
static const bool kIsHeavy=true;
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &z) const {

Expand All @@ -488,7 +526,8 @@ struct GELU_taylor {

template <int N>
struct GELU_taylor<Array<half_t, N> > {
static const bool kIsHeavy=true;
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const {

Expand All @@ -514,7 +553,8 @@ struct GELU_taylor<Array<half_t, N> > {

template <typename T, int N>
struct GELU_taylor<Array<T, N> > {
static const bool kIsHeavy=true;
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Expand All @@ -536,6 +576,8 @@ using ScaledGELU_taylor = Scale<GELU_taylor<T>>;
/// z is computed from the forward pass.
template <typename T>
struct dGELU {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
T operator()(T const &d_t, T const &z) const {

Expand All @@ -554,6 +596,8 @@ struct dGELU {

template <typename T, int N>
struct dGELU<Array<T, N> > {
static const bool kIsHeavy = true;

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const {
Array<T, N> y;
Expand All @@ -568,6 +612,45 @@ struct dGELU<Array<T, N> > {
}
};

// Gradient of ReLU: pass the upstream gradient through where the saved
// forward-pass predicate is set, produce zero elsewhere.
template <typename T>
struct dReLU {
  // d_t is the upstream gradient; d_relu records whether the forward
  // activation was positive.
  CUTLASS_HOST_DEVICE
  T operator()(T const& d_t, bool d_relu) const {
    if (d_relu) {
      return d_t;
    }
    return T(0);
  }

  // Bit-predicate form: widen the 1-bit flag to bool and reuse the
  // overload above.
  CUTLASS_HOST_DEVICE
  T operator()(T const& d_t, uint1b_t d_relu) const {
    return (*this)(d_t, static_cast<bool>(d_relu));
  }
};

// Elementwise dReLU over a fragment: applies the scalar gradient rule to
// each lane using a per-lane predicate.
template <typename T, int N>
struct dReLU<Array<T, N>> {
  // Predicates supplied as a plain bool array of length N.
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const& d_t, bool const (&d_relu)[N]) const {
    dReLU<T> scalar_op;
    Array<T, N> result;

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < N; ++idx) {
      result[idx] = scalar_op(d_t[idx], d_relu[idx]);
    }

    return result;
  }

  // Predicates supplied bit-packed: unpack once into a bool array, then
  // defer to the overload above.
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const& d_t, Array<uint1b_t, N> const& d_relu) const {
    bool unpacked[N];

    UnpackPredicates<N> unpack_op;
    unpack_op(unpacked, d_relu);

    return (*this)(d_t, unpacked);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
Expand Down

0 comments on commit 5f13dca

Please sign in to comment.