Skip to content

Commit

Permalink
[opt] Simplify replace_statements and improve demote_dense_struct_fors (
Browse files Browse the repository at this point in the history
#2335)

* [opt] Simplify replace_statements and improve demote_dense_struct_fors

* Apply review
  • Loading branch information
xumingkuan authored May 13, 2021
1 parent 553522e commit e5f439b
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 184 deletions.
25 changes: 15 additions & 10 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,30 @@ bool constant_fold(IRNode *root,
const CompileConfig &config,
const ConstantFoldPass::Args &args);
void offload(IRNode *root, const CompileConfig &config);
bool transform_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer);
/**
* @param root The IR root to be traversed.
* @param filter A function which tells if a statement need to be replaced.
* @param generator If a statement s need to be replaced, generate a new
* statement s1 with the argument s, insert s1 to s's place, and replace all
* usages of s with s1.
* @param generator If a statement |s| need to be replaced, generate a new
* statement |s1| with the argument |s|, insert |s1| to where |s| is defined,
* remove |s|'s definition, and replace all usages of |s| with |s1|.
* @return Whether the IR is modified.
*/
void replace_statements_with(
bool replace_and_insert_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<std::unique_ptr<Stmt>(Stmt *)> generator);
/**
* @param generator If a statement s need to be replaced, find the existing
* statement s1 with the argument s, and replace all usages of s with s1.
* @return Whether the IR is modified.
* @param finder If a statement |s| need to be replaced, find the existing
* statement |s1| with the argument |s|, remove |s|'s definition, and replace
* all usages of |s| with |s1|.
*/
bool replace_statements_with(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> generator);
bool replace_statements(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> finder);
void demote_dense_struct_fors(IRNode *root);
bool demote_atomics(IRNode *root, const CompileConfig &config);
void reverse_segments(IRNode *root); // for autograd
Expand Down
29 changes: 15 additions & 14 deletions taichi/transforms/demote_dense_struct_fors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,22 @@ void convert_to_range_for(OffloadedStmt *offloaded) {
}
}

for (int i = 0; i < num_loop_vars; i++) {
// TODO: Use only one (instead num_loop_vars) invocation(s) of
// replace_statements_with
irpass::replace_statements_with(
body.get(),
[&](Stmt *s) {
if (auto loop_index = s->cast<LoopIndexStmt>()) {
return loop_index->loop == offloaded &&
loop_index->index ==
snodes.back()->physical_index_position[i];
}
irpass::replace_statements(
body.get(), /*filter=*/
[&](Stmt *s) {
if (auto loop_index = s->cast<LoopIndexStmt>()) {
return loop_index->loop == offloaded;
} else {
return false;
},
[&](Stmt *) { return new_loop_vars[i]; });
}
}
},
/*finder=*/
[&](Stmt *s) {
auto index = std::find(physical_indices.begin(), physical_indices.end(),
s->as<LoopIndexStmt>()->index);
TI_ASSERT(index != physical_indices.end());
return new_loop_vars[index - physical_indices.begin()];
});

if (has_test) {
// Create an If statement
Expand Down
8 changes: 3 additions & 5 deletions taichi/transforms/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ class Inliner : public BasicStmtVisitor {
TI_ASSERT(func->rets.size() <= 1);
auto inlined_ir = irpass::analysis::clone(func->ir.get());
if (!func->args.empty()) {
// TODO: Make sure that if stmt->args is an ArgLoadStmt,
// it will not be replaced again here
irpass::replace_statements_with(
irpass::replace_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<ArgLoadStmt>(); },
/*generator=*/
/*finder=*/
[&](Stmt *s) { return stmt->args[s->as<ArgLoadStmt>()->arg_id]; });
}
if (func->rets.empty()) {
Expand All @@ -46,7 +44,7 @@ class Inliner : public BasicStmtVisitor {
// Use a local variable to store the return value
auto *return_address = inlined_ir->as<Block>()->insert(
Stmt::make<AllocaStmt>(func->rets[0].dt), /*location=*/0);
irpass::replace_statements_with(
irpass::replace_and_insert_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<KernelReturnStmt>(); },
/*generator=*/
Expand Down
30 changes: 30 additions & 0 deletions taichi/transforms/replace_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "taichi/ir/transforms.h"

TLANG_NAMESPACE_BEGIN

namespace irpass {

bool replace_and_insert_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<std::unique_ptr<Stmt>(Stmt *)> generator) {
return transform_statements(root, std::move(filter),
[&](Stmt *stmt, DelayedIRModifier *modifier) {
modifier->replace_with(stmt, generator(stmt));
});
}

bool replace_statements(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> finder) {
return transform_statements(
root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) {
auto existing_new_stmt = finder(stmt);
irpass::replace_all_usages_with(root, stmt, existing_new_stmt);
modifier->erase(stmt);
});
}

} // namespace irpass

TLANG_NAMESPACE_END
155 changes: 0 additions & 155 deletions taichi/transforms/statement_replace.cpp

This file was deleted.

61 changes: 61 additions & 0 deletions taichi/transforms/transform_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

TLANG_NAMESPACE_BEGIN

// Transform each filtered statement
class StatementsTransformer : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

StatementsTransformer(
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer)
: filter_(std::move(filter)), transformer_(std::move(transformer)) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void maybe_transform(Stmt *stmt) {
if (filter_(stmt)) {
transformer_(stmt, &modifier_);
}
}

void preprocess_container_stmt(Stmt *stmt) override {
maybe_transform(stmt);
}

void visit(Stmt *stmt) override {
maybe_transform(stmt);
}

static bool run(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> replacer) {
StatementsTransformer transformer(std::move(filter), std::move(replacer));
root->accept(&transformer);
return transformer.modifier_.modify_ir();
}

private:
std::function<bool(Stmt *)> filter_;
std::function<void(Stmt *, DelayedIRModifier *)> transformer_;
DelayedIRModifier modifier_;
};

namespace irpass {

bool transform_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer) {
return StatementsTransformer::run(root, std::move(filter),
std::move(transformer));
}

} // namespace irpass

TLANG_NAMESPACE_END

0 comments on commit e5f439b

Please sign in to comment.