Skip to content

Commit

Permalink
[ir] Add a function to test if two IRNodes are equivalent (#683)
Browse files Browse the repository at this point in the history
* [skip ci] add (dummy) same_statements.cpp

* Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *)

* use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *)

* [skip ci] implement visitor functions for a bunch of statements

* [skip ci] use unordered_map

* [skip ci] more Stmts

* [skip ci] check SNode * rather than their id

* [skip ci] macro DEFINE_SNODE_CHECK

* OffloadedStmt

* [skip ci] add (dummy) same_statements.cpp

* Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *)

* use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *)

* [skip ci] implement visitor functions for a bunch of statements

* [skip ci] use unordered_map

* [skip ci] more Stmts

* [skip ci] check SNode * rather than their id

* [skip ci] macro DEFINE_SNODE_CHECK

* OffloadedStmt

* add a test

* debug(passing) (temporarily disable python tests for fast "ti test")

* [skip ci] debug(failing)

* fixed undefined behavior

* [skip ci] remove debug outputs

* refactor some stmts

* StmtFieldSNode, and tests for SNodeLookupStmt

* [skip ci] more stmts

* more stmts

* LaneAttribute<TypedConstant>

* add void visit(ConstStmt *) back to test ci

* avoid calling operator() to test ci

* move void StmtFieldManager::operator() to the end to test ci

* revert debugging changes

* more stmts; move VectorElement to ir.h

* [skip ci] finish

* move is_specialization to meta.h

* add_operand->register_operand; foolproof double registration

* add a pass after `lower` to check if fields_registered is true for all statements

* [skip ci] enforce code format

* update

* initialize WhileStmt::mask

* true_mask, false_mask

Co-authored-by: Yuanming Hu <[email protected]>
Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
3 people authored Apr 4, 2020
1 parent 62e24dd commit b865805
Show file tree
Hide file tree
Showing 15 changed files with 660 additions and 132 deletions.
3 changes: 2 additions & 1 deletion examples/nbody_oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def initialize():
def advance(dt: ti.f32):
for i in range(N):
for k in ti.static(range(2)):
if pos[i][k] < 0 and vel[i][k] < 0 or pos[i][k] > 1 and vel[i][k] > 0:
if pos[i][k] < 0 and vel[i][k] < 0 or pos[i][k] > 1 and vel[i][
k] > 0:
vel[i][k] = -bounce * vel[i][k]

p = pos[i] - ti.Vector([0.5, 0.5])
Expand Down
60 changes: 60 additions & 0 deletions taichi/analysis/check_fields_registered.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN

class FieldsRegisteredChecker : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

FieldsRegisteredChecker() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void visit(Stmt *stmt) override {
TI_ASSERT(stmt->fields_registered);
}

void visit(IfStmt *if_stmt) override {
TI_ASSERT(if_stmt->fields_registered);
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
TI_ASSERT(stmt->fields_registered);
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->fields_registered);
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
TI_ASSERT(for_stmt->fields_registered);
for_stmt->body->accept(this);
}

void visit(OffloadedStmt *stmt) override {
TI_ASSERT(stmt->fields_registered);
if (stmt->body)
stmt->body->accept(this);
}

static void run(IRNode *root) {
FieldsRegisteredChecker checker;
root->accept(&checker);
}
};

namespace irpass {
void check_fields_registered(IRNode *root) {
return FieldsRegisteredChecker::run(root);
}
} // namespace irpass

TLANG_NAMESPACE_END
198 changes: 198 additions & 0 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#include "taichi/ir/ir.h"
#include <unordered_map>
#include <unordered_set>

TLANG_NAMESPACE_BEGIN

// Compare if two IRNodes are equivalent.
class IRNodeComparator : public IRVisitor {
private:
IRNode *other_node;
// map the id from this node to the other node
std::unordered_map<int, int> id_map;
// ids which don't belong to either node
std::unordered_set<int> captured_id;

public:
bool same;

explicit IRNodeComparator(IRNode *other_node) : other_node(other_node) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
same = true;
}

