Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add some annotations and log strings, rename mem_desc variables (#16609)
Browse files Browse the repository at this point in the history
  • Loading branch information
xziya authored and pengzhao-intel committed Oct 24, 2019
1 parent 2210b21 commit 29f2e32
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
17 changes: 9 additions & 8 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,15 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
return GetMKLDNNExact(mem, new_desc);
}

mkldnn::memory::desc desc1 = mem->get_desc();
mkldnn::memory::desc desc2 = new_desc;
mkldnn::memory::desc old_desc = mem->get_desc();
// Now we need to determine if we should reorder the memory.
// If both use the default formats, we think we don't need to reorder.
if ((!mxnet::IsMKLDNN(desc1)) && (!mxnet::IsMKLDNN(desc2))) {
if ((!mxnet::IsMKLDNN(old_desc)) && (!mxnet::IsMKLDNN(new_desc))) {
mkldnn_mem_ptr ret(new mkldnn::memory(new_desc,
CpuEngine::Get()->get_engine(), mem->get_data_handle()));
stream->RegisterMem(ret);
return ret.get();
} else if (same_shape(desc1, desc2)) {
} else if (same_shape(old_desc, new_desc)) {
// If they have the same shape, we can reorder data directly.
mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_desc);
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem }, {MKLDNN_ARG_TO, *ret}});
Expand All @@ -551,9 +550,9 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
// If they have different shapes, we need to reshape the array first.
// Since this method will only be used inside an operator, we can call
// MKLDNNDataReshape to reshape an array.
mxnet::TShape required_shape(desc2.data.ndims, -1);
for (int i = 0; i < desc2.data.ndims; i++)
required_shape[i] = desc2.data.dims[i];
mxnet::TShape required_shape(new_desc.data.ndims, -1);
for (int i = 0; i < new_desc.data.ndims; i++)
required_shape[i] = new_desc.data.dims[i];
NDArray reshaped = MKLDNNDataReshape(required_shape);
const mkldnn::memory *ret = reshaped.GetMKLDNNData();
if (ret->get_desc() == new_desc) {
Expand Down Expand Up @@ -684,7 +683,9 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {

mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc ";
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc. "
<< "MKLDNN memory requests for " << desc.get_size() << " bytes, but got "
<< shape().Size() * GetTypeSize(dtype_) << " bytes from NDArray";
return nullptr;
}
bool isDefaultFormat = IsDefaultFormat(desc);
Expand Down
13 changes: 9 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,16 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) {
this->curr_mem = static_cast<char *>(mem) + md.get_size();
return ret.get();
} else {
// If curr_mem has been initialized and we still reach here. It means
// the current allocated memory isn't enough.
// If curr_mem has been initialized and we still reach here, it means the current
// allocated memory isn't enough. But it doesn't matter for multiple invokes of a
// operator, as the TmpMemMgr could estimate the space at the first iteration and
// then re-requests abundant space from MXNet resource. MKL-DNN could allocate
// the space by itself. Thus, we just let it continue for estimating the maximum
// required space size. It will be allocated at next call.
if (this->curr_mem && dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false)) {
LOG(WARNING) << "Allocate " << md.get_size()
<< " bytes with malloc directly";
LOG(WARNING) << "mkl-dnn debug message: The rest of the temporary space is not "
<< "adequate for allocating " << md.get_size() << " bytes. Thus, mkl-dnn "
<< "allocate the space by itself.";
}
mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine()));
MKLDNNStream::Get()->RegisterMem(ret);
Expand Down

0 comments on commit 29f2e32

Please sign in to comment.