Skip to content

Commit

Permalink
fix: truncate_long_and_double incurs TorchScript inference issues
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 23, 2022
1 parent 17490b1 commit c83aa15
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions core/partitioning/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,13 @@ void getSegmentsOutputByRunning(
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
} else if (input->type()->kind() == torch::jit::TypeKind::TupleType) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
} else if (input->type()->kind() == torch::jit::TypeKind::NumberType) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
} else {
TORCHTRT_THROW_ERROR("Unable to find type for value: " << input->debugName() << " to get the ivalues.\n");
TORCHTRT_THROW_ERROR(
"Unable to find type for value: " << input->debugName()
<< " to get the ivalues. The type for this value should be "
<< input->type()->str() << " \n");
}
}

Expand Down Expand Up @@ -110,28 +115,31 @@ void getSegmentsOutputByRunning(
for (auto& i : seg_block.raw_inputs()) {
if (ivalues_maps[i].isTensor()) {
// set the input_shape and data_type
at::ScalarType t = ivalues_maps[i].toTensor().scalar_type();
// we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
// shape inference
auto cur_ivalue = ivalues_maps[i];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
TORCHTRT_THROW_ERROR(
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
} else if (partition_info.truncate_long_and_double && t == at::kLong) {
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt);
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
} else if (partition_info.truncate_long_and_double && t == at::kDouble) {
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat);
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
}
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
TORCHTRT_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
}
if (ivalues_maps[i].toTensor().sizes().size() == 0) {
if (cur_ivalue.toTensor().sizes().size() == 0) {
// handle Scalar types, which has sizes of []
input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
} else {
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes())));
}
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
input_types.push_back(cur_ivalue.toTensor().scalar_type());
}
}

Expand Down

0 comments on commit c83aa15

Please sign in to comment.