diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index c3ccdda92395e..175cba9dc2170 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -571,9 +571,11 @@ void DelayedIRModifier::replace_with(Stmt *stmt, } bool DelayedIRModifier::modify_ir() { + bool force_modified = modified_; + modified_ = false; if (to_insert_before.empty() && to_insert_after.empty() && to_erase.empty() && to_replace_with.empty()) - return false; + return force_modified; for (auto &i : to_insert_before) { i.first->parent->insert_before(i.first, std::move(i.second)); } @@ -593,6 +595,10 @@ bool DelayedIRModifier::modify_ir() { return true; } +void DelayedIRModifier::mark_as_modified() { + modified_ = true; +} + LocalAddress::LocalAddress(Stmt *var, int offset) : var(var), offset(offset) { TI_ASSERT(var->is()); } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 89499a121f67c..1b0ee54a3f927 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -729,6 +729,7 @@ class DelayedIRModifier { std::vector> to_insert_after; std::vector> to_replace_with; std::vector to_erase; + bool modified_{false}; public: ~DelayedIRModifier(); @@ -741,6 +742,9 @@ class DelayedIRModifier { VecStatement &&new_statements, bool replace_usages = true); bool modify_ir(); + + // Force the next call of modify_ir() to return true. + void mark_as_modified(); }; struct LocalAddress { diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index f0d95e9f619f0..6aa925de3817e 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -48,9 +48,19 @@ class AlgSimp : public BasicStmtVisitor { } void visit(UnaryOpStmt *stmt) override { - if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) { - stmt->replace_with(stmt->operand); - modifier.erase(stmt); + if (stmt->is_cast()) { + if (stmt->cast_type == stmt->operand->ret_type) { + stmt->replace_with(stmt->operand); + modifier.erase(stmt); + } else if (stmt->operand->is() && + stmt->operand->as()->is_cast()) { + auto prev_cast = stmt->operand->as(); + if (stmt->op_type == UnaryOpType::cast_bits && + prev_cast->op_type == UnaryOpType::cast_bits) { + stmt->operand = prev_cast->operand; + modifier.mark_as_modified(); + } + } } }