【Complex OP】No.44 complex support hardswish op #57638

Merged (2 commits) on Sep 26, 2023
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -434,7 +434,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -278,7 +278,7 @@ PD_REGISTER_KERNEL(log1p,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
74 changes: 72 additions & 2 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2574,7 +2574,38 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct HardSwishGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
float threshold;
float scale;
float offset;

typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
auto tmp = ((x + static_cast<ComplexType<T>>(offset)) <
            static_cast<ComplexType<T>>(threshold))  // NOLINT
               .template cast<ComplexType<T>>();

Contributor: This is too messy here. Do it the same way as the GPU version below: cast into intermediate variables once first, and put the formula in a comment above it.

Contributor (author): Fixed.

dx.device(d) = dout * (((x + static_cast<ComplexType<T>>(offset)) >
static_cast<ComplexType<T>>(0))
.template cast<ComplexType<T>>() *
(static_cast<ComplexType<T>>(2) * x +
static_cast<ComplexType<T>>(offset)) /
static_cast<ComplexType<T>>(scale) * tmp +
static_cast<ComplexType<T>>(1) *
(static_cast<ComplexType<T>>(1) - tmp))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
float beta;
@@ -4540,8 +4571,10 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
const MPType x_t = static_cast<MPType>(x);
const MPType temp_max = std::max(x_t + static_cast<MPType>(offset), zero);
const MPType temp_min = std::min(temp_max, static_cast<MPType>(threshold));
const MPType x_offset_t = x_t + static_cast<MPType>(offset);
const MPType temp_max = (x_offset_t >= zero) ? x_offset_t : zero;
const MPType threshold_t = static_cast<MPType>(threshold);
const MPType temp_min = (temp_max < threshold_t) ? temp_max : threshold_t;
return static_cast<T>(temp_min * x_t / static_cast<MPType>(scale));
}
};
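As a quick cross-check of the formula in the comments above (threshold = scale = 6, offset = 3 by default), here is a minimal NumPy sketch of the real-valued forward pass; the function name and defaults below are illustrative and not part of this PR.

```python
import numpy as np

def hardswish_reference(x, threshold=6.0, scale=6.0, offset=3.0):
    # hardswish(x) = x * min(max(x + offset, 0), threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale

# hardswish_reference(np.array([-4.0, -1.0, 0.0, 2.0, 4.0]))
# ~= [ 0.0, -0.333, 0.0, 1.667, 4.0 ]
```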
@@ -4581,6 +4614,43 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardSwishGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
const ComplexType<T> zero = static_cast<ComplexType<T>>(0.0f);
const ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
const ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);
float threshold;
float scale;
float offset;

typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}

// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
const ComplexType<T> dout_t = static_cast<ComplexType<T>>(dout);
const ComplexType<T> x_t = static_cast<ComplexType<T>>(x);
const ComplexType<T> offset_t = static_cast<ComplexType<T>>(offset);
const ComplexType<T> scale_t = static_cast<ComplexType<T>>(scale);
const ComplexType<T> temp1 =
static_cast<ComplexType<T>>(x_t + offset_t > zero);
const ComplexType<T> temp2 = static_cast<ComplexType<T>>(
x_t + offset_t < static_cast<ComplexType<T>>(threshold));

return static_cast<ComplexType<T>>(
dout_t *
conj(temp1 * temp2 * (two * x_t + offset_t) / scale_t + one - temp2));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
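To make the piecewise rule in the comments above concrete, here is a hedged NumPy sketch for the real-valued case; the complex specialization above applies the same elementwise rule but additionally conjugates the local derivative before multiplying by dout (the `conj(...)` / `Conj<T>()` calls). Names below are illustrative only.

```python
import numpy as np

def hardswish_grad_reference(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    # dx = 0                                 when x <= -offset
    #    = dout                              when x >= threshold - offset
    #    = dout * (2 * x + offset) / scale   otherwise
    local_grad = np.where(
        x <= -offset,
        0.0,
        np.where(x >= threshold - offset, 1.0, (2.0 * x + offset) / scale),
    )
    return dout * local_grad
```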

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -513,7 +513,8 @@ PD_REGISTER_KERNEL(log_double_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -296,7 +296,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
12 changes: 11 additions & 1 deletion python/paddle/nn/functional/activation.py
@@ -422,7 +422,17 @@ def hardswish(x, name=None):
return _C_ops.hardswish(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'hardswish'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'hardswish',
)

threshold = 6.0
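With complex64/complex128 accepted by the static-graph dtype check above, and the complex kernels registered earlier in this PR, calling the functional API with a complex tensor should dispatch to the new kernels. A minimal usage sketch, assuming a Paddle build that includes this change; the input values are arbitrary.

```python
import numpy as np
import paddle

# Illustrative complex input covering the three hardswish regions.
x_np = (
    np.random.uniform(-6, 6, (3, 4)) + 1j * np.random.uniform(-6, 6, (3, 4))
).astype(np.complex64)

x = paddle.to_tensor(x_np)
y = paddle.nn.functional.hardswish(x)
print(y.dtype)  # expected: paddle.complex64
```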
28 changes: 25 additions & 3 deletions test/legacy_test/test_activation_op.py
@@ -2734,7 +2734,13 @@ def setUp(self):
self.public_python_api = paddle.nn.functional.hardswish

np.random.seed(1024)
x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
if self.dtype is np.complex64 or self.dtype is np.complex128:
x = (
np.random.uniform(-6, 6, self.shape)
+ 1j * np.random.uniform(-6, 6, self.shape)
).astype(self.dtype)
else:
x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
threshold = 6.0
scale = 6.0
offset = 3.0
@@ -2758,19 +2764,35 @@ def test_check_grad(self):
self.check_grad(
['X'],
'Out',
check_prim=True,
check_prim=True
if self.dtype not in [np.complex64, np.complex128]
else False,
only_check_prim=self.if_only_check_prim(),
)

def test_check_output(self):
self.check_output(check_prim=True)
self.check_output(
check_prim=True
if self.dtype not in [np.complex64, np.complex128]
else False
)


class TestHardSwish_ZeroDim(TestHardSwish):
def init_shape(self):
self.shape = []


class TestHardSwishComplex64(TestHardSwish):
def init_dtype(self):
self.dtype = np.complex64


class TestHardSwishComplex128(TestHardSwish):
def init_dtype(self):
self.dtype = np.complex128


class TestHardswishAPI(unittest.TestCase):
# test paddle.nn.Hardswish, paddle.nn.functional.hardswish
def setUp(self):