diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 9c980bb06c2e8..d8d58b1557fcb 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -352,7 +352,7 @@ void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { Block *current_block = nullptr; -Expr Var(Expr x) { +Expr Var(const Expr &x) { auto var = Expr(std::make_shared()); current_ast_builder().insert(std::make_unique( std::static_pointer_cast(var.expr)->id, DataType::unknown)); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 922e5ce3a282e..b85d43cbb7654 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -419,15 +419,17 @@ struct LaneAttribute { LaneAttribute(const std::vector &data) : data(data) { } - LaneAttribute(const T &t) { - data.resize(1); - data[0] = t; + LaneAttribute(const T &t): data(1, t) { } void resize(int s) { data.resize(s); } + void reserve(int s) { + data.reserve(s); + } + void push_back(const T &t) { data.push_back(t); } @@ -1682,10 +1684,8 @@ struct LocalAddress { Stmt *var; int offset; - LocalAddress() : LocalAddress(nullptr, 0) { - } - LocalAddress(Stmt *var, int offset) : var(var), offset(offset) { + TI_ASSERT(var->is()); } }; @@ -1745,6 +1745,7 @@ class LocalStoreStmt : public Stmt { // LaneAttribute data; LocalStoreStmt(Stmt *ptr, Stmt *data) : ptr(ptr), data(data) { + TI_ASSERT(ptr->is()); TI_STMT_REG_FIELDS; } @@ -2076,7 +2077,7 @@ class FrontendWhileStmt : public Stmt { Expr cond; std::unique_ptr body; - FrontendWhileStmt(Expr cond) : cond(load_if_ptr(cond)) { + FrontendWhileStmt(const Expr &cond) : cond(load_if_ptr(cond)) { } bool is_container_statement() const override { @@ -2139,9 +2140,9 @@ extern Block *current_block; class IdExpression : public Expression { public: Identifier id; - IdExpression(std::string name = "") : id(name) { + IdExpression(const std::string &name = "") : id(name) { } - IdExpression(Identifier id) : id(id) { + IdExpression(const Identifier &id) : id(id) { } std::string serialize() override { @@ -2168,7 +2169,7 @@ class AtomicOpExpression : public Expression { AtomicOpType op_type; Expr dest, val; - AtomicOpExpression(AtomicOpType op_type, Expr dest, Expr val) + AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val) : op_type(op_type), dest(dest), val(val) { } @@ -2276,7 +2277,7 @@ class SNodeOpExpression : public Expression { class GlobalLoadExpression : public Expression { public: Expr ptr; - GlobalLoadExpression(Expr ptr) : ptr(ptr) { + GlobalLoadExpression(const Expr &ptr) : ptr(ptr) { } std::string serialize() override { @@ -2372,7 +2373,8 @@ inline void SLP(int v) { class For { public: - For(Expr i, Expr s, Expr e, const std::function &func) { + For(const Expr &i, const Expr &s, const Expr &e, + const std::function &func) { auto stmt_unique = std::make_unique(i, s, e); auto stmt = stmt_unique.get(); current_ast_builder().insert(std::move(stmt_unique)); @@ -2380,7 +2382,8 @@ class For { func(); } - For(ExprGroup i, Expr global, const std::function &func) { + For(const ExprGroup &i, const Expr &global, + const std::function &func) { auto stmt_unique = std::make_unique(i, global); auto stmt = stmt_unique.get(); current_ast_builder().insert(std::move(stmt_unique)); @@ -2393,7 +2396,7 @@ class For { class While { public: - While(Expr cond, const std::function &func) { + While(const Expr &cond, const std::function &func) { auto while_stmt = std::make_unique(cond); FrontendWhileStmt *ptr = while_stmt.get(); current_ast_builder().insert(std::move(while_stmt)); @@ -2402,7 +2405,7 @@ class While { } }; -Expr Var(Expr x); +Expr Var(const Expr &x); class VectorElement { public: diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp index f1c38e52764d0..5caa0761b5635 100644 --- a/taichi/transforms/vector_split.cpp +++ b/taichi/transforms/vector_split.cpp @@ -172,14 +172,14 @@ class BasicBlockVectorSplit : public IRVisitor { for (int i = 0; i < current_split_factor; i++) { LaneAttribute ptr; int new_width = need_split ? max_width : stmt->width(); - ptr.resize(new_width); + ptr.reserve(new_width); for (int j = 0; j < new_width; j++) { LocalAddress addr(stmt->ptr[lane_start(i) + j]); if (origin2split.find(addr.var) == origin2split.end()) { - ptr[j] = addr; + ptr.push_back(addr); } else { - ptr[j].var = lookup(addr.var, addr.offset / max_width); - ptr[j].offset = addr.offset % max_width; + ptr.push_back(LocalAddress(lookup(addr.var, addr.offset / max_width), + addr.offset % max_width)); } } current_split[i] = Stmt::make(ptr);