[ir] [refactor] Let the type of Alloca be pointer
ghstack-source-id: ebf707d7fd59f67fce8d352d493c64d1f95f57bf
Pull Request resolved: #8007
lin-hitonami committed Jun 2, 2023
1 parent e2da958 commit 22e6418
Showing 12 changed files with 70 additions and 47 deletions.
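
In short: after this change, the ret_type of an AllocaStmt (and of lvalue frontend expressions such as IdExpression and IndexExpression) is a pointer to the stored type rather than the stored type itself, and consumers strip the pointer with ptr_removed() before inspecting the pointee. A minimal sketch of the new invariant, using only calls that appear in the diff below (an illustration against Taichi's internal headers, not code from this commit):

#include "taichi/ir/statements.h"    // AllocaStmt, Stmt::make
#include "taichi/ir/type_factory.h"  // TypeFactory

using namespace taichi::lang;

void alloca_type_invariant() {
  // Before this commit: alloca->ret_type was i32.
  // After this commit: it is pointer-to-i32 (only `unknown` stays bare).
  auto alloca = Stmt::make<AllocaStmt>(PrimitiveType::i32);
  TI_ASSERT(alloca->ret_type->is<PointerType>());

  // Passes that need the stored value's type strip the pointer first:
  // this is the recurring ptr_removed() pattern in the hunks below.
  DataType pointee = alloca->ret_type.ptr_removed();
  TI_ASSERT(pointee->is_primitive(PrimitiveTypeID::i32));
}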
7 changes: 4 additions & 3 deletions taichi/ir/control_flow_graph.cpp
@@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
// result: the value to store
Stmt *result = irpass::analysis::get_store_data(
block->statements[last_def_position].get());
-bool is_tensor_involved = var->ret_type->is<TensorType>();
+bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
@@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {

// Check for aliased address
// There's a store to the same dest_addr before this stmt
-bool is_tensor_involved = var->ret_type->is<TensorType>();
+bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
@@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
continue;

// special case of alloca (initialized to 0)
-auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
+auto zero = Stmt::make<ConstStmt>(
+TypedConstant(result->ret_type.ptr_removed(), 0));
replace_with(i, std::move(zero), true);
} else {
if (result->ret_type.ptr_removed()->is<TensorType>() &&
20 changes: 11 additions & 9 deletions taichi/ir/frontend_ir.cpp
@@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
: lhs(lhs), rhs(rhs) {
TI_ASSERT(lhs->is_lvalue());
if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) {
-lhs.expr->ret_type = rhs.get_rvalue_type();
+lhs.expr->ret_type =
+TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type());
}
}

@@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) {

void FrontendForStmt::add_loop_var(const Expr &loop_var) {
loop_var_ids.push_back(loop_var.cast<IdExpression>()->id);
-loop_var.expr->ret_type = PrimitiveType::i32;
+loop_var.expr->ret_type =
+TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32);
}

FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o)
@@ -883,7 +885,7 @@ void IndexExpression::type_check(const CompileConfig *) {
"Invalid IndexExpression: the source is not among field, ndarray or "
"local tensor");
}
-
+ret_type = TypeFactory::get_instance().get_pointer_type(ret_type);
for (auto &indices : indices_group) {
for (int i = 0; i < indices.exprs.size(); i++) {
auto &expr = indices.exprs[i];
@@ -960,7 +962,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) {

void IdExpression::flatten(FlattenContext *ctx) {
stmt = ctx->current_block->lookup_var(id);
-if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) {
+if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
stmt->ret_type = ret_type;
}
}
@@ -1514,7 +1516,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr,
std::string tb) {
TI_ASSERT(expr.is<FieldExpression>() || expr.is<MatrixFieldExpression>() ||
expr.is<ExternalTensorExpression>() ||
-is_tensor(expr.expr->ret_type));
+is_tensor(expr.expr->ret_type.ptr_removed()));

// IndexExpression without ret_shape is used for matrix indexing,
// where each entry of ExprGroup is interpreted as indexing into a specific
@@ -1677,7 +1679,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
elem.expr->ret_type = struct_type->get_element_type(indices);
expanded_exprs.push_back(elem);
}
-} else if (!expr->ret_type->is<TensorType>()) {
+} else if (!expr->ret_type.ptr_removed()->is<TensorType>()) {
expanded_exprs.push_back(expr);
} else {
// Expand TensorType expr
@@ -1695,7 +1697,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
return {ind0, ind1, ind2, ind3}
*/
-auto tensor_type = expr->ret_type->cast<TensorType>();
+auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();

Expr id_expr;
if (expr.is<IdExpression>()) {
@@ -1708,7 +1710,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
for (int i = 0; i < shape[0]; i++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i)), expr->tb));
-ind.expr->ret_type = tensor_type->get_element_type();
+ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
} else {
@@ -1717,7 +1719,7 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
for (int j = 0; j < shape[1]; j++) {
auto ind = Expr(std::make_shared<IndexExpression>(
id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
-ind.expr->ret_type = tensor_type->get_element_type();
+ind->type_check(nullptr);
expanded_exprs.push_back(ind);
}
}
6 changes: 4 additions & 2 deletions taichi/ir/frontend_ir.h
@@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt {
DataType element,
bool is_shared = false)
: ident(lhs), is_shared(is_shared) {
-ret_type = DataType(TypeFactory::create_tensor_type(shape, element));
+ret_type = TypeFactory::get_instance().get_pointer_type(
+DataType(TypeFactory::create_tensor_type(shape, element)));
}

bool is_shared;
@@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression {

void type_check(const CompileConfig *config) override {
ret_type = dt;
+ret_type.set_is_pointer(true);
config_ = config;
}

@@ -585,7 +587,7 @@ class MatrixExpression : public Expression {
std::vector<int> shape,
DataType element_type)
: elements(elements) {
-this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
+dt = TypeFactory::create_tensor_type(shape, element_type);
}

void type_check(const CompileConfig *config) override;
1 change: 1 addition & 0 deletions taichi/ir/statements.cpp
@@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
DataType dt)
: stmts(stmts) {
ret_type = dt;
+ret_type.set_is_pointer(true);
TI_STMT_REG_FIELDS;
}

9 changes: 7 additions & 2 deletions taichi/ir/statements.h
@@ -19,15 +19,20 @@ class Function;
class AllocaStmt : public Stmt, public ir_traits::Store {
public:
explicit AllocaStmt(DataType type) : is_shared(false) {
-ret_type = type;
+if (type->is_primitive(PrimitiveTypeID::unknown)) {
+  ret_type = type;
+} else {
+  ret_type = TypeFactory::get_instance().get_pointer_type(type);
+}
TI_STMT_REG_FIELDS;
}

AllocaStmt(const std::vector<int> &shape,
DataType type,
bool is_shared = false)
: is_shared(is_shared) {
-ret_type = TypeFactory::create_tensor_type(shape, type);
+ret_type = TypeFactory::get_instance().get_pointer_type(
+TypeFactory::create_tensor_type(shape, type));
TI_STMT_REG_FIELDS;
}

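Note the unknown-type guard in the scalar constructor above: an alloca whose type is still to be inferred keeps the bare PrimitiveTypeID::unknown, so later inference can detect it, while every concrete alloca is wrapped. The tensor constructor wraps the same way; a quick check under the same assumptions as the sketch above (illustration, not commit code):

void tensor_alloca_invariant() {
  // A 4-element local tensor alloca now reports pointer-to-TensorType
  // (shape/element constructor per statements.h; is_shared defaults off).
  std::vector<int> shape = {4};
  auto tensor_alloca = Stmt::make<AllocaStmt>(shape, PrimitiveType::f32);
  DataType pointee = tensor_alloca->ret_type.ptr_removed();
  TI_ASSERT(pointee->is<TensorType>());
  TI_ASSERT(pointee->cast<TensorType>()->get_num_elements() == 4);
}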
11 changes: 6 additions & 5 deletions taichi/transforms/frontend_type_check.cpp
@@ -7,11 +7,12 @@ namespace taichi::lang {

class FrontendTypeCheck : public IRVisitor {
void check_cond_type(const Expr &cond, std::string stmt_name) {
-if (!cond->ret_type->is<PrimitiveType>() || !is_integral(cond->ret_type))
+auto cond_type = cond.get_rvalue_type();
+if (!cond_type->is<PrimitiveType>() || !is_integral(cond_type))
throw TaichiTypeError(fmt::format(
"`{0}` conditions must be an integer; found {1}. Consider using "
"`{0} x != 0` instead of `{0} x` for float values.",
-stmt_name, cond->ret_type->to_string()));
+stmt_name, cond_type->to_string()));
}

public:
@@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor {
}

void visit(FrontendAssignStmt *stmt) override {
-auto lhs_type = stmt->lhs->ret_type;
-auto rhs_type = stmt->rhs->ret_type;
+auto lhs_type = stmt->lhs->ret_type.ptr_removed();
+auto rhs_type = stmt->rhs->ret_type.ptr_removed();

auto error = [&]() {
throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'",
@@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor {

Expr const &expr = std::get<Expr>(content);
TI_ASSERT(expr.expr != nullptr);
-DataType data_type = expr->ret_type;
+DataType data_type = expr.get_rvalue_type();
if (data_type->is<TensorType>()) {
data_type = DataType(data_type->as<TensorType>()->get_element_type());
}
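The switch from cond->ret_type and expr->ret_type to get_rvalue_type() in this pass leans on the same convention. Judging from the call sites in this commit (an inferred contract, not a quoted definition), the relationship should be:

// Inferred from this commit's usage: for an lvalue Expr whose ret_type
// is a pointer type, get_rvalue_type() yields the pointee; for plain
// rvalue expressions it should return ret_type unchanged.
void rvalue_type_contract(const Expr &expr) {
  TI_ASSERT(expr.get_rvalue_type() == expr->ret_type.ptr_removed());
}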
6 changes: 3 additions & 3 deletions taichi/transforms/lower_ast.cpp
@@ -68,15 +68,15 @@ class LowerAST : public IRVisitor {
auto ident = stmt->ident;
TI_ASSERT(block->local_var_to_stmt.find(ident) ==
block->local_var_to_stmt.end());
-if (stmt->ret_type->is<TensorType>()) {
-auto tensor_type = stmt->ret_type->cast<TensorType>();
+auto alloca_type = stmt->ret_type.ptr_removed();
+if (auto tensor_type = alloca_type->cast<TensorType>()) {
auto lowered = std::make_unique<AllocaStmt>(
tensor_type->get_shape(), tensor_type->get_element_type(),
stmt->is_shared);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
} else {
-auto lowered = std::make_unique<AllocaStmt>(stmt->ret_type);
+auto lowered = std::make_unique<AllocaStmt>(alloca_type);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
}
10 changes: 5 additions & 5 deletions taichi/transforms/offload.cpp
@@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end())
return;
VecStatement replacement;
-auto ret_type = stmt->ret_type;
-local_to_global_vector_type_[stmt] = ret_type;
+auto alloca_type = stmt->ret_type.ptr_removed();
+local_to_global_vector_type_[stmt] = alloca_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
-local_to_global_offset_.at(stmt), ret_type);
+local_to_global_offset_.at(stmt), alloca_type);
auto offloaded = stmt_to_offloaded_[stmt];
stmt_to_offloaded_[ptr] = offloaded;

-TypedConstant zero(stmt->ret_type.get_element_type());
+TypedConstant zero(alloca_type.get_element_type());
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
-if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
+if (auto tensor_type = alloca_type->cast<TensorType>()) {
std::vector<Stmt *> zero_values(tensor_type->get_num_elements(),
const_zero_stmt);
auto zero_matrix_init_stmt =
23 changes: 13 additions & 10 deletions taichi/transforms/scalarize.cpp
@@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor {
template <typename T>
void scalarize_store_stmt(T *stmt) {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
-auto val_dtype = stmt->val->ret_type;
+auto val_dtype = stmt->val->ret_type.ptr_removed();
if (dest_dtype->template is<TensorType>() &&
val_dtype->template is<TensorType>()) {
// Needs scalarize
@@ -185,7 +185,8 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(UnaryOpStmt *stmt) override {
-auto operand_dtype = stmt->operand->ret_type;
+auto operand_dtype = stmt->operand->ret_type.ptr_removed();
+auto stmt_dtype = stmt->ret_type.ptr_removed();
if (operand_dtype->is<TensorType>()) {
// Needs scalarize
auto operand_tensor_type = operand_dtype->as<TensorType>();
@@ -198,7 +199,7 @@ class Scalarize : public BasicStmtVisitor {

std::vector<Stmt *> matrix_init_values;
int num_elements = operand_tensor_type->get_num_elements();
-auto primitive_type = stmt->ret_type.get_element_type();
+auto primitive_type = stmt_dtype.get_element_type();
for (size_t i = 0; i < num_elements; i++) {
auto unary_stmt = std::make_unique<UnaryOpStmt>(
stmt->op_type, operand_matrix_init_stmt->values[i]);
@@ -246,8 +247,9 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(BinaryOpStmt *stmt) override {
-auto lhs_dtype = stmt->lhs->ret_type;
-auto rhs_dtype = stmt->rhs->ret_type;
+auto lhs_dtype = stmt->lhs->ret_type.ptr_removed();
+auto rhs_dtype = stmt->rhs->ret_type.ptr_removed();
+auto stmt_dtype = stmt->ret_type.ptr_removed();
if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// BinaryOpExpression::type_check().
@@ -270,7 +272,7 @@ class Scalarize : public BasicStmtVisitor {
TI_ASSERT(rhs_vals.size() == lhs_vals.size());

size_t num_elements = lhs_vals.size();
-auto primitive_type = stmt->ret_type.get_element_type();
+auto primitive_type = stmt_dtype.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto binary_stmt = std::make_unique<BinaryOpStmt>(
@@ -581,9 +583,9 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(TernaryOpStmt *stmt) override {
-auto cond_dtype = stmt->op1->ret_type;
-auto op2_dtype = stmt->op2->ret_type;
-auto op3_dtype = stmt->op3->ret_type;
+auto cond_dtype = stmt->op1->ret_type.ptr_removed();
+auto op2_dtype = stmt->op2->ret_type.ptr_removed();
+auto op3_dtype = stmt->op3->ret_type.ptr_removed();
if (cond_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// TernaryOpExpression::type_check().
@@ -1026,7 +1028,8 @@ class ScalarizePointers : public BasicStmtVisitor {
for (size_t i = 0; i < tensor_type->get_num_elements(); i++) {
auto scalarized_alloca_stmt =
std::make_unique<AllocaStmt>(primitive_type);
-scalarized_alloca_stmt->ret_type = primitive_type;
+scalarized_alloca_stmt->ret_type =
+TypeFactory::get_instance().get_pointer_type(primitive_type);

scalarized_local_tensor_map_[stmt].push_back(
scalarized_alloca_stmt.get());
10 changes: 6 additions & 4 deletions taichi/transforms/type_check.cpp
@@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor {
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
+auto val_type = val->ret_type.ptr_removed();
if (is_quant(dst_type)) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();
}
-if (dst_type != val->ret_type) {
-  auto promoted = promoted_type(dst_type, val->ret_type);
+if (dst_type != val_type) {
+  auto promoted = promoted_type(dst_type, val_type);
if (dst_type != promoted) {
TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(),
stmt_name, dst_type->to_string(), val->ret_data_type_name(),
@@ -88,13 +89,14 @@ class TypeCheck : public IRVisitor {
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>() ||
stmt->src->is<MatrixOfMatrixPtrStmt>());
if (auto ptr_offset_stmt = stmt->src->cast<MatrixPtrStmt>()) {
-auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as<TensorType>()
+auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed()
+->as<TensorType>()
->get_element_type())
.ptr_removed();
stmt->ret_type = lookup;
} else {
auto lookup = stmt->src->ret_type;
-stmt->ret_type = lookup;
+stmt->ret_type = lookup.ptr_removed();
}
}

12 changes: 9 additions & 3 deletions tests/cpp/ir/frontend_type_inference_test.cpp
@@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) {
auto const_i32 = value<int32>(-(1 << 20));
const_i32->type_check(nullptr);
auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb);
-EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
+EXPECT_EQ(id_i32->ret_type,
+DataType(TypeFactory::get_instance().get_pointer_type(
+PrimitiveType::i32)));
}

TEST(FrontendTypeInference, BinaryOp) {
@@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) {
index->type_check(nullptr);
auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index));
global_ptr->type_check(nullptr);
-EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8);
+EXPECT_EQ(global_ptr->ret_type,
+DataType(TypeFactory::get_instance().get_pointer_type(
+PrimitiveType::u8)));
}

TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) {
@@ -172,7 +176,9 @@ TEST(FrontendTypeInference, TensorElement) {
index->type_check(nullptr);
auto tensor_element = Expr::make<IndexExpression>(var, ExprGroup(index));
tensor_element->type_check(nullptr);
-EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32);
+EXPECT_EQ(tensor_element->ret_type,
+DataType(TypeFactory::get_instance().get_pointer_type(
+PrimitiveType::u32)));
}

TEST(FrontendTypeInference, AtomicOp) {
2 changes: 1 addition & 1 deletion tests/python/test_unary_ops.py
@@ -78,7 +78,7 @@ def test(x: ti.f32) -> ti.i32:


@test_utils.test(arch=[ti.cuda, ti.vulkan, ti.opengl, ti.metal])
-def _test_frexp(): # Fails in this PR, but will be fixed in the last PR of this series
+def test_frexp():
@ti.kernel
def get_frac(x: ti.f32) -> ti.f32:
a, b = ti.frexp(x)