void map_id(int this_id, int other_id) {
if (captured_id.find(this_id) != captured_id.end() ||
captured_id.find(other_id) != captured_id.end()) {
same = false;
return;
}
auto it = id_map.find(this_id);
if (it == id_map.end()) {
id_map[this_id] = other_id;
} else if (it->second != other_id) {
same = false;
}
}

int get_other_id(int this_id) {
// get the corresponding id in the other node
auto it = id_map.find(this_id);
if (it != id_map.end()) {
return it->second;
}
// if not found, should be captured
// (What if this_id belongs to the other node? Ignoring this case here.)
if (captured_id.find(this_id) == captured_id.end()) {
captured_id.insert(this_id);
}
return this_id;
}

void visit(Block *stmt_list) override {
if (!other_node->is<Block>()) {
same = false;
return;
}

auto other = other_node->as<Block>();
if (stmt_list->size() != other->size()) {
same = false;
return;
}
for (int i = 0; i < (int)stmt_list->size(); i++) {
other_node = other->statements[i].get();
stmt_list->statements[i]->accept(this);
if (!same)
break;
}
other_node = other;
}

void basic_check(Stmt *stmt) {
// type check
if (typeid(*other_node) != typeid(*stmt)) {
same = false;
return;
}

// operand check
auto other = other_node->as<Stmt>();
if (stmt->num_operands() != other->num_operands()) {
same = false;
return;
}
for (int i = 0; i < stmt->num_operands(); i++) {
if (get_other_id(stmt->operand(i)->id) != other->operand(i)->id) {
same = false;
return;
}
}

// field check
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
}

map_id(stmt->id, other->id);
}

void visit(Stmt *stmt) override {
basic_check(stmt);
}

void visit(IfStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<IfStmt>();
if (stmt->true_statements) {
if (!other->true_statements) {
same = false;
return;
}
other_node = other->true_statements.get();
stmt->true_statements->accept(this);
other_node = other;
}
if (stmt->false_statements && same) {
if (!other->false_statements) {
same = false;
return;
}
other_node = other->false_statements.get();
stmt->false_statements->accept(this);
other_node = other;
}
}

void visit(FuncBodyStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<FuncBodyStmt>();
other_node = other->body.get();
stmt->body->accept(this);
other_node = other;
}

void visit(WhileStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<WhileStmt>();
other_node = other->body.get();
stmt->body->accept(this);
other_node = other;
}

void visit(RangeForStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<RangeForStmt>();
other_node = other->body.get();
stmt->body->accept(this);
other_node = other;
}

void visit(StructForStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<StructForStmt>();
other_node = other->body.get();
stmt->body->accept(this);
other_node = other;
}

void visit(OffloadedStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto other = other_node->as<OffloadedStmt>();
if (stmt->has_body()) {
TI_ASSERT(stmt->body);
TI_ASSERT(other->body);
other_node = other->body.get();
stmt->body->accept(this);
other_node = other;
}
}

static bool run(IRNode *root1, IRNode *root2) {
IRNodeComparator comparator(root2);
root1->accept(&comparator);
return comparator.same;
}
};

namespace irpass {
bool same_statements(IRNode *root1, IRNode *root2) {
return IRNodeComparator::run(root1, root2);
}
} // namespace irpass

