Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Statement Field Manager #690

Merged
merged 3 commits into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/common/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ class BinarySerializer : public Serializer {
"the raw pointer.");
}
} else {
std::size_t val_ptr;
std::size_t val_ptr = 0;
this->operator()("", val_ptr);
if (val_ptr != 0) {
TI_ASSERT(assets.find(val_ptr) != assets.end());
Expand Down
82 changes: 81 additions & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,82 @@ struct LaneAttribute {
}
};

class StmtField {
public:
StmtField() = default;

virtual bool equal(const StmtField *other) const = 0;

virtual ~StmtField() = default;
};

template <typename T>
class StmtFieldNumeric final : public StmtField {
private:
T value;

public:
explicit StmtFieldNumeric(T value) : value(value) {
}

bool equal(const StmtField *other_generic) const override {
if (auto other = dynamic_cast<const StmtFieldNumeric *>(other_generic)) {
return other->value == value;
} else {
// Different types
return false;
}
}
};

class StmtFieldManager {
private:
Stmt *stmt;

public:
std::vector<std::unique_ptr<StmtField>> fields;

StmtFieldManager(Stmt *stmt) : stmt(stmt) {
}

template <typename T>
void operator()(const char *key, T &&value);

template <typename T, typename... Args>
void operator()(const char *key_, T &&t, Args &&... rest) {
std::string key(key_);
size_t pos = key.find(",");
std::string first_name = key.substr(0, pos);
std::string rest_names =
key.substr(pos + 2, int(key.size()) - (int)pos - 2);
this->operator()(first_name.c_str(), std::forward<T>(t));
this->operator()(rest_names.c_str(), std::forward<Args>(rest)...);
}

bool equal(StmtFieldManager &other) const {
if (fields.size() != other.fields.size()) {
return false;
}
auto num_fields = fields.size();
for (std::size_t i = 0; i < num_fields; i++) {
if (!fields[i]->equal(other.fields[i].get())) {
return false;
}
}
return true;
}
};

#define TI_STMT_DEF_FIELDS(...) TI_IO_DEF(__VA_ARGS__)
#define TI_STMT_REG_FIELDS io(field_manager)

class Stmt : public IRNode {
protected: // NOTE: operands should not be directly modified, for the
// correctness of operand_bitmap
std::vector<Stmt **> operands;

public:
StmtFieldManager field_manager;
static std::atomic<int> instance_id_counter;
int instance_id;
int id;
Expand All @@ -519,7 +589,7 @@ class Stmt : public IRNode {

Stmt(const Stmt &stmt) = delete;

Stmt() {
Stmt() : field_manager(this) {
parent = nullptr;
instance_id = instance_id_counter++;
id = instance_id;
Expand Down Expand Up @@ -680,6 +750,16 @@ class Stmt : public IRNode {
virtual ~Stmt() override = default;
};

template <typename T>
inline void StmtFieldManager::operator()(const char *key, T &&value) {
if constexpr (std::is_same<typename std::decay<T>::type, Stmt *>::value) {
stmt->add_operand(const_cast<Stmt *&>(value));
} else {
stmt->field_manager.fields.emplace_back(
std::make_unique<StmtFieldNumeric<T>>(value));
}
}

// always a tree - used as rvalues
class Expression {
public:
Expand Down
6 changes: 6 additions & 0 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// TODO: gradually cppize statements.h
#include "statements.h"

TLANG_NAMESPACE_BEGIN

TLANG_NAMESPACE_END
37 changes: 37 additions & 0 deletions tests/cpp/test_stmt_field_manager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "taichi/ir/ir.h"
#include "taichi/common/testing.h"

TLANG_NAMESPACE_BEGIN

class TestStmt : public Stmt {
private:
Stmt *input;
int a;
float b;

public:
TestStmt(Stmt *input, int a, int b) : input(input), a(a), b(b) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(input, a, b);
};

TI_TEST("test_stmt_field_manager") {
auto a = Stmt::make<TestStmt>(nullptr, 1, 2.0f);

TI_CHECK(a->num_operands() == 1);
TI_CHECK(a->field_manager.fields.size() == 2);

auto b = Stmt::make<TestStmt>(nullptr, 1, 2.0f);

TI_CHECK(a->field_manager.equal(b->field_manager) == true);

auto c = Stmt::make<TestStmt>(nullptr, 2, 2.1f);

TI_CHECK(a->field_manager.equal(c->field_manager) == false);
// To test two statements are equal: 1) same Stmt type 2) same operands 3)
// same field_manager
}

TLANG_NAMESPACE_END