Skip to content

Commit

Permalink
add LaunchElementwiseCudaKernel in phi
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi committed Nov 10, 2022
1 parent 71f8e13 commit 588f45b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
39 changes: 39 additions & 0 deletions paddle/phi/kernels/funcs/elementwise_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

#define HOSTDEVICE __host__ __device__
Expand Down Expand Up @@ -824,6 +825,44 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
#endif
}

// Overload of LaunchElementwiseCudaKernel that takes an explicit `axis` and
// delegates the actual launch to phi::funcs::BroadcastKernel.
//
// Template parameters ET, InT, OutT, Functor and NumOuts are forwarded
// unchanged to BroadcastKernel.
//
// Parameters:
//   ctx  - device context handed straight to BroadcastKernel.
//   ins  - input tensors; each one is re-wrapped through
//          paddle::experimental::MakePhiDenseTensor before the call.
//   outs - output tensors, re-wrapped the same way. Must not be null, since
//          it is dereferenced unconditionally.
//   axis - forwarded as-is to BroadcastKernel (presumably the broadcast
//          axis; confirm against BroadcastKernel's documentation).
//   func - element-wise functor forwarded to BroadcastKernel.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void LaunchElementwiseCudaKernel(
    const KPDevice &ctx,
    const std::vector<const phi::DenseTensor *> &ins,
    std::vector<phi::DenseTensor *> *outs,
    int axis,
    Functor func) {
  // TODO(YuanRisheng): the *_tmp vectors exist only to own the DenseTensor
  // objects produced by MakePhiDenseTensor, so that the raw pointers collected
  // in pt_inputs / pt_outputs stay valid until BroadcastKernel returns. They
  // can be removed once DenseTensor supports a copy constructor.
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_inputs_tmp;
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_outputs_tmp;
  std::vector<const phi::DenseTensor *> pt_inputs;
  std::vector<phi::DenseTensor *> pt_outputs;

  // Final sizes are known up front, so allocate once instead of growing.
  pt_inputs_tmp.reserve(ins.size());
  pt_inputs.reserve(ins.size());
  pt_outputs_tmp.reserve(outs->size());
  pt_outputs.reserve(outs->size());

  // MakePhiDenseTensor already yields a temporary, so no std::move is needed.
  // The pointee address is owned by the unique_ptr and is unaffected by any
  // later movement of the unique_ptr inside the owning vector.
  for (const phi::DenseTensor *in : ins) {
    pt_inputs_tmp.emplace_back(paddle::experimental::MakePhiDenseTensor(*in));
    pt_inputs.push_back(pt_inputs_tmp.back().get());
  }
  for (phi::DenseTensor *out : *outs) {
    pt_outputs_tmp.emplace_back(paddle::experimental::MakePhiDenseTensor(*out));
    pt_outputs.push_back(pt_outputs_tmp.back().get());
  }

  phi::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
      ctx, pt_inputs, &pt_outputs, axis, func);
}

template <typename OutT, typename Functor, int NumOuts = 1>
void ElementwiseKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct BinaryOperation {
std::vector<const DenseTensor*> ins{&lhs, &rhs};
std::vector<DenseTensor*> outs{output};
phi::funcs::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, BinaryFunctor<T>());
dev_ctx, ins, &outs, 0, BinaryFunctor<T>());
}
};

Expand Down

0 comments on commit 588f45b

Please sign in to comment.