MKLDNN layout: Support for pool operator #11101
Merged
tensor-tang merged 1 commit into PaddlePaddle:develop from mozga-intel:mozga-intel/Pool_mkldnn_layout on Jun 11, 2018
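In short, the commit reworks the MKLDNN pool2d forward and backward kernels to operate on MKLDNN memory layouts end to end: both kernels now require their inputs to carry DataLayout::kMKLDNN with a defined memory format, the forward kernel describes its destination with format 'any' so the pooling primitive can choose the layout it prefers and records that choice on the output tensor, and the backward kernel reorders the incoming output gradient whenever its layout differs from the one the forward primitive selected.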
@@ -18,9 +18,14 @@ limitations under the License. */ | |
namespace paddle { | ||
namespace operators { | ||
|
||
using mkldnn::memory; // Note: paddle has also "memory" namespace | ||
using mkldnn::pooling_forward; | ||
using framework::DataLayout; | ||
using mkldnn::memory; | ||
using mkldnn::pooling_backward; | ||
using mkldnn::pooling_forward; | ||
using mkldnn::primitive; | ||
using mkldnn::reorder; | ||
using mkldnn::stream; | ||
using platform::to_void_cast; | ||
|
||
// Generate keys for storing/retrieving primitives for this operator | ||
// TODO(jczaja): Make hashing function more optimal | ||
|
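The gethash helper referred to by the comment above is not part of this hunk; the following is a hypothetical sketch of how such a key builder can look (illustrative only, the name, parameters, and concatenation order are assumptions, not the file's actual implementation):

    #include <string>
    #include <vector>

    // Hypothetical key builder: concatenates the shapes and attributes that make a
    // cached pooling primitive reusable, plus a per-output suffix.
    static std::string gethash_sketch(const std::vector<int>& input_dims,
                                      const std::string& pooling_type,
                                      const std::vector<int>& ksize,
                                      const std::vector<int>& strides,
                                      const std::vector<int>& paddings,
                                      const std::string& suffix) {
      auto dims2str = [](const std::vector<int>& dims) {
        std::string s;
        for (int d : dims) s += std::to_string(d) + "-";
        return s;
      };
      return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
             dims2str(paddings) + pooling_type + suffix;
    }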
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { | |
const Tensor* input = ctx.Input<Tensor>("X"); | ||
Tensor* output = ctx.Output<Tensor>("Out"); | ||
|
||
// Get a unique name from "argument" name of "Out" variable | ||
// This name will be used as key when saving info into device context | ||
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && | ||
input->format() != memory::format::format_undef, | ||
"Wrong layout/format set for Input tensor"); | ||
|
||
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); | ||
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); | ||
|
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { | |
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); | ||
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); | ||
|
||
auto input_format = input->format(); | ||
memory::format output_format{memory::format::format_undef}; | ||
|
||
const std::string key = gethash(src_tz, pooling_type, ksize, strides, | ||
paddings, ctx.op().Output("Out")); | ||
const std::string key_pool_p = key + "@pool_p"; | ||
|
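The key and key_pool_p strings above index into the MKLDNN device context's blob cache, which keeps primitives and memories alive across kernel invocations. Conceptually the cache behaves like the sketch below (an illustration of the idea, not Paddle's actual MKLDNNDeviceContext code):

    #include <memory>
    #include <string>
    #include <unordered_map>

    // Conceptual stand-in for SetBlob/GetBlob: a string-keyed map of type-erased
    // shared pointers; callers static_pointer_cast the result back to the stored type.
    struct BlobCacheSketch {
      std::unordered_map<std::string, std::shared_ptr<void>> blobs;

      void SetBlob(const std::string& key, std::shared_ptr<void> value) {
        blobs[key] = std::move(value);
      }
      std::shared_ptr<void> GetBlob(const std::string& key) const {
        auto it = blobs.find(key);
        return it == blobs.end() ? nullptr : it->second;
      }
    };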
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { | |
auto pool_p = | ||
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p)); | ||
if (pool_p == nullptr) { | ||
// TODO(pzelazko-intel): support more formats | ||
auto src_md = platform::MKLDNNMemDesc( | ||
src_tz, platform::MKLDNNGetDataType<T>(), input_format); | ||
|
||
auto src_md = | ||
platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(), | ||
mkldnn::memory::format::nchw); | ||
auto dst_md = | ||
platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(), | ||
mkldnn::memory::format::nchw); | ||
/* create memory descriptor for pooling without specified format | ||
* ('any') which lets a primitive (pooling in this case) choose | ||
* the memory format preferred for best performance | ||
*/ | ||
auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32, | ||
mkldnn::memory::format::any); | ||
|
||
std::shared_ptr<pooling_forward::primitive_desc> pool_pd = | ||
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd = | ||
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize, | ||
pooling_type, mkldnn_engine); | ||
|
||
|
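The hunk above is the core of the change: the destination descriptor is created with format 'any', and the resulting primitive descriptor then dictates the actual layouts. Below is a minimal standalone sketch of that pattern in the MKLDNN 0.x-style C++ API this file uses; the shapes, pooling attributes, and function name are made up for illustration:

    #include <mkldnn.hpp>

    // Let the pooling primitive pick its preferred dst layout and report which
    // format it chose; this is the value later stored on the output tensor.
    mkldnn::memory::format choose_pool_dst_format() {
      using namespace mkldnn;
      auto cpu_engine = engine(engine::kind::cpu, 0);

      memory::dims src_tz = {1, 16, 8, 8};  // NCHW input
      memory::dims dst_tz = {1, 16, 4, 4};  // NCHW output

      // src keeps the format the input actually has; dst uses 'any'.
      auto src_md = memory::desc(src_tz, memory::data_type::f32, memory::format::nchw);
      auto dst_md = memory::desc(dst_tz, memory::data_type::f32, memory::format::any);

      auto pool_desc = pooling_forward::desc(
          prop_kind::forward_training, algorithm::pooling_max, src_md, dst_md,
          /*strides=*/{2, 2}, /*kernel=*/{2, 2},
          /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0}, padding_kind::zero);
      auto pool_pd = pooling_forward::primitive_desc(pool_desc, cpu_engine);

      return static_cast<memory::format>(
          pool_pd.dst_primitive_desc().desc().data.format);
    }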
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { | |
// save pool_workspace_memory to be referred in backward path | ||
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); | ||
|
||
auto pool_src_memory_p = std::make_shared<memory>( | ||
memory::primitive_desc{src_md, mkldnn_engine}, | ||
static_cast<void*>(const_cast<T*>(input_data))); | ||
dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p); | ||
auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(), | ||
to_void_cast<T>(input_data)); | ||
auto dst_memory = | ||
std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data); | ||
|
||
auto pool_dst_memory_p = std::make_shared<memory>( | ||
memory::primitive_desc{dst_md, mkldnn_engine}, | ||
static_cast<void*>(output_data)); | ||
dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p); | ||
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); | ||
dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); | ||
|
||
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()), | ||
*(dst_memory.get()), | ||
*workspace_memory); | ||
|
||
pool_p = std::make_shared<pooling_forward>( | ||
*pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()), | ||
*workspace_memory); | ||
dev_ctx.SetBlob(key_pool_p, pool_p); | ||
|
||
output_format = | ||
(memory::format)dst_memory->get_primitive_desc().desc().data.format; | ||
} else { | ||
// Primitives already exist | ||
auto pool_src_memory_p = | ||
|
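One detail worth noting in the hunk above: max pooling in MKLDNN produces a workspace during the forward pass that records which input element each output maximum came from, and the backward primitive cannot run without it. That is why workspace_memory is stored in the device context here and looked up again by the gradient kernel further down.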
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { | |
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p)); | ||
PADDLE_ENFORCE(pool_dst_memory_p != nullptr, | ||
"Fail to find pooling dst mem_p in device context"); | ||
pool_src_memory_p->set_data_handle( | ||
reinterpret_cast<void*>(const_cast<T*>(input_data))); | ||
pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data)); | ||
pool_dst_memory_p->set_data_handle(output_data); | ||
|
||
output_format = (memory::format)pool_dst_memory_p->get_primitive_desc() | ||
.desc() | ||
.data.format; | ||
} | ||
|
||
// push primitive to stream and wait until it's executed | ||
std::vector<mkldnn::primitive> pipeline{*(pool_p.get())}; | ||
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); | ||
stream(stream::kind::eager).submit(pipeline).wait(); | ||
|
||
output->set_layout(DataLayout::kMKLDNN); | ||
output->set_format(output_format); | ||
} | ||
|
||
private: | ||
|
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { | |
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
|
||
PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN && | ||
in_x->format() != memory::format::format_undef, | ||
"Wrong layout/format set for Input X tensor"); | ||
PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN && | ||
out_grad->format() != memory::format::format_undef, | ||
"Wrong layout/format set for Input output_grad tensor"); | ||
|
||
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); | ||
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); | ||
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); | ||
|
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { | |
|
||
const T* out_grad_data = out_grad->data<T>(); | ||
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace()); | ||
memory::format in_x_grad_format{memory::format::format_undef}; | ||
|
||
std::vector<int> diff_src_tz = | ||
paddle::framework::vectorize2int(in_x_grad->dims()); | ||
|
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { | |
const std::string key_pool_bwd_p = key + "@pool_bwd_p"; | ||
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; | ||
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p"; | ||
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; | ||
const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; | ||
const std::string key_pool_pd = key + "@pool_pd"; | ||
const std::string key_pool_workspace_memory = | ||
key + "@pool_workspace_memory"; | ||
|
||
auto user_diff_dst_memory = | ||
memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()}, | ||
mkldnn_engine}, | ||
to_void_cast<T>(out_grad_data)); | ||
|
||
std::shared_ptr<memory> diff_src_memory; | ||
std::shared_ptr<memory> diff_dst_memory; | ||
auto dst_memory = | ||
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p)); | ||
PADDLE_ENFORCE(dst_memory != nullptr, | ||
"Fail to find dst_memory in device context"); | ||
|
||
primitive reorder_diff_dst; | ||
bool is_diff_dst_reordered = false; | ||
auto pool_bwd_p = std::static_pointer_cast<pooling_backward>( | ||
dev_ctx.GetBlob(key_pool_bwd_p)); | ||
if (pool_bwd_p == nullptr) { | ||
auto diff_src_md = | ||
platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(), | ||
mkldnn::memory::format::nchw); | ||
auto diff_dst_md = | ||
platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(), | ||
mkldnn::memory::format::nchw); | ||
// Retrieve src_memory/dst_memory saved in forward pass | ||
auto src_memory = | ||
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p)); | ||
PADDLE_ENFORCE(src_memory != nullptr, | ||
"Fail to find src_memory in device context"); | ||
// Retrieve pool_pd/pool_workspace_memory from device context | ||
auto pool_pd = | ||
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>( | ||
dev_ctx.GetBlob(key_pool_pd)); | ||
PADDLE_ENFORCE(pool_pd != nullptr, | ||
"Fail to find pool_pd in device context"); | ||
|
||
auto workspace_memory = std::static_pointer_cast<mkldnn::memory>( | ||
auto workspace_memory = std::static_pointer_cast<memory>( | ||
dev_ctx.GetBlob(key_pool_workspace_memory)); | ||
PADDLE_ENFORCE(workspace_memory != nullptr, | ||
"Fail to find workspace_memory in device context"); | ||
|
||
auto pool_diff_src_memory_p = std::make_shared<memory>(memory( | ||
{diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data))); | ||
dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p); | ||
|
||
auto pool_diff_dst_memory_p = std::make_shared<memory>( | ||
memory({diff_dst_md, mkldnn_engine}, | ||
static_cast<void*>(const_cast<T*>(out_grad_data)))); | ||
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p); | ||
// create memory descriptors for pooling | ||
auto diff_src_md = src_memory.get()->get_primitive_desc().desc(); | ||
auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc(); | ||
|
||
auto pool_bwd_desc = mkldnn::pooling_backward::desc( | ||
pooling_type == "max" ? mkldnn::algorithm::pooling_max | ||
|
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { | |
auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc( | ||
pool_bwd_desc, mkldnn_engine, *pool_pd); | ||
|
||
// reorder between user_diff_dst and pool diff_dst if needed | ||
diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory); | ||
||
if (memory::primitive_desc(dst_memory->get_primitive_desc()) != | ||
user_diff_dst_memory.get_primitive_desc()) { | ||
diff_dst_memory = | ||
std::make_shared<memory>(dst_memory.get()->get_primitive_desc()); | ||
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); | ||
is_diff_dst_reordered = true; | ||
} | ||
|
||
diff_src_memory = std::make_shared<memory>( | ||
pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data); | ||
|
||
dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory); | ||
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory); | ||
|
||
pool_bwd_p = std::make_shared<pooling_backward>( | ||
pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory, | ||
*(pool_diff_src_memory_p)); | ||
pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory, | ||
*(diff_src_memory)); | ||
dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); | ||
|
||
} else { | ||
// Primitives already exist | ||
auto pool_diff_src_memory_p = std::static_pointer_cast<memory>( | ||
diff_src_memory = std::static_pointer_cast<memory>( | ||
dev_ctx.GetBlob(key_pool_diff_src_mem_p)); | ||
PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr, | ||
PADDLE_ENFORCE(diff_src_memory != nullptr, | ||
"Fail to find pooling src mem_p in device context"); | ||
auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>( | ||
diff_dst_memory = std::static_pointer_cast<memory>( | ||
dev_ctx.GetBlob(key_pool_diff_dst_mem_p)); | ||
PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr, | ||
PADDLE_ENFORCE(diff_dst_memory != nullptr, | ||
"Fail to find pooling dst mem_p in device context"); | ||
pool_diff_src_memory_p->set_data_handle( | ||
reinterpret_cast<void*>(in_x_grad_data)); | ||
pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data)); | ||
|
||
diff_src_memory->set_data_handle(reinterpret_cast<void*>(in_x_grad_data)); | ||
diff_dst_memory->set_data_handle(const_cast<T*>(out_grad_data)); | ||
|
||
// reorder between user_diff_dst and pool diff_dst if needed | ||
if (memory::primitive_desc(dst_memory->get_primitive_desc()) != | ||
user_diff_dst_memory.get_primitive_desc()) { | ||
diff_dst_memory = | ||
std::make_shared<memory>(dst_memory.get()->get_primitive_desc()); | ||
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); | ||
is_diff_dst_reordered = true; | ||
} | ||
} | ||
|
||
in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc() | ||
.desc() | ||
.data.format; | ||
|
||
// push primitive to stream and wait until it's executed | ||
std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())}; | ||
std::vector<mkldnn::primitive> pipeline; | ||
if (is_diff_dst_reordered) { | ||
pipeline.push_back(reorder_diff_dst); | ||
} | ||
pipeline.push_back(*(pool_bwd_p.get())); | ||
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); | ||
|
||
in_x_grad->set_layout(DataLayout::kMKLDNN); | ||
in_x_grad->set_format(in_x_grad_format); | ||
} // Compute() | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace, | ||
paddle::operators::PoolMKLDNNOpKernel<float>); | ||
ops::PoolMKLDNNOpKernel<float>); | ||
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace, | ||
paddle::operators::PoolMKLDNNGradOpKernel<float>); | ||
ops::PoolMKLDNNGradOpKernel<float>); |
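To summarize the backward path changed above: the kernel first makes sure the incoming gradient is in the layout the forward primitive chose for dst, inserting a reorder when it is not, and only then runs pooling_backward. A condensed paraphrase using the diff's own variable names (not a drop-in replacement; pool_bwd_pd, workspace_memory, diff_src_memory, and dst_memory are as defined in the diff):

    // Reorder user_diff_dst_memory into the layout chosen by the forward primitive
    // when the two layouts differ, then execute the backward pooling primitive.
    auto diff_dst_memory = std::make_shared<mkldnn::memory>(user_diff_dst_memory);
    bool is_diff_dst_reordered = false;
    if (dst_memory->get_primitive_desc() != user_diff_dst_memory.get_primitive_desc()) {
      diff_dst_memory =
          std::make_shared<mkldnn::memory>(dst_memory->get_primitive_desc());
      is_diff_dst_reordered = true;
    }

    auto pool_bwd_p = std::make_shared<mkldnn::pooling_backward>(
        pool_bwd_pd, *diff_dst_memory, *workspace_memory, *diff_src_memory);

    std::vector<mkldnn::primitive> pipeline;
    if (is_diff_dst_reordered) {
      pipeline.push_back(mkldnn::reorder(user_diff_dst_memory, *diff_dst_memory));
    }
    pipeline.push_back(*pool_bwd_p);
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();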
Review comment on input->format(): input_format is only used once.