Skip to content

Commit

Permalink
[opt] Simplify bit_cast of bit_cast (#2152)
Browse files Browse the repository at this point in the history
  • Loading branch information
xumingkuan authored Jan 12, 2021
1 parent b21a2b1 commit b533d92
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
8 changes: 7 additions & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,11 @@ void DelayedIRModifier::replace_with(Stmt *stmt,
}

bool DelayedIRModifier::modify_ir() {
bool force_modified = modified_;
modified_ = false;
if (to_insert_before.empty() && to_insert_after.empty() && to_erase.empty() &&
to_replace_with.empty())
return false;
return force_modified;
for (auto &i : to_insert_before) {
i.first->parent->insert_before(i.first, std::move(i.second));
}
Expand All @@ -593,6 +595,10 @@ bool DelayedIRModifier::modify_ir() {
return true;
}

void DelayedIRModifier::mark_as_modified() {
modified_ = true;
}

LocalAddress::LocalAddress(Stmt *var, int offset) : var(var), offset(offset) {
TI_ASSERT(var->is<AllocaStmt>());
}
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ class DelayedIRModifier {
std::vector<std::pair<Stmt *, VecStatement>> to_insert_after;
std::vector<std::tuple<Stmt *, VecStatement, bool>> to_replace_with;
std::vector<Stmt *> to_erase;
bool modified_{false};

public:
~DelayedIRModifier();
Expand All @@ -741,6 +742,9 @@ class DelayedIRModifier {
VecStatement &&new_statements,
bool replace_usages = true);
bool modify_ir();

// Force the next call of modify_ir() to return true.
void mark_as_modified();
};

struct LocalAddress {
Expand Down
16 changes: 13 additions & 3 deletions taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,19 @@ class AlgSimp : public BasicStmtVisitor {
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) {
stmt->replace_with(stmt->operand);
modifier.erase(stmt);
if (stmt->is_cast()) {
if (stmt->cast_type == stmt->operand->ret_type) {
stmt->replace_with(stmt->operand);
modifier.erase(stmt);
} else if (stmt->operand->is<UnaryOpStmt>() &&
stmt->operand->as<UnaryOpStmt>()->is_cast()) {
auto prev_cast = stmt->operand->as<UnaryOpStmt>();
if (stmt->op_type == UnaryOpType::cast_bits &&
prev_cast->op_type == UnaryOpType::cast_bits) {
stmt->operand = prev_cast->operand;
modifier.mark_as_modified();
}
}
}
}

Expand Down

0 comments on commit b533d92

Please sign in to comment.