Skip to content

Commit

Permalink
fix op_teller dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
sanbuphy committed Jan 14, 2023
1 parent 47e3822 commit 8094c09
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2132,19 +2132,29 @@ struct SimpleOpTypeSetTeller : public Teller {
}

auto dtype = x_var_desc->GetDataType();
if (op_type == "reduce_all" || op_type == "reduce_any") {
if (dtype != framework::proto::VarType::BOOL) {
VLOG(3)
<< "reduce_all and reduce_any op input data type must be bool";
return false;
}
} else {
#if IS_TRT_VERSION_GE(7000)
if (dtype != framework::proto::VarType::INT32 &&
dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be int32 or float32";
return false;
}
if (dtype != framework::proto::VarType::INT32 &&
dtype != framework::proto::VarType::FP32) {
std::cout << "reduce op input data type must be int32 or float32"
<< std::endl;
VLOG(3) << "reduce op input data type must be int32 or float32";
return false;
}
#else
if (dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
"< 7.0";
return false;
}
if (dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
"< 7.0";
return false;
}
#endif
}
}
#if IS_TRT_VERSION_GE(7000)
if (op_type == "tile") {
Expand Down

0 comments on commit 8094c09

Please sign in to comment.