Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5 No.5】为 Paddle 增强 scatter API #57748

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
328254b
scatter api
lisamhy Sep 18, 2023
14639d4
cpu forward w/o mean
lisamhy Sep 21, 2023
1542a95
cpu forward ok
lisamhy Sep 22, 2023
6ce92b7
update
lisamhy Sep 22, 2023
6ad884c
update
lisamhy Sep 25, 2023
1aaeb9a
update
lisamhy Sep 25, 2023
f0a7e04
ut
lisamhy Sep 26, 2023
fbe4d2b
fix
lisamhy Sep 26, 2023
0d88319
Merge branch 'develop' into scatter
lisamhy Sep 26, 2023
16bc859
format
lisamhy Sep 26, 2023
5fa2033
fix
lisamhy Sep 26, 2023
d87bb1d
fix xpu
lisamhy Sep 26, 2023
a217c28
fix
lisamhy Sep 27, 2023
db9de60
fix
lisamhy Sep 27, 2023
69a980c
scatter does not support zero dim
lisamhy Sep 27, 2023
c885cf0
fix
lisamhy Sep 27, 2023
7dff65c
fix
lisamhy Sep 27, 2023
bee1110
fix
lisamhy Oct 8, 2023
15907a4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lisamhy Oct 8, 2023
f4da983
fix
lisamhy Oct 8, 2023
5e9d4b2
fix
lisamhy Oct 8, 2023
d9015e2
fix
lisamhy Oct 8, 2023
b486fae
fix
lisamhy Oct 8, 2023
c296453
fix
lisamhy Oct 8, 2023
328e1f9
fix
lisamhy Oct 8, 2023
6523a29
fix
lisamhy Oct 9, 2023
2e26a28
format
lisamhy Oct 9, 2023
a694390
fix
lisamhy Oct 9, 2023
ae9013f
fix
lisamhy Oct 9, 2023
e1aabed
fix
lisamhy Oct 10, 2023
b4f41e8
fix
lisamhy Oct 11, 2023
d9fa122
fix
lisamhy Oct 12, 2023
4a43d2c
fix
lisamhy Oct 13, 2023
9671c06
fix
lisamhy Oct 13, 2023
b7aafba
fix
lisamhy Oct 13, 2023
35d98af
fix comment
lisamhy Oct 26, 2023
663af87
merge develop
lisamhy Oct 26, 2023
21b5759
fix
lisamhy Oct 30, 2023
e8af985
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lisamhy Nov 3, 2023
fb811cd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lisamhy Nov 3, 2023
f33b83c
fix
lisamhy Nov 3, 2023
25b7d8f
Merge branch 'develop' into scatter
lisamhy Nov 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1191,10 +1191,15 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
}

template <typename T>
void scatter_grad(const Tensor& index,
void scatter_grad(const Tensor& x,
const Tensor& index,
const Tensor& updates,
const Tensor& out,
const Tensor& out_grad,
bool overwrite,
int axis,
const std::string& reduce,
bool include_self,
Tensor* x_grad,
Tensor* updates_grad) {
if (x_grad) {
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/api/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,10 @@ class PADDLE_API Tensor final {
const std::vector<int64_t>& axis = {}) const;
Tensor scatter(const Tensor& index,
const Tensor& updates,
bool overwrite = true) const;
bool overwrite = true,
int axis = 0,
const std::string& reduce = "add",
bool include_self = false) const;
Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const;
Tensor abs() const;
Tensor assign() const;
Expand Down
9 changes: 4 additions & 5 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1918,16 +1918,15 @@
invoke : scale(out_grad, scale, 0.0f, true)

- backward_op : scatter_grad
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite)
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true, int axis=0, str reduce="add", bool include_self=false) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor updates, Tensor out, Tensor out_grad, bool overwrite, int axis, str reduce, bool include_self)
output : Tensor(x_grad), Tensor(updates_grad)
infer_meta :
func : ScatterGradInferMeta
param : [index, updates, out_grad, overwrite]
param : [index, updates, out_grad]
kernel :
func : scatter_grad
no_need_buffer : updates
composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)
composite: scatter_grad(x, index, updates, out, out_grad, overwrite, axis, reduce, include_self, x_grad, updates_grad)

