Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type] [refactor] Remove LegacyVectorType #1967

Merged
merged 4 commits into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ class KernelCodegen : public IRVisitor {

std::string inject_load_global_tmp(int offset,
DataType dt = PrimitiveType::i32) {
const auto vt = LegacyVectorType(/*width=*/1, dt);
const auto vt = TypeFactory::create_vector_or_scalar_type(1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
gtmp->accept(this);
auto gload = Stmt::make<GlobalLoadStmt>(gtmp.get());
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,15 +1307,15 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds(
begin = tlctx->get_constant(stmt->begin_value);
} else {
auto begin_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->begin_offset, LegacyVectorType(1, PrimitiveType::i32));
stmt->begin_offset, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
begin_stmt->accept(this);
begin = builder->CreateLoad(llvm_val[begin_stmt.get()]);
}
if (stmt->const_end) {
end = tlctx->get_constant(stmt->end_value);
} else {
auto end_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->end_offset, LegacyVectorType(1, PrimitiveType::i32));
stmt->end_offset, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
end_stmt->accept(this);
end = builder->CreateLoad(llvm_val[end_stmt.get()]);
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FrontendAllocaStmt : public Stmt {
Identifier ident;

FrontendAllocaStmt(const Identifier &lhs, DataType type) : ident(lhs) {
ret_type = LegacyVectorType(1, type);
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
}

TI_DEFINE_ACCEPT
Expand Down Expand Up @@ -203,7 +203,7 @@ class FrontendKernelReturnStmt : public Stmt {
Expr value;

FrontendKernelReturnStmt(const Expr &value, DataType dt) : value(value) {
ret_type = LegacyVectorType(1, dt);
ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
}

bool is_container_statement() const override {
Expand Down
10 changes: 5 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ TLANG_NAMESPACE_BEGIN
class AllocaStmt : public Stmt {
public:
AllocaStmt(DataType type) {
ret_type = LegacyVectorType(1, type);
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
TI_STMT_REG_FIELDS;
}

AllocaStmt(int width, DataType type) {
ret_type = LegacyVectorType(width, type);
ret_type = TypeFactory::create_vector_or_scalar_type(width, type);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -105,7 +105,7 @@ class ArgLoadStmt : public Stmt {
bool is_ptr;

ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = LegacyVectorType(1, dt);
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
this->is_ptr = is_ptr;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -615,7 +615,7 @@ class KernelReturnStmt : public Stmt {
Stmt *value;

KernelReturnStmt(Stmt *value, DataType dt) : value(value) {
this->ret_type = LegacyVectorType(1, dt);
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -999,7 +999,7 @@ class InternalFuncStmt : public Stmt {
std::string func_name;

InternalFuncStmt(const std::string &func_name) : func_name(func_name) {
this->ret_type = LegacyVectorType(1, PrimitiveType::i32);
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
TI_STMT_REG_FIELDS;
}

Expand Down
9 changes: 0 additions & 9 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,4 @@ int Type::vector_width() const {
}
}

DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) {
TI_ASSERT(width == 1);
if (is_pointer) {
return TypeFactory::get_instance().get_pointer_type(data_type.get_ptr());
} else {
return data_type;
}
}

TLANG_NAMESPACE_END
5 changes: 1 addition & 4 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class VectorType : public Type {
public:
VectorType(int num_elements, Type *element)
: num_elements_(num_elements), element_(element) {
TI_ASSERT(num_elements_ != 1);
}

Type *get_element_type() const {
Expand All @@ -161,8 +162,4 @@ class VectorType : public Type {
Type *element_{nullptr};
};

DataType LegacyVectorType(int width,
DataType data_type,
bool is_pointer = false);

TLANG_NAMESPACE_END
12 changes: 12 additions & 0 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,16 @@ Type *TypeFactory::get_pointer_type(Type *element) {
TypeFactory::TypeFactory() {
}

DataType TypeFactory::create_vector_or_scalar_type(int width,
DataType element,
bool element_is_pointer) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core change, we use this function to replace the LegacyVectorType

TI_ASSERT(width == 1);
if (element_is_pointer) {
return TypeFactory::get_instance().get_pointer_type(element.get_ptr());
}
else {
return element;
}
}

TLANG_NAMESPACE_END
2 changes: 2 additions & 0 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class TypeFactory {

Type *get_pointer_type(Type *element);

static DataType create_vector_or_scalar_type(int width, DataType element, bool element_is_pointer = false);

private:
TypeFactory();

Expand Down
1 change: 1 addition & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/common/core.h"
#include "taichi/system/profiler.h"
#include "taichi/ir/type.h"
#include "taichi/ir/type_factory.h"

TLANG_NAMESPACE_BEGIN

Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/make_block_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
TypedConstant(data_type, 0));
}
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, LegacyVectorType(1, data_type, true));
bls_element_offset_bytes, TypeFactory::create_vector_or_scalar_type(1, data_type, true));
element_block->push_back<GlobalStoreStmt>(bls_ptr, value);
});
}
Expand Down Expand Up @@ -269,7 +269,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
bls.push_back<ConstStmt>(TypedConstant((int32)bls_offset)));

bls.push_back<BlockLocalPtrStmt>(bls_element_offset,
LegacyVectorType(1, data_type, true));
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
global_ptr->replace_with(std::move(bls));
}
}
Expand All @@ -283,7 +283,7 @@ void make_block_local_offload(OffloadedStmt *offload) {
Stmt *bls_element_offset_bytes) {
// Store/accumulate from BLS to global
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, LegacyVectorType(1, data_type, true));
bls_element_offset_bytes, TypeFactory::create_vector_or_scalar_type(1, data_type, true));
auto bls_val = element_block->push_back<GlobalLoadStmt>(bls_ptr);

auto global_pointer =
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/make_thread_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size;

auto tls_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
tls_offset, LegacyVectorType(1, data_type, true));
tls_offset, TypeFactory::create_vector_or_scalar_type(1, data_type, true));

