[oneDNN] Added Elementwise Mul grad fp32/bf16 #31647
Diff: oneDNN elementwise_mul kernel (C++)

@@ -14,6 +14,118 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"

namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
namespace platform {
class CPUDeviceContext;
struct CPUPlace;
}  // namespace platform
}  // namespace paddle
Comment on lines +17 to +25

Reviewer: Why are you using forward declarations instead of including an appropriate header file?

Author: I copied that from elementwise_add, which this implementation is based on. What is the advantage of not using forward declarations? I have always seen them as a way to speed up compilation.
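A minimal sketch of the trade-off being discussed (illustrative code, not part of this PR): a forward declaration is enough when the type is only used through pointers or references, so the translation unit avoids parsing the full header and compiles faster; any use that needs the complete type (calling members, sizeof, by-value copies) still requires the real include.

// Illustration only: forward declaration vs. full include.
namespace paddle {
namespace framework {
class ExecutionContext;  // forward declaration: the complete type is not needed below
}  // namespace framework
}  // namespace paddle

// Fine with only the forward declaration: the type is used by const reference.
void Compute(const paddle::framework::ExecutionContext& ctx);

// Would need the full header, because the body uses the complete type:
// void Run(const paddle::framework::ExecutionContext& ctx) { auto p = ctx.GetPlace(); }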
namespace paddle {
namespace operators {
template <typename T>
class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

    if (dx) {
      // dx = dout * y
      platform::BinaryMKLDNNHandler<T> handler(
          dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
          ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f,
          ctx.InputName(framework::GradVarName("Out")));

      const auto src_dout_memory = handler.AcquireSrcMemory(dout);
      const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
      const auto dst_dx_memory = handler.AcquireDstMemory(dx);

      const auto binary_prim = handler.AcquireForwardPrimitive();

      const std::unordered_map<int, dnnl::memory> args = {
          {DNNL_ARG_SRC_0, *src_dout_memory},
          {DNNL_ARG_SRC_1, *src_y_memory},
          {DNNL_ARG_DST, *dst_dx_memory}};

      binary_prim->execute(astream, args);
      astream.wait();

      dx->set_layout(framework::DataLayout::kMKLDNN);
      dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
    }

    if (dy) {
      // dy = dout * x
      // nullptr is passed to the handler instead of the output tensor because
      // we want the Dst buffer to be allocated by oneDNN, not taken from the Tensor.
      platform::BinaryMKLDNNHandler<T> handler(
          dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
          ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f,
          ctx.InputName(framework::GradVarName("Out")));

      const auto src_dout_memory = handler.AcquireSrcMemory(dout);
      const auto src_x_memory = handler.AcquireSecondSrcMemory(x);

      // If broadcasting is in use, write to a temporary buffer allocated by oneDNN.
      const auto dst_dy_memory = (dout->dims() == dy->dims())
                                     ? handler.AcquireDstMemory(dy)
                                     : handler.AcquireDstMemory();

      const auto binary_prim = handler.AcquireForwardPrimitive();

      const std::unordered_map<int, dnnl::memory> args = {
          {DNNL_ARG_SRC_0, *src_dout_memory},
          {DNNL_ARG_SRC_1, *src_x_memory},
          {DNNL_ARG_DST, *dst_dy_memory}};

      binary_prim->execute(astream, args);
      astream.wait();

      dy->set_layout(framework::DataLayout::kMKLDNN);

      // Reduction is needed for the broadcasting scenario.
      if (dout->dims() != dy->dims()) {
        platform::ReductionMKLDNNHandler<T> handler_sum(
            dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
            ctx.GetPlace(), dout, dy,
            ctx.InputName(framework::GradVarName("Out")));
        auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
        auto reduction_p = handler_sum.AcquireForwardPrimitive();
        // As the source we use the memory object holding the results of the
        // binary operation above.
        reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
                                       {DNNL_ARG_DST, *dy_memory_p}});
        astream.wait();
        dy->set_format(
            platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
                paddle::framework::vectorize<int64_t>(dy->dims()))));
      } else {
        dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(

@@ -23,3 +135,7 @@ REGISTER_OP_KERNEL(
                       dnnl::algorithm::binary_mul>,
    ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
    ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)

REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::EltwiseMulMKLDNNGradKernel<paddle::platform::bfloat16>,
                   ops::EltwiseMulMKLDNNGradKernel<float>)
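Why the dy path needs the reduction (a sketch, not part of the PR): when Y was broadcast in the forward pass, Out = X * Y with Y repeated along the broadcast axes, so the chain rule gives dY as the elementwise product dOut * X summed over those axes; the reduction_sum primitive above performs that sum so dY ends up with Y's shape. For a 2-D X and a 1-D Y broadcast over the rows, the computation amounts to:

// Illustration only (plain C++, not oneDNN): gradient of Y for Out = X * Y,
// where Y (length n) is broadcast over the rows of X (m x n).
#include <cstddef>
#include <vector>

std::vector<float> MulGradY(const std::vector<std::vector<float>>& dout,
                            const std::vector<std::vector<float>>& x) {
  std::vector<float> dy(x.front().size(), 0.0f);
  for (std::size_t i = 0; i < x.size(); ++i) {    // reduce over the broadcast axis
    for (std::size_t j = 0; j < x[i].size(); ++j) {
      dy[j] += dout[i][j] * x[i][j];              // dY[j] = sum_i dOut[i][j] * X[i][j]
    }
  }
  return dy;
}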
Diff: oneDNN elementwise_mul unit test (Python)

@@ -17,6 +17,7 @@
import numpy as np
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
from paddle.fluid.tests.unittests.test_elementwise_mul_op import ElementwiseMulOp
from paddle import enable_static


class TestMKLDNNElementwiseMulOp(ElementwiseMulOp):

@@ -51,13 +52,17 @@ def init_input_output(self):
    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass
Comment on lines -54 to 56

Reviewer: Why is this being removed?

Author: Very nice catch. The reason is that there may be something broken in the reference implementation for test_check_grad_ingore_y, because that UT failed even without oneDNN being used. So I enabled only the test that I verified works and left the others commented out.

class TestMKLDNNElementwiseMulOp5(TestMKLDNNElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype)
        self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
        self.out = np.multiply(self.x, self.y)


''' INT8 Tests '''

@@ -140,4 +145,5 @@ def init_dtype(self):

if __name__ == '__main__':
    enable_static()
    unittest.main()
Reviewer: What about a case like dx_dims = [2, 3, 4, 5] and dy_dims = [3, 3, 5, 5]? These shapes cannot be broadcast together, yet they would still pass this condition.

Author: Yes, but the operator itself will never be given such values. If it were given such shapes, the op's InferShape should reject them.
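For illustration only (a hypothetical helper, not Paddle's actual InferShape code): a NumPy-style compatibility check walks the two shapes from the trailing dimension and requires each pair of sizes to be equal or for one of them to be 1. Shapes such as [2, 3, 4, 5] and [3, 3, 5, 5] fail it, so an op whose shape inference enforces such a rule never reaches the grad kernel with them, which is why the kernel can rely on the simple dims comparison.

// Hypothetical broadcast-compatibility check (NumPy-style trailing-dim rule).
// {2, 3, 4, 5} vs {3, 3, 5, 5} -> false; {2, 3, 4, 100} vs {100} -> true.
#include <cstdint>
#include <vector>

bool BroadcastCompatible(const std::vector<int64_t>& x,
                         const std::vector<int64_t>& y) {
  auto xi = x.rbegin();
  auto yi = y.rbegin();
  for (; xi != x.rend() && yi != y.rend(); ++xi, ++yi) {
    if (*xi != *yi && *xi != 1 && *yi != 1) {
      return false;  // e.g. 4 vs 5: neither equal nor broadcastable
    }
  }
  return true;  // any leftover leading dims of the longer shape broadcast freely
}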