From 98e27878891cb338188fa6fa9927c0097fea955a Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 31 Mar 2020 19:55:19 -0400 Subject: [PATCH 1/3] base --- python/taichi/main.py | 2 +- taichi/common/serialization.h | 2 +- taichi/ir/ir.h | 82 ++++++++++++++++++++++++++- taichi/ir/statements.cpp | 6 ++ tests/cpp/test_stmt_field_manager.cpp | 27 +++++++++ 5 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 taichi/ir/statements.cpp create mode 100644 tests/cpp/test_stmt_field_manager.cpp diff --git a/python/taichi/main.py b/python/taichi/main.py index e8eb313a2f814..176c7c6fde7dd 100644 --- a/python/taichi/main.py +++ b/python/taichi/main.py @@ -150,10 +150,10 @@ def main(debug=False): script = script.read() exec(script, {'__name__': '__main__'}) elif mode == "test": + ret = test_cpp(args) ret = test_python(args) if ret: return -1 - ret = test_cpp(args) return ret elif mode == "build": ti.core.build() diff --git a/taichi/common/serialization.h b/taichi/common/serialization.h index c4522b3d945d7..b2360a295bdda 100644 --- a/taichi/common/serialization.h +++ b/taichi/common/serialization.h @@ -450,7 +450,7 @@ class BinarySerializer : public Serializer { "the raw pointer."); } } else { - std::size_t val_ptr; + std::size_t val_ptr = 0; this->operator()("", val_ptr); if (val_ptr != 0) { TI_ASSERT(assets.find(val_ptr) != assets.end()); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 9744d40869a80..ad50e90df2f52 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -502,12 +502,82 @@ struct LaneAttribute { } }; +class StmtField { + public: + StmtField() = default; + + virtual bool equal(const StmtField *other) = 0; + + virtual ~StmtField() = default; +}; + +template +class StmtFieldNumeric final : public StmtField { + public: + T value; + + explicit StmtFieldNumeric(T value) : value(value) { + } + + bool equal(const StmtField *other_generic) override { + if (auto other = dynamic_cast(other_generic)) { + return other->value == value; + } else { + // Different types + return false; + } + } +}; + +class StmtFieldManager { + private: + Stmt *stmt; + std::vector> fields; + + public: + StmtFieldManager(Stmt *stmt) : stmt(stmt) { + } + + void operator()(const char *_, Stmt *&value); + + template + void operator()(const char *key, T value); + + template + void operator()(const char *key_, T &t, Args &&... rest) { + std::string key(key_); + size_t pos = key.find(","); + std::string first_name = key.substr(0, pos); + std::string rest_names = + key.substr(pos + 2, int(key.size()) - (int)pos - 2); + this->operator()(first_name.c_str(), t); + this->operator()(rest_names.c_str(), std::forward(rest)...); + } + + bool equal(StmtFieldManager *other) { + if (fields.size() != other->fields.size()) { + return false; + } + auto num_fields = fields.size(); + for (std::size_t i = 0; i < num_fields; i++) { + if (!fields[i]->equal(other->fields[i].get())) { + return false; + } + } + return true; + } +}; + +#define TI_STMT_DEF_FIELDS(...) TI_IO_DEF(__VA_ARGS__) +#define TI_STMT_REG_FIELDS io(field_manager) + class Stmt : public IRNode { protected: // NOTE: operands should not be directly modified, for the // correctness of operand_bitmap std::vector operands; public: + StmtFieldManager field_manager; static std::atomic instance_id_counter; int instance_id; int id; @@ -519,7 +589,7 @@ class Stmt : public IRNode { Stmt(const Stmt &stmt) = delete; - Stmt() { + Stmt() : field_manager(this) { parent = nullptr; instance_id = instance_id_counter++; id = instance_id; @@ -680,6 +750,16 @@ class Stmt : public IRNode { virtual ~Stmt() override = default; }; +inline void StmtFieldManager::operator()(const char *_, Stmt *&value) { + stmt->add_operand(value); +} + +template +inline void StmtFieldManager::operator()(const char *key, T value) { + stmt->field_manager.fields.emplace_back( + std::make_unique>(value)); +} + // always a tree - used as rvalues class Expression { public: diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp new file mode 100644 index 0000000000000..96fad1c712d44 --- /dev/null +++ b/taichi/ir/statements.cpp @@ -0,0 +1,6 @@ +// TODO: gradually cppize statements.h +#include "statements.h" + +TLANG_NAMESPACE_BEGIN + +TLANG_NAMESPACE_END diff --git a/tests/cpp/test_stmt_field_manager.cpp b/tests/cpp/test_stmt_field_manager.cpp new file mode 100644 index 0000000000000..7b46c6fccebbd --- /dev/null +++ b/tests/cpp/test_stmt_field_manager.cpp @@ -0,0 +1,27 @@ +#include "taichi/ir/ir.h" +#include "taichi/common/testing.h" + +TLANG_NAMESPACE_BEGIN + +class TestStmt : public Stmt { +private: + Stmt *input; + int a; + float b; + +public: + TestStmt(Stmt *input, int a, int b): input(input), a(a), b(b) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(input, a, b); +}; + + +TI_TEST("test_stmt_field_manager") { + auto a = Stmt::make(nullptr, 1, 2.0f); + + TI_CHECK(a->num_operands() == 1); +} + +TLANG_NAMESPACE_END From 6b9d00cbb55baf899ca3278b1d9ef47372c775d5 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 31 Mar 2020 20:31:35 -0400 Subject: [PATCH 2/3] update --- taichi/ir/ir.h | 34 +++++++++++++-------------- tests/cpp/test_stmt_field_manager.cpp | 17 ++++++++++---- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index ad50e90df2f52..2c6b769166da7 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -506,7 +506,7 @@ class StmtField { public: StmtField() = default; - virtual bool equal(const StmtField *other) = 0; + virtual bool equal(const StmtField *other) const = 0; virtual ~StmtField() = default; }; @@ -519,7 +519,7 @@ class StmtFieldNumeric final : public StmtField { explicit StmtFieldNumeric(T value) : value(value) { } - bool equal(const StmtField *other_generic) override { + bool equal(const StmtField *other_generic) const override { if (auto other = dynamic_cast(other_generic)) { return other->value == value; } else { @@ -530,7 +530,7 @@ class StmtFieldNumeric final : public StmtField { }; class StmtFieldManager { - private: + public: Stmt *stmt; std::vector> fields; @@ -538,29 +538,27 @@ class StmtFieldManager { StmtFieldManager(Stmt *stmt) : stmt(stmt) { } - void operator()(const char *_, Stmt *&value); - template - void operator()(const char *key, T value); + void operator()(const char *key, T &value); template - void operator()(const char *key_, T &t, Args &&... rest) { + void operator()(const char *key_, T &t, Args &... rest) { std::string key(key_); size_t pos = key.find(","); std::string first_name = key.substr(0, pos); std::string rest_names = key.substr(pos + 2, int(key.size()) - (int)pos - 2); this->operator()(first_name.c_str(), t); - this->operator()(rest_names.c_str(), std::forward(rest)...); + this->operator()(rest_names.c_str(), rest...); } - bool equal(StmtFieldManager *other) { - if (fields.size() != other->fields.size()) { + bool equal(StmtFieldManager &other) const { + if (fields.size() != other.fields.size()) { return false; } auto num_fields = fields.size(); for (std::size_t i = 0; i < num_fields; i++) { - if (!fields[i]->equal(other->fields[i].get())) { + if (!fields[i]->equal(other.fields[i].get())) { return false; } } @@ -750,14 +748,14 @@ class Stmt : public IRNode { virtual ~Stmt() override = default; }; -inline void StmtFieldManager::operator()(const char *_, Stmt *&value) { - stmt->add_operand(value); -} - template -inline void StmtFieldManager::operator()(const char *key, T value) { - stmt->field_manager.fields.emplace_back( - std::make_unique>(value)); +inline void StmtFieldManager::operator()(const char *key, T &value) { + if constexpr (std::is_same::type, Stmt *>::value) { + stmt->add_operand(const_cast(value)); + } else { + stmt->field_manager.fields.emplace_back( + std::make_unique>(value)); + } } // always a tree - used as rvalues diff --git a/tests/cpp/test_stmt_field_manager.cpp b/tests/cpp/test_stmt_field_manager.cpp index 7b46c6fccebbd..5eba9e6a0f361 100644 --- a/tests/cpp/test_stmt_field_manager.cpp +++ b/tests/cpp/test_stmt_field_manager.cpp @@ -4,24 +4,33 @@ TLANG_NAMESPACE_BEGIN class TestStmt : public Stmt { -private: + private: Stmt *input; int a; float b; -public: - TestStmt(Stmt *input, int a, int b): input(input), a(a), b(b) { + public: + TestStmt(Stmt *input, int a, int b) : input(input), a(a), b(b) { TI_STMT_REG_FIELDS; } TI_STMT_DEF_FIELDS(input, a, b); }; - TI_TEST("test_stmt_field_manager") { auto a = Stmt::make(nullptr, 1, 2.0f); TI_CHECK(a->num_operands() == 1); + TI_CHECK(a->field_manager.fields.size() == 2); + + auto b = Stmt::make(nullptr, 1, 2.0f); + + TI_CHECK(a->field_manager.equal(b->field_manager) == true); + + auto c = Stmt::make(nullptr, 2, 2.1f); + + TI_CHECK(a->field_manager.equal(c->field_manager) == false); + // To test two statements are equal: 1) same Stmt type 2) same operands 3) same field_manager } TLANG_NAMESPACE_END From 214b4b352b385b92e4418c94a916c9ee769a5cae Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 31 Mar 2020 20:42:27 -0400 Subject: [PATCH 3/3] finalized --- python/taichi/main.py | 2 +- taichi/ir/ir.h | 18 ++++++++++-------- tests/cpp/test_stmt_field_manager.cpp | 3 ++- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/taichi/main.py b/python/taichi/main.py index 176c7c6fde7dd..e8eb313a2f814 100644 --- a/python/taichi/main.py +++ b/python/taichi/main.py @@ -150,10 +150,10 @@ def main(debug=False): script = script.read() exec(script, {'__name__': '__main__'}) elif mode == "test": - ret = test_cpp(args) ret = test_python(args) if ret: return -1 + ret = test_cpp(args) return ret elif mode == "build": ti.core.build() diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 2c6b769166da7..788c744a02e9d 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -513,9 +513,10 @@ class StmtField { template class StmtFieldNumeric final : public StmtField { - public: + private: T value; + public: explicit StmtFieldNumeric(T value) : value(value) { } @@ -530,26 +531,27 @@ class StmtFieldNumeric final : public StmtField { }; class StmtFieldManager { - public: + private: Stmt *stmt; - std::vector> fields; public: + std::vector> fields; + StmtFieldManager(Stmt *stmt) : stmt(stmt) { } template - void operator()(const char *key, T &value); + void operator()(const char *key, T &&value); template - void operator()(const char *key_, T &t, Args &... rest) { + void operator()(const char *key_, T &&t, Args &&... rest) { std::string key(key_); size_t pos = key.find(","); std::string first_name = key.substr(0, pos); std::string rest_names = key.substr(pos + 2, int(key.size()) - (int)pos - 2); - this->operator()(first_name.c_str(), t); - this->operator()(rest_names.c_str(), rest...); + this->operator()(first_name.c_str(), std::forward(t)); + this->operator()(rest_names.c_str(), std::forward(rest)...); } bool equal(StmtFieldManager &other) const { @@ -749,7 +751,7 @@ class Stmt : public IRNode { }; template -inline void StmtFieldManager::operator()(const char *key, T &value) { +inline void StmtFieldManager::operator()(const char *key, T &&value) { if constexpr (std::is_same::type, Stmt *>::value) { stmt->add_operand(const_cast(value)); } else { diff --git a/tests/cpp/test_stmt_field_manager.cpp b/tests/cpp/test_stmt_field_manager.cpp index 5eba9e6a0f361..07431788eccbe 100644 --- a/tests/cpp/test_stmt_field_manager.cpp +++ b/tests/cpp/test_stmt_field_manager.cpp @@ -30,7 +30,8 @@ TI_TEST("test_stmt_field_manager") { auto c = Stmt::make(nullptr, 2, 2.1f); TI_CHECK(a->field_manager.equal(c->field_manager) == false); - // To test two statements are equal: 1) same Stmt type 2) same operands 3) same field_manager + // To test two statements are equal: 1) same Stmt type 2) same operands 3) + // same field_manager } TLANG_NAMESPACE_END