Feature/use cudnn #7141
Changes from all commits
@@ -37,6 +37,28 @@ auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
                               DataLayout::kNCHW, LibraryType::kPlain);

// TODO(dzhwinter): Only for testing multiple op kernel.
// Dummy transform function for library_type
// should be removed.
auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
                                DataLayout::kAnyLayout, LibraryType::kPlain);

auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
                                DataLayout::kAnyLayout, LibraryType::kCUDNN);

void DummyTrans(const platform::DeviceContext* ctx,
                const KernelTypePair& kernel_pair, const Variable& in,
                Variable* out) {
  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
Review comment: Since all …
Reply: OK. Will fix in the next PR.
  PADDLE_ENFORCE(
      platform::places_are_same_class(kernel_pair.first.place_,
                                      kernel_pair.second.place_),
      "TransDataType Only Support DataType transform on same place!");
  auto src = in.Get<Tensor>();
  auto* dst = out->GetMutable<Tensor>();
  *dst = src;
}

void TransDataType(const platform::DeviceContext* ctx,
                   const KernelTypePair& kernel_pair, const Variable& in,
                   Variable* out) {
@@ -121,6 +143,8 @@ std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
}

REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
                           std::bind(f::TransDataLayout, NHWC2NCHW,
                                     std::placeholders::_1,
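To make the registration pattern above concrete: the PR maps a (source kernel type, destination kernel type) pair to a transform callback, which is later looked up with GetNullable. The following is a minimal, self-contained sketch of that registry pattern; KernelKey, TransformFn, and TransformRegistry are illustrative names, not the actual Paddle types touched by this PR.

// Standalone sketch of a pair-keyed transform registry; names are illustrative.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct KernelKey {
  std::string library;  // e.g. "Plain" or "CUDNN"
  bool operator<(const KernelKey& o) const { return library < o.library; }
};

using TransformFn = std::function<void(const std::string& in, std::string* out)>;

class TransformRegistry {
 public:
  static TransformRegistry& Instance() {
    static TransformRegistry r;
    return r;
  }
  void Register(const KernelKey& from, const KernelKey& to, TransformFn fn) {
    map_[{from, to}] = std::move(fn);
  }
  // Returns nullptr when no transform is registered, mirroring GetNullable.
  const TransformFn* GetNullable(const KernelKey& from, const KernelKey& to) const {
    auto it = map_.find({from, to});
    return it == map_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::pair<KernelKey, KernelKey>, TransformFn> map_;
};

int main() {
  auto& reg = TransformRegistry::Instance();
  // Roughly analogous to registering DummyTrans for the Plain -> CUDNN pair.
  reg.Register({"Plain"}, {"CUDNN"},
               [](const std::string& in, std::string* out) { *out = in; });

  std::string out;
  if (const TransformFn* fn = reg.GetNullable({"Plain"}, {"CUDNN"})) {
    (*fn)("tensor-data", &out);  // no-op copy, like DummyTrans
  }
  std::cout << out << std::endl;  // prints "tensor-data"
}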
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>

#include <algorithm>
#include <atomic>
@@ -25,6 +26,53 @@ limitations under the License. */
namespace paddle {
namespace framework {

std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

void UseCPU() {
  kKernelPriority.clear();
  /*Plain CPU*/
  auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kPlain);
  kKernelPriority.insert(kKernelPriority.begin(), pair0);
}

void UseMKLDNN() {
  UseCPU();
#if PADDLE_WITH_MKLML
  {
    /*MKLDNN Kernel*/
    auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN);
    kKernelPriority.insert(kKernelPriority.begin(), pair0);
  }
#endif
}

void UseCUDA() {
  UseMKLDNN();
#if PADDLE_WITH_CUDA
  /*Plain GPU*/
  auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain);
  kKernelPriority.insert(kKernelPriority.begin(), pair0);
#endif
}

void UseCUDNN() {
  UseCUDA();
#if PADDLE_WITH_CUDA
  if (platform::dynload::HasCUDNN()) {
    /*CUDNN Kernel*/
    auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN);
    kKernelPriority.insert(kKernelPriority.begin(), pair0);
  }
#endif
}

void UseALL() {
Review comment: Since UseCUDNN calls UseCUDA, UseALL is not needed; we can call UseCUDNN directly.
Reply: Actually, each UseXXX recursively calls the previous UseXXX.
  UseCPU();
  UseMKLDNN();
  UseCUDA();
  UseCUDNN();
}

std::string OperatorBase::Input(const std::string& name) const {
  auto& ins = Inputs(name);
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
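The UseXXX chain above builds kKernelPriority by having each function call the previous one and then prepend its own (place, library) pair, so the most specific available backend ends up at the front of the list. A minimal standalone illustration of that cumulative-prepend pattern is below; the names and string tags are illustrative, not Paddle's.

// Standalone sketch of the cumulative priority-list pattern behind
// UseCPU/UseMKLDNN/UseCUDA/UseCUDNN. Names and tags are illustrative only.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> g_priority;

void UseCpu() {
  g_priority.clear();
  g_priority.insert(g_priority.begin(), "CPU/Plain");
}

void UseMkldnn() {
  UseCpu();
  // Prepending makes MKLDNN the highest-priority choice when it is available.
  g_priority.insert(g_priority.begin(), "CPU/MKLDNN");
}

void UseCuda() {
  UseMkldnn();
  g_priority.insert(g_priority.begin(), "GPU/Plain");
}

void UseCudnn() {
  UseCuda();
  g_priority.insert(g_priority.begin(), "GPU/CUDNN");
}

int main() {
  UseCudnn();
  // Prints: GPU/CUDNN GPU/Plain CPU/MKLDNN CPU/Plain
  for (const auto& k : g_priority) std::cout << k << " ";
  std::cout << std::endl;
}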
@@ -402,6 +450,12 @@ const platform::DeviceContext* GetDeviceContext(
  }
}

const platform::DeviceContext* GetDeviceContext(
    const framework::OpKernelType& kernel) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  return pool.Get(kernel.place_);
}

void OperatorWithKernel::Run(const Scope& scope,
                             const platform::Place& place) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
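The new GetDeviceContext overload above fetches a device context from a singleton pool keyed by the kernel's place. A minimal standalone sketch of that pool-lookup pattern follows; ContextPool and FakeContext are illustrative stand-ins, not the actual platform::DeviceContextPool API.

// Standalone sketch of a place-keyed context pool with a singleton accessor.
// Illustrative names only; places are modeled as plain strings here.
#include <iostream>
#include <map>
#include <string>

struct FakeContext {
  std::string place;
};

class ContextPool {
 public:
  static ContextPool& Instance() {
    static ContextPool pool;  // created once, shared by all callers
    return pool;
  }
  FakeContext* Get(const std::string& place) {
    auto it = contexts_.find(place);
    if (it == contexts_.end()) {
      it = contexts_.emplace(place, FakeContext{place}).first;
    }
    return &it->second;
  }

 private:
  std::map<std::string, FakeContext> contexts_;
};

int main() {
  // A kernel key would carry its place; only the place matters for the lookup.
  FakeContext* ctx = ContextPool::Instance().Get("CUDAPlace(0)");
  std::cout << ctx->place << std::endl;
}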
@@ -422,23 +476,33 @@ void OperatorWithKernel::Run(const Scope& scope,

  ExecutionContext ctx(*this, scope, *dev_ctx);
  auto actual_kernel_key = GetActualKernelType(ctx);
  auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
  auto kernel_iter = kernels.find(expected_kernel_key);

  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("The operator %s does not support %s", type_,
                 expected_kernel_key);
  }
  auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);

  if (actual_kernel_key == expected_kernel_key) {
    PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
                      "Currently, model parallelism is only supported between "
                      "CPU and other devices. For example, multi-GPU model "
                      "parallelism will failed.");
  } else {
    // find the best key candidate
    const DataTransformFnMap& trans_map = DataTransformFnMap::Instance();
    for (auto& candidate : kKernelPriority) {
      auto candidate_key =
          OpKernelType(actual_kernel_key.data_type_, std::get<0>(candidate),
                       actual_kernel_key.data_layout_, std::get<1>(candidate));

      auto candidate_pair = std::make_pair(actual_kernel_key, candidate_key);
      if ((actual_kernel_key == candidate_key) ||
          (kernels.count(candidate_key) &&
           trans_map.GetNullable(candidate_pair))) {
        expected_kernel_key = candidate_key;
Review comment: The default priority will overwrite the user's configuration.
Reply: Yes, this does not obey the user-configuration-first rule. Will fix it in the next PR.
Review comment: The costs of the data transforms differ; from small to large the order is the following: …
Review comment: I think we are overcomplicating this. So far the only requirements are CPU <-> GPU and MKLDNNLayout <-> kPlain. In the next PR it is enough to let the user choose via an op attribute; once we start weighing costs, it becomes the same as TensorFlow's cost model.
Reply: Premature optimization is the root of all evil.
        break;
      }
    }

    auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
    const DataTransformFn* trans_fun =
        DataTransformFnMap::Instance().GetNullable(kernel_pair);
    const DataTransformFn* trans_fun = trans_map.GetNullable(kernel_pair);
    if (trans_fun) {
      auto input_vars = this->InputVars();
      // TODO(qijun) filter the input vars that do not need to be transformed
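The selection loop above walks kKernelPriority and picks the first candidate that either equals the actual kernel key, or has both a registered kernel and a registered data transform from the actual key. Here is a minimal standalone sketch of that "first viable candidate wins" fallback; the string-based keys and set-based lookups are illustrative only, not Paddle's data structures.

// Standalone sketch of the kernel-fallback selection in the loop above.
// Keys are modeled as strings; transforms as "from->to" entries.
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Highest priority first, mirroring how kKernelPriority is built.
  std::vector<std::string> priority = {"GPU/CUDNN", "GPU/Plain", "CPU/Plain"};
  std::set<std::string> registered_kernels = {"GPU/Plain", "CPU/Plain"};
  std::set<std::string> registered_transforms = {"CPU/Plain->GPU/Plain"};

  std::string actual = "CPU/Plain";
  std::string expected = actual;  // default: run where the data already is
  for (const auto& candidate : priority) {
    bool same = (candidate == actual);
    bool reachable = registered_kernels.count(candidate) &&
                     registered_transforms.count(actual + "->" + candidate);
    if (same || reachable) {
      expected = candidate;  // first viable candidate wins
      break;
    }
  }
  std::cout << "expected kernel: " << expected << std::endl;  // GPU/Plain
}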
@@ -471,7 +535,20 @@ void OperatorWithKernel::Run(const Scope& scope,
    }
  }

  kernel_iter->second->Compute(ctx);
  VLOG(10) << "Actual kernel: " << actual_kernel_key
           << "Expected kernel: " << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);

  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("The operator %s does not support %s", type_,
                 expected_kernel_key);
  }

  auto* expected_dev_ctx = GetDeviceContext(expected_kernel_key);
  ExecutionContext expected_ctx(*this, scope, *expected_dev_ctx);

  kernel_iter->second->Compute(expected_ctx);
}

OpKernelType OperatorWithKernel::GetActualKernelType(
Review comment: init does rely on operator