-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ir] [refactor] Simplify the "re_id" pass (#1304)
- Loading branch information
1 parent
dac7724
commit 4c25993
Showing
1 changed file
with
8 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,94 +1,42 @@ | ||
#include "taichi/ir/ir.h" | ||
#include "taichi/ir/transforms.h" | ||
#include "taichi/ir/visitors.h" | ||
#include "taichi/ir/frontend_ir.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
// This pass manipulates the id of statements so that they are successive values | ||
// starting from 0 | ||
class ReId : public IRVisitor { | ||
class ReId : public BasicStmtVisitor { | ||
public: | ||
int id_counter; | ||
|
||
ReId(IRNode *node) { | ||
ReId() : id_counter(0) { | ||
allow_undefined_visitor = true; | ||
invoke_default_visitor = true; | ||
id_counter = 0; | ||
node->accept(this); | ||
} | ||
|
||
void re_id(Stmt *stmt) { | ||
stmt->id = id_counter++; | ||
} | ||
|
||
void visit(Stmt *stmt) { | ||
void visit(Stmt *stmt) override { | ||
re_id(stmt); | ||
} | ||
|
||
void visit(Block *stmt_list) { // block itself has no id | ||
for (auto &stmt : stmt_list->statements) { | ||
stmt->accept(this); | ||
} | ||
} | ||
|
||
void visit(IfStmt *if_stmt) { | ||
re_id(if_stmt); | ||
if (if_stmt->true_statements) | ||
if_stmt->true_statements->accept(this); | ||
if (if_stmt->false_statements) { | ||
if_stmt->false_statements->accept(this); | ||
} | ||
} | ||
|
||
void visit(FrontendIfStmt *if_stmt) { | ||
re_id(if_stmt); | ||
if (if_stmt->true_statements) | ||
if (if_stmt->true_statements) | ||
if_stmt->true_statements->accept(this); | ||
if (if_stmt->false_statements) { | ||
if_stmt->false_statements->accept(this); | ||
} | ||
} | ||
|
||
void visit(WhileStmt *stmt) { | ||
re_id(stmt); | ||
stmt->body->accept(this); | ||
} | ||
|
||
void visit(FrontendWhileStmt *stmt) { | ||
void preprocess_container_stmt(Stmt *stmt) override { | ||
re_id(stmt); | ||
stmt->body->accept(this); | ||
} | ||
|
||
void visit(FrontendForStmt *for_stmt) { | ||
re_id(for_stmt); | ||
for_stmt->body->accept(this); | ||
} | ||
|
||
void visit(RangeForStmt *for_stmt) { | ||
re_id(for_stmt); | ||
for_stmt->body->accept(this); | ||
} | ||
|
||
void visit(StructForStmt *for_stmt) { | ||
re_id(for_stmt); | ||
for_stmt->body->accept(this); | ||
} | ||
|
||
void visit(OffloadedStmt *stmt) { | ||
re_id(stmt); | ||
if (stmt->body) | ||
stmt->body->accept(this); | ||
static void run(IRNode *node) { | ||
ReId instance; | ||
node->accept(&instance); | ||
} | ||
}; | ||
|
||
namespace irpass { | ||
|
||
void re_id(IRNode *root) { | ||
ReId instance(root); | ||
ReId::run(root); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |