Skip to content

Commit

Permalink
[ir] [refactor] Let the type of Alloca be pointer
Browse files Browse the repository at this point in the history
ghstack-source-id: ebf707d7fd59f67fce8d352d493c64d1f95f57bf
Pull Request resolved: #8007
  • Loading branch information
lin-hitonami committed Jun 2, 2023
1 parent b8c8ba9 commit d6cb6af
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 18 deletions.
20 changes: 11 additions & 9 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
: lhs(lhs), rhs(rhs) {
TI_ASSERT(lhs->is_lvalue());
if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) {
lhs.expr->ret_type = rhs.get_rvalue_type();
lhs.expr->ret_type =
TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type());
}
}

Expand Down Expand Up @@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) {

void FrontendForStmt::add_loop_var(const Expr &loop_var) {
loop_var_ids.push_back(loop_var.cast<IdExpression>()->id);
loop_var.expr->ret_type = PrimitiveType::i32;
loop_var.expr->ret_type =
TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32);
}

FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o)
Expand Down Expand Up @@ -883,7 +885,7 @@ void IndexExpression::type_check(const CompileConfig *) {
"Invalid IndexExpression: the source is not among field, ndarray or "
"local tensor");
}

ret_type = TypeFactory::get_instance().get_pointer_type(ret_type);
for (auto &indices : indices_group) {
for (int i = 0; i < indices.exprs.size(); i++) {
auto &expr = indices.exprs[i];
Expand Down Expand Up @@ -960,7 +962,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) {

void IdExpression::flatten(FlattenContext *ctx) {
stmt = ctx->current_block->lookup_var(id);
if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) {
if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
stmt->ret_type = ret_type;
}
}
Expand Down Expand Up @@ -1514,7 +1516,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr,
std::string tb) {
TI_ASSERT(expr.is<FieldExpression>() || expr.is<MatrixFieldExpression>() ||
expr.is<ExternalTensorExpression>() ||
is_tensor(expr.expr->ret_type));
is_tensor(expr.expr->ret_type.ptr_removed()));

// IndexExpression without ret_shape is used for matrix indexing,
// where each entry of ExprGroup is interpreted as indexing into a specific
Expand Down Expand Up @@ -1677,7 +1679,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
elem.expr->ret_type = struct_type->get_element_type(indices);
expanded_exprs.push_back(elem);
}
} else if (!expr->ret_type->is<TensorType>()) {
} else if (!expr->ret_type.ptr_removed()->is<TensorType>()) {
expanded_exprs.push_back(expr);
} else {
// Expand TensorType expr
Expand All @@ -1695,7 +1697,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
return {ind0, ind1, ind2, ind3}
*/
auto tensor_type = expr->ret_type->cast<TensorType>();
auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();

Expr id_expr;
if (expr.is<IdExpression>()) {
Expand All @@ -1708,7 +1710,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
for (int i = 0; i < shape[0]; i++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i)), expr->tb));
ind.expr->ret_type = tensor_type->get_element_type();
ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
} else {
Expand All @@ -1717,7 +1719,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
for (int j = 0; j < shape[1]; j++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
ind.expr->ret_type = tensor_type->get_element_type();
ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
}
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt {
DataType element,
bool is_shared = false)
: ident(lhs), is_shared(is_shared) {
ret_type = DataType(TypeFactory::create_tensor_type(shape, element));
ret_type = TypeFactory::get_instance().get_pointer_type(
DataType(TypeFactory::create_tensor_type(shape, element)));
}

bool is_shared;
Expand Down Expand Up @@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression {

void type_check(const CompileConfig *config) override {
ret_type = dt;
ret_type.set_is_pointer(true);
config_ = config;
}

Expand Down Expand Up @@ -585,7 +587,7 @@ class MatrixExpression : public Expression {
std::vector<int> shape,
DataType element_type)
: elements(elements) {
this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
dt = TypeFactory::create_tensor_type(shape, element_type);
}

void type_check(const CompileConfig *config) override;
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
DataType dt)
: stmts(stmts) {
ret_type = dt;
ret_type.set_is_pointer(true);
TI_STMT_REG_FIELDS;
}

Expand Down
9 changes: 7 additions & 2 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ class Function;
class AllocaStmt : public Stmt, public ir_traits::Store {
public:
explicit AllocaStmt(DataType type) : is_shared(false) {
ret_type = type;
if (type->is_primitive(PrimitiveTypeID::unknown)) {
ret_type = type;
} else {
ret_type = TypeFactory::get_instance().get_pointer_type(type);
}
TI_STMT_REG_FIELDS;
}

AllocaStmt(const std::vector<int> &shape,
DataType type,
bool is_shared = false)
: is_shared(is_shared) {
ret_type = TypeFactory::create_tensor_type(shape, type);
ret_type = TypeFactory::get_instance().get_pointer_type(
TypeFactory::create_tensor_type(shape, type));
TI_STMT_REG_FIELDS;
}

Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,8 @@ class ScalarizePointers : public BasicStmtVisitor {
for (size_t i = 0; i < tensor_type->get_num_elements(); i++) {
auto scalarized_alloca_stmt =
std::make_unique<AllocaStmt>(primitive_type);
scalarized_alloca_stmt->ret_type = primitive_type;
scalarized_alloca_stmt->ret_type =
TypeFactory::get_instance().get_pointer_type(primitive_type);

scalarized_local_tensor_map_[stmt].push_back(
scalarized_alloca_stmt.get());
Expand Down
12 changes: 9 additions & 3 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) {
auto const_i32 = value<int32>(-(1 << 20));
const_i32->type_check(nullptr);
auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb);
EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
EXPECT_EQ(id_i32->ret_type,
DataType(TypeFactory::get_instance().get_pointer_type(
PrimitiveType::i32)));
}

TEST(FrontendTypeInference, BinaryOp) {
Expand Down Expand Up @@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) {
index->type_check(nullptr);
auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index));
global_ptr->type_check(nullptr);
EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8);
EXPECT_EQ(global_ptr->ret_type,
DataType(TypeFactory::get_instance().get_pointer_type(
PrimitiveType::u8)));
}

TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) {
Expand Down Expand Up @@ -172,7 +176,9 @@ TEST(FrontendTypeInference, TensorElement) {
index->type_check(nullptr);
auto tensor_element = Expr::make<IndexExpression>(var, ExprGroup(index));
tensor_element->type_check(nullptr);
EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32);
EXPECT_EQ(tensor_element->ret_type,
DataType(TypeFactory::get_instance().get_pointer_type(
PrimitiveType::u32)));
}

TEST(FrontendTypeInference, AtomicOp) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test(x: ti.f32) -> ti.i32:


@test_utils.test(arch=[ti.cuda, ti.vulkan, ti.opengl, ti.metal])
def _test_frexp(): # Fails in this PR, but will be fixed in the last PR of this series
def test_frexp():
@ti.kernel
def get_frac(x: ti.f32) -> ti.f32:
a, b = ti.frexp(x)
Expand Down

0 comments on commit d6cb6af

Please sign in to comment.