-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[opt] Simplify replace_statements and improve demote_dense_struct_fors (
#2335) * [opt] Simplify replace_statements and improve demote_dense_struct_fors * Apply review
- Loading branch information
1 parent
553522e
commit e5f439b
Showing
6 changed files
with
124 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "taichi/ir/transforms.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
namespace irpass { | ||
|
||
bool replace_and_insert_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<std::unique_ptr<Stmt>(Stmt *)> generator) { | ||
return transform_statements(root, std::move(filter), | ||
[&](Stmt *stmt, DelayedIRModifier *modifier) { | ||
modifier->replace_with(stmt, generator(stmt)); | ||
}); | ||
} | ||
|
||
bool replace_statements(IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<Stmt *(Stmt *)> finder) { | ||
return transform_statements( | ||
root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) { | ||
auto existing_new_stmt = finder(stmt); | ||
irpass::replace_all_usages_with(root, stmt, existing_new_stmt); | ||
modifier->erase(stmt); | ||
}); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include "taichi/ir/ir.h" | ||
#include "taichi/ir/statements.h" | ||
#include "taichi/ir/transforms.h" | ||
#include "taichi/ir/visitors.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
// Transform each filtered statement | ||
class StatementsTransformer : public BasicStmtVisitor { | ||
public: | ||
using BasicStmtVisitor::visit; | ||
|
||
StatementsTransformer( | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer) | ||
: filter_(std::move(filter)), transformer_(std::move(transformer)) { | ||
allow_undefined_visitor = true; | ||
invoke_default_visitor = true; | ||
} | ||
|
||
void maybe_transform(Stmt *stmt) { | ||
if (filter_(stmt)) { | ||
transformer_(stmt, &modifier_); | ||
} | ||
} | ||
|
||
void preprocess_container_stmt(Stmt *stmt) override { | ||
maybe_transform(stmt); | ||
} | ||
|
||
void visit(Stmt *stmt) override { | ||
maybe_transform(stmt); | ||
} | ||
|
||
static bool run(IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> replacer) { | ||
StatementsTransformer transformer(std::move(filter), std::move(replacer)); | ||
root->accept(&transformer); | ||
return transformer.modifier_.modify_ir(); | ||
} | ||
|
||
private: | ||
std::function<bool(Stmt *)> filter_; | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer_; | ||
DelayedIRModifier modifier_; | ||
}; | ||
|
||
namespace irpass { | ||
|
||
bool transform_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer) { | ||
return StatementsTransformer::run(root, std::move(filter), | ||
std::move(transformer)); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |