From 572e9b0814f7deb7fcfe545d3641d6e3141bcf38 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sun, 12 Apr 2020 17:19:55 -0400 Subject: [PATCH 1/7] Slim ir.h --- taichi/ir/ir.cpp | 69 ++++++++++++++++++++++++++++++++++++++---------- taichi/ir/ir.h | 39 +++------------------------ 2 files changed, 59 insertions(+), 49 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 3b47cbcbae4aa..7c95655c1030b 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -14,6 +14,50 @@ TLANG_NAMESPACE_BEGIN #define TI_EXPRESSION_IMPLEMENTATION #include "expression.h" +void DecoratorRecorder::reset() { + vectorize = -1; + parallelize = 0; + uniform = false; + scratch_opt.clear(); + block_dim = 0; + strictly_serialized = false; +} + +Block *IRBuilder::current_block() { + if (stack.empty()) + return nullptr; + else + return stack.back(); +} + +Stmt *IRBuilder::get_last_stmt() { + return stack.back()->back(); +} + +void IRBuilder::insert(std::unique_ptr &&stmt, int location) { + TI_ASSERT(!stack.empty()); + stack.back()->insert(std::move(stmt), location); +} + +void IRBuilder::stop_gradient(SNode *snode) { + TI_ASSERT(!stack.empty()); + stack.back()->stop_gradients.push_back(snode); +} + +int Identifier::id_counter = 0; +std::string Identifier::raw_name() const { + if (name_.empty()) + return fmt::format("tmp{}", id); + else + return name_; +} + +Stmt *VecStatement::push_back(pStmt &&stmt) { + auto ret = stmt.get(); + stmts.push_back(std::move(stmt)); + return ret; +} + class StatementTypeNameVisitor : public IRVisitor { public: std::string type_name; @@ -29,21 +73,23 @@ class StatementTypeNameVisitor : public IRVisitor { #undef PER_STATEMENT }; +inline Expr load_if_ptr(const Expr &ptr) { + if (ptr.is()) { + return load(ptr); + } else if (ptr.is()) { + TI_ASSERT(ptr.cast()->snode->num_active_indices == + 0); + return load(ptr[ExprGroup()]); + } else + return ptr; +} + std::string Stmt::type() { StatementTypeNameVisitor v; this->accept(&v); return v.type_name; } -void IRBuilder::insert(std::unique_ptr &&stmt, int location) { - TI_ASSERT(!stack.empty()); - stack.back()->insert(std::move(stmt), location); -} - -void IRBuilder::stop_gradient(SNode *snode) { - TI_ASSERT(!stack.empty()); - stack.back()->stop_gradients.push_back(snode); -} GetChStmt::GetChStmt(taichi::lang::Stmt *input_ptr, int chid) : input_ptr(input_ptr), chid(chid) { @@ -267,7 +313,6 @@ IRNode *FrontendContext::root() { return static_cast(root_node.get()); } -int Identifier::id_counter = 0; std::atomic Stmt::instance_id_counter(0); std::unique_ptr context; @@ -501,10 +546,6 @@ For::For(const Expr &s, const Expr &e, const std::function &func) { func(i); } -Stmt *IRBuilder::get_last_stmt() { - return stack.back()->back(); -} - OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type) : OffloadedStmt(task_type, nullptr) { } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 4116cf465418f..ff4ba9c7f9a4f 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -199,14 +199,7 @@ class DecoratorRecorder { reset(); } - void reset() { - vectorize = -1; - parallelize = 0; - uniform = false; - scratch_opt.clear(); - block_dim = 0; - strictly_serialized = false; - } + void reset(); }; class FrontendContext { @@ -255,12 +248,7 @@ class IRBuilder { std::unique_ptr create_scope(std::unique_ptr &list); - Block *current_block() { - if (stack.empty()) - return nullptr; - else - return stack.back(); - } + Block *current_block(); Stmt *get_last_stmt(); @@ -282,12 +270,7 @@ class Identifier { id = id_counter++; } - std::string raw_name() const { - if (name_.empty()) - return fmt::format("tmp{}", id); - else - return name_; - } + std::string raw_name() const; std::string name() const { return "@" + raw_name(); @@ -323,11 +306,7 @@ class VecStatement { stmts = std::move(other_stmts); } - Stmt *push_back(pStmt &&stmt) { - auto ret = stmt.get(); - stmts.push_back(std::move(stmt)); - return ret; - } + Stmt *push_back(pStmt &&stmt); template T *push_back(Args &&... args) { @@ -2371,16 +2350,6 @@ inline Expr load(Expr ptr) { return Expr::make(ptr); } -inline Expr load_if_ptr(const Expr &ptr) { - if (ptr.is()) { - return load(ptr); - } else if (ptr.is()) { - TI_ASSERT(ptr.cast()->snode->num_active_indices == - 0); - return load(ptr[ExprGroup()]); - } else - return ptr; -} inline Expr ptr_if_global(const Expr &var) { if (var.is()) { From ffe98d9cb6ca07e0cb7091c8336ba9f5a6e45aa7 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sun, 12 Apr 2020 22:31:57 -0400 Subject: [PATCH 2/7] move lots of stuffs to expr.h/expr.cpp --- taichi/ir/expr.cpp | 73 ++++++++ taichi/ir/expr.h | 23 +++ taichi/ir/ir.cpp | 415 ++++++++++++++++++++++++++++++--------------- taichi/ir/ir.h | 241 +++----------------------- 4 files changed, 397 insertions(+), 355 deletions(-) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index ce596b239d715..a5fcc69c7e4d9 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -1,5 +1,6 @@ #include "expr.h" #include "ir.h" +#include "taichi/program/program.h" TLANG_NAMESPACE_BEGIN @@ -20,4 +21,76 @@ std::string Expr::get_attribute(const std::string &key) const { return expr->get_attribute(key); } +Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) { + return Expr::make(TernaryOpType::select, cond, true_val, + false_val); +} + +Expr operator-(const Expr &expr) { + return Expr::make(UnaryOpType::neg, expr); +} + +Expr operator~(const Expr &expr) { + return Expr::make(UnaryOpType::bit_not, expr); +} + +Expr cast(const Expr &input, DataType dt) { + auto ret = std::make_shared(UnaryOpType::cast, input); + ret->cast_type = dt; + ret->cast_by_value = true; + return Expr(ret); +} + +Expr bit_cast(const Expr &input, DataType dt) { + auto ret = std::make_shared(UnaryOpType::cast, input); + ret->cast_type = dt; + ret->cast_by_value = false; + return Expr(ret); +} + +Expr Expr::operator[](const ExprGroup &indices) const { + TI_ASSERT(is() || is()); + return Expr::make(*this, indices.loaded()); +} + +Expr &Expr::operator=(const Expr &o) { + if (get_current_program().current_kernel) { + if (expr == nullptr) { + set(o.eval()); + } else if (expr->is_lvalue()) { + current_ast_builder().insert(std::make_unique( + ptr_if_global(*this), load_if_ptr(o))); + } else { + // set(o.eval()); + TI_ERROR("Cannot assign to non-lvalue: {}", serialize()); + } + } else { + set(o); + } + return *this; +} + +Expr Expr::parent() const { + TI_ASSERT(is()); + return Expr::make( + cast()->snode->parent); +} + +SNode *Expr::snode() const { + TI_ASSERT(is()); + return cast()->snode; +} + +Expr Expr::operator!() { + return Expr::make(UnaryOpType::logic_not, expr); +} + +void Expr::declare(DataType dt) { + set(Expr::make(dt, Identifier())); +} + +void Expr::set_grad(const Expr &o) { + this->cast()->adjoint.set(o); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 1f1ec0f2181e1..1deee99971f05 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -102,4 +102,27 @@ class Expr { std::string get_attribute(const std::string &key) const; }; +Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val); + +Expr operator-(const Expr &expr); + +Expr operator~(const Expr &expr); + +// Value cast +Expr cast(const Expr &input, DataType dt); + +template +Expr cast(const Expr &input) { + return taichi::lang::cast(input, get_data_type()); +} + +Expr bit_cast(const Expr &input, DataType dt); + +template +Expr bit_cast(const Expr &input) { + return taichi::lang::bit_cast(input, get_data_type()); +} + + + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 7c95655c1030b..96ddb82d5b16d 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -14,6 +14,28 @@ TLANG_NAMESPACE_BEGIN #define TI_EXPRESSION_IMPLEMENTATION #include "expression.h" +IRBuilder ¤t_ast_builder() { + return context->builder(); +} + +std::string VectorType::pointer_suffix() const { + if (is_pointer()) { + return "*"; + } else { + return ""; + } +} + +std::string VectorType::element_type_name() const { + return fmt::format("{}{}", data_type_short_name(data_type), + pointer_suffix()); +} + +std::string VectorType::str() const { + auto ename = element_type_name(); + return fmt::format("{:4}x{}", ename, width); +} + void DecoratorRecorder::reset() { vectorize = -1; parallelize = 0; @@ -44,6 +66,16 @@ void IRBuilder::stop_gradient(SNode *snode) { stack.back()->stop_gradients.push_back(snode); } +std::unique_ptr IRBuilder::create_scope( + std::unique_ptr &list) { + TI_ASSERT(list == nullptr); + list = std::make_unique(); + if (!stack.empty()) { + list->parent = stack.back(); + } + return std::make_unique(this, list.get()); +} + int Identifier::id_counter = 0; std::string Identifier::raw_name() const { if (name_.empty()) @@ -84,51 +116,155 @@ inline Expr load_if_ptr(const Expr &ptr) { return ptr; } +inline Expr smart_load(const Expr &var) { + return load_if_ptr(ptr_if_global(var)); +} + +int StmtFieldSNode::get_snode_id(taichi::lang::SNode *snode) { + if (snode == nullptr) + return -1; + return snode->id; +} + +bool StmtFieldSNode::equal(const StmtField *other_generic) const { + if (auto other = dynamic_cast(other_generic)) { + return get_snode_id(snode) == get_snode_id(other->snode); + } else { + // Different types + return false; + } +} + +bool StmtFieldManager::equal(StmtFieldManager &other) const { + if (fields.size() != other.fields.size()) { + return false; + } + auto num_fields = fields.size(); + for (std::size_t i = 0; i < num_fields; i++) { + if (!fields[i]->equal(other.fields[i].get())) { + return false; + } + } + return true; +} + +std::atomic Stmt::instance_id_counter(0); + +Stmt::Stmt() : field_manager(this), fields_registered(false) { + parent = nullptr; + instance_id = instance_id_counter++; + id = instance_id; + operand_bitmap = 0; + erased = false; + is_ptr = false; +} + +Stmt *Stmt::insert_before_me(std::unique_ptr &&new_stmt) { + auto ret = new_stmt.get(); + TI_ASSERT(parent); + auto &stmts = parent->statements; + int loc = -1; + for (int i = 0; i < (int)stmts.size(); i++) { + if (stmts[i].get() == this) { + loc = i; + break; + } + } + TI_ASSERT(loc != -1); + new_stmt->parent = parent; + stmts.insert(stmts.begin() + loc, std::move(new_stmt)); + return ret; +} + +Stmt *Stmt::insert_after_me(std::unique_ptr &&new_stmt) { + auto ret = new_stmt.get(); + TI_ASSERT(parent); + auto &stmts = parent->statements; + int loc = -1; + for (int i = 0; i < (int)stmts.size(); i++) { + if (stmts[i].get() == this) { + loc = i; + break; + } + } + TI_ASSERT(loc != -1); + new_stmt->parent = parent; + stmts.insert(stmts.begin() + loc + 1, std::move(new_stmt)); + return ret; +} + +void Stmt::replace_with(Stmt *new_stmt) { + auto root = get_ir_root(); + irpass::replace_all_usages_with(root, this, new_stmt); + // Note: the current structure should have been destroyed now.. +} + +void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) { + parent->replace_with(this, std::move(new_statements), replace_usages); +} + +void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { + operand_bitmap = 0; + int n_op = num_operands(); + for (int i = 0; i < n_op; i++) { + if (operand(i) == old_stmt) { + *operands[i] = new_stmt; + } + operand_bitmap |= operand_hash(operand(i)); + } + rebuild_operand_bitmap(); +} + +std::string Stmt::type_hint() const { + if (ret_type.data_type == DataType::unknown) + return ""; + else + return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " "); +} + std::string Stmt::type() { StatementTypeNameVisitor v; this->accept(&v); return v.type_name; } - -GetChStmt::GetChStmt(taichi::lang::Stmt *input_ptr, int chid) - : input_ptr(input_ptr), chid(chid) { - TI_ASSERT(input_ptr->is()); - input_snode = input_ptr->as()->snode; - output_snode = input_snode->ch[chid].get(); - TI_STMT_REG_FIELDS; -} - -Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) { - return Expr::make(TernaryOpType::select, cond, true_val, - false_val); +IRNode *Stmt::get_ir_root() { + auto block = parent; + while (block->parent) + block = block->parent; + return dynamic_cast(block); } -Expr operator-(const Expr &expr) { - return Expr::make(UnaryOpType::neg, expr); +std::vector Stmt::get_operands() const { + std::vector ret; + for (int i = 0; i < num_operands(); i++) { + ret.push_back(*operands[i]); + } + return ret; } -Expr operator~(const Expr &expr) { - return Expr::make(UnaryOpType::bit_not, expr); +void Stmt::set_operand(int i, Stmt *stmt) { + *operands[i] = stmt; + rebuild_operand_bitmap(); } -Expr cast(const Expr &input, DataType dt) { - auto ret = std::make_shared(UnaryOpType::cast, input); - ret->cast_type = dt; - ret->cast_by_value = true; - return Expr(ret); +void Stmt::register_operand(Stmt *&stmt) { + operands.push_back(&stmt); + rebuild_operand_bitmap(); } -Expr bit_cast(const Expr &input, DataType dt) { - auto ret = std::make_shared(UnaryOpType::cast, input); - ret->cast_type = dt; - ret->cast_by_value = false; - return Expr(ret); +void Stmt::mark_fields_registered() { + TI_ASSERT(!fields_registered); + fields_registered = true; } -Expr Expr::operator[](const ExprGroup &indices) const { - TI_ASSERT(is() || is()); - return Expr::make(*this, indices.loaded()); +std::string Expression::get_attribute( + const std::string &key) const { + if (auto it = attributes.find(key); it == attributes.end()) { + TI_ERROR("Attribute {} not found.", key); + } else { + return it->second; + } } ExprGroup ExprGroup::loaded() const { @@ -138,39 +274,127 @@ ExprGroup ExprGroup::loaded() const { return indices_loaded; } -DecoratorRecorder dec; - -IRBuilder ¤t_ast_builder() { - return context->builder(); +std::string ExprGroup::serialize() const { + std::string ret; + for (int i = 0; i < (int)exprs.size(); i++) { + ret += exprs[i].serialize(); + if (i + 1 < (int)exprs.size()) { + ret += ", "; + } + } + return ret; } -std::unique_ptr IRBuilder::create_scope( - std::unique_ptr &list) { - TI_ASSERT(list == nullptr); - list = std::make_unique(); - if (!stack.empty()) { - list->parent = stack.back(); - } - return std::make_unique(this, list.get()); +UnaryOpStmt::UnaryOpStmt(taichi::lang::UnaryOpType op_type, + taichi::lang::Stmt *operand) + : op_type(op_type), operand(operand) { + TI_ASSERT(!operand->is()); + cast_type = DataType::unknown; + cast_by_value = true; + TI_STMT_REG_FIELDS; } -Expr &Expr::operator=(const Expr &o) { - if (get_current_program().current_kernel) { - if (expr == nullptr) { - set(o.eval()); - } else if (expr->is_lvalue()) { - current_ast_builder().insert(std::make_unique( - ptr_if_global(*this), load_if_ptr(o))); +bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { + if (op_type == o->op_type) { + if (op_type == UnaryOpType::cast) { + return cast_type == o->cast_type; } else { - // set(o.eval()); - TI_ERROR("Cannot assign to non-lvalue: {}", serialize()); + return true; } + } + return false; +} + +std::string UnaryOpExpression::serialize() { + if (type == UnaryOpType::cast) { + std::string reint = cast_by_value ? "" : "reinterpret_"; + return fmt::format("({}{}<{}> {})", reint, unary_op_type_name(type), + data_type_name(cast_type), operand->serialize()); + } else { + return fmt::format("({} {})", unary_op_type_name(type), + operand->serialize()); + } +} + +void UnaryOpExpression::flatten(VecStatement &ret) { + operand->flatten(ret); + auto unary = std::make_unique(type, operand->stmt); + if (type == UnaryOpType::cast) { + unary->cast_type = cast_type; + unary->cast_by_value = cast_by_value; + } + stmt = unary.get(); + stmt->tb = tb; + ret.push_back(std::move(unary)); +} + +ExternalPtrStmt::ExternalPtrStmt( + const taichi::lang::LaneAttribute &base_ptrs, + const std::vector &indices) + : base_ptrs(base_ptrs), indices(indices) { + DataType dt = DataType::f32; + for (int i = 0; i < (int)base_ptrs.size(); i++) { + TI_ASSERT(base_ptrs[i] != nullptr); + TI_ASSERT(base_ptrs[i]->is()); + } + width() = base_ptrs.size(); + element_type() = dt; + TI_STMT_REG_FIELDS; +} + +GlobalPtrStmt::GlobalPtrStmt( + const taichi::lang::LaneAttribute &snodes, + const std::vector &indices, + bool activate) + : snodes(snodes), indices(indices), activate(activate) { + for (int i = 0; i < (int)snodes.size(); i++) { + TI_ASSERT(snodes[i] != nullptr); + TI_ASSERT(snodes[0]->dt == snodes[i]->dt); + } + width() = snodes.size(); + element_type() = snodes[0]->dt; + TI_STMT_REG_FIELDS; +} + +std::string GlobalPtrExpression::serialize() { + std::string s = fmt::format("{}[", var.serialize()); + for (int i = 0; i < (int)indices.size(); i++) { + s += indices.exprs[i]->serialize(); + if (i + 1 < (int)indices.size()) + s += ", "; + } + s += "]"; + return s; +} + +void GlobalPtrExpression::flatten(VecStatement &ret) { + std::vector index_stmts; + for (int i = 0; i < (int)indices.size(); i++) { + indices.exprs[i]->flatten(ret); + index_stmts.push_back(indices.exprs[i]->stmt); + } + if (var.is()) { + ret.push_back(std::make_unique( + var.cast()->snode, index_stmts)); } else { - set(o); + TI_ASSERT(var.is()); + var->flatten(ret); + ret.push_back(std::make_unique( + var.cast()->stmt, index_stmts)); } - return *this; + stmt = ret.back().get(); } +GetChStmt::GetChStmt(taichi::lang::Stmt *input_ptr, int chid) + : input_ptr(input_ptr), chid(chid) { + TI_ASSERT(input_ptr->is()); + input_snode = input_ptr->as()->snode; + output_snode = input_snode->ch[chid].get(); + TI_STMT_REG_FIELDS; +} + +DecoratorRecorder dec; + FrontendContext::FrontendContext() { root_node = std::make_unique(); current_builder = std::make_unique(root_node.get()); @@ -291,12 +515,6 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } -IRNode *Stmt::get_ir_root() { - auto block = parent; - while (block->parent) - block = block->parent; - return dynamic_cast(block); -} FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { @@ -313,89 +531,8 @@ IRNode *FrontendContext::root() { return static_cast(root_node.get()); } -std::atomic Stmt::instance_id_counter(0); - std::unique_ptr context; -Expr Expr::parent() const { - TI_ASSERT(is()); - return Expr::make( - cast()->snode->parent); -} - -SNode *Expr::snode() const { - TI_ASSERT(is()); - return cast()->snode; -} - -Expr Expr::operator!() { - return Expr::make(UnaryOpType::logic_not, expr); -} - -void Expr::declare(DataType dt) { - set(Expr::make(dt, Identifier())); -} - -void Expr::set_grad(const Expr &o) { - this->cast()->adjoint.set(o); -} - -Stmt *Stmt::insert_before_me(std::unique_ptr &&new_stmt) { - auto ret = new_stmt.get(); - TI_ASSERT(parent); - auto &stmts = parent->statements; - int loc = -1; - for (int i = 0; i < (int)stmts.size(); i++) { - if (stmts[i].get() == this) { - loc = i; - break; - } - } - TI_ASSERT(loc != -1); - new_stmt->parent = parent; - stmts.insert(stmts.begin() + loc, std::move(new_stmt)); - return ret; -} - -Stmt *Stmt::insert_after_me(std::unique_ptr &&new_stmt) { - auto ret = new_stmt.get(); - TI_ASSERT(parent); - auto &stmts = parent->statements; - int loc = -1; - for (int i = 0; i < (int)stmts.size(); i++) { - if (stmts[i].get() == this) { - loc = i; - break; - } - } - TI_ASSERT(loc != -1); - new_stmt->parent = parent; - stmts.insert(stmts.begin() + loc + 1, std::move(new_stmt)); - return ret; -} - -void Stmt::replace_with(Stmt *new_stmt) { - auto root = get_ir_root(); - irpass::replace_all_usages_with(root, this, new_stmt); - // Note: the current structure should have been destroyed now.. -} - -void Stmt::replace_with(VecStatement &&new_statements, bool replace_usages) { - parent->replace_with(this, std::move(new_statements), replace_usages); -} - -void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { - operand_bitmap = 0; - int n_op = num_operands(); - for (int i = 0; i < n_op; i++) { - if (operand(i) == old_stmt) { - *operands[i] = new_stmt; - } - operand_bitmap |= operand_hash(operand(i)); - } - rebuild_operand_bitmap(); -} - Block *current_block = nullptr; Expr Var(const Expr &x) { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index ff4ba9c7f9a4f..e7db7b8855abd 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -159,23 +159,9 @@ struct VectorType { return !(*this == o); } - std::string pointer_suffix() const { - if (is_pointer()) { - return "*"; - } else { - return ""; - } - } - - std::string element_type_name() const { - return fmt::format("{}{}", data_type_short_name(data_type), - pointer_suffix()); - } - - std::string str() const { - auto ename = element_type_name(); - return fmt::format("{:4}x{}", ename, width); - } + std::string pointer_suffix() const; + std::string element_type_name() const; + std::string str() const; bool is_pointer() const { return _is_pointer; @@ -247,11 +233,8 @@ class IRBuilder { }; std::unique_ptr create_scope(std::unique_ptr &list); - Block *current_block(); - Stmt *get_last_stmt(); - void stop_gradient(SNode *); }; @@ -548,20 +531,9 @@ class StmtFieldSNode final : public StmtField { explicit StmtFieldSNode(SNode *const &snode) : snode(snode) { } - static int get_snode_id(SNode *snode) { - if (snode == nullptr) - return -1; - return snode->id; - } + static int get_snode_id(SNode *snode); - bool equal(const StmtField *other_generic) const override { - if (auto other = dynamic_cast(other_generic)) { - return get_snode_id(snode) == get_snode_id(other->snode); - } else { - // Different types - return false; - } - } + bool equal(const StmtField *other_generic) const override; }; class StmtFieldManager { @@ -588,18 +560,7 @@ class StmtFieldManager { this->operator()(rest_names.c_str(), std::forward(rest)...); } - bool equal(StmtFieldManager &other) const { - if (fields.size() != other.fields.size()) { - return false; - } - auto num_fields = fields.size(); - for (std::size_t i = 0; i < num_fields; i++) { - if (!fields[i]->equal(other.fields[i].get())) { - return false; - } - } - return true; - } + bool equal(StmtFieldManager &other) const; }; #define TI_STMT_DEF_FIELDS(...) TI_IO_DEF(__VA_ARGS__) @@ -623,17 +584,10 @@ class Stmt : public IRNode { bool fields_registered; std::string tb; bool is_ptr; + VectorType ret_type; Stmt(const Stmt &stmt) = delete; - - Stmt() : field_manager(this), fields_registered(false) { - parent = nullptr; - instance_id = instance_id_counter++; - id = instance_id; - operand_bitmap = 0; - erased = false; - is_ptr = false; - } + Stmt(); static uint64 operand_hash(Stmt *stmt) { return uint64(1) << ((uint64(stmt) >> 4) % 64); @@ -655,18 +609,11 @@ class Stmt : public IRNode { return ret_type.data_type; } - VectorType ret_type; - std::string ret_data_type_name() const { return ret_type.str(); } - std::string type_hint() const { - if (ret_type.data_type == DataType::unknown) - return ""; - else - return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " "); - } + std::string type_hint() const; std::string name() const { return fmt::format("${}", id); @@ -689,13 +636,7 @@ class Stmt : public IRNode { return *operands[i]; } - std::vector get_operands() const { - std::vector ret; - for (int i = 0; i < num_operands(); i++) { - ret.push_back(*operands[i]); - } - return ret; - } + std::vector get_operands() const; void rebuild_operand_bitmap() { return; // disable bitmap maintenance since the fact that the user can @@ -708,20 +649,9 @@ class Stmt : public IRNode { } } - void set_operand(int i, Stmt *stmt) { - *operands[i] = stmt; - rebuild_operand_bitmap(); - } - - void register_operand(Stmt *&stmt) { - operands.push_back(&stmt); - rebuild_operand_bitmap(); - } - - void mark_fields_registered() { - TI_ASSERT(!fields_registered); - fields_registered = true; - } + void set_operand(int i, Stmt *stmt); + void register_operand(Stmt *&stmt); + void mark_fields_registered(); virtual void rebuild_operands() { TI_NOT_IMPLEMENTED; @@ -732,9 +662,7 @@ class Stmt : public IRNode { } void replace_with(Stmt *new_stmt); - void replace_with(VecStatement &&new_statements, bool replace_usages = true); - virtual void replace_operand_with(Stmt *old_stmt, Stmt *new_stmt); IRNode *get_ir_root(); @@ -808,13 +736,7 @@ class Expression { attributes[key] = value; } - std::string get_attribute(const std::string &key) const { - if (auto it = attributes.find(key); it == attributes.end()) { - TI_ERROR("Attribute {} not found.", key); - } else { - return it->second; - } - } + std::string get_attribute(const std::string &key) const; }; class ExprGroup { @@ -859,17 +781,7 @@ class ExprGroup { return exprs[i]; } - std::string serialize() { - std::string ret; - for (int i = 0; i < (int)exprs.size(); i++) { - ret += exprs[i].serialize(); - if (i + 1 < (int)exprs.size()) { - ret += ", "; - } - } - return ret; - } - + std::string serialize() const; ExprGroup loaded() const; }; @@ -969,24 +881,9 @@ class UnaryOpStmt : public Stmt { DataType cast_type; bool cast_by_value = true; - UnaryOpStmt(UnaryOpType op_type, Stmt *operand) - : op_type(op_type), operand(operand) { - TI_ASSERT(!operand->is()); - cast_type = DataType::unknown; - cast_by_value = true; - TI_STMT_REG_FIELDS; - } + UnaryOpStmt(UnaryOpType op_type, Stmt *operand); - bool same_operation(UnaryOpStmt *o) const { - if (op_type == o->op_type) { - if (op_type == UnaryOpType::cast) { - return cast_type == o->cast_type; - } else { - return true; - } - } - return false; - } + bool same_operation(UnaryOpStmt *o) const; virtual bool has_global_side_effect() const override { return false; @@ -1114,28 +1011,9 @@ class UnaryOpExpression : public Expression { cast_by_value = true; } - std::string serialize() override { - if (type == UnaryOpType::cast) { - std::string reint = cast_by_value ? "" : "reinterpret_"; - return fmt::format("({}{}<{}> {})", reint, unary_op_type_name(type), - data_type_name(cast_type), operand->serialize()); - } else { - return fmt::format("({} {})", unary_op_type_name(type), - operand->serialize()); - } - } + std::string serialize() override; - void flatten(VecStatement &ret) override { - operand->flatten(ret); - auto unary = std::make_unique(type, operand->stmt); - if (type == UnaryOpType::cast) { - unary->cast_type = cast_type; - unary->cast_by_value = cast_by_value; - } - stmt = unary.get(); - stmt->tb = tb; - ret.push_back(std::move(unary)); - } + void flatten(VecStatement &ret) override; }; class BinaryOpStmt : public Stmt { @@ -1220,12 +1098,12 @@ class BinaryOpExpression : public Expression { } }; -class TrinaryOpExpression : public Expression { +class TernaryOpExpression : public Expression { public: TernaryOpType type; Expr op1, op2, op3; - TrinaryOpExpression(TernaryOpType type, + TernaryOpExpression(TernaryOpType type, const Expr &op1, const Expr &op2, const Expr &op3) @@ -1259,17 +1137,7 @@ class ExternalPtrStmt : public Stmt { bool activate; ExternalPtrStmt(const LaneAttribute &base_ptrs, - const std::vector &indices) - : base_ptrs(base_ptrs), indices(indices) { - DataType dt = DataType::f32; - for (int i = 0; i < (int)base_ptrs.size(); i++) { - TI_ASSERT(base_ptrs[i] != nullptr); - TI_ASSERT(base_ptrs[i]->is()); - } - width() = base_ptrs.size(); - element_type() = dt; - TI_STMT_REG_FIELDS; - } + const std::vector &indices); virtual bool has_global_side_effect() const override { return false; @@ -1287,16 +1155,7 @@ class GlobalPtrStmt : public Stmt { GlobalPtrStmt(const LaneAttribute &snodes, const std::vector &indices, - bool activate = true) - : snodes(snodes), indices(indices), activate(activate) { - for (int i = 0; i < (int)snodes.size(); i++) { - TI_ASSERT(snodes[i] != nullptr); - TI_ASSERT(snodes[0]->dt == snodes[i]->dt); - } - width() = snodes.size(); - element_type() = snodes[0]->dt; - TI_STMT_REG_FIELDS; - } + bool activate = true); virtual bool has_global_side_effect() const override { return activate; @@ -1377,34 +1236,9 @@ class GlobalPtrExpression : public Expression { : var(var), indices(indices) { } - std::string serialize() override { - std::string s = fmt::format("{}[", var.serialize()); - for (int i = 0; i < (int)indices.size(); i++) { - s += indices.exprs[i]->serialize(); - if (i + 1 < (int)indices.size()) - s += ", "; - } - s += "]"; - return s; - } + std::string serialize() override; - void flatten(VecStatement &ret) override { - std::vector index_stmts; - for (int i = 0; i < (int)indices.size(); i++) { - indices.exprs[i]->flatten(ret); - index_stmts.push_back(indices.exprs[i]->stmt); - } - if (var.is()) { - ret.push_back(std::make_unique( - var.cast()->snode, index_stmts)); - } else { - TI_ASSERT(var.is()); - var->flatten(ret); - ret.push_back(std::make_unique( - var.cast()->stmt, index_stmts)); - } - stmt = ret.back().get(); - } + void flatten(VecStatement &ret) override; bool is_lvalue() const override { return true; @@ -1413,27 +1247,6 @@ class GlobalPtrExpression : public Expression { #include "expression.h" -Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val); - -Expr operator-(const Expr &expr); - -Expr operator~(const Expr &expr); - -// Value cast -Expr cast(const Expr &input, DataType dt); - -template -Expr cast(const Expr &input) { - return taichi::lang::cast(input, get_data_type()); -} - -Expr bit_cast(const Expr &input, DataType dt); - -template -Expr bit_cast(const Expr &input) { - return taichi::lang::bit_cast(input, get_data_type()); -} - class Block : public IRNode { public: Block *parent; @@ -2362,10 +2175,6 @@ inline Expr ptr_if_global(const Expr &var) { } } -inline Expr smart_load(const Expr &var) { - return load_if_ptr(ptr_if_global(var)); -} - extern DecoratorRecorder dec; inline void Vectorize(int v) { From d901a41fed900bb94f66294645e7b1daed502f03 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sun, 12 Apr 2020 22:45:27 -0400 Subject: [PATCH 3/7] finish --- taichi/ir/expr.cpp | 62 ++++++++++ taichi/ir/ir.cpp | 302 +++++++++++++++++++++++++++++++++++---------- taichi/ir/ir.h | 229 +++------------------------------- 3 files changed, 315 insertions(+), 278 deletions(-) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index a5fcc69c7e4d9..03a8253bdb261 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -93,4 +93,66 @@ void Expr::set_grad(const Expr &o) { this->cast()->adjoint.set(o); } +Expr::Expr(int32 x) : Expr() { + expr = std::make_shared(x); +} + +Expr::Expr(int64 x) : Expr() { + expr = std::make_shared(x); +} + +Expr::Expr(float32 x) : Expr() { + expr = std::make_shared(x); +} + +Expr::Expr(float64 x) : Expr() { + expr = std::make_shared(x); +} + +Expr::Expr(const Identifier &id) : Expr() { + expr = std::make_shared(id); +} + +Expr Expr::eval() const { + TI_ASSERT(expr != nullptr); + if (is()) { + return *this; + } + auto eval_stmt = Stmt::make(*this); + auto eval_expr = Expr::make(eval_stmt.get()); + eval_stmt->as()->eval_expr.set(eval_expr); + // needed in lower_ast to replace the statement itself with the + // lowered statement + current_ast_builder().insert(std::move(eval_stmt)); + return eval_expr; +} + +void Expr::operator+=(const Expr &o) { + if (this->atomic) { + current_ast_builder().insert(Stmt::make( + AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o))); + } else { + (*this) = (*this) + o; + } +} + +void Expr::operator-=(const Expr &o) { + if (this->atomic) { + current_ast_builder().insert(Stmt::make( + AtomicOpType::add, *this, -load_if_ptr(o))); + } else { + (*this) = (*this) - o; + } +} + +void Expr::operator*=(const Expr &o) { + TI_ASSERT(!this->atomic); + (*this) = (*this) * load_if_ptr(o); +} + +void Expr::operator/=(const Expr &o) { + TI_ASSERT(!this->atomic); + (*this) = (*this) / load_if_ptr(o); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 96ddb82d5b16d..41904b672367e 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -120,7 +120,7 @@ inline Expr smart_load(const Expr &var) { return load_if_ptr(ptr_if_global(var)); } -int StmtFieldSNode::get_snode_id(taichi::lang::SNode *snode) { +int StmtFieldSNode::get_snode_id(SNode *snode) { if (snode == nullptr) return -1; return snode->id; @@ -285,8 +285,8 @@ std::string ExprGroup::serialize() const { return ret; } -UnaryOpStmt::UnaryOpStmt(taichi::lang::UnaryOpType op_type, - taichi::lang::Stmt *operand) +UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, + Stmt *operand) : op_type(op_type), operand(operand) { TI_ASSERT(!operand->is()); cast_type = DataType::unknown; @@ -329,7 +329,7 @@ void UnaryOpExpression::flatten(VecStatement &ret) { } ExternalPtrStmt::ExternalPtrStmt( - const taichi::lang::LaneAttribute &base_ptrs, + const LaneAttribute &base_ptrs, const std::vector &indices) : base_ptrs(base_ptrs), indices(indices) { DataType dt = DataType::f32; @@ -343,7 +343,7 @@ ExternalPtrStmt::ExternalPtrStmt( } GlobalPtrStmt::GlobalPtrStmt( - const taichi::lang::LaneAttribute &snodes, + const LaneAttribute &snodes, const std::vector &indices, bool activate) : snodes(snodes), indices(indices), activate(activate) { @@ -385,7 +385,7 @@ void GlobalPtrExpression::flatten(VecStatement &ret) { stmt = ret.back().get(); } -GetChStmt::GetChStmt(taichi::lang::Stmt *input_ptr, int chid) +GetChStmt::GetChStmt(Stmt *input_ptr, int chid) : input_ptr(input_ptr), chid(chid) { TI_ASSERT(input_ptr->is()); input_snode = input_ptr->as()->snode; @@ -400,65 +400,6 @@ FrontendContext::FrontendContext() { current_builder = std::make_unique(root_node.get()); } -Expr::Expr(int32 x) : Expr() { - expr = std::make_shared(x); -} - -Expr::Expr(int64 x) : Expr() { - expr = std::make_shared(x); -} - -Expr::Expr(float32 x) : Expr() { - expr = std::make_shared(x); -} - -Expr::Expr(float64 x) : Expr() { - expr = std::make_shared(x); -} - -Expr::Expr(const Identifier &id) : Expr() { - expr = std::make_shared(id); -} - -Expr Expr::eval() const { - TI_ASSERT(expr != nullptr); - if (is()) { - return *this; - } - auto eval_stmt = Stmt::make(*this); - auto eval_expr = Expr::make(eval_stmt.get()); - eval_stmt->as()->eval_expr.set(eval_expr); - // needed in lower_ast to replace the statement itself with the - // lowered statement - current_ast_builder().insert(std::move(eval_stmt)); - return eval_expr; -} - -void Expr::operator+=(const Expr &o) { - if (this->atomic) { - current_ast_builder().insert(Stmt::make( - AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o))); - } else { - (*this) = (*this) + o; - } -} -void Expr::operator-=(const Expr &o) { - if (this->atomic) { - current_ast_builder().insert(Stmt::make( - AtomicOpType::add, *this, -load_if_ptr(o))); - } else { - (*this) = (*this) - o; - } -} -void Expr::operator*=(const Expr &o) { - TI_ASSERT(!this->atomic); - (*this) = (*this) * load_if_ptr(o); -} -void Expr::operator/=(const Expr &o) { - TI_ASSERT(!this->atomic); - (*this) = (*this) / load_if_ptr(o); -} - FrontendForStmt::FrontendForStmt(const Expr &loop_var, const Expr &begin, const Expr &end) @@ -515,7 +456,6 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } - FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); @@ -547,6 +487,22 @@ void Print_(const Expr &a, const std::string &str) { current_ast_builder().insert(std::make_unique(a, str)); } +Expr load(const Expr &ptr) { + TI_ASSERT(ptr.is()); + return Expr::make(ptr); +} + +Expr ptr_if_global(const Expr &var) { + if (var.is()) { + // singleton global variable + TI_ASSERT(var.snode()->num_active_indices == 0); + return var[ExprGroup()]; + } else { + // may be any local or global expr + return var; + } +} + template <> std::string to_string(const LaneAttribute &ptr) { std::string ret = " ["; @@ -582,6 +538,29 @@ Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() { return nullptr; } +void LocalLoadStmt::rebuild_operands() { + operands.clear(); + for (int i = 0; i < (int)ptr.size(); i++) { + register_operand(this->ptr[i].var); + } +} + +bool LocalLoadStmt::same_source() const { + for (int i = 1; i < (int)ptr.size(); i++) { + if (ptr[i].var != ptr[0].var) + return false; + } + return true; +} + +bool LocalLoadStmt::has_source(Stmt *_alloca) const { + for (int i = 0; i < width(); i++) { + if (ptr[i].var == alloca) + return true; + } + return false; +} + void Block::erase(int location) { statements[location]->erased = true; trash_bin.push_back(std::move(statements[location])); // do not delete the @@ -651,7 +630,7 @@ void Block::replace_with(Stmt *old_statement, replace_with(old_statement, std::move(vec)); } -Stmt *Block::lookup_var(const taichi::lang::Ident &ident) const { +Stmt *Block::lookup_var(const Ident &ident) const { auto ptr = local_var_alloca.find(ident); if (ptr != local_var_alloca.end()) { return ptr->second; @@ -674,6 +653,174 @@ Stmt *Block::mask() { } } +void Block::set_statements(VecStatement &&stmts) { + statements.clear(); + for (int i = 0; i < (int)stmts.size(); i++) { + insert(std::move(stmts[i]), i); + } +} + +void Block::insert_before(Stmt *old_statement, VecStatement &&new_statements) { + int location = -1; + for (int i = 0; i < (int)statements.size(); i++) { + if (old_statement == statements[i].get()) { + location = i; + break; + } + } + TI_ASSERT(location != -1); + for (int i = (int)new_statements.size() - 1; i >= 0; i--) { + insert(std::move(new_statements[i]), location); + } +} + +void Block::replace_with(Stmt *old_statement, + VecStatement &&new_statements, + bool replace_usages) { + int location = -1; + for (int i = 0; i < (int)statements.size(); i++) { + if (old_statement == statements[i].get()) { + location = i; + break; + } + } + TI_ASSERT(location != -1); + if (replace_usages) + old_statement->replace_with(new_statements.back().get()); + trash_bin.push_back(std::move(statements[location])); + statements.erase(statements.begin() + location); + for (int i = (int)new_statements.size() - 1; i >= 0; i--) { + insert(std::move(new_statements[i]), location); + } +} + +bool Block::has_container_statements() { + for (auto &s : statements) { + if (s->is_container_statement()) + return true; + } + return false; +} + +int Block::locate(Stmt *stmt) { + for (int i = 0; i < (int)statements.size(); i++) { + if (statements[i].get() == stmt) { + return i; + } + } + return -1; +} + +FrontendSNodeOpStmt::FrontendSNodeOpStmt( + SNodeOpType op_type, + SNode *snode, + const ExprGroup &indices, + const Expr &val) + : op_type(op_type), snode(snode), indices(indices.loaded()), val(val) { + if (val.expr != nullptr) { + TI_ASSERT(op_type == SNodeOpType::append); + this->val.set(load_if_ptr(val)); + } else { + TI_ASSERT(op_type != SNodeOpType::append); + } +} + +SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, + SNode *snode, + Stmt *ptr, + Stmt *val) + : op_type(op_type), snode(snode), ptr(ptr), val(val) { + width() = 1; + element_type() = DataType::i32; + TI_STMT_REG_FIELDS; +} + +SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, + SNode *snode, + const std::vector &indices) + : op_type(op_type), snode(snode), indices(indices) { + ptr = nullptr; + val = nullptr; + TI_ASSERT(op_type == SNodeOpType::is_active || + op_type == SNodeOpType::deactivate); + width() = 1; + element_type() = DataType::i32; + TI_STMT_REG_FIELDS; +} + +std::string AtomicOpExpression::serialize() { + if (op_type == AtomicOpType::add) { + return fmt::format("atomic_add({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::sub) { + return fmt::format("atomic_sub({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::min) { + return fmt::format("atomic_min({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::max) { + return fmt::format("atomic_max({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::bit_and) { + return fmt::format("atomic_bit_and({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::bit_or) { + return fmt::format("atomic_bit_or({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::bit_xor) { + return fmt::format("atomic_bit_xor({}, {})", dest.serialize(), + val.serialize()); + } else { + // min/max not supported in the LLVM backend yet. + TI_NOT_IMPLEMENTED; + } +} + +std::string SNodeOpExpression::serialize() { + if (value.expr) { + return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type), + snode->get_node_type_name_hinted(), + indices.serialize(), value.serialize()); + } else { + return fmt::format("{}({}, [{}])", snode_op_type_name(op_type), + snode->get_node_type_name_hinted(), + indices.serialize()); + } +} + +void SNodeOpExpression::flatten(VecStatement &ret) { + std::vector indices_stmt; + for (int i = 0; i < (int)indices.size(); i++) { + indices[i]->flatten(ret); + indices_stmt.push_back(indices[i]->stmt); + } + if (op_type == SNodeOpType::is_active) { + // is_active cannot be lowered all the way to a global pointer. + // It should be lowered into a pointer to parent and an index. + TI_ERROR_IF( + snode->type != SNodeType::pointer && snode->type != SNodeType::hash && + snode->type != SNodeType::bitmasked, + "ti.is_active only works on pointer, hash or bitmasked nodes."); + ret.push_back(SNodeOpType::is_active, snode, indices_stmt); + } else { + auto ptr = ret.push_back(snode, indices_stmt); + if (op_type == SNodeOpType::append) { + value->flatten(ret); + ret.push_back(SNodeOpType::append, snode, ptr, + ret.back().get()); + TI_ERROR_IF(snode->type != SNodeType::dynamic, + "ti.append only works on dynamic nodes."); + TI_ERROR_IF(snode->ch.size() != 1, + "ti.append only works on single-child dynamic nodes."); + TI_ERROR_IF(data_type_size(snode->ch[0]->dt) != 4, + "ti.append only works on i32/f32 nodes."); + } else if (op_type == SNodeOpType::length) { + ret.push_back(SNodeOpType::length, snode, ptr, nullptr); + } + } + stmt = ret.back().get(); +} + For::For(const Expr &s, const Expr &e, const std::function &func) { auto i = Expr(std::make_shared()); auto stmt_unique = std::make_unique(i, s, e); @@ -683,6 +830,27 @@ For::For(const Expr &s, const Expr &e, const std::function &func) { func(i); } +For::For(const Expr &i, + const Expr &s, + const Expr &e, + const std::function &func) { + auto stmt_unique = std::make_unique(i, s, e); + auto stmt = stmt_unique.get(); + current_ast_builder().insert(std::move(stmt_unique)); + auto _ = current_ast_builder().create_scope(stmt->body); + func(); +} + +For::For(const ExprGroup &i, + const Expr &global, + const std::function &func) { + auto stmt_unique = std::make_unique(i, global); + auto stmt = stmt_unique.get(); + current_ast_builder().insert(std::move(stmt_unique)); + auto _ = current_ast_builder().create_scope(stmt->body); + func(); +} + OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type) : OffloadedStmt(task_type, nullptr) { } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index e7db7b8855abd..d8b1708b23c0a 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -1260,82 +1260,22 @@ class Block : public IRNode { parent = nullptr; } - bool has_container_statements() { - for (auto &s : statements) { - if (s->is_container_statement()) - return true; - } - return false; - } - - int locate(Stmt *stmt) { - for (int i = 0; i < (int)statements.size(); i++) { - if (statements[i].get() == stmt) { - return i; - } - } - return -1; - } - + bool has_container_statements(); + int locate(Stmt *stmt); void erase(int location); - void erase(Stmt *stmt); - std::unique_ptr extract(int location); - std::unique_ptr extract(Stmt *stmt); - void insert(std::unique_ptr &&stmt, int location = -1); - void insert(VecStatement &&stmt, int location = -1); - void replace_statements_in_range(int start, int end, VecStatement &&stmts); - - void set_statements(VecStatement &&stmts) { - statements.clear(); - for (int i = 0; i < (int)stmts.size(); i++) { - insert(std::move(stmts[i]), i); - } - } - + void set_statements(VecStatement &&stmts); void replace_with(Stmt *old_statement, std::unique_ptr &&new_statement); - - void insert_before(Stmt *old_statement, VecStatement &&new_statements) { - int location = -1; - for (int i = 0; i < (int)statements.size(); i++) { - if (old_statement == statements[i].get()) { - location = i; - break; - } - } - TI_ASSERT(location != -1); - for (int i = (int)new_statements.size() - 1; i >= 0; i--) { - insert(std::move(new_statements[i]), location); - } - } - + void insert_before(Stmt *old_statement, VecStatement &&new_statements); void replace_with(Stmt *old_statement, VecStatement &&new_statements, - bool replace_usages = true) { - int location = -1; - for (int i = 0; i < (int)statements.size(); i++) { - if (old_statement == statements[i].get()) { - location = i; - break; - } - } - TI_ASSERT(location != -1); - if (replace_usages) - old_statement->replace_with(new_statements.back().get()); - trash_bin.push_back(std::move(statements[location])); - statements.erase(statements.begin() + location); - for (int i = (int)new_statements.size() - 1; i >= 0; i--) { - insert(std::move(new_statements[i]), location); - } - } - + bool replace_usages = true); Stmt *lookup_var(const Ident &ident) const; - Stmt *mask(); Stmt *back() const { @@ -1381,15 +1321,7 @@ class FrontendSNodeOpStmt : public Stmt { FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, const ExprGroup &indices, - const Expr &val = Expr(nullptr)) - : op_type(op_type), snode(snode), indices(indices.loaded()), val(val) { - if (val.expr != nullptr) { - TI_ASSERT(op_type == SNodeOpType::append); - this->val.set(load_if_ptr(val)); - } else { - TI_ASSERT(op_type != SNodeOpType::append); - } - } + const Expr &val = Expr(nullptr)); DEFINE_ACCEPT }; @@ -1402,25 +1334,11 @@ class SNodeOpStmt : public Stmt { Stmt *val; std::vector indices; - SNodeOpStmt(SNodeOpType op_type, SNode *snode, Stmt *ptr, Stmt *val = nullptr) - : op_type(op_type), snode(snode), ptr(ptr), val(val) { - width() = 1; - element_type() = DataType::i32; - TI_STMT_REG_FIELDS; - } + SNodeOpStmt(SNodeOpType op_type, SNode *snode, Stmt *ptr, Stmt *val = nullptr); SNodeOpStmt(SNodeOpType op_type, SNode *snode, - const std::vector &indices) - : op_type(op_type), snode(snode), indices(indices) { - ptr = nullptr; - val = nullptr; - TI_ASSERT(op_type == SNodeOpType::is_active || - op_type == SNodeOpType::deactivate); - width() = 1; - element_type() = DataType::i32; - TI_STMT_REG_FIELDS; - } + const std::vector &indices); static bool activation_related(SNodeOpType op) { return op == SNodeOpType::activate || op == SNodeOpType::deactivate || @@ -1538,28 +1456,9 @@ class LocalLoadStmt : public Stmt { TI_STMT_REG_FIELDS; } - void rebuild_operands() override { - operands.clear(); - for (int i = 0; i < (int)ptr.size(); i++) { - register_operand(this->ptr[i].var); - } - } - - bool same_source() const { - for (int i = 1; i < (int)ptr.size(); i++) { - if (ptr[i].var != ptr[0].var) - return false; - } - return true; - } - - bool has_source(Stmt *alloca) const { - for (int i = 0; i < width(); i++) { - if (ptr[i].var == alloca) - return true; - } - return false; - } + void rebuild_operands() override; + bool same_source() const; + bool has_source(Stmt *alloca) const; bool integral_operands() const override { return false; @@ -2022,33 +1921,7 @@ class AtomicOpExpression : public Expression { : op_type(op_type), dest(dest), val(val) { } - std::string serialize() override { - if (op_type == AtomicOpType::add) { - return fmt::format("atomic_add({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::sub) { - return fmt::format("atomic_sub({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::min) { - return fmt::format("atomic_min({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::max) { - return fmt::format("atomic_max({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::bit_and) { - return fmt::format("atomic_bit_and({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::bit_or) { - return fmt::format("atomic_bit_or({}, {})", dest.serialize(), - val.serialize()); - } else if (op_type == AtomicOpType::bit_xor) { - return fmt::format("atomic_bit_xor({}, {})", dest.serialize(), - val.serialize()); - } else { - // min/max not supported in the LLVM backend yet. - TI_NOT_IMPLEMENTED; - } - } + std::string serialize() override; void flatten(VecStatement &ret) override { // FrontendAtomicStmt is the correct place to flatten sub-exprs like |dest| @@ -2077,50 +1950,9 @@ class SNodeOpExpression : public Expression { : snode(snode), op_type(op_type), indices(indices), value(value) { } - std::string serialize() override { - if (value.expr) { - return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), - indices.serialize(), value.serialize()); - } else { - return fmt::format("{}({}, [{}])", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), - indices.serialize()); - } - } + std::string serialize() override; - void flatten(VecStatement &ret) override { - std::vector indices_stmt; - for (int i = 0; i < (int)indices.size(); i++) { - indices[i]->flatten(ret); - indices_stmt.push_back(indices[i]->stmt); - } - if (op_type == SNodeOpType::is_active) { - // is_active cannot be lowered all the way to a global pointer. - // It should be lowered into a pointer to parent and an index. - TI_ERROR_IF( - snode->type != SNodeType::pointer && snode->type != SNodeType::hash && - snode->type != SNodeType::bitmasked, - "ti.is_active only works on pointer, hash or bitmasked nodes."); - ret.push_back(SNodeOpType::is_active, snode, indices_stmt); - } else { - auto ptr = ret.push_back(snode, indices_stmt); - if (op_type == SNodeOpType::append) { - value->flatten(ret); - ret.push_back(SNodeOpType::append, snode, ptr, - ret.back().get()); - TI_ERROR_IF(snode->type != SNodeType::dynamic, - "ti.append only works on dynamic nodes."); - TI_ERROR_IF(snode->ch.size() != 1, - "ti.append only works on single-child dynamic nodes."); - TI_ERROR_IF(data_type_size(snode->ch[0]->dt) != 4, - "ti.append only works on i32/f32 nodes."); - } else if (op_type == SNodeOpType::length) { - ret.push_back(SNodeOpType::length, snode, ptr, nullptr); - } - } - stmt = ret.back().get(); - } + void flatten(VecStatement &ret) override; }; class GlobalLoadExpression : public Expression { @@ -2158,22 +1990,9 @@ class ConstExpression : public Expression { } }; -inline Expr load(Expr ptr) { - TI_ASSERT(ptr.is()); - return Expr::make(ptr); -} - +inline Expr load(const Expr &ptr); -inline Expr ptr_if_global(const Expr &var) { - if (var.is()) { - // singleton global variable - TI_ASSERT(var.snode()->num_active_indices == 0); - return var[ExprGroup()]; - } else { - // may be any local or global expr - return var; - } -} +inline Expr ptr_if_global(const Expr &var); extern DecoratorRecorder dec; @@ -2211,23 +2030,11 @@ class For { For(const Expr &i, const Expr &s, const Expr &e, - const std::function &func) { - auto stmt_unique = std::make_unique(i, s, e); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - auto _ = current_ast_builder().create_scope(stmt->body); - func(); - } + const std::function &func); For(const ExprGroup &i, const Expr &global, - const std::function &func) { - auto stmt_unique = std::make_unique(i, global); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - auto _ = current_ast_builder().create_scope(stmt->body); - func(); - } + const std::function &func); For(const Expr &s, const Expr &e, const std::function &func); }; From 16e6403ff783b70ecf60714ca0c7e8bb0641d4ca Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Sun, 12 Apr 2020 23:25:53 -0400 Subject: [PATCH 4/7] [skip ci] enforce code format --- taichi/ir/expr.h | 2 -- taichi/ir/ir.cpp | 61 ++++++++++++++++++++---------------------------- taichi/ir/ir.h | 5 +++- 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 1deee99971f05..cfab4599bf4d2 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -123,6 +123,4 @@ Expr bit_cast(const Expr &input) { return taichi::lang::bit_cast(input, get_data_type()); } - - TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 41904b672367e..e6faa02523434 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -27,8 +27,7 @@ std::string VectorType::pointer_suffix() const { } std::string VectorType::element_type_name() const { - return fmt::format("{}{}", data_type_short_name(data_type), - pointer_suffix()); + return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix()); } std::string VectorType::str() const { @@ -110,7 +109,7 @@ inline Expr load_if_ptr(const Expr &ptr) { return load(ptr); } else if (ptr.is()) { TI_ASSERT(ptr.cast()->snode->num_active_indices == - 0); + 0); return load(ptr[ExprGroup()]); } else return ptr; @@ -258,8 +257,7 @@ void Stmt::mark_fields_registered() { fields_registered = true; } -std::string Expression::get_attribute( - const std::string &key) const { +std::string Expression::get_attribute(const std::string &key) const { if (auto it = attributes.find(key); it == attributes.end()) { TI_ERROR("Attribute {} not found.", key); } else { @@ -285,8 +283,7 @@ std::string ExprGroup::serialize() const { return ret; } -UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, - Stmt *operand) +UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand) : op_type(op_type), operand(operand) { TI_ASSERT(!operand->is()); cast_type = DataType::unknown; @@ -328,9 +325,8 @@ void UnaryOpExpression::flatten(VecStatement &ret) { ret.push_back(std::move(unary)); } -ExternalPtrStmt::ExternalPtrStmt( - const LaneAttribute &base_ptrs, - const std::vector &indices) +ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, + const std::vector &indices) : base_ptrs(base_ptrs), indices(indices) { DataType dt = DataType::f32; for (int i = 0; i < (int)base_ptrs.size(); i++) { @@ -342,10 +338,9 @@ ExternalPtrStmt::ExternalPtrStmt( TI_STMT_REG_FIELDS; } -GlobalPtrStmt::GlobalPtrStmt( - const LaneAttribute &snodes, - const std::vector &indices, - bool activate) +GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute &snodes, + const std::vector &indices, + bool activate) : snodes(snodes), indices(indices), activate(activate) { for (int i = 0; i < (int)snodes.size(); i++) { TI_ASSERT(snodes[i] != nullptr); @@ -711,11 +706,10 @@ int Block::locate(Stmt *stmt) { return -1; } -FrontendSNodeOpStmt::FrontendSNodeOpStmt( - SNodeOpType op_type, - SNode *snode, - const ExprGroup &indices, - const Expr &val) +FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, + SNode *snode, + const ExprGroup &indices, + const Expr &val) : op_type(op_type), snode(snode), indices(indices.loaded()), val(val) { if (val.expr != nullptr) { TI_ASSERT(op_type == SNodeOpType::append); @@ -742,7 +736,7 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, ptr = nullptr; val = nullptr; TI_ASSERT(op_type == SNodeOpType::is_active || - op_type == SNodeOpType::deactivate); + op_type == SNodeOpType::deactivate); width() = 1; element_type() = DataType::i32; TI_STMT_REG_FIELDS; @@ -750,17 +744,13 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, std::string AtomicOpExpression::serialize() { if (op_type == AtomicOpType::add) { - return fmt::format("atomic_add({}, {})", dest.serialize(), - val.serialize()); + return fmt::format("atomic_add({}, {})", dest.serialize(), val.serialize()); } else if (op_type == AtomicOpType::sub) { - return fmt::format("atomic_sub({}, {})", dest.serialize(), - val.serialize()); + return fmt::format("atomic_sub({}, {})", dest.serialize(), val.serialize()); } else if (op_type == AtomicOpType::min) { - return fmt::format("atomic_min({}, {})", dest.serialize(), - val.serialize()); + return fmt::format("atomic_min({}, {})", dest.serialize(), val.serialize()); } else if (op_type == AtomicOpType::max) { - return fmt::format("atomic_max({}, {})", dest.serialize(), - val.serialize()); + return fmt::format("atomic_max({}, {})", dest.serialize(), val.serialize()); } else if (op_type == AtomicOpType::bit_and) { return fmt::format("atomic_bit_and({}, {})", dest.serialize(), val.serialize()); @@ -779,12 +769,11 @@ std::string AtomicOpExpression::serialize() { std::string SNodeOpExpression::serialize() { if (value.expr) { return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), - indices.serialize(), value.serialize()); + snode->get_node_type_name_hinted(), indices.serialize(), + value.serialize()); } else { return fmt::format("{}({}, [{}])", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), - indices.serialize()); + snode->get_node_type_name_hinted(), indices.serialize()); } } @@ -797,10 +786,10 @@ void SNodeOpExpression::flatten(VecStatement &ret) { if (op_type == SNodeOpType::is_active) { // is_active cannot be lowered all the way to a global pointer. // It should be lowered into a pointer to parent and an index. - TI_ERROR_IF( - snode->type != SNodeType::pointer && snode->type != SNodeType::hash && - snode->type != SNodeType::bitmasked, - "ti.is_active only works on pointer, hash or bitmasked nodes."); + TI_ERROR_IF(snode->type != SNodeType::pointer && + snode->type != SNodeType::hash && + snode->type != SNodeType::bitmasked, + "ti.is_active only works on pointer, hash or bitmasked nodes."); ret.push_back(SNodeOpType::is_active, snode, indices_stmt); } else { auto ptr = ret.push_back(snode, indices_stmt); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index d8b1708b23c0a..dd8d4d5bf7a6f 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -1334,7 +1334,10 @@ class SNodeOpStmt : public Stmt { Stmt *val; std::vector indices; - SNodeOpStmt(SNodeOpType op_type, SNode *snode, Stmt *ptr, Stmt *val = nullptr); + SNodeOpStmt(SNodeOpType op_type, + SNode *snode, + Stmt *ptr, + Stmt *val = nullptr); SNodeOpStmt(SNodeOpType op_type, SNode *snode, From b43238dec5ce4892720626c68f76941c15225086 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sun, 12 Apr 2020 23:59:29 -0400 Subject: [PATCH 5/7] remove `inline` to test CI --- taichi/ir/ir.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index dd8d4d5bf7a6f..22d261a01627f 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -238,8 +238,8 @@ class IRBuilder { void stop_gradient(SNode *); }; -inline Expr load_if_ptr(const Expr &ptr); -inline Expr smart_load(const Expr &var); +Expr load_if_ptr(const Expr &ptr); +Expr smart_load(const Expr &var); class Identifier { public: @@ -1993,9 +1993,9 @@ class ConstExpression : public Expression { } }; -inline Expr load(const Expr &ptr); +Expr load(const Expr &ptr); -inline Expr ptr_if_global(const Expr &var); +Expr ptr_if_global(const Expr &var); extern DecoratorRecorder dec; From b095a9a73478e721954ded85895e48828a1c4f29 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 13 Apr 2020 13:24:20 -0400 Subject: [PATCH 6/7] fix CI -- why did CLion's "split function into declaration and definition" give me an underscore? --- taichi/ir/ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index e6faa02523434..405dd6b760d71 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -548,7 +548,7 @@ bool LocalLoadStmt::same_source() const { return true; } -bool LocalLoadStmt::has_source(Stmt *_alloca) const { +bool LocalLoadStmt::has_source(Stmt *alloca) const { for (int i = 0; i < width(); i++) { if (ptr[i].var == alloca) return true; From 6c902eb3410eced8f1f9df732fb8731391e4b160 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 13 Apr 2020 14:16:53 -0400 Subject: [PATCH 7/7] fix inline --- taichi/ir/ir.cpp | 34 +++++++++++++++------------------- taichi/ir/ir.h | 11 ++++++----- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 405dd6b760d71..3aa2ce3ee3cf2 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -104,7 +104,7 @@ class StatementTypeNameVisitor : public IRVisitor { #undef PER_STATEMENT }; -inline Expr load_if_ptr(const Expr &ptr) { +Expr load_if_ptr(const Expr &ptr) { if (ptr.is()) { return load(ptr); } else if (ptr.is()) { @@ -115,8 +115,20 @@ inline Expr load_if_ptr(const Expr &ptr) { return ptr; } -inline Expr smart_load(const Expr &var) { - return load_if_ptr(ptr_if_global(var)); +Expr load(const Expr &ptr) { + TI_ASSERT(ptr.is()); + return Expr::make(ptr); +} + +Expr ptr_if_global(const Expr &var) { + if (var.is()) { + // singleton global variable + TI_ASSERT(var.snode()->num_active_indices == 0); + return var[ExprGroup()]; + } else { + // may be any local or global expr + return var; + } } int StmtFieldSNode::get_snode_id(SNode *snode) { @@ -482,22 +494,6 @@ void Print_(const Expr &a, const std::string &str) { current_ast_builder().insert(std::make_unique(a, str)); } -Expr load(const Expr &ptr) { - TI_ASSERT(ptr.is()); - return Expr::make(ptr); -} - -Expr ptr_if_global(const Expr &var) { - if (var.is()) { - // singleton global variable - TI_ASSERT(var.snode()->num_active_indices == 0); - return var[ExprGroup()]; - } else { - // may be any local or global expr - return var; - } -} - template <> std::string to_string(const LaneAttribute &ptr) { std::string ret = " ["; diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 22d261a01627f..5ce9796fb3644 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -239,7 +239,12 @@ class IRBuilder { }; Expr load_if_ptr(const Expr &ptr); -Expr smart_load(const Expr &var); +Expr load(const Expr &ptr); +Expr ptr_if_global(const Expr &var); + +inline Expr smart_load(const Expr &var) { + return load_if_ptr(ptr_if_global(var)); +} class Identifier { public: @@ -1993,10 +1998,6 @@ class ConstExpression : public Expression { } }; -Expr load(const Expr &ptr); - -Expr ptr_if_global(const Expr &var); - extern DecoratorRecorder dec; inline void Vectorize(int v) {