From 908340fb4a28bd4f73a55804cd4463604ddd07c2 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 7 Feb 2022 19:00:15 -0800
Subject: [PATCH] feat(aten::Int): Lowers out aten::Int

This commit adds a pass to lower out aten::[Int/Float/Bool],
aten::NumToTensor pairs without exception. We are assuming this is safe
as there are similar passes in PyTorch for ONNX lowering; however, the
scope of this rule is intentionally limited to avoid possible cases
where it is not safe. Therefore it should not be expected that all
aten::Int issues will be solved with this change, and the operator
itself remains a limitation of TorchTRT.

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 core/lowering/lowering.cpp                        |  2 +
 core/lowering/passes/BUILD                        |  1 +
 core/lowering/passes/passes.h                     |  1 +
 .../passes/remove_unnecessary_casts.cpp           | 61 ++++++++++++++
 tests/core/lowering/BUILD                         |  5 ++
 .../test_remove_unnecessary_casts.cpp             | 79 +++++++++++++++++++
 6 files changed, 149 insertions(+)
 create mode 100644 core/lowering/passes/remove_unnecessary_casts.cpp
 create mode 100644 tests/core/lowering/test_remove_unnecessary_casts.cpp

diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index e5e2e780df..3196cc550e 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -1,6 +1,7 @@
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/create_functional_graphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/erase_number_types.h"
 #include "torch/csrc/jit/passes/freeze_module.h"
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
@@ -63,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
+  passes::RemoveUnnecessaryCasts(g);
 
   LOG_GRAPH(*g);
 }
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index fde517c428..de0e488376 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -23,6 +23,7 @@ cc_library(
         "view_to_reshape.cpp",
         "remove_dropout.cpp",
         "remove_nops.cpp",
+        "remove_unnecessary_casts.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index a79d6c1cdc..00c18883cd 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
new file mode 100644
index 0000000000..78bc0d0a71
--- /dev/null
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -0,0 +1,61 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+#include <vector>
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists, which just
+// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
+void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string int_cast_pattern = R"IR(
+    graph(%1: int):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: int = aten::Int(%2)
+      return (%3))IR";
+  std::string int_clean_pattern = R"IR(
+    graph(%1: int):
+      return (%1))IR";
+
+  std::string float_cast_pattern = R"IR(
+    graph(%1: float):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: float = aten::Float(%2)
+      return (%3))IR";
+  std::string float_clean_pattern = R"IR(
+    graph(%1: float):
+      return (%1))IR";
+
+  std::string bool_cast_pattern = R"IR(
+    graph(%1: bool):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: bool = aten::Bool(%2)
+      return (%3))IR";
+  std::string bool_clean_pattern = R"IR(
+    graph(%1: bool):
+      return (%1))IR";
+
+  torch::jit::SubgraphRewriter int_cast_rewriter;
+  int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
+  int_cast_rewriter.runOnGraph(graph);
+
+  torch::jit::SubgraphRewriter float_cast_rewriter;
+  float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
+  float_cast_rewriter.runOnGraph(graph);
+
+  torch::jit::SubgraphRewriter bool_cast_rewriter;
+  bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
+  bool_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index b97f9ba451..6eebb79d3f 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -50,6 +50,10 @@ lowering_test(
     name = "test_remove_detach_pass",
 )
 
+lowering_test(
+    name = "test_remove_unnecessary_casts",
+)
+
 lowering_test(
     name = "test_view_to_reshape_pass",
 )
@@ -81,6 +85,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_view_to_reshape_pass",
         ":test_remove_dropout_pass",
+        ":test_remove_unnecessary_casts",
         ":test_reduce_to_pass",
         ":test_reduce_gelu",
         ":test_unpack_hardswish",
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
new file mode 100644
index 0000000000..e6a9b8373c
--- /dev/null
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -0,0 +1,79 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: int):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: int = aten::Int(%2)
+      %4: int = aten::add(%3, %3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: int):
+      %4: int = aten::add(%1, %1, %1)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: float):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: float = aten::Float(%2)
+      %4: float = aten::add(%3, %3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: float):
+      %4: float = aten::add(%1, %1, %1)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: bool):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: bool = aten::Bool(%2)
+      %4: bool = aten::__and__(%3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: bool):
+      %4: bool = aten::__and__(%1, %1)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
\ No newline at end of file
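
Note for reviewers: the sketch below is not part of the patch. It is a minimal
standalone program showing the int-case rewrite in isolation, using the same
torch::jit APIs the pass calls (parseIR and SubgraphRewriter). It assumes only
that libtorch headers and libraries are available; the file name and variable
names are illustrative, not from the repo.

// cast_rewrite_demo.cpp -- hypothetical standalone demo, not in the repo.
#include <iostream>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

int main() {
  // A graph that needlessly round-trips an int through a Tensor:
  // exactly the aten::NumToTensor/aten::Int pair the pass removes.
  const std::string src = R"IR(
    graph(%x: int):
      %t: Tensor = aten::NumToTensor(%x)
      %y: int = aten::Int(%t)
      return (%y))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());
  std::cout << "Before:\n" << *g;

  // Register the same pattern/replacement pair as the int case of
  // RemoveUnnecessaryCasts, then run the rewriter over the graph.
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      R"IR(
        graph(%1: int):
          %2: Tensor = aten::NumToTensor(%1)
          %3: int = aten::Int(%2)
          return (%3))IR",
      R"IR(
        graph(%1: int):
          return (%1))IR");
  rewriter.runOnGraph(g);

  // The cast pair is gone; the graph now returns %x directly.
  std::cout << "After:\n" << *g;
  return 0;
}

Running this prints the graph before and after the rewrite: the
aten::NumToTensor/aten::Int nodes disappear and the input is returned as-is,
which is the same property the new tests assert via findPatternMatches.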