Skip to content

Commit

Permalink
[ir] [refactor] Update passes for the refactor of Alloca
Browse files Browse the repository at this point in the history
ghstack-source-id: 5b1414e3ca9fa712f8c3594608c2bdb898664114
Pull Request resolved: #8125
  • Loading branch information
lin-hitonami committed Jun 2, 2023
1 parent e2da958 commit b8c8ba9
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 29 deletions.
7 changes: 4 additions & 3 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
// result: the value to store
Stmt *result = irpass::analysis::get_store_data(
block->statements[last_def_position].get());
bool is_tensor_involved = var->ret_type->is<TensorType>();
bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
Expand Down Expand Up @@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {

// Check for aliased address
// There's a store to the same dest_addr before this stmt
bool is_tensor_involved = var->ret_type->is<TensorType>();
bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
Expand Down Expand Up @@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
continue;

// special case of alloca (initialized to 0)
auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
auto zero = Stmt::make<ConstStmt>(
TypedConstant(result->ret_type.ptr_removed(), 0));
replace_with(i, std::move(zero), true);
} else {
if (result->ret_type.ptr_removed()->is<TensorType>() &&
Expand Down
11 changes: 6 additions & 5 deletions taichi/transforms/frontend_type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ namespace taichi::lang {

class FrontendTypeCheck : public IRVisitor {
void check_cond_type(const Expr &cond, std::string stmt_name) {
if (!cond->ret_type->is<PrimitiveType>() || !is_integral(cond->ret_type))
auto cond_type = cond.get_rvalue_type();
if (!cond_type->is<PrimitiveType>() || !is_integral(cond_type))
throw TaichiTypeError(fmt::format(
"`{0}` conditions must be an integer; found {1}. Consider using "
"`{0} x != 0` instead of `{0} x` for float values.",
stmt_name, cond->ret_type->to_string()));
stmt_name, cond_type->to_string()));
}

public:
Expand Down Expand Up @@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor {
}

void visit(FrontendAssignStmt *stmt) override {
auto lhs_type = stmt->lhs->ret_type;
auto rhs_type = stmt->rhs->ret_type;
auto lhs_type = stmt->lhs->ret_type.ptr_removed();
auto rhs_type = stmt->rhs->ret_type.ptr_removed();

auto error = [&]() {
throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'",
Expand Down Expand Up @@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor {

Expr const &expr = std::get<Expr>(content);
TI_ASSERT(expr.expr != nullptr);
DataType data_type = expr->ret_type;
DataType data_type = expr.get_rvalue_type();
if (data_type->is<TensorType>()) {
data_type = DataType(data_type->as<TensorType>()->get_element_type());
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ class LowerAST : public IRVisitor {
auto ident = stmt->ident;
TI_ASSERT(block->local_var_to_stmt.find(ident) ==
block->local_var_to_stmt.end());
if (stmt->ret_type->is<TensorType>()) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto alloca_type = stmt->ret_type.ptr_removed();
if (auto tensor_type = alloca_type->cast<TensorType>()) {
auto lowered = std::make_unique<AllocaStmt>(
tensor_type->get_shape(), tensor_type->get_element_type(),
stmt->is_shared);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
} else {
auto lowered = std::make_unique<AllocaStmt>(stmt->ret_type);
auto lowered = std::make_unique<AllocaStmt>(alloca_type);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
}
Expand Down
10 changes: 5 additions & 5 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end())
return;
VecStatement replacement;
auto ret_type = stmt->ret_type;
local_to_global_vector_type_[stmt] = ret_type;
auto alloca_type = stmt->ret_type.ptr_removed();
local_to_global_vector_type_[stmt] = alloca_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset_.at(stmt), ret_type);
local_to_global_offset_.at(stmt), alloca_type);
auto offloaded = stmt_to_offloaded_[stmt];
stmt_to_offloaded_[ptr] = offloaded;

TypedConstant zero(stmt->ret_type.get_element_type());
TypedConstant zero(alloca_type.get_element_type());
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
if (auto tensor_type = alloca_type->cast<TensorType>()) {
std::vector<Stmt *> zero_values(tensor_type->get_num_elements(),
const_zero_stmt);
auto zero_matrix_init_stmt =
Expand Down
20 changes: 11 additions & 9 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor {
template <typename T>
void scalarize_store_stmt(T *stmt) {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
auto val_dtype = stmt->val->ret_type;
auto val_dtype = stmt->val->ret_type.ptr_removed();
if (dest_dtype->template is<TensorType>() &&
val_dtype->template is<TensorType>()) {
// Needs scalarize
Expand Down Expand Up @@ -185,7 +185,8 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(UnaryOpStmt *stmt) override {
auto operand_dtype = stmt->operand->ret_type;
auto operand_dtype = stmt->operand->ret_type.ptr_removed();
auto stmt_dtype = stmt->ret_type.ptr_removed();
if (operand_dtype->is<TensorType>()) {
// Needs scalarize
auto operand_tensor_type = operand_dtype->as<TensorType>();
Expand All @@ -198,7 +199,7 @@ class Scalarize : public BasicStmtVisitor {

std::vector<Stmt *> matrix_init_values;
int num_elements = operand_tensor_type->get_num_elements();
auto primitive_type = stmt->ret_type.get_element_type();
auto primitive_type = stmt_dtype.get_element_type();
for (size_t i = 0; i < num_elements; i++) {
auto unary_stmt = std::make_unique<UnaryOpStmt>(
stmt->op_type, operand_matrix_init_stmt->values[i]);
Expand Down Expand Up @@ -246,8 +247,9 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(BinaryOpStmt *stmt) override {
auto lhs_dtype = stmt->lhs->ret_type;
auto rhs_dtype = stmt->rhs->ret_type;
auto lhs_dtype = stmt->lhs->ret_type.ptr_removed();
auto rhs_dtype = stmt->rhs->ret_type.ptr_removed();
auto stmt_dtype = stmt->ret_type.ptr_removed();
if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// BinaryOpExpression::type_check().
Expand All @@ -270,7 +272,7 @@ class Scalarize : public BasicStmtVisitor {
TI_ASSERT(rhs_vals.size() == lhs_vals.size());

size_t num_elements = lhs_vals.size();
auto primitive_type = stmt->ret_type.get_element_type();
auto primitive_type = stmt_dtype.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto binary_stmt = std::make_unique<BinaryOpStmt>(
Expand Down Expand Up @@ -581,9 +583,9 @@ class Scalarize : public BasicStmtVisitor {
stmt->replace_all_usages_with(tmp)
*/
void visit(TernaryOpStmt *stmt) override {
auto cond_dtype = stmt->op1->ret_type;
auto op2_dtype = stmt->op2->ret_type;
auto op3_dtype = stmt->op3->ret_type;
auto cond_dtype = stmt->op1->ret_type.ptr_removed();
auto op2_dtype = stmt->op2->ret_type.ptr_removed();
auto op3_dtype = stmt->op3->ret_type.ptr_removed();
if (cond_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// TernaryOpExpression::type_check().
Expand Down
10 changes: 6 additions & 4 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor {
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
auto val_type = val->ret_type.ptr_removed();
if (is_quant(dst_type)) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();
}
if (dst_type != val->ret_type) {
auto promoted = promoted_type(dst_type, val->ret_type);
if (dst_type != val_type) {
auto promoted = promoted_type(dst_type, val_type);
if (dst_type != promoted) {
TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(),
stmt_name, dst_type->to_string(), val->ret_data_type_name(),
Expand Down Expand Up @@ -88,13 +89,14 @@ class TypeCheck : public IRVisitor {
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>() ||
stmt->src->is<MatrixOfMatrixPtrStmt>());
if (auto ptr_offset_stmt = stmt->src->cast<MatrixPtrStmt>()) {
auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as<TensorType>()
auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed()
->as<TensorType>()
->get_element_type())
.ptr_removed();
stmt->ret_type = lookup;
} else {
auto lookup = stmt->src->ret_type;
stmt->ret_type = lookup;
stmt->ret_type = lookup.ptr_removed();
}
}

Expand Down

0 comments on commit b8c8ba9

Please sign in to comment.