Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI decoupling] move gather_scatter_kernel from fluid to phi #49132

Merged
merged 2 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()

if (WITH_ROCM)
hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()

set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta)
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta)

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ set(COMMON_KERNEL_DEPS
deformable_conv_functor
matrix_reduce
segment_pooling
gather_scatter_kernel
pooling
maxouting
matrix_inverse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_scatter_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/kernels/gather_scatter_kernel.h"

namespace phi {

class TensorAssign {
public:
Expand Down Expand Up @@ -54,7 +54,7 @@ struct cpu_gather_scatter_functor {
const phi::DenseTensor& src,
const std::string& method_name,
const func_t& reduce_op,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
if (index.numel() == 0) {
return;
}
Expand All @@ -69,7 +69,7 @@ struct cpu_gather_scatter_functor {
auto src_dims = src.dims();
if (self_size == 0 || src_size == 0 || index_size == 0) {
VLOG(3) << "zero size input found";
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"self_size, src_size, index_size cannot be 0");
return;
}
Expand Down Expand Up @@ -132,7 +132,7 @@ void cpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/false>()(
Expand All @@ -144,7 +144,7 @@ void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
Expand All @@ -156,7 +156,7 @@ void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
Expand All @@ -168,7 +168,7 @@ void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
Expand All @@ -180,7 +180,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor output,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>();

Expand Down Expand Up @@ -219,5 +219,4 @@ Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)

} // namespace operators
} // namespace paddle
} // namespace phi
10 changes: 5 additions & 5 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"

#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/gather_scatter_kernel.h"

namespace phi {

Expand All @@ -41,7 +41,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
phi::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
// same arguments list.
Expand All @@ -51,7 +51,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
*x_grad,
dev_ctx);
} else {
paddle::operators::cpu_scatter_input_grad_kernel<T, int64_t>(
phi::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}
Expand All @@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
phi::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
phi::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/kernels/cpu/put_along_axis_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

#include "paddle/phi/kernels/put_along_axis_kernel.h"

#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/gather_scatter_kernel.h"

namespace phi {

Expand All @@ -40,26 +40,26 @@ void PutAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (reduce == "add") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
phi::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
phi::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
phi::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
phi::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
phi::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
phi::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"

#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gather_scatter_kernel.h"

namespace phi {

Expand Down Expand Up @@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
phi::cpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
phi::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
}
Expand Down
8 changes: 3 additions & 5 deletions paddle/phi/kernels/cpu/take_along_axis_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

#include "paddle/phi/kernels/take_along_axis_kernel.h"

#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gather_scatter_kernel.h"

namespace phi {

Expand All @@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx,

const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
phi::cpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
phi::cpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,103 +12,102 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

#pragma once

namespace paddle {
namespace operators {
namespace phi {

#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, platform::float16) \
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, unsigned char)

#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const platform::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const platform::DeviceContext& ctx);
#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const phi::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
} // namespace operators
} // namespace paddle
const phi::DeviceContext& ctx);

} // namespace phi
Loading