From f18af28b8210c3d2f2d1565107f46750255d47f2 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 2 Jun 2023 14:40:13 +0800 Subject: [PATCH] [ir] [refactor] Let the type of Alloca be pointer ghstack-source-id: a8ec16888a32815714beabcc567ccc3ff54f4d20 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8007 --- taichi/ir/frontend_ir.cpp | 8 +++++--- taichi/ir/frontend_ir.h | 6 ++++-- taichi/ir/statements.cpp | 1 + taichi/ir/statements.h | 9 +++++++-- taichi/transforms/scalarize.cpp | 3 ++- tests/cpp/ir/frontend_type_inference_test.cpp | 12 +++++++++--- tests/python/test_unary_ops.py | 2 +- 7 files changed, 29 insertions(+), 12 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index a8c6f473a39fb..9ba46903f1a9f 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { - lhs.expr->ret_type = rhs.get_rvalue_type(); + lhs.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type()); } } @@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) { void FrontendForStmt::add_loop_var(const Expr &loop_var) { loop_var_ids.push_back(loop_var.cast()->id); - loop_var.expr->ret_type = PrimitiveType::i32; + loop_var.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32); } FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o) @@ -883,7 +885,7 @@ void IndexExpression::type_check(const CompileConfig *) { "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - + ret_type = TypeFactory::get_instance().get_pointer_type(ret_type); for (auto &indices : indices_group) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index dc9f6fad7d7da..35ecf42241e6d 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt { DataType element, bool is_shared = false) : ident(lhs), is_shared(is_shared) { - ret_type = DataType(TypeFactory::create_tensor_type(shape, element)); + ret_type = TypeFactory::get_instance().get_pointer_type( + DataType(TypeFactory::create_tensor_type(shape, element))); } bool is_shared; @@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression { void type_check(const CompileConfig *config) override { ret_type = dt; + ret_type.set_is_pointer(true); config_ = config; } @@ -585,7 +587,7 @@ class MatrixExpression : public Expression { std::vector shape, DataType element_type) : elements(elements) { - this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + dt = TypeFactory::create_tensor_type(shape, element_type); } void type_check(const CompileConfig *config) override; diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 88ed18f0cea3b..c01691abf4434 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { ret_type = dt; + ret_type.set_is_pointer(true); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index b84122ae3c347..935124d6ee015 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -19,7 +19,11 @@ class Function; class AllocaStmt : public Stmt, public ir_traits::Store { public: explicit AllocaStmt(DataType type) : is_shared(false) { - ret_type = type; + if (type->is_primitive(PrimitiveTypeID::unknown)) { + ret_type = type; + } else { + ret_type = TypeFactory::get_instance().get_pointer_type(type); + } TI_STMT_REG_FIELDS; } @@ -27,7 +31,8 @@ class AllocaStmt : public Stmt, public ir_traits::Store { DataType type, bool is_shared = false) : is_shared(is_shared) { - ret_type = TypeFactory::create_tensor_type(shape, type); + ret_type = TypeFactory::get_instance().get_pointer_type( + TypeFactory::create_tensor_type(shape, type)); TI_STMT_REG_FIELDS; } diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index f14859af99ede..1687b7a48bba3 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -1028,7 +1028,8 @@ class ScalarizePointers : public BasicStmtVisitor { for (size_t i = 0; i < tensor_type->get_num_elements(); i++) { auto scalarized_alloca_stmt = std::make_unique(primitive_type); - scalarized_alloca_stmt->ret_type = primitive_type; + scalarized_alloca_stmt->ret_type = + TypeFactory::get_instance().get_pointer_type(primitive_type); scalarized_local_tensor_map_[stmt].push_back( scalarized_alloca_stmt.get()); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index bce26256c0ac5..dd21a5ec7b92c 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) { auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb); - EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); + EXPECT_EQ(id_i32->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::i32))); } TEST(FrontendTypeInference, BinaryOp) { @@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) { index->type_check(nullptr); auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); - EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); + EXPECT_EQ(global_ptr->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u8))); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { @@ -172,7 +176,9 @@ TEST(FrontendTypeInference, TensorElement) { index->type_check(nullptr); auto tensor_element = Expr::make(var, ExprGroup(index)); tensor_element->type_check(nullptr); - EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32); + EXPECT_EQ(tensor_element->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u32))); } TEST(FrontendTypeInference, AtomicOp) { diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py index 2dc7f4c68310e..6fe014bee872b 100644 --- a/tests/python/test_unary_ops.py +++ b/tests/python/test_unary_ops.py @@ -78,7 +78,7 @@ def test(x: ti.f32) -> ti.i32: @test_utils.test(arch=[ti.cuda, ti.vulkan, ti.opengl, ti.metal]) -def _test_frexp(): # Fails in this PR, but will be fixed in the last PR of this series +def test_frexp(): @ti.kernel def get_frac(x: ti.f32) -> ti.f32: a, b = ti.frexp(x)