TLANG_NAMESPACE_END
3 changes: 2 additions & 1 deletion taichi/backends/opengl/opengl_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ GLSSBO *root_ssbo;
void create_glsl_root_buffer(size_t size) {
// if (root_ssbo) return;
initialize_opengl();
root_ssbo = new GLSSBO; // TODO(archibate): mem leaking, use std::optional instead
root_ssbo =
new GLSSBO; // TODO(archibate): mem leaking, use std::optional instead
size += 2 * sizeof(int);
void *buffer = std::calloc(size, 1);
root_ssbo->bind_data(buffer, size, GL_DYNAMIC_READ);
Expand Down
24 changes: 13 additions & 11 deletions taichi/codegen/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,38 +195,38 @@ class KernelGen : public IRVisitor {
if (used.atomic_float && !opengl_has_GL_NV_shader_atomic_float) { // {{{
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_data_f32.glsl.h"
);
);
#ifdef _GLSL_INT64
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_data_f64.glsl.h"
);
);
#endif
if (used.global_temp) {
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_gtmp_f32.glsl.h"
);
);
#ifdef _GLSL_INT64
kernel_header += (
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_gtmp_f64.glsl.h"
);
);
#endif
}
if (used.external_ptr) {
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_extr_f32.glsl.h"
);
);
#ifdef _GLSL_INT64
kernel_header += (
kernel_header += (
#include "taichi/backends/opengl/shaders/atomics_extr_f64.glsl.h"
);
);
#endif
}
} // }}}
if (used.random) { // TODO(archibate): random in different offloads should
// share rand seed? {{{
kernel_header += (
#include "taichi/backends/opengl/shaders/random.glsl.h"
);
);
} // }}}

line_appender_header_.append_raw(kernel_header);
Expand All @@ -236,8 +236,10 @@ class KernelGen : public IRVisitor {
threads_per_group = std::max(1, num_threads_);
else
num_groups_ = (num_threads_ + threads_per_group - 1) / threads_per_group;
emit("layout(local_size_x = {} /* {}, {} */, local_size_y = 1, local_size_z = 1) in;",
threads_per_group, num_groups_, num_threads_);
emit(
"layout(local_size_x = {} /* {}, {} */, local_size_y = 1, local_size_z "
"= 1) in;",
threads_per_group, num_groups_, num_threads_);
std::string extensions = "";
if (opengl_has_GL_NV_shader_atomic_float) {
extensions += "#extension GL_NV_shader_atomic_float: enable\n";
Expand Down
7 changes: 7 additions & 0 deletions taichi/common/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ struct copy_refcv {
template <typename T, typename G>
using copy_refcv_t = typename copy_refcv<T, G>::type;

template <class T, template <class...> class Template>
struct is_specialization : std::false_type {};

template <template <class...> class Template, class... Args>
struct is_specialization<Template<Args...>, Template> : std::true_type {};

TI_STATIC_ASSERT((std::is_same<const volatile int, volatile const int>::value));
TI_STATIC_ASSERT(
(std::is_same<int,
Expand All @@ -96,5 +102,6 @@ TI_STATIC_ASSERT(
(std::is_same<copy_refcv_t<const int &, real>, const real &>::value));
TI_STATIC_ASSERT((std::is_same<copy_refcv_t<const volatile int &, real>,
const volatile real &>::value));
TI_STATIC_ASSERT((is_specialization<std::vector<int>, std::vector>::value));

} // namespace taichi
5 changes: 2 additions & 3 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ void IRBuilder::stop_gradient(SNode *snode) {

GetChStmt::GetChStmt(taichi::lang::Stmt *input_ptr, int chid)
: input_ptr(input_ptr), chid(chid) {
add_operand(this->input_ptr);
TI_ASSERT(input_ptr->is<SNodeLookupStmt>());
input_snode = input_ptr->as<SNodeLookupStmt>()->snode;
output_snode = input_snode->ch[chid].get();
TI_STMT_REG_FIELDS;
}

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) {
Expand Down Expand Up @@ -495,8 +495,6 @@ OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type)

OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type, SNode *snode)
: task_type(task_type), snode(snode) {
add_operand(begin_stmt);
add_operand(end_stmt);
num_cpu_threads = 1;
const_begin = false;
const_end = false;
Expand All @@ -511,6 +509,7 @@ OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type, SNode *snode)
if (task_type != TaskType::listgen) {
body = std::make_unique<Block>();
}
TI_STMT_REG_FIELDS;
}

std::string OffloadedStmt::task_name() const {
Expand Down
Loading

0 comments on commit b865805

Please sign in to comment.