diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 3196cc550e..311f4277b0 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -64,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
+  passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
   LOG_GRAPH(*g);
 }
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 00c18883cd..348b56997f 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/core/lowering/passes/remove_set_attrs.cpp b/core/lowering/passes/remove_set_attrs.cpp
new file mode 100644
index 0000000000..6645707f49
--- /dev/null
+++ b/core/lowering/passes/remove_set_attrs.cpp
@@ -0,0 +1,35 @@
+#include <stack>
+#include <unordered_set>
+
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/lowering/passes/passes.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
+  auto g = mod.get_method(method_name).graph();
+
+  std::string set_attr_pattern = R"IR(
+    graph(%self, %0):
+        None = prim::SetAttr[name="_has_warned"](%self, %0)
+        return ())IR";
+  std::string no_set_attr_pattern = R"IR(
+    graph(%self, %0):
+        return ())IR";
+
+  // remove SetAttr nodes
+  torch::jit::SubgraphRewriter remove_set_attr;
+  remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
+  remove_set_attr.runOnGraph(g);
+
+  LOG_GRAPH("Post remove SetAttr: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
index 78bc0d0a71..7f6cc85284 100644
--- a/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
+#include "torch/csrc/jit/ir/constants.h"
 
 #include "core/util/prelude.h"
 
@@ -55,6 +56,119 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
   LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
 }
 
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
+  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
+    if (it->kind() == torch::jit::prim::Constant) {
+      // Going from a constant and is single use means we can fuse
+      if (it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
+        // Get the tensor stored in constant
+        at::Tensor t = *torch::jit::constant_as<at::Tensor>(it->output());
+        // If shape is 0D
+        if (t.sizes() == std::vector<int64_t>({})) {
+          LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
+          LOG_GRAPH("Number of uses: " << it->output()->uses().size());
+          // If the tensor is only used once
+          if (it->output()->uses().size() == 1) {
+            auto use = it->output()->uses()[0];
+            auto user = use.user;
+
+            // Is a NumToTensor / aten::[Int/Float] case
+            if (user->outputs().size() == 1 && user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get())) {
+              if (user->output()->uses().size() == 1) {
+                auto potential_cast = user->output()->uses()[0].user;
+                // The downstream user is aten::Int
+                if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
+                    || potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
+                  LOG_GRAPH("Downstream user is aten::Int/aten::Float");
+                  auto arg = use.offset;
+
+                  for (size_t k = 0; k < user->inputs().size(); ++k) {
+                    if (k != arg) {
+                      if (user->inputs()[k]->type()->isSubtypeOf(c10::TensorType::get())) {
+                        LOG_GRAPH("Input " << k << " is a Tensor");
+                        if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
+                          auto num_to_tensor = user->inputs()[k]->node();
+
+                          LOG_GRAPH(
+                              "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n    "
+                              << *(*it) << *num_to_tensor << *user << *potential_cast);
+
+                          // Replace the Tensor Constant with a scalar constant
+                          LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
+                          torch::jit::WithInsertPoint guard(*it);
+
+                          auto new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
+                          new_const_val->copyMetadata(it->output());
+                          // How to determine the internal scalar type instead of assuming?
+                          if (potential_cast->kind() == c10::aten::Int) {
+                            new_const_val->setType(c10::IntType::get());
+                          } else if (potential_cast->kind() == c10::aten::Float) {
+                            new_const_val->setType(c10::FloatType::get());
+                          }
+                          it->output()->replaceAllUsesWith(new_const_val);
+                          it.destroyCurrent();
+
+                          LOG_GRAPH("New constant: " << *new_const_val->node());
+
+                          // Delete NumToTensor
+                          LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
+                          num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
+                          num_to_tensor->destroy();
+
+                          // Change intermediate op output type
+                          LOG_GRAPH(user->schema());
+
+                          torch::jit::Node* new_node;
+                          switch (user->kind()) {
+                            // Use this to handle special cases where the scalar version of the intermediate operator
+                            // has a different schema than the original
+                            case c10::aten::add:
+                              new_node = g->create(
+                                  user->kind(),
+                                  torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
+                                  1);
+                              new_node->insertAfter(user);
+                              new_node->outputs()[0]->setType(c10::IntType::get());
+                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+                              user->destroy();
+                              break;
+                            default:
+                              new_node = g->create(
+                                  user->kind(),
+                                  user->inputs(),
+                                  1);
+                              new_node->insertAfter(user);
+                              new_node->outputs()[0]->setType(c10::IntType::get());
+                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+                              user->destroy();
+                              break;
+                          }
+
+                          LOG_GRAPH("New intermediate operation: " << *new_node);
+                          LOG_GRAPH(new_node->schema());
+
+                          // Delete aten::Int
+                          LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
+                          potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
+                          potential_cast->destroy();
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
+}
+
+
 } // namespace passes
 } // namespace lowering
 } // namespace core
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
index e6a9b8373c..ef370a81c2 100644
--- a/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -22,7 +22,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());
@@ -46,7 +46,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());
@@ -70,7 +70,85 @@ TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %2: int = prim::Constant[value=1]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: int = aten::Int(%4)
+      %6: int = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: int = prim::Constant[value=8]()
+      %4: int = aten::add(%1, %0)
+      %6: int = aten::add(%4, %4)
+      return (%6))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(
+      c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: float):
+      %1: Tensor = prim::Constant[value=[8.]]()
+      %2: float = prim::Constant[value=1.]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: float = aten::Float(%4)
+      %6: float = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: float):
+      %1: float = prim::Constant[value=8.]()
+      %4: float = aten::add(%1, %0)
+      %6: float = aten::add(%4, %4)
+      return (%6))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(
+      c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());