【Complex OP】No.44 complex support hardswish op (PaddlePaddle#57638)
* complex support hardswish op

* fix 2023-09-25
yangguohao authored Sep 26, 2023
1 parent f025068 commit e855881
Showing 7 changed files with 118 additions and 10 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -436,7 +436,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -278,7 +278,7 @@ PD_REGISTER_KERNEL(log1p,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
78 changes: 76 additions & 2 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2618,7 +2618,42 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct HardSwishGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
float threshold;
float scale;
float offset;

typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
auto offset_t = static_cast<ComplexType<T>>(offset);
auto threshold_t = static_cast<ComplexType<T>>(threshold);
auto one = static_cast<ComplexType<T>>(1);
auto zero = static_cast<ComplexType<T>>(0);
auto two = static_cast<ComplexType<T>>(2);
auto scale_t = static_cast<ComplexType<T>>(scale);
auto tmp1 = ((x + offset_t) < threshold_t) // NOLINT
.template cast<ComplexType<T>>();
auto tmp2 = ((x + offset_t) > zero).template cast<ComplexType<T>>();
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
dx.device(d) = dout * (tmp2 * (two * x + offset_t) / scale_t * tmp1 +
one * (one - tmp1))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
float beta;
@@ -4627,8 +4662,10 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
const MPType x_t = static_cast<MPType>(x);
const MPType temp_max = std::max(x_t + static_cast<MPType>(offset), zero);
const MPType temp_min = std::min(temp_max, static_cast<MPType>(threshold));
const MPType x_offset_t = x_t + static_cast<MPType>(offset);
const MPType temp_max = (x_offset_t >= zero) ? x_offset_t : zero;
const MPType threshold_t = static_cast<MPType>(threshold);
const MPType temp_min = (temp_max < threshold_t) ? temp_max : threshold_t;
return static_cast<T>(temp_min * x_t / static_cast<MPType>(scale));
}
};
@@ -4668,6 +4705,43 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardSwishGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
const ComplexType<T> zero = static_cast<ComplexType<T>>(0.0f);
const ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
const ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);
float threshold;
float scale;
float offset;

typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}

// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
const ComplexType<T> dout_t = static_cast<ComplexType<T>>(dout);
const ComplexType<T> x_t = static_cast<ComplexType<T>>(x);
const ComplexType<T> offset_t = static_cast<ComplexType<T>>(offset);
const ComplexType<T> scale_t = static_cast<ComplexType<T>>(scale);
const ComplexType<T> temp1 =
static_cast<ComplexType<T>>(x_t + offset_t > zero);
const ComplexType<T> temp2 = static_cast<ComplexType<T>>(
x_t + offset_t < static_cast<ComplexType<T>>(threshold));

return static_cast<ComplexType<T>>(
dout_t *
conj(temp1 * temp2 * (two * x_t + offset_t) / scale_t + one - temp2));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
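The comment blocks inside HardSwishGradFunctor<ComplexType<T>> and CudaHardSwishGradFunctor<ComplexType<T>> above spell out the piecewise derivative. The NumPy sketch below mirrors that forward/backward logic for complex inputs; it is illustrative only, not Paddle source. The branch selection uses the real part of x + offset, which is a simplifying assumption of the sketch, and the backward pass multiplies by the conjugate of the local derivative, matching the Conj<T>() and conj(...) calls in the functors.

import numpy as np

# Sketch only. Defaults match the kernels: threshold = scale = 6, offset = 3.
def hardswish_fwd(x, threshold=6.0, scale=6.0, offset=3.0):
    shifted = x + offset
    # Assumption: branches are selected by the real part of x + offset.
    gate = np.where(shifted.real > 0.0, shifted, 0.0)           # max(x + offset, 0)
    gate = np.where(shifted.real < threshold, gate, threshold)  # min(..., threshold)
    return x * gate / scale

def hardswish_bwd(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    # Local derivative d(out)/dx:
    #   0                         when x <= -offset
    #   1                         when x >= threshold - offset
    #   (2 * x + offset) / scale  otherwise
    below = (x + offset).real < threshold  # tmp1 in the CPU functor, temp2 in the CUDA one
    above = (x + offset).real > 0.0        # tmp2 in the CPU functor, temp1 in the CUDA one
    local = np.where(below, above * (2 * x + offset) / scale, 1.0)
    # Complex backward: multiply dout by the conjugate of the local derivative.
    return dout * np.conj(local)

x = np.array([-4.0 + 0.5j, 1.0 - 2.0j, 5.0 + 1.0j], dtype=np.complex64)
print(hardswish_fwd(x))
print(hardswish_bwd(np.ones_like(x), x))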
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -515,7 +515,8 @@ PD_REGISTER_KERNEL(log_double_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -296,7 +296,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
12 changes: 11 additions & 1 deletion python/paddle/nn/functional/activation.py
@@ -422,7 +422,17 @@ def hardswish(x, name=None):
return _C_ops.hardswish(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'hardswish'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'hardswish',
)

threshold = 6.0
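With the CPU and GPU kernels registered for complex types and complex64/complex128 accepted by the dtype check above, hardswish can now be called on complex tensors from the Python API. A minimal usage sketch, assuming a Paddle build that contains this commit (earlier builds reject complex input at this check):

import numpy as np
import paddle

x_np = (np.random.uniform(-6, 6, [2, 3])
        + 1j * np.random.uniform(-6, 6, [2, 3])).astype('complex64')
x = paddle.to_tensor(x_np)
out = paddle.nn.functional.hardswish(x)
print(out.dtype)  # expected: paddle.complex64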
28 changes: 25 additions & 3 deletions test/legacy_test/test_activation_op.py
@@ -2812,7 +2812,13 @@ def setUp(self):
self.public_python_api = paddle.nn.functional.hardswish

np.random.seed(1024)
x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
if self.dtype is np.complex64 or self.dtype is np.complex128:
x = (
np.random.uniform(-6, 6, self.shape)
+ 1j * np.random.uniform(-6, 6, self.shape)
).astype(self.dtype)
else:
x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
threshold = 6.0
scale = 6.0
offset = 3.0
@@ -2836,19 +2842,35 @@ def test_check_grad(self):
self.check_grad(
['X'],
'Out',
check_prim=True,
check_prim=True
if self.dtype not in [np.complex64, np.complex128]
else False,
only_check_prim=self.if_only_check_prim(),
)

def test_check_output(self):
self.check_output(check_prim=True)
self.check_output(
check_prim=True
if self.dtype not in [np.complex64, np.complex128]
else False
)


class TestHardSwish_ZeroDim(TestHardSwish):
def init_shape(self):
self.shape = []


class TestHardSwishComplex64(TestHardSwish):
def init_dtype(self):
self.dtype = np.complex64


class TestHardSwishComplex128(TestHardSwish):
def init_dtype(self):
self.dtype = np.complex128


class TestHardswishAPI(unittest.TestCase):
# test paddle.nn.Hardswish, paddle.nn.functional.hardswish
def setUp(self):
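For the complex dtypes the test draws real and imaginary parts from a uniform distribution on [-6, 6], as in the setUp change above, and skips check_prim for those dtypes. The sketch below shows how a NumPy reference output could be produced for such an input; the actual ref_hardswish helper is defined elsewhere in test_activation_op.py and is not shown in this diff, and the lexicographic complex ordering of np.maximum/np.minimum is an assumption of this sketch.

import numpy as np

def ref_hardswish_complex(x, threshold=6.0, scale=6.0, offset=3.0):
    # np.maximum / np.minimum compare complex values by real part first,
    # then imaginary part; whether that matches the C++ kernels exactly is
    # an assumption of this sketch.
    gate = np.minimum(np.maximum(x + offset, 0.0), threshold)
    return (x * gate / scale).astype(x.dtype)

np.random.seed(1024)
shape = [4, 5]  # arbitrary; not necessarily the shape the test uses
x = (np.random.uniform(-6, 6, shape)
     + 1j * np.random.uniform(-6, 6, shape)).astype(np.complex64)
expected = ref_hardswish_complex(x)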
