Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] VectorType -> LegacyVectorType #1943

Merged
merged 2 commits into from
Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ class KernelCodegen : public IRVisitor {

std::string inject_load_global_tmp(int offset,
DataType dt = PrimitiveType::i32) {
const auto vt = VectorType(/*width=*/1, dt);
const auto vt = LegacyVectorType(/*width=*/1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
gtmp->accept(this);
auto gload = Stmt::make<GlobalLoadStmt>(gtmp.get());
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1297,15 +1297,15 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds(
begin = tlctx->get_constant(stmt->begin_value);
} else {
auto begin_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->begin_offset, VectorType(1, PrimitiveType::i32));
stmt->begin_offset, LegacyVectorType(1, PrimitiveType::i32));
begin_stmt->accept(this);
begin = builder->CreateLoad(llvm_val[begin_stmt.get()]);
}
if (stmt->const_end) {
end = tlctx->get_constant(stmt->end_value);
} else {
auto end_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->end_offset, VectorType(1, PrimitiveType::i32));
stmt->end_offset, LegacyVectorType(1, PrimitiveType::i32));
end_stmt->accept(this);
end = builder->CreateLoad(llvm_val[end_stmt.get()]);
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FrontendAllocaStmt : public Stmt {
Identifier ident;

FrontendAllocaStmt(const Identifier &lhs, DataType type) : ident(lhs) {
ret_type = VectorType(1, type);
ret_type = LegacyVectorType(1, type);
}

TI_DEFINE_ACCEPT
Expand Down Expand Up @@ -203,7 +203,7 @@ class FrontendKernelReturnStmt : public Stmt {
Expr value;

FrontendKernelReturnStmt(const Expr &value, DataType dt) : value(value) {
ret_type = VectorType(1, dt);
ret_type = LegacyVectorType(1, dt);
}

bool is_container_statement() const override {
Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ IRBuilder &current_ast_builder() {
return context->builder();
}

std::string VectorType::pointer_suffix() const {
std::string LegacyVectorType::pointer_suffix() const {
if (is_pointer()) {
return "*";
} else {
return "";
}
}

std::string VectorType::element_type_name() const {
std::string LegacyVectorType::element_type_name() const {
return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix());
}

std::string VectorType::str() const {
std::string LegacyVectorType::str() const {
auto ename = element_type_name();
return fmt::format("{:4}x{}", ename, width);
}
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,27 @@ using ScratchPadOptions = std::vector<std::pair<int, SNode *>>;

IRBuilder &current_ast_builder();

struct VectorType {
struct LegacyVectorType {
private:
bool _is_pointer;

public:
int width;
DataType data_type;

VectorType(int width, DataType data_type, bool is_pointer = false)
LegacyVectorType(int width, DataType data_type, bool is_pointer = false)
: _is_pointer(is_pointer), width(width), data_type(data_type) {
}

VectorType()
LegacyVectorType()
: _is_pointer(false), width(1), data_type(PrimitiveType::unknown) {
}

bool operator==(const VectorType &o) const {
bool operator==(const LegacyVectorType &o) const {
return width == o.width && data_type == o.data_type;
}

bool operator!=(const VectorType &o) const {
bool operator!=(const LegacyVectorType &o) const {
return !(*this == o);
}

Expand Down Expand Up @@ -531,7 +531,7 @@ class Stmt : public IRNode {
bool fields_registered;
std::string tb;
bool is_ptr;
VectorType ret_type;
LegacyVectorType ret_type;

Stmt();
Stmt(const Stmt &stmt);
Expand Down
17 changes: 9 additions & 8 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ TLANG_NAMESPACE_BEGIN
class AllocaStmt : public Stmt {
public:
AllocaStmt(DataType type) {
ret_type = VectorType(1, type);
ret_type = LegacyVectorType(1, type);
TI_STMT_REG_FIELDS;
}

AllocaStmt(int width, DataType type) {
ret_type = VectorType(width, type);
ret_type = LegacyVectorType(width, type);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -103,7 +103,7 @@ class ArgLoadStmt : public Stmt {
int arg_id;

ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = VectorType(1, dt);
this->ret_type = LegacyVectorType(1, dt);
this->is_ptr = is_ptr;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -613,7 +613,7 @@ class KernelReturnStmt : public Stmt {
Stmt *value;

KernelReturnStmt(Stmt *value, DataType dt) : value(value) {
this->ret_type = VectorType(1, dt);
this->ret_type = LegacyVectorType(1, dt);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -938,7 +938,7 @@ class GlobalTemporaryStmt : public Stmt {
public:
std::size_t offset;

GlobalTemporaryStmt(std::size_t offset, VectorType ret_type)
GlobalTemporaryStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
Expand All @@ -956,7 +956,8 @@ class ThreadLocalPtrStmt : public Stmt {
public:
std::size_t offset;

ThreadLocalPtrStmt(std::size_t offset, VectorType ret_type) : offset(offset) {
ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type)
: offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand All @@ -973,7 +974,7 @@ class BlockLocalPtrStmt : public Stmt {
public:
Stmt *offset;

BlockLocalPtrStmt(Stmt *offset, VectorType ret_type) : offset(offset) {
BlockLocalPtrStmt(Stmt *offset, LegacyVectorType ret_type) : offset(offset) {
this->ret_type = ret_type;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -1004,7 +1005,7 @@ class InternalFuncStmt : public Stmt {
std::string func_name;

InternalFuncStmt(const std::string &func_name) : func_name(func_name) {
this->ret_type = VectorType(1, PrimitiveType::i32);
this->ret_type = LegacyVectorType(1, PrimitiveType::i32);
TI_STMT_REG_FIELDS;
}

Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/make_block_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
TypedConstant(data_type, 0));
}
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, VectorType(1, data_type));
bls_element_offset_bytes, LegacyVectorType(1, data_type));
element_block->push_back<GlobalStoreStmt>(bls_ptr, value);
});
}
Expand Down Expand Up @@ -268,7 +268,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
bls.push_back<ConstStmt>(TypedConstant((int32)bls_offset)));

bls.push_back<BlockLocalPtrStmt>(bls_element_offset,
VectorType(1, data_type));
LegacyVectorType(1, data_type));
global_ptr->replace_with(std::move(bls));
}
}
Expand All @@ -282,7 +282,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
Stmt *bls_element_offset_bytes) {
// Store/accumulate from BLS to global
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, VectorType(1, data_type));
bls_element_offset_bytes, LegacyVectorType(1, data_type));
auto bls_val = element_block->push_back<GlobalLoadStmt>(bls_ptr);

auto global_pointer =
Expand Down
11 changes: 6 additions & 5 deletions taichi/transforms/make_thread_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size;

auto tls_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
tls_offset, VectorType(1, data_type));
tls_offset, LegacyVectorType(1, data_type));

auto zero = offload->tls_prologue->insert(
std::make_unique<ConstStmt>(TypedConstant(data_type, 0)), -1);
Expand All @@ -139,9 +139,10 @@ void make_thread_local_offload(OffloadedStmt *offload) {
// Step 2:
// Make loop body accumulate to TLS ptr instead of global ptr
{
auto tls_ptr = offload->body->insert(
Stmt::make<ThreadLocalPtrStmt>(tls_offset, VectorType(1, data_type)),
0);
auto tls_ptr =
offload->body->insert(Stmt::make<ThreadLocalPtrStmt>(
tls_offset, LegacyVectorType(1, data_type)),
0);
dest->replace_with(tls_ptr);
}

Expand All @@ -153,7 +154,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
offload->tls_epilogue->parent_stmt = offload;
}
auto tls_ptr = offload->tls_epilogue->push_back<ThreadLocalPtrStmt>(
tls_offset, VectorType(1, data_type));
tls_offset, LegacyVectorType(1, data_type));
// TODO: do not use global load from TLS.
auto tls_load = offload->tls_epilogue->push_back<GlobalLoadStmt>(tls_ptr);
auto global_ptr = offload->tls_epilogue->insert(
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
global_offset = 0;
}

std::size_t allocate_global(VectorType type) {
std::size_t allocate_global(LegacyVectorType type) {
TI_ASSERT(type.width == 1);
auto ret = global_offset;
global_offset += data_type_size(type.data_type);
Expand Down Expand Up @@ -563,7 +563,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
StmtToOffsetMap local_to_global_offset;
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded;
OffloadedRanges *const offloaded_ranges_;
std::unordered_map<Stmt *, VectorType> local_to_global_vector_type;
std::unordered_map<Stmt *, LegacyVectorType> local_to_global_vector_type;
};

void insert_gc(IRNode *root) {
Expand Down
32 changes: 16 additions & 16 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TypeCheck : public IRVisitor {
allow_undefined_visitor = true;
}

static void mark_as_if_const(Stmt *stmt, VectorType t) {
static void mark_as_if_const(Stmt *stmt, LegacyVectorType t) {
if (stmt->is<ConstStmt>()) {
stmt->ret_type = t;
}
Expand Down Expand Up @@ -110,11 +110,11 @@ class TypeCheck : public IRVisitor {
}

void visit(SNodeOpStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(ExternalTensorShapeAlongAxisStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(GlobalPtrStmt *stmt) {
Expand Down Expand Up @@ -161,8 +161,8 @@ class TypeCheck : public IRVisitor {
}

void visit(RangeForStmt *stmt) {
mark_as_if_const(stmt->begin, VectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->end, VectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->begin, LegacyVectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->end, LegacyVectorType(1, PrimitiveType::i32));
stmt->body->accept(this);
}

Expand Down Expand Up @@ -288,7 +288,7 @@ class TypeCheck : public IRVisitor {
}
if (is_comparison(stmt->op_type)) {
stmt->ret_type =
VectorType(stmt->lhs->ret_type.width, PrimitiveType::i32);
LegacyVectorType(stmt->lhs->ret_type.width, PrimitiveType::i32);
} else {
stmt->ret_type = stmt->lhs->ret_type;
}
Expand All @@ -309,7 +309,7 @@ class TypeCheck : public IRVisitor {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type);
stmt->op3 = cast_stmt;
}
stmt->ret_type = VectorType(stmt->op1->width(), ret_type);
stmt->ret_type = LegacyVectorType(stmt->op1->width(), ret_type);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -343,36 +343,36 @@ class TypeCheck : public IRVisitor {

void visit(ExternalPtrStmt *stmt) {
stmt->ret_type.set_is_pointer(true);
stmt->ret_type = VectorType(stmt->base_ptrs.size(),
stmt->base_ptrs[0]->ret_type.data_type);
stmt->ret_type = LegacyVectorType(stmt->base_ptrs.size(),
stmt->base_ptrs[0]->ret_type.data_type);
}

void visit(LoopIndexStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(LoopLinearIndexStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(BlockCornerIndexStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(BlockDimStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::i32);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
}

void visit(GetRootStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::gen, true);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
}

void visit(SNodeLookupStmt *stmt) {
stmt->ret_type = VectorType(1, PrimitiveType::gen, true);
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
}

void visit(GetChStmt *stmt) {
stmt->ret_type = VectorType(1, stmt->output_snode->dt);
stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt);
stmt->ret_type.set_is_pointer(true);
}

Expand Down
Loading