Skip to content

Commit

Permalink
opt kernel_selection error msg (PaddlePaddle#48864)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored Dec 8, 2022
1 parent fa65fed commit 3fbbee7
Showing 1 changed file with 12 additions and 26 deletions.
38 changes: 12 additions & 26 deletions paddle/phi/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
#include "paddle/phi/core/compat/convert_utils.h"
#endif
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h"

DECLARE_bool(enable_api_kernel_fallback);

namespace phi {

const static Kernel empty_kernel; // NOLINT

std::string kernel_selection_error_message(const std::string& kernel_name,
const KernelKey& target_key);
std::string KernelSelectionErrorMessage(const std::string& kernel_name,
const KernelKey& target_key);

uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0;
Expand Down Expand Up @@ -146,7 +147,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
"The kernel with key %s of kernel `%s` is not registered. %s",
kernel_key,
kernel_name,
kernel_selection_error_message(kernel_name, kernel_key)));
KernelSelectionErrorMessage(kernel_name, kernel_key)));

#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
Expand Down Expand Up @@ -176,7 +177,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
"fail to fallback to CPU one. %s",
kernel_key,
kernel_name,
kernel_selection_error_message(kernel_name, kernel_key)));
KernelSelectionErrorMessage(kernel_name, kernel_key)));

VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name
<< ", expected_kernel_key:" << kernel_key
Expand All @@ -195,7 +196,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
" to CPU one, please set the flag true before run again.",
kernel_key,
kernel_name,
kernel_selection_error_message(kernel_name, kernel_key)));
KernelSelectionErrorMessage(kernel_name, kernel_key)));

return {kernel_iter->second, false};
}
Expand Down Expand Up @@ -368,8 +369,8 @@ std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
// (GPU, Undefined(AnyLayout), [float32, float64, ...]);
// ...
// }
std::string kernel_selection_error_message(const std::string& kernel_name,
const KernelKey& target_key) {
std::string KernelSelectionErrorMessage(const std::string& kernel_name,
const KernelKey& target_key) {
PADDLE_ENFORCE_NE(
KernelFactory::Instance().kernels().find(kernel_name),
KernelFactory::Instance().kernels().end(),
Expand Down Expand Up @@ -402,25 +403,15 @@ std::string kernel_selection_error_message(const std::string& kernel_name,
// 1. If target_key not supports target backend, output "Selected wrong
// Backend ..."
if (!support_backend) {
std::string error_message = "";
for (auto iter = backend_set.begin(); iter != backend_set.end(); ++iter) {
error_message += *iter;
error_message += ", ";
}
error_message = error_message.substr(0, error_message.length() - 2);
std::string error_message = paddle::string::join_strings(backend_set, ", ");
return "Selected wrong Backend `" +
paddle::experimental::BackendToString(target_key.backend()) +
"`. Paddle support following Backends: " + error_message + ".";
}
// 2. If target_key not supports target datatype, output "Selected wrong
// DataType ..."
if (!support_dtype) {
std::string error_message = "";
for (auto iter = dtype_set.begin(); iter != dtype_set.end(); ++iter) {
error_message += *iter;
error_message += ", ";
}
error_message = error_message.substr(0, error_message.length() - 2);
std::string error_message = paddle::string::join_strings(dtype_set, ", ");
return "Selected wrong DataType `" +
paddle::experimental::DataTypeToString(target_key.dtype()) +
"`. Paddle support following DataTypes: " + error_message + ".";
Expand All @@ -431,14 +422,9 @@ std::string kernel_selection_error_message(const std::string& kernel_name,
kernel_name + "`: { ";
for (auto iter = all_kernel_key.begin(); iter != all_kernel_key.end();
++iter) {
message += "(" + iter->first + ", [";
std::vector<std::string>& dtype_vec = iter->second;
for (std::size_t i = 0; i < dtype_vec.size(); ++i) {
message += dtype_vec[i];
if (i + 1 != dtype_vec.size()) {
message += ", ";
}
}
message += "(" + iter->first + ", [";
message += paddle::string::join_strings(dtype_vec, ", ");
message += "]); ";
}
message += "}.";
Expand Down

0 comments on commit 3fbbee7

Please sign in to comment.