Skip to content

Commit

Permalink
[Type] [refactor] Remove LegacyVectorType (#1967)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Oct 17, 2020
1 parent 1c1edb7 commit f648538
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 58 deletions.
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ class KernelCodegen : public IRVisitor {

std::string inject_load_global_tmp(int offset,
DataType dt = PrimitiveType::i32) {
const auto vt = LegacyVectorType(/*width=*/1, dt);
const auto vt = TypeFactory::create_vector_or_scalar_type(1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
gtmp->accept(this);
auto gload = Stmt::make<GlobalLoadStmt>(gtmp.get());
Expand Down
6 changes: 4 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,15 +1307,17 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::get_range_for_bounds(
begin = tlctx->get_constant(stmt->begin_value);
} else {
auto begin_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->begin_offset, LegacyVectorType(1, PrimitiveType::i32));
stmt->begin_offset,
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
begin_stmt->accept(this);
begin = builder->CreateLoad(llvm_val[begin_stmt.get()]);
}
if (stmt->const_end) {
end = tlctx->get_constant(stmt->end_value);
} else {
auto end_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->end_offset, LegacyVectorType(1, PrimitiveType::i32));
stmt->end_offset,
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
end_stmt->accept(this);
end = builder->CreateLoad(llvm_val[end_stmt.get()]);
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FrontendAllocaStmt : public Stmt {
Identifier ident;

FrontendAllocaStmt(const Identifier &lhs, DataType type) : ident(lhs) {
ret_type = LegacyVectorType(1, type);
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
}

TI_DEFINE_ACCEPT
Expand Down Expand Up @@ -203,7 +203,7 @@ class FrontendKernelReturnStmt : public Stmt {
Expr value;

FrontendKernelReturnStmt(const Expr &value, DataType dt) : value(value) {
ret_type = LegacyVectorType(1, dt);
ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
}

bool is_container_statement() const override {
Expand Down
11 changes: 6 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ TLANG_NAMESPACE_BEGIN
class AllocaStmt : public Stmt {
public:
AllocaStmt(DataType type) {
ret_type = LegacyVectorType(1, type);
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
TI_STMT_REG_FIELDS;
}

AllocaStmt(int width, DataType type) {
ret_type = LegacyVectorType(width, type);
ret_type = TypeFactory::create_vector_or_scalar_type(width, type);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -105,7 +105,7 @@ class ArgLoadStmt : public Stmt {
bool is_ptr;

ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = LegacyVectorType(1, dt);
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
this->is_ptr = is_ptr;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -615,7 +615,7 @@ class KernelReturnStmt : public Stmt {
Stmt *value;

KernelReturnStmt(Stmt *value, DataType dt) : value(value) {
this->ret_type = LegacyVectorType(1, dt);
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
TI_STMT_REG_FIELDS;
}

Expand Down Expand Up @@ -999,7 +999,8 @@ class InternalFuncStmt : public Stmt {
std::string func_name;

InternalFuncStmt(const std::string &func_name) : func_name(func_name) {
this->ret_type = LegacyVectorType(1, PrimitiveType::i32);
this->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
TI_STMT_REG_FIELDS;
}

Expand Down
9 changes: 0 additions & 9 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,4 @@ int Type::vector_width() const {
}
}

DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) {
TI_ASSERT(width == 1);
if (is_pointer) {
return TypeFactory::get_instance().get_pointer_type(data_type.get_ptr());
} else {
return data_type;
}
}

TLANG_NAMESPACE_END
5 changes: 1 addition & 4 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class VectorType : public Type {
public:
VectorType(int num_elements, Type *element)
: num_elements_(num_elements), element_(element) {
TI_ASSERT(num_elements_ != 1);
}

Type *get_element_type() const {
Expand All @@ -161,8 +162,4 @@ class VectorType : public Type {
Type *element_{nullptr};
};

DataType LegacyVectorType(int width,
DataType data_type,
bool is_pointer = false);

TLANG_NAMESPACE_END
11 changes: 11 additions & 0 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,15 @@ Type *TypeFactory::get_pointer_type(Type *element) {
TypeFactory::TypeFactory() {
}

DataType TypeFactory::create_vector_or_scalar_type(int width,
DataType element,
bool element_is_pointer) {
TI_ASSERT(width == 1);
if (element_is_pointer) {
return TypeFactory::get_instance().get_pointer_type(element.get_ptr());
} else {
return element;
}
}

TLANG_NAMESPACE_END
4 changes: 4 additions & 0 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class TypeFactory {

Type *get_pointer_type(Type *element);

static DataType create_vector_or_scalar_type(int width,
DataType element,
bool element_is_pointer = false);

private:
TypeFactory();

Expand Down
1 change: 1 addition & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/common/core.h"
#include "taichi/system/profiler.h"
#include "taichi/ir/type.h"
#include "taichi/ir/type_factory.h"

TLANG_NAMESPACE_BEGIN

Expand Down
11 changes: 7 additions & 4 deletions taichi/transforms/make_block_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ void make_block_local_offload(OffloadedStmt *offload) {
TypedConstant(data_type, 0));
}
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, LegacyVectorType(1, data_type, true));
bls_element_offset_bytes,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
element_block->push_back<GlobalStoreStmt>(bls_ptr, value);
});
}
Expand Down Expand Up @@ -268,8 +269,9 @@ void make_block_local_offload(OffloadedStmt *offload) {
BinaryOpType::add, bls_element_offset,
bls.push_back<ConstStmt>(TypedConstant((int32)bls_offset)));

bls.push_back<BlockLocalPtrStmt>(bls_element_offset,
LegacyVectorType(1, data_type, true));
bls.push_back<BlockLocalPtrStmt>(
bls_element_offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
global_ptr->replace_with(std::move(bls));
}
}
Expand All @@ -283,7 +285,8 @@ void make_block_local_offload(OffloadedStmt *offload) {
Stmt *bls_element_offset_bytes) {
// Store/accumulate from BLS to global
auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
bls_element_offset_bytes, LegacyVectorType(1, data_type, true));
bls_element_offset_bytes,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
auto bls_val = element_block->push_back<GlobalLoadStmt>(bls_ptr);

auto global_pointer =
Expand Down
11 changes: 7 additions & 4 deletions taichi/transforms/make_thread_local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ void make_thread_local_offload(OffloadedStmt *offload) {
tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size;

auto tls_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
tls_offset, LegacyVectorType(1, data_type, true));
tls_offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));

auto zero = offload->tls_prologue->insert(
std::make_unique<ConstStmt>(TypedConstant(data_type, 0)), -1);
Expand All @@ -140,8 +141,9 @@ void make_thread_local_offload(OffloadedStmt *offload) {
// Make loop body accumulate to TLS ptr instead of global ptr
{
auto tls_ptr = offload->body->insert(
Stmt::make<ThreadLocalPtrStmt>(tls_offset,
LegacyVectorType(1, data_type, true)),
Stmt::make<ThreadLocalPtrStmt>(
tls_offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true)),
0);
dest->replace_with(tls_ptr);
}
Expand All @@ -154,7 +156,8 @@ void make_thread_local_offload(OffloadedStmt *offload) {
offload->tls_epilogue->parent_stmt = offload;
}
auto tls_ptr = offload->tls_epilogue->push_back<ThreadLocalPtrStmt>(
tls_offset, LegacyVectorType(1, data_type, true));
tls_offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
// TODO: do not use global load from TLS.
auto tls_load = offload->tls_epilogue->push_back<GlobalLoadStmt>(tls_ptr);
auto global_ptr = offload->tls_epilogue->insert(
Expand Down
43 changes: 28 additions & 15 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ class TypeCheck : public IRVisitor {
}

void visit(SNodeOpStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(ExternalTensorShapeAlongAxisStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(GlobalPtrStmt *stmt) {
Expand Down Expand Up @@ -163,8 +165,10 @@ class TypeCheck : public IRVisitor {
}

void visit(RangeForStmt *stmt) {
mark_as_if_const(stmt->begin, LegacyVectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->end, LegacyVectorType(1, PrimitiveType::i32));
mark_as_if_const(stmt->begin, TypeFactory::create_vector_or_scalar_type(
1, PrimitiveType::i32));
mark_as_if_const(stmt->end, TypeFactory::create_vector_or_scalar_type(
1, PrimitiveType::i32));
stmt->body->accept(this);
}

Expand Down Expand Up @@ -285,7 +289,8 @@ class TypeCheck : public IRVisitor {
}
}
if (is_comparison(stmt->op_type)) {
stmt->ret_type = LegacyVectorType(stmt->lhs->width(), PrimitiveType::i32);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
stmt->lhs->width(), PrimitiveType::i32);
} else {
stmt->ret_type = stmt->lhs->ret_type;
}
Expand All @@ -307,7 +312,8 @@ class TypeCheck : public IRVisitor {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type);
stmt->op3 = cast_stmt;
}
stmt->ret_type = LegacyVectorType(stmt->op1->width(), ret_type);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
stmt->op1->width(), ret_type);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -342,36 +348,43 @@ class TypeCheck : public IRVisitor {

void visit(ExternalPtrStmt *stmt) {
stmt->ret_type.set_is_pointer(true);
stmt->ret_type =
LegacyVectorType(stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type);
}

void visit(LoopIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(LoopLinearIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(BlockCornerIndexStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(BlockDimStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
}

void visit(GetRootStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::gen, true);
}

void visit(SNodeLookupStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true);
stmt->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::gen, true);
}

void visit(GetChStmt *stmt) {
stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt, true);
stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
1, stmt->output_snode->dt, true);
}

void visit(OffloadedStmt *stmt) {
Expand Down
Loading

0 comments on commit f648538

Please sign in to comment.