diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 6e12d6fa464cc7..aac77cf30da15a 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1190,10 +1190,15 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { } template -void scatter_grad(const Tensor& index, +void scatter_grad(const Tensor& x, + const Tensor& index, const Tensor& updates, + const Tensor& out, const Tensor& out_grad, bool overwrite, + int axis, + const std::string& reduce, + bool include_self, Tensor* x_grad, Tensor* updates_grad) { if (x_grad) { diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index b8d66f6c228c72..27d13778d09a63 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -700,7 +700,10 @@ class PADDLE_API Tensor final { const std::vector& axis = {}) const; Tensor scatter(const Tensor& index, const Tensor& updates, - bool overwrite = true) const; + bool overwrite = true, + int axis = 0, + const std::string& reduce = "add", + bool include_self = false) const; Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const; Tensor abs() const; Tensor assign() const; diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 157d34e28aaca0..819ec7e3b5a905 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1928,16 +1928,15 @@ invoke : scale(out_grad, scale, 0.0f, true) - backward_op : scatter_grad - forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out) - args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite) + forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true, int axis=0, str reduce="add", bool include_self=false) -> Tensor(out) + args : (Tensor x, Tensor index, Tensor updates, Tensor out, Tensor out_grad, bool overwrite, int axis, str reduce, bool include_self) output : Tensor(x_grad), Tensor(updates_grad) infer_meta : func : ScatterGradInferMeta - param : [index, updates, out_grad, overwrite] + param : [index, updates, out_grad] kernel : func : scatter_grad - no_need_buffer : updates - composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad) + composite: scatter_grad(x, index, updates, out, out_grad, overwrite, axis, reduce, include_self, x_grad, updates_grad) - backward_op : scatter_nd_add_grad forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5bf57114402ee9..73b3ed71156646 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2211,7 +2211,7 @@ backward : scale_grad - op : scatter - args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true) + args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true, int axis=0, str reduce="add", bool include_self=false) output : Tensor(out) infer_meta : func : ScatterInferMeta diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h index bcf5220f65545c..9c0cdd5c7f19ec 100644 --- a/paddle/phi/backends/gpu/gpu_primitives.h +++ b/paddle/phi/backends/gpu/gpu_primitives.h @@ -395,6 +395,123 @@ CUDA_ATOMIC_WRAPPER(Add, complex) { CudaAtomicAdd(imag, val.imag)); } +// Atomic multiplication implementation. +CUDA_ATOMIC_WRAPPER(Mul, int64_t) { + // Here, we check long long int must be int64_t. + static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT + "long long should be int64"); + unsigned long long int *address_as_ull = // NOLINT + (unsigned long long int *)address; // NOLINT + unsigned long long int old = *address_as_ull; // NOLINT + unsigned long long int assumed; // NOLINT + + do { + assumed = old; + old = atomicCAS(address_as_ull, + assumed, + static_cast( // NOLINT + val * static_cast(assumed))); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN + // != NaN) + } while (assumed != old); + + return static_cast(old); +} + +CUDA_ATOMIC_WRAPPER(Mul, int) { + int old = *address; + int assumed; + + do { + assumed = old; + old = atomicCAS(address, assumed, val * assumed); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != + // NaN) + } while (assumed != old); + + return old; +} + +#ifdef PADDLE_CUDA_FP16 +CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) { + unsigned int *address_as_ui = + (unsigned int *)((char *)address - ((size_t)address & 2)); // NOLINT + unsigned int old = *address_as_ui; + unsigned int assumed; + + phi::dtype::float16 hsum; + do { + assumed = old; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); // NOLINT + + hsum = hsum * val; + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) // NOLINT + : (old & 0xffff0000) | hsum.x; // NOLINT + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); // NOLINT + return hsum; +} +#endif + +CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) { + unsigned int *address_as_ui = + (unsigned int *)((char *)address - ((size_t)address & 2)); // NOLINT + unsigned int old = *address_as_ui; + unsigned int assumed; + + phi::dtype::bfloat16 bsum; + do { + assumed = old; + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); // NOLINT + bsum = bsum * val; + old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) // NOLINT + : (old & 0xffff0000) | bsum.x; // NOLINT + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); // NOLINT + return bsum; +} + +CUDA_ATOMIC_WRAPPER(Mul, double) { + unsigned long long int *address_as_ull = // NOLINT + (unsigned long long int *)address; // NOLINT + unsigned long long int old = *address_as_ull; // NOLINT + unsigned long long int assumed; // NOLINT + + do { + assumed = old; + old = atomicCAS( + address_as_ull, + assumed, + __double_as_longlong(val * __longlong_as_double(assumed))); // NOLINT + // Note: uses integer comparison to avoid hang in case of NaN (since NaN + // != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} + +// Dont use a templated function for this since the addition function defaults +// to the CUDA built-in. +CUDA_ATOMIC_WRAPPER(Mul, float) { + unsigned int *address_as_ull = (unsigned int *)address; // NOLINT + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, assumed, __float_as_int(val * __int_as_float(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != + // NaN) + } while (assumed != old); + + return __int_as_float(old); +} + // For atomicMax USE_CUDA_ATOMIC(Max, int); USE_CUDA_ATOMIC(Max, unsigned int); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4c5e130aab7a07..ddffb452a13b41 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -979,7 +979,6 @@ void RnnGradInferMeta(const MetaTensor& x, void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, - bool overwrite, MetaTensor* x_grad, MetaTensor* updates_grad) { const auto& dtype = out_grad.dtype(); diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 13dd392344f972..66b04ed315ffc6 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -401,7 +401,6 @@ void RnnGradInferMeta(const MetaTensor& x, void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, - bool overwrite, MetaTensor* x_grad, MetaTensor* updates_grad); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index d86b25b7ba224f..70e1064338a65d 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1061,6 +1061,9 @@ void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, bool overwrite, + int axis, + const std::string& reduce, + bool include_self, MetaTensor* out) { const auto& updates_dims = updates.dims(); const auto& ref_dims = x.dims(); diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 7272941504ff27..00c0fadd3551c0 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -187,6 +187,9 @@ void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, bool overwrite, + int axis, + const std::string& reduce, + bool include_self, MetaTensor* out); void ScatterNdAddInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/bitwise_kernel.h b/paddle/phi/kernels/bitwise_kernel.h index 17307004f360e1..16e416a45c9eb0 100644 --- a/paddle/phi/kernels/bitwise_kernel.h +++ b/paddle/phi/kernels/bitwise_kernel.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { @@ -41,4 +42,17 @@ void BitwiseNotKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +template +DenseTensor BitwiseAnd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_y(&y); + ElementwiseInferMeta(meta_x, meta_y, &meta_out); + BitwiseAndKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/compare_kernel.h b/paddle/phi/kernels/compare_kernel.h index 5d1bb040febb1f..805b31d9d79a38 100644 --- a/paddle/phi/kernels/compare_kernel.h +++ b/paddle/phi/kernels/compare_kernel.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { @@ -43,4 +44,30 @@ DECALRE_COMPARE_KERNEL(NotEqual) DECALRE_COMPARE_ALL_KERNEL(EqualAll) #undef DECALRE_COMPARE_KERNEL +template +DenseTensor Equal(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_y(&y); + CompareInferMeta(meta_x, meta_y, &meta_out); + EqualKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + +template +DenseTensor GreaterThan(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_y(&y); + CompareInferMeta(meta_x, meta_y, &meta_out); + GreaterThanKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/cpu/scatter_grad_kernel.cc b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc index d47e1c54976b2e..cb08d264e19257 100644 --- a/paddle/phi/kernels/cpu/scatter_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc @@ -13,52 +13,248 @@ // limitations under the License. #include "paddle/phi/kernels/scatter_grad_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/bitwise_kernel.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/funcs/gather.h" #include "paddle/phi/kernels/funcs/scatter.h" +#include "paddle/phi/kernels/index_add_kernel.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/put_along_axis_kernel.h" +#include "paddle/phi/kernels/reduce_any_kernel.h" +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/phi/kernels/where_kernel.h" namespace phi { template void ScatterGradKernel(const Context &ctx, + const DenseTensor &x, const DenseTensor &index, - const DenseTensor &updates UNUSED, + const DenseTensor &source, + const DenseTensor &out, const DenseTensor &out_grad, - bool overwrite UNUSED, + bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *x_grad, DenseTensor *updates_grad) { const auto &index_type = index.dtype(); - bool index_type_match = - index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, - true, - phi::errors::InvalidArgument( - "scatter_op index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s]", - index_type, - phi::DataType::INT32, - phi::DataType::INT64)); + PADDLE_ENFORCE_EQ( + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64, + true, + phi::errors::InvalidArgument( + "scatter_op index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); - if (x_grad) { - phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); - if (index_type == phi::DataType::INT32) { - phi::funcs::CPUScatterGradForX(ctx, index, x_grad); - } else { - phi::funcs::CPUScatterGradForX(ctx, index, x_grad); + PADDLE_ENFORCE_EQ( + reduce == "add" || reduce == "mul" || reduce == "muliply" || + reduce == "mean" || reduce == "amin" || reduce == "amax", + true, + phi::errors::InvalidArgument( + "Reduce holds the wrong value, it holds [%s]," + "but desires to be add, mul, multiply, mean, amin, amax.", + reduce)); + + if (axis < 0) { + axis += out_grad.dims().size(); + } + + std::string reducer = reduce; + if (overwrite) { + reducer = "assign"; + } + + DenseTensor new_index = index; + DenseTensor new_source = source; + if (index.dims().size() == 0) { + new_index.Resize({1}); + + if (source.dims().size() == x.dims().size() - 1) { + auto dims = vectorize(source.dims()); + dims.insert(dims.begin(), 1); + new_source.Resize(make_ddim(dims)); } } + if (x_grad) { + ctx.template Alloc(x_grad); + } + if (updates_grad) { ctx.template Alloc(updates_grad); - // Gradient by Gather: dUpdates = dO[Ids] - if (index_type == phi::DataType::INT32) { - phi::funcs::CPUGather(ctx, out_grad, index, updates_grad); - } else { - phi::funcs::CPUGather(ctx, out_grad, index, updates_grad); + } + + if (reducer == "add") { + if (x_grad) { + if (include_self) { + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + } else { + *x_grad = Full(ctx, vectorize(out_grad.dims()), 0); + } + } + + if (updates_grad) { + *updates_grad = IndexSelect(ctx, out_grad, index, axis); + } + } else if (reducer == "mean") { + auto zeros = Full(ctx, vectorize(out_grad.dims()), 0); + auto ones = Full(ctx, vectorize(out_grad.dims()), 1); + auto counts = include_self ? ones : zeros; + + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_cnts = + IndexAdd(ctx, counts, new_index, src_ones, axis); + + auto mask = Equal(ctx, src_cnts, zeros); + + auto N = Where(ctx, mask, ones, src_cnts); + + if (x_grad) { + *x_grad = Divide(ctx, out_grad, N); + } + + if (updates_grad) { + auto N_src = IndexSelect(ctx, N, index, axis); + + auto grad_src = IndexSelect(ctx, out_grad, index, axis); + + *updates_grad = Divide(ctx, grad_src, N_src); + } + } else if (reducer == "mul" || reducer == "muliply") { + auto zeros = Full(ctx, vectorize(out_grad.dims()), 0); + auto ones = Full(ctx, vectorize(out_grad.dims()), 1); + if (x_grad) { + auto mask = Equal(ctx, x, zeros); + auto masked_self = Where(ctx, mask, ones, x); + + auto masked_self_result = Scatter( + ctx, x, index, new_source, false, axis, reducer, include_self); + + auto grad_mul_masked_self_result = + Multiply(ctx, out_grad, masked_self_result); + *x_grad = + Divide(ctx, grad_mul_masked_self_result, masked_self); + } + + if (updates_grad) { + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_zeros = Full(ctx, vectorize(new_source.dims()), 1); + auto src_zero = Equal(ctx, new_source, src_zeros); + auto src_zero_t = Cast(ctx, src_zero, x.dtype()); + + auto src_num_zeros_inner = + IndexAdd(ctx, zeros, new_index, src_zero_t, axis); + + auto src_num_zeros = + IndexSelect(ctx, src_num_zeros_inner, index, axis); + + auto src_num_zeros_equal_one = + Equal(ctx, src_num_zeros, src_ones); + + auto src_single_zero_bool = + BitwiseAnd(ctx, src_zero, src_num_zeros_equal_one); + + auto masked_src = + Where(ctx, src_single_zero_bool, src_ones, new_source); + + auto masked_src_result = Scatter( + ctx, x, index, masked_src, false, axis, reducer, include_self); + + auto grad_mul_masked_src_result = + Multiply(ctx, out_grad, masked_src_result); + auto grad_mul_masked_src_result_index_select = + IndexSelect(ctx, grad_mul_masked_src_result, index, axis); + + auto grad_mul_out = Multiply(ctx, out_grad, out); + + auto grad_mul_out_index_select = + IndexSelect(ctx, grad_mul_out, index, axis); + + auto src_masked_fill_one = + Where(ctx, src_zero, src_ones, new_source); + auto where_2 = Divide( + ctx, grad_mul_out_index_select, src_masked_fill_one); + + auto grad_src1 = + Where(ctx, + src_single_zero_bool, + grad_mul_masked_src_result_index_select, + where_2); + + auto tmp_ones = Full(ctx, vectorize(src_num_zeros.dims()), 1); + auto src_num_zeros_greater_one = + GreaterThan(ctx, src_num_zeros, tmp_ones); + auto src_num_zeros_greater_one_any = + Any(ctx, src_num_zeros_greater_one, {}, false); + + bool out_data = src_num_zeros_greater_one_any.template data()[0]; + if (out_data) { + VLOG(3) << "index_reduce(): Double backward is unsupported for " + "new_source when " + ">1 zeros in new_source are scattered to the same position " + "in x"; + *updates_grad = grad_src1; + } else { + *updates_grad = grad_src1; + } + } + + } else if (reducer == "amin" || reducer == "amax") { + auto value = IndexSelect(ctx, out, index, axis); + + auto self_is_result = Equal(ctx, x, out); + auto self_is_result_t = Cast(ctx, self_is_result, x.dtype()); + + auto source_is_result = Equal(ctx, new_source, value); + auto source_is_result_t = + Cast(ctx, source_is_result, x.dtype()); + + auto N_to_distribute = IndexAdd( + ctx, self_is_result_t, new_index, source_is_result_t, axis); + + auto grad_distributed = Divide(ctx, out_grad, N_to_distribute); + + if (x_grad) { + *x_grad = Multiply(ctx, self_is_result_t, grad_distributed); + } + + if (updates_grad) { + auto src_grad_dist = + IndexSelect(ctx, grad_distributed, index, axis); + + *updates_grad = + Multiply(ctx, source_is_result_t, src_grad_dist); } + } else if (reducer == "assign") { + if (x_grad) { + include_self = false; + } + + if (updates_grad) { + *updates_grad = IndexSelect(ctx, out_grad, index, axis); + } + } + + if (!include_self && x_grad) { + auto self_dims = out_grad.dims(); + auto zeros = Full(ctx, vectorize(self_dims), 0); + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_cnts = IndexAdd(ctx, zeros, new_index, src_ones, axis); + auto mask = Equal(ctx, src_cnts, zeros); + *x_grad = Where(ctx, mask, out_grad, zeros); } } diff --git a/paddle/phi/kernels/cpu/scatter_kernel.cc b/paddle/phi/kernels/cpu/scatter_kernel.cc index 2c3e8a2f31d098..e097b662835b6d 100644 --- a/paddle/phi/kernels/cpu/scatter_kernel.cc +++ b/paddle/phi/kernels/cpu/scatter_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/cpu/scatter_kernel_impl.h" #include "paddle/phi/kernels/funcs/scatter.h" namespace phi { @@ -27,34 +28,85 @@ void ScatterKernel(const Context &ctx, const DenseTensor &index, const DenseTensor &updates, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *out) { - // In place output: Out = X, Out[Ids] = Updates - phi::Copy(ctx, x, ctx.GetPlace(), false, out); - // Apply ScatterUpdate: Out[index] = Updates[:] const auto &index_type = index.dtype(); - bool index_type_match = - index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; PADDLE_ENFORCE_EQ( - index_type_match, + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64, true, phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s]," "but desires to be [%s] or [%s].", index_type, phi::DataType::INT32, phi::DataType::INT64)); - if (overwrite) { - if (index_type == phi::DataType::INT32) { - phi::funcs::ScatterAssign(ctx, updates, index, out); - } else { - phi::funcs::ScatterAssign(ctx, updates, index, out); + + PADDLE_ENFORCE_EQ( + reduce == "add" || reduce == "mul" || reduce == "muliply" || + reduce == "mean" || reduce == "amin" || reduce == "amax", + true, + phi::errors::InvalidArgument( + "Reduce holds the wrong value, it holds [%s]," + "but desires to be add, mul, multiply, mean, amin, amax.", + reduce)); + + DenseTensor new_index = index; + DenseTensor new_updates = updates; + + if (new_index.dims().size() == 2) { + PADDLE_ENFORCE_EQ( + index.dims()[1], + 1, + phi::errors::InvalidArgument("index.dims()[1] should be 1 when " + "index.dims().size() =2 in scatter_op." + "But received value is [%d]", + new_index.dims()[1])); + auto index_dim = new_index.dims()[0]; + new_index.Resize(make_ddim({index_dim})); + } else if (index.dims().size() == 0) { + new_index.Resize(make_ddim({1})); + + if (updates.dims().size() == x.dims().size() - 1) { + auto dims = vectorize(updates.dims()); + dims.insert(dims.begin(), 1); + new_updates.Resize(make_ddim(dims)); } } else { - if (index_type == phi::DataType::INT32) { - phi::funcs::ScatterAssignAdd(ctx, updates, index, out); - } else { - phi::funcs::ScatterAssignAdd(ctx, updates, index, out); - } + PADDLE_ENFORCE_EQ( + new_index.dims().size() == 1, + true, + phi::errors::InvalidArgument("index.dims().size() should be 1 in " + "scatter_op. But received value is [%d]", + new_index.dims().size())); } + + auto src_dims = updates.dims(); + auto dst_dims = out->dims(); + + if (new_index.dims().size() != 0) { + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE_EQ( + src_dims[i], + dst_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + src_dims[i], + i, + dst_dims[i])); + } + + std::string reducer = reduce; + if (overwrite) { + reducer = "assign"; + } + + IndexReduceBaseKernel( + ctx, x, new_index, new_updates, axis, reducer, include_self, out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/scatter_kernel_impl.h b/paddle/phi/kernels/cpu/scatter_kernel_impl.h new file mode 100644 index 00000000000000..42e3d8a2a2ae68 --- /dev/null +++ b/paddle/phi/kernels/cpu/scatter_kernel_impl.h @@ -0,0 +1,186 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "glog/logging.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/cpu/index_add_impl.h" + +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/index_add_kernel.h" +#include "paddle/phi/kernels/where_kernel.h" + +namespace phi { + +template +void IndexReduceInner(const Context& ctx, + DenseTensor* input, + const DenseTensor& index, + int axis, + const std::string& reduce, + bool include_self, + DenseTensor* source, + DenseTensor* output) { + auto input_dim = input->dims(); + auto input_dim_size = input_dim.size(); + auto output_dim = output->dims(); + auto index_size = index.dims()[0]; + auto source_dim = source->dims(); + + const IndexT* index_data = index.data(); + ctx.template Alloc(output); + + auto zeros = Full(ctx, vectorize(input_dim), 0); + auto ones = Full(ctx, vectorize(input_dim), 1); + auto counts = include_self ? ones : zeros; + auto src_ones = Full(ctx, vectorize(source->dims()), 1); + auto src_cnts = IndexAdd(ctx, counts, index, src_ones, axis); + auto mask = Equal(ctx, src_cnts, zeros); + + if (include_self) { + phi::Copy(ctx, *input, ctx.GetPlace(), false, output); + } else { + T init_val; + if (reduce == "mul" || reduce == "multiply") { + init_val = static_cast(1); + } else if (reduce == "amin") { + init_val = std::numeric_limits::has_infinity + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + } else if (reduce == "amax") { + init_val = std::numeric_limits::has_infinity + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + } else { + init_val = static_cast(0); + } + auto init = Full(ctx, vectorize(input_dim), init_val); + + auto out = Where(ctx, mask, *input, init); + phi::Copy(ctx, out, ctx.GetPlace(), false, output); + } + + auto slice_size = 1; + for (auto i = axis + 1; i < input_dim_size; i++) { + slice_size *= input_dim[i]; + } + auto outer_nums = 1; + for (auto i = 0; i < axis; i++) { + outer_nums *= input_dim[i]; + } + + for (int i = 0; i < index_size; i++) { + PADDLE_ENFORCE_GE( + index_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + PADDLE_ENFORCE_LT( + index_data[i], + input_dim[axis], + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + } + + output->Resize(phi::make_ddim({outer_nums, input_dim[axis], slice_size})); + source->Resize(phi::make_ddim({outer_nums, index_size, slice_size})); + + auto source_tensor = EigenTensor::From(*source); + auto output_tensor = EigenTensor::From(*output); + + auto& place = *ctx.eigen_device(); + for (auto j = 0; j < index_size; j++) { + IndexT index_value = index_data[j]; + auto output_t = output_tensor.chip(index_value, 1); + auto source_t = source_tensor.chip(j, 1); + if (reduce == "add" || reduce == "mean") { + output_t.device(place) = output_t + source_t; + } else if (reduce == "mul" || reduce == "muliply") { + output_t.device(place) = output_t * source_t; + } else if (reduce == "amin") { + output_t.device(place) = output_t.cwiseMin(source_t); + } else if (reduce == "amax") { + output_t.device(place) = output_t.cwiseMax(source_t); + } else if (reduce == "assign") { + output_t.device(place) = source_t; + } + } + + output->Resize(output_dim); + source->Resize(source_dim); + + if (reduce == "mean") { + auto src_cnts_wo_zeros = Where(ctx, mask, ones, src_cnts); + auto out = Divide(ctx, *output, src_cnts_wo_zeros); + phi::Copy(ctx, out, ctx.GetPlace(), false, output); + } +} + +template +void IndexReduceBaseKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& source, + int axis, + const std::string& reduce, + bool include_self, + DenseTensor* output) { + const auto& index_type = index.dtype(); + + if (axis < 0) { + axis += x.dims().size(); + } + + PADDLE_ENFORCE_LT( + axis, + x.dims().size(), + phi::errors::InvalidArgument( + "Axis value (axis) of OP(scatter) " + "expected >= 0 and < %ld, but got %ld. Please check axis " + "value.", + x.dims().size(), + axis)); + + auto inputs = x; + auto src = source; + if (index_type == phi::DataType::INT32) { + IndexReduceInner( + dev_ctx, &inputs, index, axis, reduce, include_self, &src, output); + } else if (index_type == phi::DataType::INT64) { + IndexReduceInner( + dev_ctx, &inputs, index, axis, reduce, include_self, &src, output); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h index c3f0cf61986918..2fe305d6b38dbe 100644 --- a/paddle/phi/kernels/funcs/scatter.cu.h +++ b/paddle/phi/kernels/funcs/scatter.cu.h @@ -30,7 +30,8 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, size_t output_count, size_t index_size, - size_t slice_size) { + size_t slice_size, + int init_val) { CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { int64_t indices_i = i / slice_size; int64_t slice_i = i - indices_i * slice_size; // offset inside the slice @@ -46,7 +47,7 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, scatter_i); int64_t out_i = scatter_i * slice_size + slice_i; - *(output + out_i) = static_cast(0); + *(output + out_i) = static_cast(init_val); } } @@ -57,7 +58,8 @@ __global__ void ScatterCUDAKernel(const T* params, size_t output_count, size_t index_size, size_t slice_size, - bool overwrite) { + bool overwrite, + int reduce) { CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { int64_t indices_i = i / slice_size; int64_t slice_i = i - indices_i * slice_size; // offset inside the slice @@ -76,7 +78,13 @@ __global__ void ScatterCUDAKernel(const T* params, if (overwrite) { *(output + out_i) = *(params + i); } else { - phi::CudaAtomicAdd(output + out_i, *(params + i)); + if (reduce == 0) { + // sum + phi::CudaAtomicAdd(output + out_i, *(params + i)); + } else if (reduce == 1) { + // mul + phi::CudaAtomicMul(output + out_i, *(params + i)); + } } } } @@ -127,7 +135,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx, const DenseTensor& src, const DenseTensor& index, DenseTensor* output, - bool overwrite = true) { + bool overwrite = true, + int reduce = 0) { if (index.dims().size() == 2) { PADDLE_ENFORCE_EQ( index.dims()[1], @@ -146,6 +155,13 @@ void GPUScatterAssign(const phi::GPUContext& ctx, index.dims().size())); } + PADDLE_ENFORCE_EQ( + reduce == 0 || reduce == 1, + true, + phi::errors::InvalidArgument("reduce should be 0, or 1 in scatter_op." + "But received value is [%d]", + reduce)); + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); @@ -173,8 +189,15 @@ void GPUScatterAssign(const phi::GPUContext& ctx, // if not overwrite mode, init data if (!overwrite) { - ScatterInitCUDAKernel<<>>( - p_index, p_output, output_dims[0], index_size, slice_size); + if (reduce == 0) { + // sum + ScatterInitCUDAKernel<<>>( + p_index, p_output, output_dims[0], index_size, slice_size, 0); + } else if (reduce == 1) { + // mul + ScatterInitCUDAKernel<<>>( + p_index, p_output, output_dims[0], index_size, slice_size, 1); + } } ScatterCUDAKernel<<>>(p_src, @@ -183,7 +206,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx, output_dims[0], index_size, slice_size, - overwrite); + overwrite, + reduce); } // The function is only for scatter grad x, @@ -209,7 +233,7 @@ void GPUScatterGradForX(const phi::GPUContext& ctx, phi::backends::gpu::LimitGridDim(ctx, &grid); ScatterInitCUDAKernel<<>>( - p_index, p_output, dst_dims[0], index_size, slice_size); + p_index, p_output, dst_dims[0], index_size, slice_size, 0); } template diff --git a/paddle/phi/kernels/funcs/scatter.h b/paddle/phi/kernels/funcs/scatter.h index 5934f57b47ddec..a4fe1582e09bbf 100644 --- a/paddle/phi/kernels/funcs/scatter.h +++ b/paddle/phi/kernels/funcs/scatter.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/common/place.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/paddle/phi/kernels/gpu/scatter_grad_kernel.cu b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu index 33f02aad948832..ddfa8cdf3943aa 100644 --- a/paddle/phi/kernels/gpu/scatter_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu @@ -21,45 +21,252 @@ #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/bitwise_kernel.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/elementwise_multiply_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/index_add_kernel.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/put_along_axis_kernel.h" +#include "paddle/phi/kernels/reduce_any_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/phi/kernels/where_kernel.h" + namespace phi { template void ScatterGradKernel(const Context &ctx, + const DenseTensor &x, const DenseTensor &index, - const DenseTensor &updates, + const DenseTensor &source, + const DenseTensor &out, const DenseTensor &out_grad, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *x_grad, DenseTensor *updates_grad) { - auto index_type = index.dtype(); - bool index_type_match = - index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, - true, - phi::errors::InvalidArgument( - "scatter_op index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s]", - index_type, - phi::DataType::INT32, - phi::DataType::INT64)); + const auto &index_type = index.dtype(); + PADDLE_ENFORCE_EQ( + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64, + true, + phi::errors::InvalidArgument( + "scatter_op index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); - if (x_grad) { - phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); - if (index_type == phi::DataType::INT32) { - phi::funcs::GPUScatterGradForX(ctx, index, x_grad); - } else { - phi::funcs::GPUScatterGradForX(ctx, index, x_grad); + PADDLE_ENFORCE_EQ( + reduce == "add" || reduce == "mul" || reduce == "muliply" || + reduce == "mean" || reduce == "amin" || reduce == "amax", + true, + phi::errors::InvalidArgument( + "Reduce holds the wrong value, it holds [%s]," + "but desires to be add, mul, multiply, mean, amin, amax.", + reduce)); + + if (axis < 0) { + axis += out_grad.dims().size(); + } + + std::string reducer = reduce; + if (overwrite) { + reducer = "assign"; + } + + DenseTensor new_index = index; + DenseTensor new_source = source; + if (index.dims().size() == 0) { + new_index.Resize({1}); + + if (new_source.dims().size() == x.dims().size() - 1) { + auto dims = vectorize(new_source.dims()); + dims.insert(dims.begin(), 1); + new_source.Resize(make_ddim(dims)); } } + if (x_grad) { + ctx.template Alloc(x_grad); + } + if (updates_grad) { ctx.template Alloc(updates_grad); - // Gradient by Gather: dUpdates = dO[Ids] - if (index_type == phi::DataType::INT32) { - phi::funcs::GPUGather(ctx, out_grad, index, updates_grad); - } else { - phi::funcs::GPUGather(ctx, out_grad, index, updates_grad); + } + + if (reducer == "add") { + if (x_grad) { + if (include_self) { + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + } else { + *x_grad = Full(ctx, vectorize(out_grad.dims()), 0); + } + } + + if (updates_grad) { + *updates_grad = IndexSelect(ctx, out_grad, index, axis); + } + } else if (reducer == "mean") { + auto zeros = Full(ctx, vectorize(out_grad.dims()), 0); + auto ones = Full(ctx, vectorize(out_grad.dims()), 1); + auto counts = include_self ? ones : zeros; + + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_cnts = + IndexAdd(ctx, counts, new_index, src_ones, axis); + + auto mask = Equal(ctx, src_cnts, zeros); + + auto N = Where(ctx, mask, ones, src_cnts); + + if (x_grad) { + *x_grad = Divide(ctx, out_grad, N); + } + + if (updates_grad) { + auto N_src = IndexSelect(ctx, N, index, axis); + + auto grad_src = IndexSelect(ctx, out_grad, index, axis); + + *updates_grad = Divide(ctx, grad_src, N_src); + } + } else if (reducer == "mul" || reducer == "muliply") { + auto zeros = Full(ctx, vectorize(out_grad.dims()), 0); + auto ones = Full(ctx, vectorize(out_grad.dims()), 1); + if (x_grad) { + auto mask = Equal(ctx, x, zeros); + auto masked_self = Where(ctx, mask, ones, x); + + auto masked_self_result = Scatter( + ctx, x, index, new_source, false, axis, reducer, include_self); + + auto grad_mul_masked_self_result = + Multiply(ctx, out_grad, masked_self_result); + *x_grad = + Divide(ctx, grad_mul_masked_self_result, masked_self); + } + + if (updates_grad) { + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_zeros = Full(ctx, vectorize(new_source.dims()), 1); + auto src_zero = Equal(ctx, new_source, src_zeros); + auto src_zero_t = Cast(ctx, src_zero, x.dtype()); + + auto src_num_zeros_inner = + IndexAdd(ctx, zeros, new_index, src_zero_t, axis); + + auto src_num_zeros = + IndexSelect(ctx, src_num_zeros_inner, index, axis); + + auto src_num_zeros_equal_one = + Equal(ctx, src_num_zeros, src_ones); + + auto src_single_zero_bool = + BitwiseAnd(ctx, src_zero, src_num_zeros_equal_one); + + auto masked_src = + Where(ctx, src_single_zero_bool, src_ones, new_source); + + auto masked_src_result = Scatter( + ctx, x, index, masked_src, false, axis, reducer, include_self); + + auto grad_mul_masked_src_result = + Multiply(ctx, out_grad, masked_src_result); + auto grad_mul_masked_src_result_index_select = + IndexSelect(ctx, grad_mul_masked_src_result, index, axis); + + auto grad_mul_out = Multiply(ctx, out_grad, out); + + auto grad_mul_out_index_select = + IndexSelect(ctx, grad_mul_out, index, axis); + + auto src_masked_fill_one = + Where(ctx, src_zero, src_ones, new_source); + auto where_2 = Divide( + ctx, grad_mul_out_index_select, src_masked_fill_one); + + auto grad_src1 = + Where(ctx, + src_single_zero_bool, + grad_mul_masked_src_result_index_select, + where_2); + + auto tmp_ones = Full(ctx, vectorize(src_num_zeros.dims()), 1); + auto src_num_zeros_greater_one = + GreaterThan(ctx, src_num_zeros, tmp_ones); + + auto src_num_zeros_greater_one_any = + Any(ctx, src_num_zeros_greater_one, {}, false); + + DenseTensor flag_cpu; + phi::Copy(ctx, + src_num_zeros_greater_one_any, + phi::CPUPlace(), + false, + &flag_cpu); + bool out_data = flag_cpu.template data()[0]; + + if (out_data) { + VLOG(3) << "index_reduce(): Double backward is unsupported for " + "new_source when " + ">1 zeros in new_source are scattered to the same position " + "in x"; + *updates_grad = grad_src1; + } else { + *updates_grad = grad_src1; + } } + } else if (reducer == "amin" || reducer == "amax") { + auto value = IndexSelect(ctx, out, index, axis); + + auto self_is_result = Equal(ctx, x, out); + auto self_is_result_t = Cast(ctx, self_is_result, x.dtype()); + + auto source_is_result = Equal(ctx, new_source, value); + auto source_is_result_t = + Cast(ctx, source_is_result, x.dtype()); + + auto N_to_distribute = IndexAdd( + ctx, self_is_result_t, new_index, source_is_result_t, axis); + + auto grad_distributed = Divide(ctx, out_grad, N_to_distribute); + + if (x_grad) { + *x_grad = Multiply(ctx, self_is_result_t, grad_distributed); + } + + if (updates_grad) { + auto src_grad_dist = + IndexSelect(ctx, grad_distributed, index, axis); + + *updates_grad = + Multiply(ctx, source_is_result_t, src_grad_dist); + } + } else if (reducer == "assign") { + if (x_grad) { + include_self = false; + } + + if (updates_grad) { + *updates_grad = IndexSelect(ctx, out_grad, index, axis); + } + } + + if (!include_self && x_grad) { + auto self_dims = out_grad.dims(); + auto zeros = Full(ctx, vectorize(self_dims), 0); + auto src_ones = Full(ctx, vectorize(new_source.dims()), 1); + auto src_cnts = IndexAdd(ctx, zeros, new_index, src_ones, axis); + auto mask = Equal(ctx, src_cnts, zeros); + *x_grad = Where(ctx, mask, out_grad, zeros); } } diff --git a/paddle/phi/kernels/gpu/scatter_kernel.cu b/paddle/phi/kernels/gpu/scatter_kernel.cu index e5d8eb7704ba82..eb5e2f55a460b6 100644 --- a/paddle/phi/kernels/gpu/scatter_kernel.cu +++ b/paddle/phi/kernels/gpu/scatter_kernel.cu @@ -14,41 +14,365 @@ #include "paddle/phi/kernels/scatter_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/utils/flags.h" + +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/index_add_kernel.h" +#include "paddle/phi/kernels/where_kernel.h" + +PD_DECLARE_bool(cudnn_deterministic); namespace phi { +using phi::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void index_reduce_cuda_kernel(const T* input, + const IndexT* index, + const T* source, + int reduce, + int64_t N, + int64_t stride, + int64_t size, + int64_t delta, + T* output) { + CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) { + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = + idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + if (reduce == 0 || reduce == 1) { + phi::CudaAtomicAdd(&output[input_idx], source[idx]); + } else if (reduce == 2) { + phi::CudaAtomicMul(&output[input_idx], source[idx]); + } else if (reduce == 3) { + phi::CudaAtomicMin(&output[input_idx], source[idx]); + } else if (reduce == 4) { + phi::CudaAtomicMax(&output[input_idx], source[idx]); + } else if (reduce == 5) { + output[input_idx] = source[idx]; + } + } +} + +template +__global__ void index_fill_cuda_kernel(const T* input, + const IndexT* index, + const T* source, + T init_val, + int64_t N, + int64_t stride, + int64_t size, + int64_t delta, + T* output) { + CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) { + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = + idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + + output[input_idx] = init_val; + } +} + template -void ScatterKernel(const Context &ctx, - const DenseTensor &x, - const DenseTensor &index, - const DenseTensor &updates, +void IndexReduceKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& source, + int axis, + const std::string& reduce, + bool include_self, + DenseTensor* output) { + auto input_dim = x.dims(); + auto output_dim = output->dims(); + auto source_dim = source.dims(); + + const auto& index_type = index.dtype(); + + int dim = axis; + dim = dim >= 0 ? dim : dim + input_dim.size(); + + auto stride_dim = phi::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = source_dim[dim]; + int64_t delta = input_dim[dim] - size; + + auto* in_data = x.data(); + T* out_data = ctx.template Alloc(output); + auto* source_data = source.data(); + + int64_t numel = source.numel(); + if (numel == 0) { + return; + } + + auto stream = ctx.stream(); + + unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; + dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim); + phi::backends::gpu::LimitGridDim(ctx, &grid_dim); + + if (FLAGS_cudnn_deterministic) { + VLOG(2) << "Run grad kernel of index_add with single thread."; + block_dim = 1; + grid_dim.x = 1; + } + + phi::Copy(ctx, x, ctx.GetPlace(), false, output); + if (!include_self) { + T init_val; + if (reduce == "mul" || reduce == "multiply") { + init_val = static_cast(1); + } else if (reduce == "amin") { + init_val = std::numeric_limits::has_infinity + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + } else if (reduce == "amax") { + init_val = std::numeric_limits::has_infinity + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + } else { + init_val = static_cast(0); + } + + if (index_type == phi::DataType::INT64) { + const int64_t* index_data = index.data(); + index_fill_cuda_kernel + <<>>(in_data, + index_data, + source_data, + init_val, + numel, + stride, + size, + delta, + out_data); + } else { + const int* index_data = index.data(); + index_fill_cuda_kernel + <<>>(in_data, + index_data, + source_data, + init_val, + numel, + stride, + size, + delta, + out_data); + } + } + + int reduce_type = 0; + if (reduce == "add") { + reduce_type = 0; + } else if (reduce == "mean") { + reduce_type = 1; + } else if (reduce == "mul" || reduce == "multiply") { + reduce_type = 2; + } else if (reduce == "amin") { + reduce_type = 3; + } else if (reduce == "amax") { + reduce_type = 4; + } else if (reduce == "assign") { + reduce_type = 5; + } + + if (index_type == phi::DataType::INT64) { + const int64_t* index_data = index.data(); + index_reduce_cuda_kernel + <<>>(in_data, + index_data, + source_data, + reduce_type, + numel, + stride, + size, + delta, + out_data); + } else { + const int* index_data = index.data(); + index_reduce_cuda_kernel + <<>>(in_data, + index_data, + source_data, + reduce_type, + numel, + stride, + size, + delta, + out_data); + } + + if (reduce == "mean") { + auto zeros = Full(ctx, vectorize(input_dim), 0); + auto ones = Full(ctx, vectorize(input_dim), 1); + auto counts = include_self ? ones : zeros; + auto src_ones = Full(ctx, vectorize(source.dims()), 1); + auto src_cnts = IndexAdd(ctx, counts, index, src_ones, dim); + auto mask = Equal(ctx, src_cnts, zeros); + + auto src_cnts_wo_zeros = Where(ctx, mask, ones, src_cnts); + auto out = Divide(ctx, *output, src_cnts_wo_zeros); + phi::Copy(ctx, out, ctx.GetPlace(), false, output); + } +} + +template +void ScatterKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& updates, bool overwrite, - DenseTensor *out) { - phi::Copy(ctx, x, ctx.GetPlace(), false, out); - // use template class to support int32_t and int64_t + int axis, + const std::string& reduce, + bool include_self, + DenseTensor* out) { auto index_type = index.dtype(); - bool index_type_match = - index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, - true, - phi::errors::InvalidArgument( - "scatter_op Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - index_type, - phi::DataType::INT32, - phi::DataType::INT64)); - if (index_type == phi::DataType::INT32) { - phi::funcs::GPUScatterAssign( - ctx, updates, index, out, overwrite); + PADDLE_ENFORCE_EQ( + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64, + true, + phi::errors::InvalidArgument( + "scatter_op Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + PADDLE_ENFORCE_EQ( + reduce == "add" || reduce == "mul" || reduce == "muliply" || + reduce == "mean" || reduce == "amin" || reduce == "amax", + true, + phi::errors::InvalidArgument( + "Reduce holds the wrong value, it holds [%s]," + "but desires to be add, mul, multiply, mean, amin, amax.", + reduce)); + + DenseTensor new_index = index; + DenseTensor new_updates = updates; + + if (index.dims().size() == 2) { + PADDLE_ENFORCE_EQ( + index.dims()[1], + 1, + phi::errors::InvalidArgument("index.dims()[1] should be 1 when " + "index.dims().size() =2 in scatter_op." + "But received value is [%d]", + index.dims()[1])); + auto index_dim = new_index.dims()[0]; + new_index.Resize(make_ddim({index_dim})); + } else if (index.dims().size() == 0) { + new_index.Resize(make_ddim({1})); + + if (updates.dims().size() == x.dims().size() - 1) { + auto dims = vectorize(updates.dims()); + dims.insert(dims.begin(), 1); + new_updates.Resize(make_ddim(dims)); + } } else { - phi::funcs::GPUScatterAssign( - ctx, updates, index, out, overwrite); + PADDLE_ENFORCE_EQ( + index.dims().size() == 1, + true, + phi::errors::InvalidArgument("index.dims().size() should be 1 in " + "scatter_op. But received value is [%d]", + index.dims().size())); } + + auto src_dims = updates.dims(); + auto dst_dims = out->dims(); + + if (new_index.dims().size() != 0) { + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE_EQ( + src_dims[i], + dst_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + src_dims[i], + i, + dst_dims[i])); + } + + auto input_dim = x.dims(); + axis = axis >= 0 ? axis : axis + input_dim.size(); + int index_size = new_index.dims().size(); + + DenseTensor index_cpu; + phi::Copy(ctx, new_index, phi::CPUPlace(), false, &index_cpu); + + for (int i = 0; i < index_size; i++) { + if (index_type == phi::DataType::INT32) { + const int* index_data = index_cpu.data(); + + PADDLE_ENFORCE_GE( + index_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + PADDLE_ENFORCE_LT( + index_data[i], + input_dim[axis], + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + + } else if (index_type == phi::DataType::INT64) { + const int64_t* index_data = index_cpu.data(); + + PADDLE_ENFORCE_GE( + index_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + PADDLE_ENFORCE_LT( + index_data[i], + input_dim[axis], + phi::errors::InvalidArgument( + "Variable value (index) of OP(index_add) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[axis], + index_data[i])); + } + } + + std::string reducer = reduce; + if (overwrite) { + reducer = "assign"; + } + + IndexReduceKernel( + ctx, x, new_index, new_updates, axis, reducer, include_self, out); } } // namespace phi diff --git a/paddle/phi/kernels/index_add_kernel.h b/paddle/phi/kernels/index_add_kernel.h index 62693af8229426..3141782b2db1b3 100644 --- a/paddle/phi/kernels/index_add_kernel.h +++ b/paddle/phi/kernels/index_add_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { @@ -25,4 +26,21 @@ void IndexAddKernel(const Context& ctx, const DenseTensor& add_value, int axis, DenseTensor* output); + +template +DenseTensor IndexAdd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& index, + const DenseTensor& add_value, + int axis) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_index(&index); + MetaTensor meta_add_value(&add_value); + IndexAddInferMeta(meta_x, meta_index, meta_add_value, axis, &meta_out); + IndexAddKernel(dev_ctx, x, index, add_value, axis, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/index_select_kernel.h b/paddle/phi/kernels/index_select_kernel.h index 295c5aeaece5b7..83aff7b1528d43 100644 --- a/paddle/phi/kernels/index_select_kernel.h +++ b/paddle/phi/kernels/index_select_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { @@ -32,4 +33,18 @@ void IndexSelectStridedKernel(const Context& ctx, int dim, DenseTensor* output); +template +DenseTensor IndexSelect(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + int dim) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_index(&index); + IndexSelectInferMeta(meta_x, meta_index, dim, &meta_out); + IndexSelectKernel(ctx, x, index, dim, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/reduce_any_kernel.h b/paddle/phi/kernels/reduce_any_kernel.h index 9514d02dbdf94a..782ac86b7562d2 100644 --- a/paddle/phi/kernels/reduce_any_kernel.h +++ b/paddle/phi/kernels/reduce_any_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { template @@ -32,4 +33,17 @@ void AnyKernel(const Context& dev_ctx, bool keep_dim, DenseTensor* out); +template +DenseTensor Any(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + ReduceInferMeta(meta_x, dims, keep_dim, &meta_out); + AnyKernel(dev_ctx, x, dims, keep_dim, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/scatter_grad_kernel.h b/paddle/phi/kernels/scatter_grad_kernel.h index cf1482fca7f667..4b09c34898b116 100644 --- a/paddle/phi/kernels/scatter_grad_kernel.h +++ b/paddle/phi/kernels/scatter_grad_kernel.h @@ -19,10 +19,15 @@ namespace phi { template void ScatterGradKernel(const Context &ctx, + const DenseTensor &x, const DenseTensor &index, const DenseTensor &updates, + const DenseTensor &out, const DenseTensor &out_grad, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *x_grad, DenseTensor *updates_grad); diff --git a/paddle/phi/kernels/scatter_kernel.h b/paddle/phi/kernels/scatter_kernel.h index 5191d6bce45f26..d8b93d4a1c34d4 100644 --- a/paddle/phi/kernels/scatter_kernel.h +++ b/paddle/phi/kernels/scatter_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/ternary.h" namespace phi { @@ -24,6 +25,43 @@ void ScatterKernel(const Context &ctx, const DenseTensor &index, const DenseTensor &updates, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *out); +template +DenseTensor Scatter(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + bool overwrite, + int axis, + const std::string &reduce, + bool include_self) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_x(&x); + MetaTensor meta_index(&index); + MetaTensor meta_source(&updates); + ScatterInferMeta(meta_x, + meta_index, + meta_source, + overwrite, + axis, + reduce, + include_self, + &meta_out); + ScatterKernel(ctx, + x, + index, + updates, + overwrite, + axis, + reduce, + include_self, + &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/where_kernel.h b/paddle/phi/kernels/where_kernel.h index 6348177e697647..e3d350b4757e58 100644 --- a/paddle/phi/kernels/where_kernel.h +++ b/paddle/phi/kernels/where_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/multiary.h" namespace phi { @@ -25,4 +26,18 @@ void WhereKernel(const Context& ctx, const DenseTensor& y, DenseTensor* out); +template +DenseTensor Where(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + MetaTensor meta_cond(&condition); + MetaTensor meta_x(&x); + MetaTensor meta_y(&y); + WhereInferMeta(meta_cond, meta_x, meta_y, &meta_out); + WhereKernel(ctx, condition, x, y, &dense_out); + return dense_out; +} } // namespace phi diff --git a/paddle/phi/kernels/xpu/scatter_grad_kernel.cc b/paddle/phi/kernels/xpu/scatter_grad_kernel.cc index 2d4007d6a2a8e6..ecb56a92cecb53 100644 --- a/paddle/phi/kernels/xpu/scatter_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_grad_kernel.cc @@ -21,10 +21,15 @@ namespace phi { template void ScatterGradKernel(const Context &ctx, + const DenseTensor &x, const DenseTensor &index, const DenseTensor &updates, + const DenseTensor &out, const DenseTensor &out_grad, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *x_grad, DenseTensor *updates_grad) { using XPUType = typename XPUTypeTrait::Type; @@ -86,6 +91,7 @@ void ScatterGradKernel(const Context &ctx, } PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter grad"); } + } // namespace phi PD_REGISTER_KERNEL(scatter_grad, diff --git a/paddle/phi/kernels/xpu/scatter_kernel.cc b/paddle/phi/kernels/xpu/scatter_kernel.cc index 9052cd5b5f5f0d..7702cf966e2ef9 100644 --- a/paddle/phi/kernels/xpu/scatter_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_kernel.cc @@ -26,6 +26,9 @@ void ScatterKernel(const Context &ctx, const DenseTensor &index, const DenseTensor &updates, bool overwrite, + int axis, + const std::string &reduce, + bool include_self, DenseTensor *out) { using XPUTypeT = typename XPUTypeTrait::Type; out->Resize(x.dims()); diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index 3673a20ddbf619..41069f3214d9b1 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -258,7 +258,7 @@ def __call__(self, var, block=None): if framework.in_dygraph_mode(): with base.dygraph.no_grad(): tmp_out = _C_ops.scatter( - out_var, index_tensor, value_tensor, True + out_var, index_tensor, value_tensor, True, 0, 'add', False ) tmp_out._share_underline_tensor_to(out_var) tmp_reshape_out = _C_ops.reshape(out_var, origin_shape) @@ -274,7 +274,12 @@ def __call__(self, var, block=None): "Ids": index_tensor, "Updates": value_tensor, }, - attrs={'overwrite': True}, + attrs={ + 'overwrite': True, + "axis": 0, + "reduce": "add", + "include_self": False, + }, outputs={"Out": out_var}, stop_gradient=True, ) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 203b98c78683b5..30c4c55ecdc024 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3010,36 +3010,19 @@ def unbind(input, axis=0): return outs -def scatter(x, index, updates, overwrite=True, name=None): +def scatter( + x, + index, + updates, + overwrite=True, + axis=0, + reduce='add', + include_self=False, + name=None, +): """ **Scatter Layer** Output is obtained by updating the input on selected indices based on updates. - - .. code-block:: python - :name: code-example1 - - >>> import paddle - >>> #input: - >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') - >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') - >>> # shape of updates should be the same as x - >>> # shape of updates with dim > 1 should be the same as input - >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') - >>> overwrite = False - >>> # calculation: - >>> if not overwrite: - ... for i in range(len(index)): - ... x[index[i]] = paddle.zeros([2]) - >>> for i in range(len(index)): - ... if (overwrite): - ... x[index[i]] = updates[i] - ... else: - ... x[index[i]] += updates[i] - >>> # output: - >>> out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]]) - >>> print(out.shape) - [3, 2] - **NOTICE**: The order in which updates are applied is nondeterministic, so the output will be nondeterministic if index contains duplicates. @@ -3048,6 +3031,9 @@ def scatter(x, index, updates, overwrite=True, name=None): index (Tensor): The index is a 1-D or 0-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length. updates (Tensor): Update input with updates parameter based on index. When the index is a 1-D tensor, the updates shape should be the same as input, and dim value with dim > 1 should be the same as input. When the index is a 0-D tensor, the updates should be a (N-1)-D tensor, the ith dim of the updates should be queal with the (i+1)th dim of the input. overwrite (bool, optional): The mode that updating the output when there are same indices.If True, use the overwrite mode to update the output of the same index,if False, use the accumulate mode to update the output of the same index. Default value is True. + axis(int, optional): The axis along which the scatter operation is performed. The default is 0. + reduce (str, optional): Reduction operation to apply, can be either 'add', 'mul', 'multiply', 'mean', 'amin', 'amax'. The default is 'add'. + include_self(bool, optional): If True, the self index will be included in output. Default value is False. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: @@ -3057,35 +3043,20 @@ def scatter(x, index, updates, overwrite=True, name=None): .. code-block:: python >>> import paddle - >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') - >>> output1 = paddle.scatter(x, index, updates, overwrite=False) >>> print(output1) - Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[3., 3.], - [6., 6.], - [1., 1.]]) - - >>> output2 = paddle.scatter(x, index, updates, overwrite=True) - >>> # CPU device: - >>> # [[3., 3.], - >>> # [4., 4.], - >>> # [1., 1.]] - >>> # GPU device maybe have two results because of the repeated numbers in index - >>> # result 1: - >>> # [[3., 3.], - >>> # [4., 4.], - >>> # [1., 1.]] - >>> # result 2: - >>> # [[3., 3.], - >>> # [2., 2.], - >>> # [1., 1.]] + Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[3., 3.], + [6., 6.], + [1., 1.]]) """ if in_dynamic_or_pir_mode(): - return _C_ops.scatter(x, index, updates, overwrite) + return _C_ops.scatter( + x, index, updates, overwrite, axis, reduce, include_self + ) else: check_variable_and_dtype( x, @@ -3094,24 +3065,43 @@ def scatter(x, index, updates, overwrite=True, name=None): 'scatter', ) check_type(overwrite, 'overwrite', bool, 'scatter') + check_type(axis, 'axis', int, 'scatter') + check_type(reduce, 'reduce', str, 'scatter') + check_type(include_self, 'include_self', bool, 'scatter') helper = LayerHelper('scatter', **locals()) out = helper.create_variable_for_type_inference(x.dtype) helper.append_op( type="scatter", inputs={"X": x, "Ids": index, "Updates": updates}, - attrs={'overwrite': overwrite}, + attrs={ + 'overwrite': overwrite, + 'axis': axis, + 'reduce': reduce, + 'include_self': include_self, + }, outputs={"Out": out}, ) return out @inplace_apis_in_dygraph_only -def scatter_(x, index, updates, overwrite=True, name=None): +def scatter_( + x, + index, + updates, + overwrite=True, + axis=0, + reduce='add', + include_self=False, + name=None, +): """ Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_tensor_scatter`. """ - return _C_ops.scatter_(x, index, updates, overwrite) + return _C_ops.scatter_( + x, index, updates, overwrite, axis, reduce, include_self + ) def scatter_nd_add(x, index, updates, name=None): diff --git a/test/legacy_test/test_initializer.py b/test/legacy_test/test_initializer.py index 44c952fa38d740..daa8369cbe9148 100644 --- a/test/legacy_test/test_initializer.py +++ b/test/legacy_test/test_initializer.py @@ -1185,7 +1185,6 @@ def test_dirac(self): weight_attr=self.weight_attr, ) weight_dygraph = conv.weight.numpy() - paddle.enable_static() start_prog = paddle.static.Program() main_prog = paddle.static.Program() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index cac243f5e8682b..685e95f3c34817 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -411,6 +411,22 @@ def non_inplace_api_processing(self, var): return paddle.scatter(var, index, updates, overwrite=False) + def non_inplace_api_processing2(self, var): + index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + updates = paddle.to_tensor( + [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32' + ) + + return paddle.scatter( + var, + index, + updates, + overwrite=False, + axis=0, + reduce='add', + include_self=True, + ) + def inplace_api_processing(self, var): index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') updates = paddle.to_tensor( @@ -419,6 +435,22 @@ def inplace_api_processing(self, var): return paddle.scatter_(var, index, updates, overwrite=False) + def inplace_api_processing2(self, var): + index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + updates = paddle.to_tensor( + [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32' + ) + + return paddle.scatter_( + var, + index, + updates, + overwrite=False, + axis=0, + reduce='add', + include_self=True, + ) + class TestDygraphInplaceElu(TestDygraphInplace): def non_inplace_api_processing(self, var): diff --git a/test/legacy_test/test_scatter_op.py b/test/legacy_test/test_scatter_op.py index d44982c6321d09..4430d19285f493 100644 --- a/test/legacy_test/test_scatter_op.py +++ b/test/legacy_test/test_scatter_op.py @@ -526,18 +526,20 @@ def setUp(self): self.prim_op_type = "prim" self.if_enable_cinn() self._set_dtype() - target_dtype = "float16" if self.dtype == np.float16 else "float32" - ref_np = np.ones((3, 50)).astype(target_dtype) - index_np = np.array([[1], [2]]).astype("int32") - updates_np = np.random.random((2, 50)).astype(target_dtype) - output_np = np.copy(ref_np) - output_np[np.array([1, 2]).astype("int32")] = updates_np + self.target_dtype = "float16" if self.dtype == np.float16 else "float32" + self._set_attr() + self._set_inout() + if self.dtype == np.uint16: - ref_np = convert_float_to_uint16(ref_np) - updates_np = convert_float_to_uint16(updates_np) - output_np = convert_float_to_uint16(output_np) - self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} - self.outputs = {'Out': output_np} + self.ref_np = convert_float_to_uint16(self.ref_np) + self.updates_np = convert_float_to_uint16(self.updates_np) + self.output_np = convert_float_to_uint16(self.output_np) + self.inputs = { + 'X': self.ref_np, + 'Ids': self.index_np, + 'Updates': self.updates_np, + } + self.outputs = {'Out': self.output_np} def if_enable_cinn(self): pass @@ -545,13 +547,26 @@ def if_enable_cinn(self): def _set_dtype(self): self.dtype = np.float32 + def _set_attr(self): + self.attrs = { + 'overwrite': True, + 'axis': 0, + 'reduce': "add", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.ones((3, 50)).astype(self.target_dtype) + self.index_np = np.array([1, 2]).astype("int32") + self.updates_np = np.random.random((2, 50)).astype(self.target_dtype) + self.output_np = np.copy(self.ref_np) + self.output_np[np.array([1, 2]).astype("int32")] = self.updates_np + def test_check_output(self): self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad( - ["X", "Updates"], "Out", check_prim=True, check_pir=True - ) + self.check_grad(["X", "Updates"], "Out") class TestScatterFP16Op6(TestScatterOp6): @@ -599,6 +614,7 @@ def executed_api(self): self.scatter = paddle.scatter def check_static_result(self, place): + paddle.enable_static() with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): @@ -633,6 +649,7 @@ def check_static_result(self, place): ).all(), True, ) + paddle.disable_static() @test_with_pir_api def test_static(self): @@ -733,24 +750,30 @@ def compute_ref_grad_updates(self): return ref_grad_updates def test_scatter_fp16(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False) - index_tensor = paddle.to_tensor(self.index_np) - updates_tensor = paddle.to_tensor(self.updates_np, stop_gradient=False) - out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor) - paddle.autograd.backward( - [out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True - ) - ref_grad_updates = self.compute_ref_grad_updates() - np.testing.assert_allclose( - ref_grad_updates.numpy(False), - updates_tensor.grad.numpy(False), - rtol=1e-5, - atol=1e-5, - ) - np.testing.assert_allclose( - self.ref_dx, x_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5 - ) + paddle.disable_static() + if paddle.device.is_compiled_with_cuda(): + paddle.set_device('gpu:0') + x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False) + index_tensor = paddle.to_tensor(self.index_np) + updates_tensor = paddle.to_tensor( + self.updates_np, stop_gradient=False + ) + out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor) + paddle.autograd.backward( + [out_tensor], + [paddle.to_tensor(self.dout_np)], + retain_graph=True, + ) + ref_grad_updates = self.compute_ref_grad_updates() + np.testing.assert_allclose( + ref_grad_updates.numpy(False), + updates_tensor.grad.numpy(False), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + self.ref_dx, x_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5 + ) class TestScatterInplaceAPI(TestScatterAPI): @@ -764,24 +787,1550 @@ def test_scatter_index(self): paddle.disable_static() x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32') - def test_neg_index(): - index = paddle.to_tensor([2, 1, -1, 1], dtype='int64') + def test_too_big_index(): + index = paddle.to_tensor([2, 1, 5, 1], dtype='int64') updates = paddle.to_tensor( [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32' ) out = paddle.scatter(x, index, updates) - self.assertRaises(IndexError, test_neg_index) + self.assertRaises(ValueError, test_too_big_index) + paddle.enable_static() - def test_too_big_index(): - index = paddle.to_tensor([2, 1, 5, 1], dtype='int64') - updates = paddle.to_tensor( - [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32' + +class TestScatterOp6ReduceAdd(TestScatterOp6): + def _set_attr(self): + self.attrs = { + 'overwrite': False, + 'axis': 0, + 'reduce': "add", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.array([[1, 1], [2, 2], [3, 3]]).astype( + self.target_dtype + ) + self.index_np = np.array([2, 1, 0, 1]).astype(np.int32) + self.updates_np = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + self.target_dtype + ) + self.output_np = np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]]).astype( + self.target_dtype + ) + + +class TestScatterAPIReduceAdd(unittest.TestCase): + def setUp(self): + self.places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.executed_api() + + def executed_api(self): + self.scatter = paddle.scatter + + def check_static_result(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" ) - out = paddle.scatter(x, index, updates) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + axis=0, + reduce='add', + include_self=False, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce='add', + include_self=False, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]]) + ).all(), + True, + ) + + def test_large_data(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce='add', + include_self=False, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce='add', + include_self=False, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) - self.assertRaises(IndexError, test_too_big_index) + def check_static_result2(self, place): paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce='add', + include_self=True, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": "add", + "include_self": True, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[4.0, 4.0], [8.0, 8.0], [4.0, 4.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static2(self): + for place in self.places: + self.check_static_result2(place=place) + + def test_dygraph2(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce='add', + include_self=True, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[4.0, 4.0], [8.0, 8.0], [4.0, 4.0]]) + ).all(), + True, + ) + + def test_large_data2(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce='add', + include_self=True, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce='add', + include_self=True, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": 'add', + "include_self": True, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + +class TestScatterOp6ReduceMul(TestScatterOp6): + def _set_attr(self): + self.attrs = { + 'overwrite': False, + 'axis': 0, + 'reduce': "mul", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.array([[1, 1], [2, 2], [3, 3]]).astype( + self.target_dtype + ) + self.index_np = np.array([2, 1, 0, 1]).astype(np.int32) + self.updates_np = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + self.target_dtype + ) + self.output_np = np.array([[3.0, 3.0], [8.0, 8.0], [1.0, 1.0]]).astype( + self.target_dtype + ) + + +class TestScatterAPIReduceMul(unittest.TestCase): + def setUp(self): + self.places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.executed_api() + + def executed_api(self): + self.scatter = paddle.scatter + + def check_static_result(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce='mul', + include_self=False, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "axis": 0, + "reduce": "mul", + "inlcude_self": False, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [8.0, 8.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce='mul', + include_self=False, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [8.0, 8.0], [1.0, 1.0]]) + ).all(), + True, + ) + + def test_large_data(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce='mul', + include_self=False, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce='mul', + include_self=False, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": "mul", + "inlcude_self": False, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + def check_static_result2(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce='mul', + include_self=True, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": "mul", + "include_self": True, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] + == np.array([[3.0, 3.0], [16.0, 16.0], [3.0, 3.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static2(self): + for place in self.places: + self.check_static_result2(place=place) + + def test_dygraph2(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce='mul', + include_self=True, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [16.0, 16.0], [3.0, 3.0]]) + ).all(), + True, + ) + + def test_large_data2(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce='mul', + include_self=True, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce='mul', + include_self=True, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": 'mul', + "include_self": True, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + +class TestScatterOp6ReduceAmin(TestScatterOp6): + def _set_attr(self): + self.attrs = { + 'overwrite': False, + 'axis': 0, + 'reduce': "amin", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.array([[1, 1], [2, 2], [3, 3]]).astype( + self.target_dtype + ) + self.index_np = np.array([2, 1, 0, 1]).astype(np.int32) + self.updates_np = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + self.target_dtype + ) + self.output_np = np.array([[3.0, 3.0], [2.0, 2.0], [1.0, 1.0]]).astype( + self.target_dtype + ) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_prim=True) + + +class TestScatterAPIReduceAmin(unittest.TestCase): + def setUp(self): + self.places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.scatter = paddle.scatter + self.reduce = "amin" + + def check_static_result(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [2.0, 2.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [2.0, 2.0], [1.0, 1.0]]) + ).all(), + True, + ) + + x.stop_gradient = False + updates.stop_gradient = False + self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ).sum().backward() + + self.assertEqual( + ( + x.grad.numpy() + == np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + ).all(), + True, + ) + + self.assertEqual( + ( + updates.grad.numpy() + == np.array( + [[1.0, 1.0], [0.5, 0.5], [1.0, 1.0], [0.0, 0.0]] + ) + ).all(), + True, + ) + + def test_large_data(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + def check_static_result2(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static2(self): + for place in self.places: + self.check_static_result2(place=place) + + def test_dygraph2(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]]) + ).all(), + True, + ) + + def test_large_data2(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + +class TestScatterOp6ReduceAmax(TestScatterOp6): + def _set_attr(self): + self.attrs = { + 'overwrite': False, + 'axis': 0, + 'reduce': "amax", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.array([[1, 1], [2, 2], [3, 3]]).astype( + self.target_dtype + ) + self.index_np = np.array([2, 1, 0, 1]).astype(np.int32) + self.updates_np = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + self.target_dtype + ) + self.output_np = np.array([[3.0, 3.0], [4.0, 4.0], [1.0, 1.0]]).astype( + self.target_dtype + ) + + +class TestScatterAPIReduceAmax(unittest.TestCase): + def setUp(self): + self.places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.scatter = paddle.scatter + self.reduce = "amax" + + def check_static_result(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [4.0, 4.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [4.0, 4.0], [1.0, 1.0]]) + ).all(), + True, + ) + + def test_large_data(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + def check_static_result2(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [4.0, 4.0], [3.0, 3.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static2(self): + for place in self.places: + self.check_static_result2(place=place) + + def test_dygraph2(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [4.0, 4.0], [3.0, 3.0]]) + ).all(), + True, + ) + + def test_large_data2(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + +class TestScatterOp6ReduceMean(TestScatterOp6): + def _set_attr(self): + self.attrs = { + 'overwrite': False, + 'axis': 0, + 'reduce': "mean", + 'include_self': False, + } + + def _set_inout(self): + self.ref_np = np.array([[1, 1], [2, 2], [3, 3]]).astype( + self.target_dtype + ) + self.index_np = np.array([2, 1, 0, 1]).astype(np.int32) + self.updates_np = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + self.target_dtype + ) + self.output_np = np.array([[3.0, 3.0], [3.0, 3.0], [1.0, 1.0]]).astype( + self.target_dtype + ) + + +class TestScatterAPIReduceMean(unittest.TestCase): + def setUp(self): + self.places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(base.CUDAPlace(0)) + self.scatter = paddle.scatter + self.reduce = "mean" + + def check_static_result(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + }, + fetch_list=[result], + ) + self.assertEqual( + ( + fetches[0] == np.array([[3.0, 3.0], [3.0, 3.0], [1.0, 1.0]]) + ).all(), + True, + ) + paddle.disable_static() + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + self.assertEqual( + ( + output1.numpy() + == np.array([[3.0, 3.0], [3.0, 3.0], [1.0, 1.0]]) + ).all(), + True, + ) + + def test_large_data(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=False, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "inlcude_self": False, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) + + def check_static_result2(self, place): + paddle.enable_static() + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data( + name="input", shape=[3, 2], dtype="float64" + ) + index = paddle.static.data(name="index", shape=[4], dtype="int64") + updates = paddle.static.data( + name="updates", shape=[4, 2], dtype="float64" + ) + result = self.scatter( + input, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + + input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype( + np.float64 + ) + + exe = base.Executor(place) + fetches = exe.run( + base.default_main_program(), + feed={ + "input": input_data, + "index": index_data, + "updates": updates_data, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + }, + fetch_list=[result], + ) + np.testing.assert_allclose( + fetches[0], + np.array([[2.0, 2.0], [2.66666667, 2.66666667], [2.0, 2.0]]), + ) + paddle.disable_static() + + def test_static2(self): + for place in self.places: + self.check_static_result2(place=place) + + def test_dygraph2(self): + for place in self.places: + with base.dygraph.guard(place): + x_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) + index_data = np.array([2, 1, 0, 1]).astype(np.int64) + updates_data = np.array( + [[1, 1], [2, 2], [3, 3], [4, 4]] + ).astype(np.float64) + + x = base.dygraph.to_variable(x_data) + index = base.dygraph.to_variable(index_data) + updates = base.dygraph.to_variable(updates_data) + + output1 = self.scatter( + x, + index, + updates, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + np.testing.assert_allclose( + output1.numpy(), + np.array( + [[2.0, 2.0], [2.66666667, 2.66666667], [2.0, 2.0]] + ), + ) + + def test_large_data2(self): + if os.name == "nt" or not paddle.is_compiled_with_cuda(): + return + + x = np.random.rand(183826, 256).astype("float32") + index = np.ones(107592, dtype="int64") + updates = np.ones(shape=[107592, 256], dtype="float32") + + def test_dygraph(): + with base.dygraph.guard(): + gpu_out = paddle.scatter( + paddle.to_tensor(x), + paddle.to_tensor(index), + paddle.to_tensor(updates), + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + return gpu_out.numpy() + + @switch_to_static_graph + def test_static_graph(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape) + index_t = paddle.static.data( + name="index", dtype=index.dtype, shape=index.shape + ) + updates_t = paddle.static.data( + name="updates", dtype=updates.dtype, shape=updates.shape + ) + out_t = paddle.scatter( + x_t, + index_t, + updates_t, + overwrite=False, + reduce=self.reduce, + include_self=True, + ) + feed = { + x_t.name: x, + index_t.name: index, + updates_t.name: updates, + "overwrite": False, + "reduce": self.reduce, + "include_self": True, + } + fetch = [out_t] + + gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0)) + gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0] + return gpu_value + + np.testing.assert_array_equal(test_dygraph(), test_static_graph()) if __name__ == "__main__": diff --git a/test/xpu/test_scatter_op_xpu.py b/test/xpu/test_scatter_op_xpu.py index 7ff92985b34b24..03a39fd0b3363c 100644 --- a/test/xpu/test_scatter_op_xpu.py +++ b/test/xpu/test_scatter_op_xpu.py @@ -132,7 +132,12 @@ def setUp(self): 'Ids': self.index_np, 'Updates': self.updates_np, } - self.attrs = {'overwrite': self.overwrite} + self.attrs = { + 'overwrite': self.overwrite, + "axis": 0, + "reduce": "add", + "include_self": False, + } self.outputs = {'Out': self.output_np} def init_config(self):