diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index aaa7aedf8bcd..78a6cfb15fd2 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -647,6 +647,7 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { // If this is a view, we can't create a MKLDNN memory for the chunk // because we don't have the complete data type and shape information for // the chunk. + CheckAndAlloc(); void *off_addr = static_cast(ptr_->shandle.dptr) + byte_offset_; // Create the primitive desc for the new mkldnn memory. mkldnn::memory::dims dims(shape().ndim()); @@ -665,6 +666,7 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { } else { // If this isn't a view, we can create a MKLDNN memory and store it in the // chunk. + CheckAndAlloc(); ptr_->SetMKLMem(shape_, dtype_); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return ptr_->mkl_mem_->GetRaw();