Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【complex op】No.7 add complex support for isclose #56723

Merged
merged 8 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/isclose_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

PD_REGISTER_KERNEL(
isclose, CPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
PD_REGISTER_KERNEL(isclose,
CPU,
ALL_LAYOUT,
phi::IscloseKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/isclose_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(isclose,
phi::IscloseKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
87 changes: 87 additions & 0 deletions paddle/phi/kernels/impl/isclose_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
Expand Down Expand Up @@ -86,6 +87,40 @@ struct IscloseFunctor<phi::CPUContext, T> {
}
};

template <typename T>
struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
void operator()(const phi::CPUContext& ctx,
const DenseTensor& in,
const DenseTensor& other,
const double rtol,
const double atol,
bool equal_nan,
DenseTensor* output) {
auto* in_a = in.data<phi::dtype::complex<T>>();
auto* in_b = other.data<phi::dtype::complex<T>>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
// *out_data = true;
for (int i = 0; i < num; i++) {
out_data[i] = true;
}
for (int i = 0; i < num; i++) {
const phi::dtype::complex<T> a = in_a[i], b = in_b[i];
bool val;
if (std::isnan(a) || std::isnan(b)) {
val = equal_nan && std::isnan(a) == std::isnan(b);
} else {
T left = abs(a - b);
T right = atol + rtol * abs(b);
T diff = abs(left - right);
val = a == b || left <= right || diff <= 1e-15;
// *out_data &= val;
out_data[i] = val;
}
}
}
};

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
__global__ void IscloseCUDAKernel(const T* in_data,
Expand Down Expand Up @@ -113,7 +148,59 @@ __global__ void IscloseCUDAKernel(const T* in_data,
// if (!val) *out_data = false;
}
}
template <>
__global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
const phi::dtype::complex<float>* in_data,
const phi::dtype::complex<float>* other_data,
const double rtol,
const double atol,
bool equal_nan,
int num,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
const phi::dtype::complex<float> a = in_data[i];
const phi::dtype::complex<float> b = other_data[i];
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
float left = abs(a - b);
float right = atol + rtol * abs(b);
float diff = abs(left - right);
val = a == b || left <= right || diff <= 1e-15;
}
out_data[i] = val;
// if (!val) *out_data = false;
}
}

template <>
__global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
const phi::dtype::complex<double>* in_data,
const phi::dtype::complex<double>* other_data,
const double rtol,
const double atol,
bool equal_nan,
int num,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
const phi::dtype::complex<double> a = in_data[i];
const phi::dtype::complex<double> b = other_data[i];
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
double left = abs(a - b);
double right = atol + rtol * abs(b);
double diff = abs(left - right);
val = a == b || left <= right || diff <= 1e-15;
}
out_data[i] = val;
// if (!val) *out_data = false;
}
}
template <typename T>
struct GetTensorValue<phi::GPUContext, T> {
T operator()(const phi::GPUContext& dev_ctx,
Expand Down
14 changes: 10 additions & 4 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,8 +1316,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
two tensors are elementwise equal within a tolerance.

Args:
x(Tensor): The input tensor, it's data type should be float16, float32, float64.
y(Tensor): The input tensor, it's data type should be float16, float32, float64.
x(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
y(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
Expand Down Expand Up @@ -1355,10 +1355,16 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
return _C_ops.isclose(x, y, rtol, atol, equal_nan)
else:
check_variable_and_dtype(
x, "input", ['float16', 'float32', 'float64'], 'isclose'
x,
"input",
['float16', 'float32', 'float64', 'complex64', 'complex128'],
'isclose',
)
check_variable_and_dtype(
y, "input", ['float16', 'float32', 'float64'], 'isclose'
y,
"input",
['float16', 'float32', 'float64', 'complex64', 'complex128'],
'isclose',
)
check_type(rtol, 'rtol', float, 'isclose')
check_type(atol, 'atol', float, 'isclose')
Expand Down
63 changes: 63 additions & 0 deletions test/legacy_test/test_isclose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,69 @@ def test_check_output(self):
self.check_output()


class TestIscloseOpCp64(unittest.TestCase):
def test_cp64(self):
x_data = (
np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
).astype(np.complex64)
y_data = (
np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
).astype(np.complex64)
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=[10, 10], name='x', dtype=np.complex64)
y = paddle.static.data(shape=[10, 10], name='y', dtype=np.complex64)
out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])


class TestIscloseOpCp128(unittest.TestCase):
def test_cp128(self):
x_data = (
np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
).astype(np.complex128)
y_data = (
np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
).astype(np.complex128)
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
shape=[10, 10], name='x', dtype=np.complex128
)
y = paddle.static.data(
shape=[10, 10], name='y', dtype=np.complex128
)
out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])


class TestIscloseOpComplex64(TestIscloseOp):
def set_args(self):
self.input = np.array([10.1 + 0.1j]).astype(np.complex64)
self.other = np.array([10 + 0j]).astype(np.complex64)
self.rtol = np.array([0.01]).astype("float64")
self.atol = np.array([0]).astype("float64")
self.equal_nan = False


class TestIscloseOpComplex128(TestIscloseOp):
def set_args(self):
self.input = np.array([10.1 + 0.1j]).astype(np.complex128)
self.other = np.array([10 + 0j]).astype(np.complex128)
self.rtol = np.array([0.01]).astype("float64")
self.atol = np.array([0]).astype("float64")
self.equal_nan = False

def test_check_output(self):
self.check_output()


class TestIscloseOpLargeDimInput(TestIscloseOp):
def set_args(self):
self.input = np.array(np.zeros([2048, 1024])).astype("float64")
Expand Down