Update activations for MKL-DNN #10597
Conversation
Force-pushed from 771ac6f to 598ae8e.
Which changes do you mean?
Changes for registering the activation operators (look at the changes in the activation_op.cc file, e.g. the REGISTER_ACTIVATION_OP_MAKER macro that creates the XXXOpMaker classes, new macros for registering operators such as the INPLACE ones, etc.).
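For readers unfamiliar with this registration style, here is a minimal sketch of the macro pattern being discussed; the maker base class and method names are assumptions and may differ from PaddlePaddle's actual interface:

// Hedged sketch: OP_NAME and OP_COMMENT are placeholders; the real
// OpProtoAndCheckerMaker interface may differ from this.
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)              \
  class OP_NAME##OpMaker : public framework::OpProtoAndCheckerMaker {  \
   public:                                                             \
    void Make() override {                                             \
      AddInput("X", "Input of " #OP_NAME " operator");                 \
      AddOutput("Out", "Output of " #OP_NAME " operator");             \
      AddComment(OP_COMMENT);                                          \
    }                                                                  \
  };

Each activation then only needs one macro invocation instead of a hand-written maker class.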
__macro(SoftRelu, soft_relu);       \
__macro(Relu6, relu6);              \
__macro(Reciprocal, reciprocal);    \
__macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
Why the changes to activation_op.cc?
- Did the previous registration not work? I see that the unit test test_activation_mkldnn_op.py passes.
- The current changes are not suitable for other devices. For example, if there were an amd_relu_op, how should it be defined?
Force-pushed from 598ae8e to bf447fc.
-      static_cast<void *>(const_cast<float *>(src_data)));
+      static_cast<void *>(const_cast<float *>(src_data))));
   // save source memory to device context to be referred in backward path
   dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);
   auto dst_memory =
Why do you only save src_memory? What about dst_memory?
eltwise_grad needs the input data (e.g. see AddInput("X", ...) for Relu in activation_op.cc). I have no direct access to this data in the eltwise_grad function; only eltwise_forward has access to it.
dst_memory is not used in eltwise_grad.
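To make the hand-off concrete, here is a minimal sketch of the save/retrieve pattern, reusing the blob key from the diff above; the exact MKLDNNDeviceContext signatures are my assumption:

// Forward pass (eltwise_forward): stash the source memory under a known key.
dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);

// Backward pass (eltwise_grad): recover it later by the same key.
// GetBlob is assumed to return a std::shared_ptr<void>.
auto src_memory = std::static_pointer_cast<mkldnn::memory>(
    dev_ctx.GetBlob("InputX@eltwise_pd"));
PADDLE_ENFORCE(src_memory != nullptr,
               "Fail to find src_memory in device context");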
For example, in MKL-DNN, sqrt_bwd has a different implementation than in PaddlePaddle:

// MKL-DNN
template <typename T>
T sqrt_bwd(T dd, T s) {
  return s > 0 ? dd / (2 * ::sqrtf(s)) : 0;
}

// PaddlePaddle
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
  const Out out_conj = Eigen::numext::conj(out);
  dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
}
This is the reason I need the input data in eltwise_grad: MKL-DNN computes the gradient from the forward input s, while PaddlePaddle computes it from the forward output out.
I know why src_memory is needed, but my point is: why not also save dst_memory to the context to save time, since you have already saved src?
@@ -69,7 +71,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                       forward_desc, mkldnn_engine);
   dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
I can only see forward_pd being set on every forward iteration, but not the existing one being retrieved?
forward_pd is retrieved in the eltwise_grad() function, in the line with dev_ctx.GetBlob(key_eltwise_pd).
Actually, my point is that I can only see you creating and setting this pd in the context on every forward iteration, without reusing it in the next iteration. Could we avoid this recreation across iterations to improve performance?
Yes, of course. I will improve it in the next commit.
I added support for reusing memory buffers to the MKL-DNN activations.
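In outline, the reuse pattern looks like the following sketch; the cast target and variable names are my assumptions, not necessarily the committed code:

// Look up the cached primitive descriptor first; only create it on a miss.
auto forward_pd =
    std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
        dev_ctx.GetBlob(key_eltwise_pd));
if (forward_pd == nullptr) {
  // First iteration: create the descriptor and cache it in the context.
  forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
      forward_desc, mkldnn_engine);
  dev_ctx.SetBlob(key_eltwise_pd, forward_pd);  // reused by later iterations
}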
Force-pushed from bf447fc to f6404f0.
Could you possibly continue your review? I added support for reusing memory buffers.
Sorry for the late reply; I have one question.
@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;

 namespace {
 std::string gethash(const mkldnn::memory::dims &operand_dims,
                     const mkldnn::algorithm algorithm) {
   auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
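The excerpt cuts off inside the lambda; a plausible completion of the helper, assuming the obvious dims-to-string implementation (my sketch, not necessarily the PR's exact code):

std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  // Serialize dims, e.g. {2, 3, 4} -> "2-3-4-", then append the algorithm id,
  // so structurally identical ops map to the same key.
  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dim2str(operand_dims) + std::to_string(static_cast<int>(algorithm));
}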
Is it possible for two fc ops to have the same src dims? Would this still be right, given they would share the same hash code?
They will be reusable except for the input data. I improved the hash code for the input data (key_src_data).
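One plausible way to keep the per-input key unique between ops that share dims (my assumption of the approach, not the committed code):

// The base key depends only on dims + algorithm, so structurally identical
// ops can share cached primitives; the data key is assumed to be further
// disambiguated per operator, e.g. by its output variable name.
const std::string key = gethash(src_dims, algorithm);
const std::string key_src_data =
    key + ctx.op().Output("Out") + "@eltwise_src_data";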
OK, good.
Force-pushed from f6404f0 to 24904b9.
LGTM
Updated activations for MKL-DNN after changes in PaddlePaddle activations.