-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ir] Add a function to test if two IRNodes are equivalent (#683)
* [skip ci] add (dummy) same_statements.cpp * Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *) * use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *) * [skip ci] implement visitor functions for a bunch of statements * [skip ci] use unordered_map * [skip ci] more Stmts * [skip ci] check SNode * rather than their id * [skip ci] macro DEFINE_SNODE_CHECK * OffloadedStmt * [skip ci] add (dummy) same_statements.cpp * Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *) * use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *) * [skip ci] implement visitor functions for a bunch of statements * [skip ci] use unordered_map * [skip ci] more Stmts * [skip ci] check SNode * rather than their id * [skip ci] macro DEFINE_SNODE_CHECK * OffloadedStmt * add a test * debug(passing) (temporarily disable python tests for fast "ti test") * [skip ci] debug(failing) * fixed undefined behavior * [skip ci] remove debug outputs * refactor some stmts * StmtFieldSNode, and tests for SNodeLookupStmt * [skip ci] more stmts * more stmts * LaneAttribute<TypedConstant> * add void visit(ConstStmt *) back to test ci * avoid calling operator() to test ci * move void StmtFieldManager::operator() to the end to test ci * revert debugging changes * more stmts; move VectorElement to ir.h * [skip ci] finish * move is_specialization to meta.h * add_operand->register_operand; foolproof double registration * add a pass after `lower` to check if fields_registered is true for all statements * [skip ci] enforce code format * update * initialize WhileStmt::mask * true_mask, false_mask Co-authored-by: Yuanming Hu <[email protected]> Co-authored-by: Taichi Gardener <[email protected]>
- Loading branch information
1 parent
62e24dd
commit b865805
Showing
15 changed files
with
660 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#include "taichi/ir/ir.h" | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
class FieldsRegisteredChecker : public BasicStmtVisitor { | ||
public: | ||
using BasicStmtVisitor::visit; | ||
|
||
FieldsRegisteredChecker() { | ||
allow_undefined_visitor = true; | ||
invoke_default_visitor = true; | ||
} | ||
|
||
void visit(Stmt *stmt) override { | ||
TI_ASSERT(stmt->fields_registered); | ||
} | ||
|
||
void visit(IfStmt *if_stmt) override { | ||
TI_ASSERT(if_stmt->fields_registered); | ||
if (if_stmt->true_statements) | ||
if_stmt->true_statements->accept(this); | ||
if (if_stmt->false_statements) { | ||
if_stmt->false_statements->accept(this); | ||
} | ||
} | ||
|
||
void visit(WhileStmt *stmt) override { | ||
TI_ASSERT(stmt->fields_registered); | ||
stmt->body->accept(this); | ||
} | ||
|
||
void visit(RangeForStmt *for_stmt) override { | ||
TI_ASSERT(for_stmt->fields_registered); | ||
for_stmt->body->accept(this); | ||
} | ||
|
||
void visit(StructForStmt *for_stmt) override { | ||
TI_ASSERT(for_stmt->fields_registered); | ||
for_stmt->body->accept(this); | ||
} | ||
|
||
void visit(OffloadedStmt *stmt) override { | ||
TI_ASSERT(stmt->fields_registered); | ||
if (stmt->body) | ||
stmt->body->accept(this); | ||
} | ||
|
||
static void run(IRNode *root) { | ||
FieldsRegisteredChecker checker; | ||
root->accept(&checker); | ||
} | ||
}; | ||
|
||
namespace irpass { | ||
void check_fields_registered(IRNode *root) { | ||
return FieldsRegisteredChecker::run(root); | ||
} | ||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
#include "taichi/ir/ir.h" | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
TLANG_NAMESPACE_BEGIN | ||
|
||
// Compare if two IRNodes are equivalent. | ||
class IRNodeComparator : public IRVisitor { | ||
private: | ||
IRNode *other_node; | ||
// map the id from this node to the other node | ||
std::unordered_map<int, int> id_map; | ||
// ids which don't belong to either node | ||
std::unordered_set<int> captured_id; | ||
|
||
public: | ||
bool same; | ||
|
||
explicit IRNodeComparator(IRNode *other_node) : other_node(other_node) { | ||
allow_undefined_visitor = true; | ||
invoke_default_visitor = true; | ||
same = true; | ||
} | ||
|
||
void map_id(int this_id, int other_id) { | ||
if (captured_id.find(this_id) != captured_id.end() || | ||
captured_id.find(other_id) != captured_id.end()) { | ||
same = false; | ||
return; | ||
} | ||
auto it = id_map.find(this_id); | ||
if (it == id_map.end()) { | ||
id_map[this_id] = other_id; | ||
} else if (it->second != other_id) { | ||
same = false; | ||
} | ||
} | ||
|
||
int get_other_id(int this_id) { | ||
// get the corresponding id in the other node | ||
auto it = id_map.find(this_id); | ||
if (it != id_map.end()) { | ||
return it->second; | ||
} | ||
// if not found, should be captured | ||
// (What if this_id belongs to the other node? Ignoring this case here.) | ||
if (captured_id.find(this_id) == captured_id.end()) { | ||
captured_id.insert(this_id); | ||
} | ||
return this_id; | ||
} | ||
|
||
void visit(Block *stmt_list) override { | ||
if (!other_node->is<Block>()) { | ||
same = false; | ||
return; | ||
} | ||
|
||
auto other = other_node->as<Block>(); | ||
if (stmt_list->size() != other->size()) { | ||
same = false; | ||
return; | ||
} | ||
for (int i = 0; i < (int)stmt_list->size(); i++) { | ||
other_node = other->statements[i].get(); | ||
stmt_list->statements[i]->accept(this); | ||
if (!same) | ||
break; | ||
} | ||
other_node = other; | ||
} | ||
|
||
void basic_check(Stmt *stmt) { | ||
// type check | ||
if (typeid(*other_node) != typeid(*stmt)) { | ||
same = false; | ||
return; | ||
} | ||
|
||
// operand check | ||
auto other = other_node->as<Stmt>(); | ||
if (stmt->num_operands() != other->num_operands()) { | ||
same = false; | ||
return; | ||
} | ||
for (int i = 0; i < stmt->num_operands(); i++) { | ||
if (get_other_id(stmt->operand(i)->id) != other->operand(i)->id) { | ||
same = false; | ||
return; | ||
} | ||
} | ||
|
||
// field check | ||
if (!stmt->field_manager.equal(other->field_manager)) { | ||
same = false; | ||
return; | ||
} | ||
|
||
map_id(stmt->id, other->id); | ||
} | ||
|
||
void visit(Stmt *stmt) override { | ||
basic_check(stmt); | ||
} | ||
|
||
void visit(IfStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<IfStmt>(); | ||
if (stmt->true_statements) { | ||
if (!other->true_statements) { | ||
same = false; | ||
return; | ||
} | ||
other_node = other->true_statements.get(); | ||
stmt->true_statements->accept(this); | ||
other_node = other; | ||
} | ||
if (stmt->false_statements && same) { | ||
if (!other->false_statements) { | ||
same = false; | ||
return; | ||
} | ||
other_node = other->false_statements.get(); | ||
stmt->false_statements->accept(this); | ||
other_node = other; | ||
} | ||
} | ||
|
||
void visit(FuncBodyStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<FuncBodyStmt>(); | ||
other_node = other->body.get(); | ||
stmt->body->accept(this); | ||
other_node = other; | ||
} | ||
|
||
void visit(WhileStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<WhileStmt>(); | ||
other_node = other->body.get(); | ||
stmt->body->accept(this); | ||
other_node = other; | ||
} | ||
|
||
void visit(RangeForStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<RangeForStmt>(); | ||
other_node = other->body.get(); | ||
stmt->body->accept(this); | ||
other_node = other; | ||
} | ||
|
||
void visit(StructForStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<StructForStmt>(); | ||
other_node = other->body.get(); | ||
stmt->body->accept(this); | ||
other_node = other; | ||
} | ||
|
||
void visit(OffloadedStmt *stmt) override { | ||
basic_check(stmt); | ||
if (!same) | ||
return; | ||
auto other = other_node->as<OffloadedStmt>(); | ||
if (stmt->has_body()) { | ||
TI_ASSERT(stmt->body); | ||
TI_ASSERT(other->body); | ||
other_node = other->body.get(); | ||
stmt->body->accept(this); | ||
other_node = other; | ||
} | ||
} | ||
|
||
static bool run(IRNode *root1, IRNode *root2) { | ||
IRNodeComparator comparator(root2); | ||
root1->accept(&comparator); | ||
return comparator.same; | ||
} | ||
}; | ||
|
||
namespace irpass { | ||
bool same_statements(IRNode *root1, IRNode *root2) { | ||
return IRNodeComparator::run(root1, root2); | ||
} | ||
} // namespace irpass | ||
|
||
TLANG_NAMESPACE_END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.