Skip to content

Commit

Permalink
[Opt] Make full_simplify iterative when advanced_optimization=True (#…
Browse files Browse the repository at this point in the history
…1225)

* [Opt] Make full_simplify iterative when advanced_optimization=True

* retrigger CI
  • Loading branch information
xumingkuan authored Jun 12, 2020
1 parent bff5755 commit 4d427a2
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 27 deletions.
4 changes: 2 additions & 2 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace irpass {
void re_id(IRNode *root);
void flag_access(IRNode *root);
void die(IRNode *root);
void simplify(IRNode *root, Kernel *kernel = nullptr);
bool simplify(IRNode *root, Kernel *kernel = nullptr);
void cfg_optimization(IRNode *root);
bool alg_simp(IRNode *root);
bool whole_kernel_cse(IRNode *root);
Expand All @@ -30,7 +30,7 @@ void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt);
void check_out_of_bound(IRNode *root);
void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel = nullptr);
void make_adjoint(IRNode *root, bool use_stack = false);
void constant_fold(IRNode *root);
bool constant_fold(IRNode *root);
void offload(IRNode *root);
void fix_block_parents(IRNode *root);
void replace_statements_with(IRNode *root,
Expand Down
15 changes: 1 addition & 14 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::extract_constant(ir);
print("Constant extracted");
irpass::analysis::verify(ir);

irpass::variable_optimization(ir, false);
print("Store forwarded");
irpass::analysis::verify(ir);
Expand Down Expand Up @@ -114,19 +110,10 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::extract_constant(ir);
print("Constant extracted II");

irpass::demote_atomics(ir);
print("Atomics demoted");
irpass::analysis::verify(ir);

irpass::full_simplify(ir);
print("Simplified III");

irpass::extract_constant(ir);
print("Constant extracted III");

irpass::variable_optimization(ir, true);
print("Store forwarded II");

Expand All @@ -135,7 +122,7 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);

irpass::full_simplify(ir);
print("Simplified IV");
print("Simplified III");

// Final field registration correctness & type checking
irpass::typecheck(ir);
Expand Down
11 changes: 7 additions & 4 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,24 +183,27 @@ class ConstantFold : public BasicStmtVisitor {
}
}

static void run(IRNode *node) {
static bool run(IRNode *node) {
ConstantFold folder;
bool result = false;
while (true) {
bool modified = false;
try {
node->accept(&folder);
} catch (IRModified) {
modified = true;
result = true;
}
if (!modified)
break;
}
return result;
}
};

namespace irpass {

void constant_fold(IRNode *root) {
bool constant_fold(IRNode *root) {
// @archibate found that `debug=True` will cause JIT kernels
// failed to evaluate correctly (always return 0), so we simply
// disable constant_fold when config.debug is turned on.
Expand All @@ -209,10 +212,10 @@ void constant_fold(IRNode *root) {
auto kernel = root->get_kernel();
if (kernel && kernel->program.config.debug) {
TI_TRACE("config.debug enabled, ignoring constant fold");
return;
return false;
}
if (!advanced_optimization)
return;
return false;
return ConstantFold::run(root);
}

Expand Down
33 changes: 26 additions & 7 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,21 +1172,40 @@ class Simplify : public IRVisitor {

namespace irpass {

void simplify(IRNode *root, Kernel *kernel) {
while (1) {
bool simplify(IRNode *root, Kernel *kernel) {
bool modified = false;
while (true) {
Simplify pass(root, kernel);
if (!pass.modified)
if (pass.modified)
modified = true;
else
break;
}
return modified;
}

void full_simplify(IRNode *root, Kernel *kernel) {
constant_fold(root);
if (advanced_optimization) {
alg_simp(root);
die(root);
whole_kernel_cse(root);
while (true) {
bool modified = false;
extract_constant(root);
if (constant_fold(root))
modified = true;
if (alg_simp(root))
modified = true;
die(root);
if (whole_kernel_cse(root))
modified = true;
die(root);
if (simplify(root, kernel))
modified = true;
die(root);
if (!modified)
break;
}
return;
}
constant_fold(root);
die(root);
simplify(root, kernel);
die(root);
Expand Down

0 comments on commit 4d427a2

Please sign in to comment.