auto zero = offload->tls_prologue->insert(
std::make_unique<ConstStmt>(TypedConstant(data_type, 0)), -1);
Expand All @@ -141,7 +141,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
{
auto tls_ptr = offload->body->insert(
Stmt::make<ThreadLocalPtrStmt>(tls_offset,
LegacyVectorType(1, data_type, true)),
TypeFactory::create_vector_or_scalar_type(1, data_type, true)),
0);
dest->replace_with(tls_ptr);
}
Expand All @@ -154,7 +154,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
offload->tls_epilogue->parent_stmt = offload;
}
auto tls_ptr = offload->tls_epilogue->push_back<ThreadLocalPtrStmt>(
tls_offset, LegacyVectorType(1, data_type, true));
tls_offset, TypeFactory::create_vector_or_scalar_type(1, data_type, true));
// TODO: do not use global load from TLS.
auto tls_load = offload->tls_epilogue->push_back<GlobalLoadStmt>(tls_ptr);
auto global_ptr = offload->tls_epilogue->insert(
Expand Down
28 changes: 14 additions & 14 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ class TypeCheck : public IRVisitor {
}

void visit(SNodeOpStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(ExternalTensorShapeAlongAxisStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(GlobalPtrStmt *stmt) {
Expand Down Expand Up @@ -163,8 +163,8 @@ class TypeCheck : public IRVisitor {
}

void visit(RangeForStmt *stmt) {
mark_as_if_const(stmt->begin, LegacyVectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->end, LegacyVectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->begin, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
mark_as_if_const(stmt->end, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
stmt->body->accept(this);
}

Expand Down Expand Up @@ -285,7 +285,7 @@ class TypeCheck : public IRVisitor {
}
}
if (is_comparison(stmt->op_type)) {
stmt->ret_type = LegacyVectorType(stmt->lhs->width(), PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(stmt->lhs->width(), PrimitiveType::i32);
} else {
stmt->ret_type = stmt->lhs->ret_type;
}
Expand All @@ -307,7 +307,7 @@ class TypeCheck : public IRVisitor {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type);
stmt->op3 = cast_stmt;
}
stmt->ret_type = LegacyVectorType(stmt->op1->width(), ret_type);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(stmt->op1->width(), ret_type);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -343,35 +343,35 @@ class TypeCheck : public IRVisitor {
void visit(ExternalPtrStmt *stmt) {
stmt->ret_type.set_is_pointer(true);
stmt->ret_type =
LegacyVectorType(stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type);
TypeFactory::create_vector_or_scalar_type(stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type);
}

void visit(LoopIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(LoopLinearIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(BlockCornerIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(BlockDimStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(GetRootStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::gen, true);
}

void visit(SNodeLookupStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::gen, true);
}

void visit(GetChStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt, true);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, stmt->output_snode->dt, true);
}

void visit(OffloadedStmt *stmt) {
Expand Down
20 changes: 10 additions & 10 deletions tests/cpp/test_alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ TI_TEST("alg_simp") {
block->kernel = kernel.get();

auto global_load_addr = block->push_back<GlobalTemporaryStmt>(
0, LegacyVectorType(1, PrimitiveType::i32));
0, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
auto zero = block->push_back<ConstStmt>(TypedConstant(0));
auto add =
block->push_back<BinaryOpStmt>(BinaryOpType::add, global_load, zero);
auto global_store_addr = block->push_back<GlobalTemporaryStmt>(
4, LegacyVectorType(1, PrimitiveType::i32));
4, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_store =
block->push_back<GlobalStoreStmt>(global_store_addr, add);

Expand Down Expand Up @@ -53,7 +53,7 @@ TI_TEST("alg_simp") {
block->kernel = kernel.get();

auto global_load_addr = block->push_back<GlobalTemporaryStmt>(
0, LegacyVectorType(1, PrimitiveType::f32));
0, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::f32));
auto global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
auto one = block->push_back<ConstStmt>(TypedConstant(1.0f));
auto mul1 =
Expand All @@ -63,7 +63,7 @@ TI_TEST("alg_simp") {
auto div = block->push_back<BinaryOpStmt>(BinaryOpType::div, zero, one);
auto sub = block->push_back<BinaryOpStmt>(BinaryOpType::sub, mul2, div);
auto global_store_addr = block->push_back<GlobalTemporaryStmt>(
4, LegacyVectorType(1, PrimitiveType::f32));
4, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::f32));
auto global_store =
block->push_back<GlobalStoreStmt>(global_store_addr, sub);

Expand Down Expand Up @@ -91,15 +91,15 @@ TI_TEST("alg_simp") {
block->kernel = kernel.get();

auto global_load_addr = block->push_back<GlobalTemporaryStmt>(
0, LegacyVectorType(1, PrimitiveType::i32));
0, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
auto zero = block->push_back<ConstStmt>(TypedConstant(0));
auto mul =
block->push_back<BinaryOpStmt>(BinaryOpType::mul, global_load, zero);
auto one = block->push_back<ConstStmt>(TypedConstant(1));
auto add = block->push_back<BinaryOpStmt>(BinaryOpType::add, mul, one);
auto global_store_addr = block->push_back<GlobalTemporaryStmt>(
4, LegacyVectorType(1, PrimitiveType::i32));
4, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_store =
block->push_back<GlobalStoreStmt>(global_store_addr, add);

Expand All @@ -119,14 +119,14 @@ TI_TEST("alg_simp") {
block->kernel = kernel.get();

global_load_addr = block->push_back<GlobalTemporaryStmt>(
8, LegacyVectorType(1, PrimitiveType::f32));
8, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::f32));
global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
zero = block->push_back<ConstStmt>(TypedConstant(0));
mul = block->push_back<BinaryOpStmt>(BinaryOpType::mul, global_load, zero);
one = block->push_back<ConstStmt>(TypedConstant(1));
add = block->push_back<BinaryOpStmt>(BinaryOpType::add, mul, one);
global_store_addr = block->push_back<GlobalTemporaryStmt>(
12, LegacyVectorType(1, PrimitiveType::f32));
12, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::f32));
global_store = block->push_back<GlobalStoreStmt>(global_store_addr, add);

irpass::type_check(block.get()); // insert 2 casts
Expand All @@ -153,13 +153,13 @@ TI_TEST("alg_simp") {
auto block = std::make_unique<Block>();

auto global_load_addr = block->push_back<GlobalTemporaryStmt>(
0, LegacyVectorType(1, PrimitiveType::i32));
0, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
auto minus_one = block->push_back<ConstStmt>(TypedConstant(-1));
auto and_result = block->push_back<BinaryOpStmt>(BinaryOpType::bit_and,
minus_one, global_load);
auto global_store_addr = block->push_back<GlobalTemporaryStmt>(
4, LegacyVectorType(1, PrimitiveType::i32));
4, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_store =
block->push_back<GlobalStoreStmt>(global_store_addr, and_result);

Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ TI_TEST("same_statements") {
auto block = std::make_unique<Block>();

auto global_load_addr = block->push_back<GlobalTemporaryStmt>(
0, LegacyVectorType(1, PrimitiveType::i32));
0, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto global_load = block->push_back<GlobalLoadStmt>(global_load_addr);
auto global_store_addr = block->push_back<GlobalTemporaryStmt>(
4, LegacyVectorType(1, PrimitiveType::i32));
4, TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto one = block->push_back<ConstStmt>(TypedConstant(1));
auto if_stmt = block->push_back<IfStmt>(one)->as<IfStmt>();

Expand Down