-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[opt] Simplify replace_statements and improve demote_dense_struct_fors #2335
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,25 +75,29 @@ bool constant_fold(IRNode *root, | |
const CompileConfig &config, | ||
const ConstantFoldPass::Args &args); | ||
void offload(IRNode *root, const CompileConfig &config); | ||
bool transform_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer); | ||
/** | ||
* @param root The IR root to be traversed. | ||
* @param filter A function which tells if a statement need to be replaced. | ||
* @param generator If a statement s need to be replaced, generate a new | ||
* statement s1 with the argument s, insert s1 to s's place, and replace all | ||
* usages of s with s1. | ||
* @return Whether the IR is modified. | ||
*/ | ||
void replace_statements_with( | ||
bool replace_and_insert_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<std::unique_ptr<Stmt>(Stmt *)> generator); | ||
/** | ||
* @param generator If a statement s need to be replaced, find the existing | ||
* @param finder If a statement s need to be replaced, find the existing | ||
* statement s1 with the argument s, and replace all usages of s with s1. | ||
* @return Whether the IR is modified. | ||
*/ | ||
bool replace_statements_with(IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<Stmt *(Stmt *)> generator); | ||
bool replace_statements(IRNode *root, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following the above comment, how is |s|'s defining stmt handled in this case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |s|'s defining stmt is erased here. |
||
std::function<bool(Stmt *)> filter, | ||
std::function<Stmt *(Stmt *)> finder); | ||
void demote_dense_struct_fors(IRNode *root); | ||
bool demote_atomics(IRNode *root, const CompileConfig &config); | ||
void reverse_segments(IRNode *root); // for autograd | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "taichi/ir/transforms.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
namespace irpass { | ||
|
||
bool replace_and_insert_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<std::unique_ptr<Stmt>(Stmt *)> generator) { | ||
return transform_statements(root, std::move(filter), | ||
[&](Stmt *stmt, DelayedIRModifier *modifier) { | ||
modifier->replace_with(stmt, generator(stmt)); | ||
}); | ||
} | ||
|
||
bool replace_statements(IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<Stmt *(Stmt *)> finder) { | ||
return transform_statements( | ||
root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) { | ||
auto existing_new_stmt = finder(stmt); | ||
irpass::replace_all_usages_with(root, stmt, existing_new_stmt); | ||
modifier->erase(stmt); | ||
}); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include "taichi/ir/ir.h" | ||
#include "taichi/ir/statements.h" | ||
#include "taichi/ir/transforms.h" | ||
#include "taichi/ir/visitors.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
// Transform each filtered statement | ||
class StatementsTransformer : public BasicStmtVisitor { | ||
public: | ||
using BasicStmtVisitor::visit; | ||
|
||
StatementsTransformer( | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer) | ||
: filter_(std::move(filter)), transformer_(std::move(transformer)) { | ||
allow_undefined_visitor = true; | ||
invoke_default_visitor = true; | ||
} | ||
|
||
void maybe_transform(Stmt *stmt) { | ||
if (filter_(stmt)) { | ||
transformer_(stmt, &modifier_); | ||
} | ||
} | ||
|
||
void preprocess_container_stmt(Stmt *stmt) override { | ||
maybe_transform(stmt); | ||
} | ||
|
||
void visit(Stmt *stmt) override { | ||
maybe_transform(stmt); | ||
} | ||
|
||
static bool run(IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> replacer) { | ||
StatementsTransformer transformer(std::move(filter), std::move(replacer)); | ||
root->accept(&transformer); | ||
return transformer.modifier_.modify_ir(); | ||
} | ||
|
||
private: | ||
std::function<bool(Stmt *)> filter_; | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer_; | ||
DelayedIRModifier modifier_; | ||
}; | ||
|
||
namespace irpass { | ||
|
||
bool transform_statements( | ||
IRNode *root, | ||
std::function<bool(Stmt *)> filter, | ||
std::function<void(Stmt *, DelayedIRModifier *)> transformer) { | ||
return StatementsTransformer::run(root, std::move(filter), | ||
std::move(transformer)); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: by saying "insert s1 to s's place", do we mean "insert |s1| to where |s| is defined, removes |s|'s definition and replaces all usages of ..."?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I'll rephrase this.