Skip to content

Commit

Permalink
[ir] Add assertions of Allocas in the constructors of LocalAddress …
Browse files Browse the repository at this point in the history
…and LocalStoreStmt
  • Loading branch information
xumingkuan authored Apr 7, 2020
1 parent 0c1bb39 commit 213533d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
2 changes: 1 addition & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) {

Block *current_block = nullptr;

Expr Var(Expr x) {
Expr Var(const Expr &x) {
auto var = Expr(std::make_shared<IdExpression>());
current_ast_builder().insert(std::make_unique<FrontendAllocaStmt>(
std::static_pointer_cast<IdExpression>(var.expr)->id, DataType::unknown));
Expand Down
33 changes: 18 additions & 15 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,15 +419,17 @@ struct LaneAttribute {
LaneAttribute(const std::vector<T> &data) : data(data) {
}

LaneAttribute(const T &t) {
data.resize(1);
data[0] = t;
LaneAttribute(const T &t): data(1, t) {
}

void resize(int s) {
data.resize(s);
}

void reserve(int s) {
data.reserve(s);
}

void push_back(const T &t) {
data.push_back(t);
}
Expand Down Expand Up @@ -1682,10 +1684,8 @@ struct LocalAddress {
Stmt *var;
int offset;

LocalAddress() : LocalAddress(nullptr, 0) {
}

LocalAddress(Stmt *var, int offset) : var(var), offset(offset) {
TI_ASSERT(var->is<AllocaStmt>());
}
};

Expand Down Expand Up @@ -1745,6 +1745,7 @@ class LocalStoreStmt : public Stmt {
// LaneAttribute<Stmt *> data;

LocalStoreStmt(Stmt *ptr, Stmt *data) : ptr(ptr), data(data) {
TI_ASSERT(ptr->is<AllocaStmt>());
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -2076,7 +2077,7 @@ class FrontendWhileStmt : public Stmt {
Expr cond;
std::unique_ptr<Block> body;

FrontendWhileStmt(Expr cond) : cond(load_if_ptr(cond)) {
FrontendWhileStmt(const Expr &cond) : cond(load_if_ptr(cond)) {
}

bool is_container_statement() const override {
Expand Down Expand Up @@ -2139,9 +2140,9 @@ extern Block *current_block;
class IdExpression : public Expression {
public:
Identifier id;
IdExpression(std::string name = "") : id(name) {
IdExpression(const std::string &name = "") : id(name) {
}
IdExpression(Identifier id) : id(id) {
IdExpression(const Identifier &id) : id(id) {
}

std::string serialize() override {
Expand All @@ -2168,7 +2169,7 @@ class AtomicOpExpression : public Expression {
AtomicOpType op_type;
Expr dest, val;

AtomicOpExpression(AtomicOpType op_type, Expr dest, Expr val)
AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val)
: op_type(op_type), dest(dest), val(val) {
}

Expand Down Expand Up @@ -2276,7 +2277,7 @@ class SNodeOpExpression : public Expression {
class GlobalLoadExpression : public Expression {
public:
Expr ptr;
GlobalLoadExpression(Expr ptr) : ptr(ptr) {
GlobalLoadExpression(const Expr &ptr) : ptr(ptr) {
}

std::string serialize() override {
Expand Down Expand Up @@ -2372,15 +2373,17 @@ inline void SLP(int v) {

class For {
public:
For(Expr i, Expr s, Expr e, const std::function<void()> &func) {
For(const Expr &i, const Expr &s, const Expr &e,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func();
}

For(ExprGroup i, Expr global, const std::function<void()> &func) {
For(const ExprGroup &i, const Expr &global,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, global);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
Expand All @@ -2393,7 +2396,7 @@ class For {

class While {
public:
While(Expr cond, const std::function<void()> &func) {
While(const Expr &cond, const std::function<void()> &func) {
auto while_stmt = std::make_unique<FrontendWhileStmt>(cond);
FrontendWhileStmt *ptr = while_stmt.get();
current_ast_builder().insert(std::move(while_stmt));
Expand All @@ -2402,7 +2405,7 @@ class While {
}
};

Expr Var(Expr x);
Expr Var(const Expr &x);

class VectorElement {
public:
Expand Down
8 changes: 4 additions & 4 deletions taichi/transforms/vector_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ class BasicBlockVectorSplit : public IRVisitor {
for (int i = 0; i < current_split_factor; i++) {
LaneAttribute<LocalAddress> ptr;
int new_width = need_split ? max_width : stmt->width();
ptr.resize(new_width);
ptr.reserve(new_width);
for (int j = 0; j < new_width; j++) {
LocalAddress addr(stmt->ptr[lane_start(i) + j]);
if (origin2split.find(addr.var) == origin2split.end()) {
ptr[j] = addr;
ptr.push_back(addr);
} else {
ptr[j].var = lookup(addr.var, addr.offset / max_width);
ptr[j].offset = addr.offset % max_width;
ptr.push_back(LocalAddress(lookup(addr.var, addr.offset / max_width),
addr.offset % max_width));
}
}
current_split[i] = Stmt::make<LocalLoadStmt>(ptr);
Expand Down

0 comments on commit 213533d

Please sign in to comment.