From b865805452e542f61fa937933738a346f64d679e Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 3 Apr 2020 20:59:09 -0400 Subject: [PATCH] [ir] Add a function to test if two IRNodes are equivalent (#683) * [skip ci] add (dummy) same_statements.cpp * Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *) * use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *) * [skip ci] implement visitor functions for a bunch of statements * [skip ci] use unordered_map * [skip ci] more Stmts * [skip ci] check SNode * rather than their id * [skip ci] macro DEFINE_SNODE_CHECK * OffloadedStmt * [skip ci] add (dummy) same_statements.cpp * Move Stmt::is<>() to IRNode, and implement void IRNodeComparator::visit(Block *) * use macros to simplify the code; implement visit(AssertStmt *) and visit(SNodeOpStmt *) * [skip ci] implement visitor functions for a bunch of statements * [skip ci] use unordered_map * [skip ci] more Stmts * [skip ci] check SNode * rather than their id * [skip ci] macro DEFINE_SNODE_CHECK * OffloadedStmt * add a test * debug(passing) (temporarily disable python tests for fast "ti test") * [skip ci] debug(failing) * fixed undefined behavior * [skip ci] remove debug outputs * refactor some stmts * StmtFieldSNode, and tests for SNodeLookupStmt * [skip ci] more stmts * more stmts * LaneAttribute * add void visit(ConstStmt *) back to test ci * avoid calling operator() to test ci * move void StmtFieldManager::operator() to the end to test ci * revert debugging changes * more stmts; move VectorElement to ir.h * [skip ci] finish * move is_specialization to meta.h * add_operand->register_operand; foolproof double registration * add a pass after `lower` to check if fields_registered is true for all statements * [skip ci] enforce code format * update * initialize WhileStmt::mask * true_mask, false_mask Co-authored-by: Yuanming Hu Co-authored-by: Taichi Gardener --- examples/nbody_oscillator.py | 3 +- taichi/analysis/check_fields_registered.cpp | 60 +++++ taichi/analysis/same_statements.cpp | 198 ++++++++++++++++ taichi/backends/opengl/opengl_api.cpp | 3 +- taichi/codegen/codegen_opengl.cpp | 24 +- taichi/common/meta.h | 7 + taichi/ir/ir.cpp | 5 +- taichi/ir/ir.h | 250 +++++++++++++------- taichi/ir/statements.h | 92 ++++--- taichi/lang_util.h | 6 +- taichi/transforms/compile_to_offloads.cpp | 3 + taichi/transforms/type_check.cpp | 1 + tests/cpp/test_same_statements.cpp | 104 ++++++++ tests/cpp/test_stmt_field_manager.cpp | 34 +++ tests/python/test_random.py | 2 + 15 files changed, 660 insertions(+), 132 deletions(-) create mode 100644 taichi/analysis/check_fields_registered.cpp create mode 100644 taichi/analysis/same_statements.cpp create mode 100644 tests/cpp/test_same_statements.cpp diff --git a/examples/nbody_oscillator.py b/examples/nbody_oscillator.py index 92e867dc53e41..24ba33310ae10 100644 --- a/examples/nbody_oscillator.py +++ b/examples/nbody_oscillator.py @@ -21,7 +21,8 @@ def initialize(): def advance(dt: ti.f32): for i in range(N): for k in ti.static(range(2)): - if pos[i][k] < 0 and vel[i][k] < 0 or pos[i][k] > 1 and vel[i][k] > 0: + if pos[i][k] < 0 and vel[i][k] < 0 or pos[i][k] > 1 and vel[i][ + k] > 0: vel[i][k] = -bounce * vel[i][k] p = pos[i] - ti.Vector([0.5, 0.5]) diff --git a/taichi/analysis/check_fields_registered.cpp b/taichi/analysis/check_fields_registered.cpp new file mode 100644 index 0000000000000..dcbd99d4a42cc --- /dev/null +++ b/taichi/analysis/check_fields_registered.cpp @@ -0,0 +1,60 @@ +#include "taichi/ir/ir.h" + +TLANG_NAMESPACE_BEGIN + +class FieldsRegisteredChecker : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + FieldsRegisteredChecker() { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + + void visit(Stmt *stmt) override { + TI_ASSERT(stmt->fields_registered); + } + + void visit(IfStmt *if_stmt) override { + TI_ASSERT(if_stmt->fields_registered); + if (if_stmt->true_statements) + if_stmt->true_statements->accept(this); + if (if_stmt->false_statements) { + if_stmt->false_statements->accept(this); + } + } + + void visit(WhileStmt *stmt) override { + TI_ASSERT(stmt->fields_registered); + stmt->body->accept(this); + } + + void visit(RangeForStmt *for_stmt) override { + TI_ASSERT(for_stmt->fields_registered); + for_stmt->body->accept(this); + } + + void visit(StructForStmt *for_stmt) override { + TI_ASSERT(for_stmt->fields_registered); + for_stmt->body->accept(this); + } + + void visit(OffloadedStmt *stmt) override { + TI_ASSERT(stmt->fields_registered); + if (stmt->body) + stmt->body->accept(this); + } + + static void run(IRNode *root) { + FieldsRegisteredChecker checker; + root->accept(&checker); + } +}; + +namespace irpass { +void check_fields_registered(IRNode *root) { + return FieldsRegisteredChecker::run(root); +} +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp new file mode 100644 index 0000000000000..99d7370021a9e --- /dev/null +++ b/taichi/analysis/same_statements.cpp @@ -0,0 +1,198 @@ +#include "taichi/ir/ir.h" +#include +#include + +TLANG_NAMESPACE_BEGIN + +// Compare if two IRNodes are equivalent. +class IRNodeComparator : public IRVisitor { + private: + IRNode *other_node; + // map the id from this node to the other node + std::unordered_map id_map; + // ids which don't belong to either node + std::unordered_set captured_id; + + public: + bool same; + + explicit IRNodeComparator(IRNode *other_node) : other_node(other_node) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + same = true; + } + + void map_id(int this_id, int other_id) { + if (captured_id.find(this_id) != captured_id.end() || + captured_id.find(other_id) != captured_id.end()) { + same = false; + return; + } + auto it = id_map.find(this_id); + if (it == id_map.end()) { + id_map[this_id] = other_id; + } else if (it->second != other_id) { + same = false; + } + } + + int get_other_id(int this_id) { + // get the corresponding id in the other node + auto it = id_map.find(this_id); + if (it != id_map.end()) { + return it->second; + } + // if not found, should be captured + // (What if this_id belongs to the other node? Ignoring this case here.) + if (captured_id.find(this_id) == captured_id.end()) { + captured_id.insert(this_id); + } + return this_id; + } + + void visit(Block *stmt_list) override { + if (!other_node->is()) { + same = false; + return; + } + + auto other = other_node->as(); + if (stmt_list->size() != other->size()) { + same = false; + return; + } + for (int i = 0; i < (int)stmt_list->size(); i++) { + other_node = other->statements[i].get(); + stmt_list->statements[i]->accept(this); + if (!same) + break; + } + other_node = other; + } + + void basic_check(Stmt *stmt) { + // type check + if (typeid(*other_node) != typeid(*stmt)) { + same = false; + return; + } + + // operand check + auto other = other_node->as(); + if (stmt->num_operands() != other->num_operands()) { + same = false; + return; + } + for (int i = 0; i < stmt->num_operands(); i++) { + if (get_other_id(stmt->operand(i)->id) != other->operand(i)->id) { + same = false; + return; + } + } + + // field check + if (!stmt->field_manager.equal(other->field_manager)) { + same = false; + return; + } + + map_id(stmt->id, other->id); + } + + void visit(Stmt *stmt) override { + basic_check(stmt); + } + + void visit(IfStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + if (stmt->true_statements) { + if (!other->true_statements) { + same = false; + return; + } + other_node = other->true_statements.get(); + stmt->true_statements->accept(this); + other_node = other; + } + if (stmt->false_statements && same) { + if (!other->false_statements) { + same = false; + return; + } + other_node = other->false_statements.get(); + stmt->false_statements->accept(this); + other_node = other; + } + } + + void visit(FuncBodyStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + other_node = other->body.get(); + stmt->body->accept(this); + other_node = other; + } + + void visit(WhileStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + other_node = other->body.get(); + stmt->body->accept(this); + other_node = other; + } + + void visit(RangeForStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + other_node = other->body.get(); + stmt->body->accept(this); + other_node = other; + } + + void visit(StructForStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + other_node = other->body.get(); + stmt->body->accept(this); + other_node = other; + } + + void visit(OffloadedStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto other = other_node->as(); + if (stmt->has_body()) { + TI_ASSERT(stmt->body); + TI_ASSERT(other->body); + other_node = other->body.get(); + stmt->body->accept(this); + other_node = other; + } + } + + static bool run(IRNode *root1, IRNode *root2) { + IRNodeComparator comparator(root2); + root1->accept(&comparator); + return comparator.same; + } +}; + +namespace irpass { +bool same_statements(IRNode *root1, IRNode *root2) { + return IRNodeComparator::run(root1, root2); +} +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/backends/opengl/opengl_api.cpp b/taichi/backends/opengl/opengl_api.cpp index 5395b351adabd..2ffe6d17432a0 100644 --- a/taichi/backends/opengl/opengl_api.cpp +++ b/taichi/backends/opengl/opengl_api.cpp @@ -244,7 +244,8 @@ GLSSBO *root_ssbo; void create_glsl_root_buffer(size_t size) { // if (root_ssbo) return; initialize_opengl(); - root_ssbo = new GLSSBO; // TODO(archibate): mem leaking, use std::optional instead + root_ssbo = + new GLSSBO; // TODO(archibate): mem leaking, use std::optional instead size += 2 * sizeof(int); void *buffer = std::calloc(size, 1); root_ssbo->bind_data(buffer, size, GL_DYNAMIC_READ); diff --git a/taichi/codegen/codegen_opengl.cpp b/taichi/codegen/codegen_opengl.cpp index 0def84b6361fd..d37eb2bc4738f 100644 --- a/taichi/codegen/codegen_opengl.cpp +++ b/taichi/codegen/codegen_opengl.cpp @@ -195,30 +195,30 @@ class KernelGen : public IRVisitor { if (used.atomic_float && !opengl_has_GL_NV_shader_atomic_float) { // {{{ kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_data_f32.glsl.h" - ); + ); #ifdef _GLSL_INT64 kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_data_f64.glsl.h" - ); + ); #endif if (used.global_temp) { kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_gtmp_f32.glsl.h" - ); + ); #ifdef _GLSL_INT64 - kernel_header += ( + kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_gtmp_f64.glsl.h" - ); + ); #endif } if (used.external_ptr) { kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_extr_f32.glsl.h" - ); + ); #ifdef _GLSL_INT64 - kernel_header += ( + kernel_header += ( #include "taichi/backends/opengl/shaders/atomics_extr_f64.glsl.h" - ); + ); #endif } } // }}} @@ -226,7 +226,7 @@ class KernelGen : public IRVisitor { // share rand seed? {{{ kernel_header += ( #include "taichi/backends/opengl/shaders/random.glsl.h" - ); + ); } // }}} line_appender_header_.append_raw(kernel_header); @@ -236,8 +236,10 @@ class KernelGen : public IRVisitor { threads_per_group = std::max(1, num_threads_); else num_groups_ = (num_threads_ + threads_per_group - 1) / threads_per_group; - emit("layout(local_size_x = {} /* {}, {} */, local_size_y = 1, local_size_z = 1) in;", - threads_per_group, num_groups_, num_threads_); + emit( + "layout(local_size_x = {} /* {}, {} */, local_size_y = 1, local_size_z " + "= 1) in;", + threads_per_group, num_groups_, num_threads_); std::string extensions = ""; if (opengl_has_GL_NV_shader_atomic_float) { extensions += "#extension GL_NV_shader_atomic_float: enable\n"; diff --git a/taichi/common/meta.h b/taichi/common/meta.h index eed0d840dc609..306c373764db8 100644 --- a/taichi/common/meta.h +++ b/taichi/common/meta.h @@ -77,6 +77,12 @@ struct copy_refcv { template using copy_refcv_t = typename copy_refcv::type; +template class Template> +struct is_specialization : std::false_type {}; + +template