- backward_op : scatter_nd_add_grad
forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,7 @@
backward : scale_grad

- op : scatter
args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true)
args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true, int axis=0, str reduce="add", bool include_self=false)
output : Tensor(out)
infer_meta :
func : ScatterInferMeta
Expand Down
117 changes: 117 additions & 0 deletions paddle/phi/backends/gpu/gpu_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,123 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
CudaAtomicAdd(imag, val.imag));
}

// Atomic multiplication implementation.
// Atomic multiply for int64_t: a compare-and-swap loop that repeatedly reads
// the current value, computes current * val, and tries to publish the product
// with atomicCAS until no other thread has intervened.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
  // atomicCAS only accepts unsigned long long; verify it matches int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  unsigned long long int *address_as_ull =  // NOLINT
      (unsigned long long int *)address;    // NOLINT
  unsigned long long int old = *address_as_ull;  // NOLINT
  unsigned long long int assumed;  // NOLINT

  do {
    assumed = old;
    // Store assumed * val only if *address still holds `assumed`;
    // atomicCAS returns whatever value was actually observed there.
    old = atomicCAS(address_as_ull,
                    assumed,
                    static_cast<unsigned long long int>(  // NOLINT
                        val * static_cast<int64_t>(assumed)));
  } while (assumed != old);  // another thread raced us; retry with new value

  return static_cast<int64_t>(old);
}

// Atomic multiply for 32-bit int via the native int atomicCAS.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, int) {
  int old = *address;
  int assumed;

  do {
    assumed = old;
    // Publish assumed * val iff *address is still `assumed`.
    old = atomicCAS(address, assumed, val * assumed);
  } while (assumed != old);  // retry until the CAS succeeds

  return old;
}

#ifdef PADDLE_CUDA_FP16
// Atomic multiply for float16. CUDA has no native 16-bit atomicCAS, so we
// operate on the 4-byte-aligned 32-bit word that contains the half value:
// clear the low alignment bit of the address to find the word, then select
// the high or low 16 bits depending on which half *address refers to.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
  unsigned int *address_as_ui =
      (unsigned int *)((char *)address - ((size_t)address & 2));  // NOLINT
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  phi::dtype::float16 hsum;
  do {
    assumed = old;
    // Extract the target 16 bits from the word and multiply by val.
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);  // NOLINT

    hsum = hsum * val;
    // Splice the product back into its half, leaving the neighbouring
    // 16 bits of the word untouched.
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)  // NOLINT
                              : (old & 0xffff0000) | hsum.x;     // NOLINT
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);  // a concurrent writer touched the word; retry
  // Recover the pre-update bits of *address from the last observed word.
  hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);  // NOLINT
  return hsum;
}
#endif

// Atomic multiply for bfloat16, implemented the same way as the float16
// variant above: CAS on the 4-byte-aligned 32-bit word containing the value.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
  // Clear the low alignment bit to locate the enclosing 32-bit word.
  unsigned int *address_as_ui =
      (unsigned int *)((char *)address - ((size_t)address & 2));  // NOLINT
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  phi::dtype::bfloat16 bsum;
  do {
    assumed = old;
    // Pick out the 16 bits belonging to *address and multiply them by val.
    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);  // NOLINT
    bsum = bsum * val;
    // Re-insert the product into its half of the word.
    old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16)  // NOLINT
                              : (old & 0xffff0000) | bsum.x;     // NOLINT
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);  // retry until no concurrent writer interfered
  // Recover and return the pre-update value of *address.
  bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);  // NOLINT
  return bsum;
}

