Merge pull request #6 from tsocha/elementwise_div
Fix Elementwise div
jczaja authored Apr 4, 2023
2 parents 71090b8 + 273b31d commit a13d562
Showing 1 changed file with 138 additions and 44 deletions.
182 changes: 138 additions & 44 deletions paddle/phi/kernels/onednn/elementwise_grad_kernel.cc
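
For context, a brief sketch of the gradients this change is meant to produce (inferred from the kernel code below, assuming the forward op computes out = x / y):

\[
\frac{\partial L}{\partial x} = \frac{dout}{y},
\qquad
\frac{\partial L}{\partial y} = -\,dout \cdot \frac{x}{y^{2}} = -\,\frac{dout \cdot out}{y}
\]

The dx branch realizes the first formula as a single binary_div, while the dy branch realizes the second as binary_mul(dout, out) with a -1 scale on src_0 followed by a binary_div post-op with y.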
@@ -243,44 +243,6 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
false);

src_1_memory = binary_handler.AcquireSecondSrcMemory(non_const_x);

if (BINARY_OP == dnnl::algorithm::binary_div) {
funcs::BinaryOneDNNHandler<T> post_op_binary_handler(
dnnl::algorithm::binary_div,
axis,
onednn_engine,
dev_ctx.GetPlace(),
non_const_y,
non_const_y,
nullptr,
1.0f,
1.0f,
1.0f,
false);

post_op_memory = post_op_binary_handler.AcquireSrcMemory(non_const_y);

dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div,
post_op_memory->get_desc());

binary_handler =
funcs::BinaryOneDNNHandler<T>(dnnl::algorithm::binary_mul,
axis,
onednn_engine,
dev_ctx.GetPlace(),
&dout,
out,
nullptr,
-1.0f,
1.0f,
1.0f,
false,
po);

src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
}

src_0_memory = binary_handler.AcquireSrcMemory(&dout);

const auto dst_dy_memory = (dout.dims() == dy->dims())
@@ -294,10 +256,6 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scales_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scales_mem}};

if (BINARY_OP == dnnl::algorithm::binary_div)
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*post_op_memory});

binary_prim->execute(astream, args);
broadcast_src_memory = dst_dy_memory;
dst_memory = dst_dy_memory;
@@ -345,8 +303,144 @@ void DivideGradKernel(const Context& dev_ctx,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
ElementwiseGradKernel<T, dnnl::algorithm::binary_div>(
dev_ctx, x, y, &out, dout, axis, dx, dy);
const auto& onednn_engine = dev_ctx.GetEngine();
auto* non_const_y = &y;

float scale{1.0};

auto tz = phi::vectorize<int64_t>(dout.dims());

funcs::ReorderOneDNNHandler reorder_handler(
tz, dout.dtype(), funcs::ToOneDNNDataType(dout.dtype()), onednn_engine);

auto reorder_src_memory = reorder_handler.AcquireSrcMemory(
dout.mem_desc(), funcs::to_void_cast(dout.data<T>()));

std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory;

auto& astream = OneDNNContext::tls().get_stream();
auto scales_md =
dnnl::memory::desc({1},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::x);
auto scales_mem = dnnl::memory(scales_md, onednn_engine);
auto scale_memory_buf = static_cast<float*>(scales_mem.get_data_handle());
*scale_memory_buf = scale;

auto neg_scales_mem = dnnl::memory(scales_md, onednn_engine);
auto neg_scale_memory_buf = static_cast<float*>(neg_scales_mem.get_data_handle());
*neg_scale_memory_buf = -scale;
if (dx) {
funcs::BinaryOneDNNHandler<T> binary_handler(dnnl::algorithm::binary_div,
axis,
onednn_engine,
dev_ctx.GetPlace(),
&dout,
non_const_y,
dx,
1.0f,
1.0f,
1.0f,
false);

const auto src_dout_memory = binary_handler.AcquireSrcMemory(&dout);
const auto src_y_memory =
binary_handler.AcquireSecondSrcMemory(non_const_y);
dst_memory = binary_handler.AcquireDstMemory(dx);

const auto binary_prim = binary_handler.AcquireForwardPrimitive();

const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scales_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scales_mem}};

binary_prim->execute(astream, args);
astream.wait();

if (dout.dims() != dx->dims()) {
funcs::BroadcastReduction<T>(dev_ctx.GetPlace(),
onednn_engine,
dx,
&dout,
broadcast_src_memory,
dst_memory,
{scale},
false);
} else {
dx->set_mem_desc(dst_memory->get_desc());
}
}

if (dy) {
funcs::BinaryOneDNNHandler<T> y_handler(
dnnl::algorithm::binary_div,
axis,
onednn_engine,
dev_ctx.GetPlace(),
non_const_y,
non_const_y,
nullptr,
1.0f,
1.0f,
1.0f,
false);

const auto y_memory = y_handler.AcquireSrcMemory(non_const_y);

dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div,
y_memory->get_desc());

funcs::BinaryOneDNNHandler<T> handler =
funcs::BinaryOneDNNHandler<T>(dnnl::algorithm::binary_mul,
axis,
onednn_engine,
dev_ctx.GetPlace(),
&dout,
&out,
nullptr,
1.0f,
1.0f,
1.0f,
false,
po);

const auto src_dout_memory = handler.AcquireSrcMemory(&dout);
const auto src_out_memory = handler.AcquireSecondSrcMemory(&out);

const auto dst_dy_memory = (dout.dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();

const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_out_memory},
{DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, neg_scales_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scales_mem}};

binary_prim->execute(astream, args);
astream.wait();

if (dout.dims() != dy->dims()) {
funcs::BroadcastReduction<T>(dev_ctx.GetPlace(),
onednn_engine,
dy,
&dout,
broadcast_src_memory,
dst_memory,
{scale},
false);
} else {
dy->set_mem_desc(dst_dy_memory->get_desc());
}
}
}
} // namespace phi
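
As a sanity check, here is a minimal scalar reference (not part of the commit; the function name, vector types, and plain loops are illustrative only) of the elementwise math the two oneDNN branches above perform before any BroadcastReduction for broadcasted shapes:

#include <cstddef>
#include <vector>

// Reference-only sketch of DivideGradKernel's elementwise math, assuming
// out = x / y and that all tensors share the same shape (the real kernel
// reduces over broadcast dimensions afterwards via BroadcastReduction).
void DivideGradReference(const std::vector<float>& y,
                         const std::vector<float>& out,
                         const std::vector<float>& dout,
                         std::vector<float>* dx,   // nullptr if dx is not needed
                         std::vector<float>* dy) { // nullptr if dy is not needed
  const std::size_t n = dout.size();
  if (dx) {
    dx->resize(n);
    for (std::size_t i = 0; i < n; ++i) {
      (*dx)[i] = dout[i] / y[i];  // binary_div(dout, y)
    }
  }
  if (dy) {
    dy->resize(n);
    for (std::size_t i = 0; i < n; ++i) {
      // binary_mul(dout, out) with a -1 scale on src_0, then the binary_div
      // post-op divides the product by y
      (*dy)[i] = (-dout[i] * out[i]) / y[i];
    }
  }
}

In the kernel itself, appending the binary_div as a post-op lets the multiplication and the division by y run as a single fused oneDNN primitive execution rather than two separate passes.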
