From ab8bcf3a54ce0260cfa72260b928260bf5cd16d3 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 13 May 2021 13:54:27 +0800 Subject: [PATCH 1/2] [opt] Simplify replace_statements and improve demote_dense_struct_fors --- taichi/ir/transforms.h | 16 +- .../transforms/demote_dense_struct_fors.cpp | 29 ++-- taichi/transforms/inlining.cpp | 8 +- taichi/transforms/replace_statements.cpp | 30 ++++ taichi/transforms/statement_replace.cpp | 155 ------------------ taichi/transforms/transform_statements.cpp | 61 +++++++ 6 files changed, 119 insertions(+), 180 deletions(-) create mode 100644 taichi/transforms/replace_statements.cpp delete mode 100644 taichi/transforms/statement_replace.cpp create mode 100644 taichi/transforms/transform_statements.cpp diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index e498bff4f3e20..d3899c5a07cb0 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -75,25 +75,29 @@ bool constant_fold(IRNode *root, const CompileConfig &config, const ConstantFoldPass::Args &args); void offload(IRNode *root, const CompileConfig &config); +bool transform_statements( + IRNode *root, + std::function filter, + std::function transformer); /** * @param root The IR root to be traversed. * @param filter A function which tells if a statement need to be replaced. * @param generator If a statement s need to be replaced, generate a new * statement s1 with the argument s, insert s1 to s's place, and replace all * usages of s with s1. + * @return Whether the IR is modified. */ -void replace_statements_with( +bool replace_and_insert_statements( IRNode *root, std::function filter, std::function(Stmt *)> generator); /** - * @param generator If a statement s need to be replaced, find the existing + * @param finder If a statement s need to be replaced, find the existing * statement s1 with the argument s, and replace all usages of s with s1. - * @return Whether the IR is modified. */ -bool replace_statements_with(IRNode *root, - std::function filter, - std::function generator); +bool replace_statements(IRNode *root, + std::function filter, + std::function finder); void demote_dense_struct_fors(IRNode *root); bool demote_atomics(IRNode *root, const CompileConfig &config); void reverse_segments(IRNode *root); // for autograd diff --git a/taichi/transforms/demote_dense_struct_fors.cpp b/taichi/transforms/demote_dense_struct_fors.cpp index edac3bcb7b29f..687cf5933b264 100644 --- a/taichi/transforms/demote_dense_struct_fors.cpp +++ b/taichi/transforms/demote_dense_struct_fors.cpp @@ -86,21 +86,22 @@ void convert_to_range_for(OffloadedStmt *offloaded) { } } - for (int i = 0; i < num_loop_vars; i++) { - // TODO: Use only one (instead num_loop_vars) invocation(s) of - // replace_statements_with - irpass::replace_statements_with( - body.get(), - [&](Stmt *s) { - if (auto loop_index = s->cast()) { - return loop_index->loop == offloaded && - loop_index->index == - snodes.back()->physical_index_position[i]; - } + irpass::replace_statements( + body.get(), /*filter=*/ + [&](Stmt *s) { + if (auto loop_index = s->cast()) { + return loop_index->loop == offloaded; + } else { return false; - }, - [&](Stmt *) { return new_loop_vars[i]; }); - } + } + }, + /*finder=*/ + [&](Stmt *s) { + auto index = std::find(physical_indices.begin(), physical_indices.end(), + s->as()->index); + TI_ASSERT(index != physical_indices.end()); + return new_loop_vars[index - physical_indices.begin()]; + }); if (has_test) { // Create an If statement diff --git a/taichi/transforms/inlining.cpp b/taichi/transforms/inlining.cpp index 7c36f602420d2..a89676444186e 100644 --- a/taichi/transforms/inlining.cpp +++ b/taichi/transforms/inlining.cpp @@ -24,12 +24,10 @@ class Inliner : public BasicStmtVisitor { TI_ASSERT(func->rets.size() <= 1); auto inlined_ir = irpass::analysis::clone(func->ir.get()); if (!func->args.empty()) { - // TODO: Make sure that if stmt->args is an ArgLoadStmt, - // it will not be replaced again here - irpass::replace_statements_with( + irpass::replace_statements( inlined_ir.get(), /*filter=*/[&](Stmt *s) { return s->is(); }, - /*generator=*/ + /*finder=*/ [&](Stmt *s) { return stmt->args[s->as()->arg_id]; }); } if (func->rets.empty()) { @@ -46,7 +44,7 @@ class Inliner : public BasicStmtVisitor { // Use a local variable to store the return value auto *return_address = inlined_ir->as()->insert( Stmt::make(func->rets[0].dt), /*location=*/0); - irpass::replace_statements_with( + irpass::replace_and_insert_statements( inlined_ir.get(), /*filter=*/[&](Stmt *s) { return s->is(); }, /*generator=*/ diff --git a/taichi/transforms/replace_statements.cpp b/taichi/transforms/replace_statements.cpp new file mode 100644 index 0000000000000..08923dd5628b9 --- /dev/null +++ b/taichi/transforms/replace_statements.cpp @@ -0,0 +1,30 @@ +#include "taichi/ir/transforms.h" + +TLANG_NAMESPACE_BEGIN + +namespace irpass { + +bool replace_and_insert_statements( + IRNode *root, + std::function filter, + std::function(Stmt *)> generator) { + return transform_statements(root, std::move(filter), + [&](Stmt *stmt, DelayedIRModifier *modifier) { + modifier->replace_with(stmt, generator(stmt)); + }); +} + +bool replace_statements(IRNode *root, + std::function filter, + std::function finder) { + return transform_statements( + root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) { + auto existing_new_stmt = finder(stmt); + irpass::replace_all_usages_with(root, stmt, existing_new_stmt); + modifier->erase(stmt); + }); +} + +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/statement_replace.cpp b/taichi/transforms/statement_replace.cpp deleted file mode 100644 index 8c363f4c336d8..0000000000000 --- a/taichi/transforms/statement_replace.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" - -TLANG_NAMESPACE_BEGIN - -// TODO: rewrite and simplify these classes -// Replace all usages and remove the statements themselves -class StatementReplaceAndRemove : public IRVisitor { - public: - IRNode *node; - std::function filter; - std::function generator; - DelayedIRModifier modifier; - - StatementReplaceAndRemove(IRNode *node, - std::function filter, - std::function generator) - : node(node), filter(filter), generator(generator) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - } - - void replace_if_necessary(Stmt *stmt) { - if (filter(stmt)) { - auto new_stmt = generator(stmt); - irpass::replace_all_usages_with(node, stmt, new_stmt); - modifier.erase(stmt); - } - } - - void visit(Block *stmt_list) override { - for (auto &stmt : stmt_list->statements) { - stmt->accept(this); - } - } - - void visit(IfStmt *if_stmt) override { - replace_if_necessary(if_stmt); - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(WhileStmt *stmt) override { - replace_if_necessary(stmt); - stmt->body->accept(this); - } - - void visit(RangeForStmt *for_stmt) override { - replace_if_necessary(for_stmt); - for_stmt->body->accept(this); - } - - void visit(StructForStmt *for_stmt) override { - replace_if_necessary(for_stmt); - for_stmt->body->accept(this); - } - - void visit(Stmt *stmt) override { - replace_if_necessary(stmt); - } - - bool run() { - node->accept(this); - return modifier.modify_ir(); - } -}; - -// Replace both usages and the statements themselves -class StatementReplace : public IRVisitor { - public: - IRNode *node; - std::function filter; - std::function(Stmt *)> generator; - - StatementReplace(IRNode *node, - std::function filter, - std::function(Stmt *)> generator) - : node(node), filter(filter), generator(generator) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - } - - void replace_if_necessary(Stmt *stmt) { - if (filter(stmt)) { - auto block = stmt->parent; - auto new_stmt = generator(stmt); - irpass::replace_all_usages_with(node, stmt, new_stmt.get()); - block->replace_with(stmt, std::move(new_stmt), false); - } - } - - void visit(Block *stmt_list) override { - for (auto &stmt : stmt_list->statements) { - stmt->accept(this); - } - } - - void visit(IfStmt *if_stmt) override { - replace_if_necessary(if_stmt); - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(WhileStmt *stmt) override { - replace_if_necessary(stmt); - stmt->body->accept(this); - } - - void visit(RangeForStmt *for_stmt) override { - replace_if_necessary(for_stmt); - for_stmt->body->accept(this); - } - - void visit(StructForStmt *for_stmt) override { - replace_if_necessary(for_stmt); - for_stmt->body->accept(this); - } - - void visit(Stmt *stmt) override { - replace_if_necessary(stmt); - } - - void run() { - node->accept(this); - } -}; - -namespace irpass { - -void replace_statements_with( - IRNode *root, - std::function filter, - std::function(Stmt *)> generator) { - StatementReplace transformer(root, filter, generator); - transformer.run(); -} - -bool replace_statements_with(IRNode *root, - std::function filter, - std::function generator) { - StatementReplaceAndRemove transformer(root, filter, generator); - return transformer.run(); -} - -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/taichi/transforms/transform_statements.cpp b/taichi/transforms/transform_statements.cpp new file mode 100644 index 0000000000000..eaae474fafdd9 --- /dev/null +++ b/taichi/transforms/transform_statements.cpp @@ -0,0 +1,61 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" + +TLANG_NAMESPACE_BEGIN + +// Transform each filtered statement +class StatementsTransformer : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + StatementsTransformer( + std::function filter, + std::function transformer) + : filter_(std::move(filter)), transformer_(std::move(transformer)) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + + void maybe_transform(Stmt *stmt) { + if (filter_(stmt)) { + transformer_(stmt, &modifier_); + } + } + + void preprocess_container_stmt(Stmt *stmt) override { + maybe_transform(stmt); + } + + void visit(Stmt *stmt) override { + maybe_transform(stmt); + } + + static bool run(IRNode *root, + std::function filter, + std::function replacer) { + StatementsTransformer transformer(std::move(filter), std::move(replacer)); + root->accept(&transformer); + return transformer.modifier_.modify_ir(); + } + + private: + std::function filter_; + std::function transformer_; + DelayedIRModifier modifier_; +}; + +namespace irpass { + +bool transform_statements( + IRNode *root, + std::function filter, + std::function transformer) { + return StatementsTransformer::run(root, std::move(filter), + std::move(transformer)); +} + +} // namespace irpass + +TLANG_NAMESPACE_END From 9ba605902621c2ae3d195960bc9c3caac71031a3 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 13 May 2021 15:01:10 +0800 Subject: [PATCH 2/2] Apply review --- taichi/ir/transforms.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index d3899c5a07cb0..6b55a79dc4c67 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -82,9 +82,9 @@ bool transform_statements( /** * @param root The IR root to be traversed. * @param filter A function which tells if a statement need to be replaced. - * @param generator If a statement s need to be replaced, generate a new - * statement s1 with the argument s, insert s1 to s's place, and replace all - * usages of s with s1. + * @param generator If a statement |s| need to be replaced, generate a new + * statement |s1| with the argument |s|, insert |s1| to where |s| is defined, + * remove |s|'s definition, and replace all usages of |s| with |s1|. * @return Whether the IR is modified. */ bool replace_and_insert_statements( @@ -92,8 +92,9 @@ bool replace_and_insert_statements( std::function filter, std::function(Stmt *)> generator); /** - * @param finder If a statement s need to be replaced, find the existing - * statement s1 with the argument s, and replace all usages of s with s1. + * @param finder If a statement |s| need to be replaced, find the existing + * statement |s1| with the argument |s|, remove |s|'s definition, and replace + * all usages of |s| with |s1|. */ bool replace_statements(IRNode *root, std::function filter,