[ir] [refactor] Let the type of Alloca be pointer
ghstack-source-id: cad6b7f30197dd661e6a59e490c24343cf084ba0
Pull Request resolved: #8007
lin-hitonami committed Jun 2, 2023
1 parent f8c7fa0 commit 9b2a0ab
Showing 19 changed files with 132 additions and 90 deletions.
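
The invariant this commit establishes, which every hunk below then applies: an alloca's ret_type is the pointer type wrapping the allocated type, and consumers strip that wrapper with ptr_removed() before inspecting the pointee. A minimal self-contained sketch of the contract — toy stand-ins with hypothetical names, not Taichi's real TypeFactory/DataType API:

// alloca_ptr_invariant.cpp -- toy model with hypothetical names.
#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Type {
  std::string name;
  bool is_pointer = false;
  const Type *pointee = nullptr;
};

// Stand-in for TypeFactory::get_instance().get_pointer_type(t).
const Type *get_pointer_type(const Type *t) {
  static std::vector<std::unique_ptr<Type>> pool;
  pool.push_back(std::make_unique<Type>(Type{t->name + "*", true, t}));
  return pool.back().get();
}

// Stand-in for DataType::ptr_removed(): strip one pointer level, if any.
const Type *ptr_removed(const Type *t) {
  return t->is_pointer ? t->pointee : t;
}

int main() {
  Type i32{"i32"};
  // Before this commit: alloca->ret_type was i32 itself.
  // After this commit:  alloca->ret_type is i32*, and every consumer
  // (codegen, CFG analysis, type checking) unwraps it first.
  const Type *alloca_ret_type = get_pointer_type(&i32);
  assert(alloca_ret_type->is_pointer);
  assert(ptr_removed(alloca_ret_type)->name == "i32");
  return 0;
}

Everything in the hunks below is then mechanical fallout: each consumer that used to read ret_type directly now goes through ret_type.ptr_removed().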
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
 
   void visit(AllocaStmt *stmt) override {
     // Override shared memory codegen logic for large shared memory
-    if (stmt->ret_type->is<TensorType>() && stmt->is_shared) {
-      auto tensor_type = stmt->ret_type->cast<TensorType>();
+    auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>();
+    if (tensor_type && stmt->is_shared) {
       size_t shared_array_bytes =
           tensor_type->get_num_elements() *
           data_type_size(tensor_type->get_element_type());
14 changes: 7 additions & 7 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -126,8 +126,10 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
 }
 
 void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
-  if (stmt->ret_type->is<TensorType>()) {
-    auto tensor_type = stmt->ret_type->cast<TensorType>();
+  TI_ASSERT(stmt->ret_type.is_pointer());
+  auto alloca_type = stmt->ret_type.ptr_removed();
+  if (alloca_type->is<TensorType>()) {
+    auto tensor_type = alloca_type->cast<TensorType>();
     auto type = tlctx->get_data_type(tensor_type);
     if (stmt->is_shared) {
       auto base = new llvm::GlobalVariable(
@@ -141,12 +143,10 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
       llvm_val[stmt] = create_entry_block_alloca(type);
     }
   } else {
-    llvm_val[stmt] =
-        create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer());
+    llvm_val[stmt] = create_entry_block_alloca(alloca_type);
     // initialize as zero if element is not a pointer
-    if (!stmt->ret_type.is_pointer())
-      builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0),
-                           llvm_val[stmt]);
+    if (!alloca_type->is<PointerType>())
+      builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]);
   }
 }
 
2 changes: 0 additions & 2 deletions taichi/codegen/llvm/llvm_codegen_utils.h
@@ -83,8 +83,6 @@ class LLVMModuleBuilder {
 
   llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) {
     auto type = tlctx->get_data_type(dt);
-    if (is_pointer)
-      type = llvm::PointerType::get(type, 0);
     return create_entry_block_alloca(type);
   }
 
6 changes: 3 additions & 3 deletions taichi/codegen/spirv/spirv_codegen.cpp
@@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor {
 
   void visit(AllocaStmt *alloca) override {
     spirv::Value ptr_val;
-    if (alloca->ret_type->is<TensorType>()) {
-      auto tensor_type = alloca->ret_type->cast<TensorType>();
+    auto alloca_type = alloca->ret_type.ptr_removed();
+    if (auto tensor_type = alloca_type->cast<TensorType>()) {
       auto elem_num = tensor_type->get_num_elements();
       spirv::SType elem_type =
           ir_->get_primitive_type(tensor_type->get_element_type());
@@ -288,7 +288,7 @@ class TaskCodegen : public IRVisitor {
       }
     } else {
       // Alloca for a single variable
-      spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
+      spirv::SType src_type = ir_->get_primitive_type(alloca_type);
       ptr_val = ir_->alloca_variable(src_type);
       ir_->store_variable(ptr_val, ir_->get_zero(src_type));
     }
7 changes: 4 additions & 3 deletions taichi/ir/control_flow_graph.cpp
@@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
   // result: the value to store
   Stmt *result = irpass::analysis::get_store_data(
       block->statements[last_def_position].get());
-  bool is_tensor_involved = var->ret_type->is<TensorType>();
+  bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
   if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
     // In between the store stmt and current stmt,
     // if there's a third-stmt that "may" have stored a "different value" to
@@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
 
   // Check for aliased address
   // There's a store to the same dest_addr before this stmt
-  bool is_tensor_involved = var->ret_type->is<TensorType>();
+  bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
   if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
     // In between the store stmt and current stmt,
     // if there's a third-stmt that "may" have stored a "different value" to
@@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
         continue;
 
       // special case of alloca (initialized to 0)
-      auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
+      auto zero = Stmt::make<ConstStmt>(
+          TypedConstant(result->ret_type.ptr_removed(), 0));
       replace_with(i, std::move(zero), true);
     } else {
       if (result->ret_type.ptr_removed()->is<TensorType>() &&
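A note on the zero-forwarding hunk just above: the forwarded constant stands for the alloca's initial value, so it must be typed as the pointee rather than the pointer — hence result->ret_type.ptr_removed() in the new TypedConstant call. A hedged toy of that distinction (hypothetical classes, not Taichi's real TypedConstant):

// typed_constant_toy.cpp -- hypothetical classes, not Taichi's TypedConstant.
#include <cassert>
#include <string>

struct Ty {
  std::string name;
  bool is_pointer = false;
  const Ty *pointee = nullptr;
};

// Stand-in for DataType::ptr_removed().
const Ty *ptr_removed(const Ty *t) { return t->is_pointer ? t->pointee : t; }

struct TypedConstantToy {
  const Ty *type;  // must be a value type: a "zero of type i32*" is meaningless
  long value;
};

int main() {
  Ty i32{"i32"};
  Ty ptr{"i32*", true, &i32};
  // result->ret_type is now pointer-to-i32; the forwarded zero must be an i32.
  TypedConstantToy zero{ptr_removed(&ptr), 0};
  assert(!zero.type->is_pointer && zero.type->name == "i32");
  return 0;
}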
27 changes: 18 additions & 9 deletions taichi/ir/frontend_ir.cpp
@@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
     : lhs(lhs), rhs(rhs) {
   TI_ASSERT(lhs->is_lvalue());
   if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) {
-    lhs.expr->ret_type = rhs.get_rvalue_type();
+    lhs.expr->ret_type =
+        TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type());
   }
 }
 
@@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) {
 
 void FrontendForStmt::add_loop_var(const Expr &loop_var) {
   loop_var_ids.push_back(loop_var.cast<IdExpression>()->id);
-  loop_var.expr->ret_type = PrimitiveType::i32;
+  loop_var.expr->ret_type =
+      TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32);
 }
 
 FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o)
@@ -507,6 +509,8 @@ void TernaryOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(op1);
   TI_ASSERT_TYPE_CHECKED(op2);
   TI_ASSERT_TYPE_CHECKED(op3);
+  // TI_INFO("Ternary op {}",
+  //         ExpressionHumanFriendlyPrinter::expr_to_string(this));
 
   bool is_valid = true;
   bool is_tensor = false;
@@ -883,7 +887,10 @@ void IndexExpression::type_check(const CompileConfig *) {
         "Invalid IndexExpression: the source is not among field, ndarray or "
         "local tensor");
   }
-
+  ret_type = TypeFactory::get_instance().get_pointer_type(ret_type);
+  // TI_INFO("IndexExpression {} type checked : {}.",
+  //         ExpressionHumanFriendlyPrinter::expr_to_string(this),
+  //         ret_type->to_string());
   for (auto &indices : indices_group) {
     for (int i = 0; i < indices.exprs.size(); i++) {
       auto &expr = indices.exprs[i];
@@ -960,7 +967,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) {
 
 void IdExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->current_block->lookup_var(id);
-  if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) {
+  if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
     stmt->ret_type = ret_type;
   }
 }
@@ -1514,7 +1521,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr,
                                 std::string tb) {
   TI_ASSERT(expr.is<FieldExpression>() || expr.is<MatrixFieldExpression>() ||
             expr.is<ExternalTensorExpression>() ||
-            is_tensor(expr.expr->ret_type));
+            is_tensor(expr.expr->ret_type.ptr_removed()));
 
   // IndexExpression without ret_shape is used for matrix indexing,
   // where each entry of ExprGroup is interpreted as indexing into a specific
@@ -1677,7 +1684,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
       elem.expr->ret_type = struct_type->get_element_type(indices);
       expanded_exprs.push_back(elem);
     }
-  } else if (!expr->ret_type->is<TensorType>()) {
+  } else if (!expr->ret_type.ptr_removed()->is<TensorType>()) {
     expanded_exprs.push_back(expr);
   } else {
     // Expand TensorType expr
@@ -1695,7 +1702,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
        return {ind0, ind1, ind2, ind3}
     */
-    auto tensor_type = expr->ret_type->cast<TensorType>();
+    auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();
 
     Expr id_expr;
     if (expr.is<IdExpression>()) {
@@ -1708,7 +1715,8 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
       for (int i = 0; i < shape[0]; i++) {
         auto ind = Expr(std::make_shared<IndexExpression>(
             id_expr, ExprGroup(Expr(i)), expr->tb));
-        ind.expr->ret_type = tensor_type->get_element_type();
+        ind->type_check(nullptr);
+        // ind.expr->ret_type = tensor_type->get_element_type();
         expanded_exprs.push_back(ind);
       }
     } else {
@@ -1717,7 +1725,8 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
       for (int j = 0; j < shape[1]; j++) {
         auto ind = Expr(std::make_shared<IndexExpression>(
             id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
-        ind.expr->ret_type = tensor_type->get_element_type();
+        ind->type_check(nullptr);
+        // ind.expr->ret_type = tensor_type->get_element_type();
         expanded_exprs.push_back(ind);
       }
     }
6 changes: 4 additions & 2 deletions taichi/ir/frontend_ir.h
@@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt {
                      DataType element,
                      bool is_shared = false)
       : ident(lhs), is_shared(is_shared) {
-    ret_type = DataType(TypeFactory::create_tensor_type(shape, element));
+    ret_type = TypeFactory::get_instance().get_pointer_type(
+        DataType(TypeFactory::create_tensor_type(shape, element)));
   }
 
   bool is_shared;
@@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression {
 
   void type_check(const CompileConfig *config) override {
     ret_type = dt;
+    ret_type.set_is_pointer(true);
     config_ = config;
   }
 
@@ -585,7 +587,7 @@ class MatrixExpression : public Expression {
                    std::vector<int> shape,
                    DataType element_type)
       : elements(elements) {
-    this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
+    dt = TypeFactory::create_tensor_type(shape, element_type);
   }
 
   void type_check(const CompileConfig *config) override;
1 change: 1 addition & 0 deletions taichi/ir/statements.cpp
@@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
                                              DataType dt)
     : stmts(stmts) {
   ret_type = dt;
+  ret_type.set_is_pointer(true);
   TI_STMT_REG_FIELDS;
 }
 
9 changes: 7 additions & 2 deletions taichi/ir/statements.h
@@ -19,15 +19,20 @@ class Function;
 
 class AllocaStmt : public Stmt, public ir_traits::Store {
  public:
   explicit AllocaStmt(DataType type) : is_shared(false) {
-    ret_type = type;
+    if (type->is_primitive(PrimitiveTypeID::unknown)) {
+      ret_type = type;
+    } else {
+      ret_type = TypeFactory::get_instance().get_pointer_type(type);
+    }
     TI_STMT_REG_FIELDS;
   }
 
   AllocaStmt(const std::vector<int> &shape,
              DataType type,
              bool is_shared = false)
       : is_shared(is_shared) {
-    ret_type = TypeFactory::create_tensor_type(shape, type);
+    ret_type = TypeFactory::get_instance().get_pointer_type(
+        TypeFactory::create_tensor_type(shape, type));
     TI_STMT_REG_FIELDS;
   }
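The constructor above deliberately leaves the unknown primitive unwrapped so later type inference can overwrite it, while every concrete scalar or tensor type gets the pointer wrapper. A self-contained toy of that branching (hypothetical names, not the real classes):

// alloca_ctor_toy.cpp -- hypothetical names, not the real AllocaStmt.
#include <cassert>
#include <string>

struct Ty {
  std::string name;
  bool is_pointer = false;
};

// Mirrors the branching in the scalar AllocaStmt constructor above.
Ty alloca_ret_type(const Ty &type) {
  if (type.name == "unknown")
    return type;                     // left for type inference to fill in
  return Ty{type.name + "*", true};  // pointer to the allocated type
}

int main() {
  assert(!alloca_ret_type(Ty{"unknown"}).is_pointer);
  assert(alloca_ret_type(Ty{"i32"}).is_pointer);         // scalar alloca
  assert(alloca_ret_type(Ty{"[4 x f32]"}).is_pointer);   // tensor alloca
  return 0;
}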
2 changes: 1 addition & 1 deletion taichi/ir/type_factory.cpp
@@ -222,7 +222,7 @@ static bool compare_types(DataType x, DataType y) {
 static DataType to_primitive_type(DataType d) {
   if (d->is<PointerType>()) {
     d = d->as<PointerType>()->get_pointee_type();
-    TI_WARN("promoted_type got a pointer input.");
+    TI_ERROR("promoted_type got a pointer input.");
   }
 
   if (d->is<TensorType>()) {
(The remaining 9 of the 19 changed files are not shown here.)
