Skip to content

Commit

Permalink
feat: suppourt complex for softplus
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww committed Sep 25, 2023
1 parent 3e06282 commit 8bcc691
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 11 deletions.
7 changes: 4 additions & 3 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)

PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
ReluDoubleGradKernel)
Expand All @@ -320,8 +321,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
RsqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)

PD_REGISTER_KERNEL(tanh_triple_grad,
CPU,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

PD_REGISTER_KERNEL(exp,
CPU,
Expand Down
55 changes: 54 additions & 1 deletion paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,31 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftplusGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
float beta;
float threshold;
typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
auto x_beta = static_cast<ComplexType<T>>(beta) * x; // NOLINT
dx.device(d) =
(x_beta > static_cast<ComplexType<T>>(threshold))
.select(dout,
dout / (static_cast<ComplexType<T>>(1) + (-x_beta).exp())
.unaryExpr(Conj<T>()));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
float beta;
Expand Down Expand Up @@ -3576,7 +3601,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
MPType x_beta = x * static_cast<MPType>(beta);
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
}
};
Expand Down Expand Up @@ -3606,6 +3631,34 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftplusGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
using MPType = typename phi::dtype::MPTypeTrait<ComplexType<T>>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;

typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}

// dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * static_cast<MPType>(beta);
return x_beta > t
? dout
: static_cast<ComplexType<T>>(dout / conj(one + exp(-x_beta)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

PD_REGISTER_KERNEL(exp,
GPU,
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ def softplus(x, beta=1, threshold=20, name=None):
\end{cases}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
beta (float, optional): The value of :math:`\beta` for softplus. Default is 1
threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Expand All @@ -1294,7 +1294,17 @@ def softplus(x, beta=1, threshold=20, name=None):
return _C_ops.softplus(x, beta, threshold)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'softplus'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'softplus',
)
helper = LayerHelper('softplus', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
Expand Down
18 changes: 18 additions & 0 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3760,6 +3760,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = ref_softplus(x, beta, threshold)
self.inputs = {'X': x}
self.attrs = {'beta': beta, "threshold": threshold}
Expand All @@ -3774,6 +3779,19 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestSoftplus_Complex64(TestSoftplus):
def init_dtype(self):
self.dtype = np.complex64

def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.06)


class TestSoftplus_Complex128(TestSoftplus):
def init_dtype(self):
self.dtype = np.complex128


class TestSoftplus_ZeroDim(TestSoftplus):
def init_shape(self):
self.shape = []
Expand Down

0 comments on commit 8bcc691

Please sign in to comment.