// Atomic multiply for double: CAS on the 64-bit bit pattern of the value.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, double) {
  unsigned long long int *raw =         // NOLINT
      (unsigned long long int *)address;  // NOLINT
  unsigned long long int observed = *raw;  // NOLINT
  while (true) {
    const unsigned long long int expected = observed;  // NOLINT
    // atomicCAS compares raw bit patterns, not floating-point values, so
    // the loop cannot hang when the stored value is NaN (NaN != NaN).
    observed = atomicCAS(
        raw,
        expected,
        __double_as_longlong(val * __longlong_as_double(expected)));
    if (observed == expected) break;  // CAS succeeded, product is stored
  }
  return __longlong_as_double(observed);
}

// Written out by hand (not via a template) so it never collides with the
// CUDA built-in atomic intrinsics for float.
// Atomic multiply for float: CAS on the 32-bit bit pattern of the value.
// Returns the value stored at *address before this multiply took effect.
CUDA_ATOMIC_WRAPPER(Mul, float) {
  unsigned int *address_as_ui = (unsigned int *)address;  // NOLINT
  unsigned int observed = *address_as_ui;
  while (true) {
    const unsigned int expected = observed;
    // Bit-pattern comparison inside atomicCAS avoids hanging on NaN,
    // since NaN != NaN under floating-point comparison.
    observed = atomicCAS(address_as_ui,
                         expected,
                         __float_as_int(val * __int_as_float(expected)));
    if (observed == expected) break;  // CAS succeeded, product is stored
  }
  return __int_as_float(observed);
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,6 @@ void RnnGradInferMeta(const MetaTensor& x,
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
bool overwrite,
MetaTensor* x_grad,
MetaTensor* updates_grad) {
const auto& dtype = out_grad.dtype();
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ void RnnGradInferMeta(const MetaTensor& x,
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
bool overwrite,
MetaTensor* x_grad,
MetaTensor* updates_grad);

Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,9 @@ void ScatterInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
bool overwrite,
int axis,
const std::string& reduce,
bool include_self,
MetaTensor* out) {
const auto& updates_dims = updates.dims();
const auto& ref_dims = x.dims();
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ void ScatterInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
bool overwrite,
int axis,
const std::string& reduce,
bool include_self,
MetaTensor* out);

void ScatterNdAddInferMeta(const MetaTensor& x,
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/kernels/bitwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"

namespace phi {

Expand All @@ -41,4 +42,17 @@ void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

// Functional convenience wrapper around BitwiseAndKernel: infers the output
// meta (shape/dtype) from the two inputs, runs the kernel, and returns the
// elementwise result as a freshly constructed DenseTensor.
template <typename T, typename Context>
DenseTensor BitwiseAnd(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  MetaTensor x_meta(&x);
  MetaTensor y_meta(&y);
  // Output metadata follows the elementwise/broadcast rules of the inputs.
  ElementwiseInferMeta(x_meta, y_meta, &result_meta);
  BitwiseAndKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

} // namespace phi
27 changes: 27 additions & 0 deletions paddle/phi/kernels/compare_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"

namespace phi {

Expand Down Expand Up @@ -43,4 +44,30 @@ DECALRE_COMPARE_KERNEL(NotEqual)
DECALRE_COMPARE_ALL_KERNEL(EqualAll)
#undef DECALRE_COMPARE_KERNEL

// Functional convenience wrapper around EqualKernel: infers the comparison
// output meta from the two inputs, runs the kernel, and returns the
// elementwise x == y result as a freshly constructed DenseTensor.
template <typename T, typename Context>
DenseTensor Equal(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  MetaTensor x_meta(&x);
  MetaTensor y_meta(&y);
  // Comparison output shape/dtype comes from CompareInferMeta.
  CompareInferMeta(x_meta, y_meta, &result_meta);
  EqualKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

// Functional convenience wrapper around GreaterThanKernel: infers the
// comparison output meta from the two inputs, runs the kernel, and returns
// the elementwise x > y result as a freshly constructed DenseTensor.
template <typename T, typename Context>
DenseTensor GreaterThan(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  MetaTensor x_meta(&x);
  MetaTensor y_meta(&y);
  // Comparison output shape/dtype comes from CompareInferMeta.
  CompareInferMeta(x_meta, y_meta, &result_meta);
  GreaterThanKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

} // namespace phi
Loading