[ir] [refactor] Update passes for the refactor of Alloca #8125

Merged 2 commits on Jun 5, 2023
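For context on the pattern running through all six files below: after the Alloca refactor, an AllocaStmt's ret_type carries the pointer type of the stack slot rather than the allocated value type, so any pass that inspects the allocated value must first strip the pointer with ret_type.ptr_removed(). The following is a minimal self-contained sketch using hypothetical mock types (not the actual taichi API; the real DataType lives in taichi/ir/type.h) that shows why a raw is<TensorType>() check on an alloca's ret_type no longer works:

#include <cassert>

struct Type {
  virtual ~Type() = default;
};
struct TensorType : Type {};
struct PointerType : Type {
  const Type *pointee;
  explicit PointerType(const Type *p) : pointee(p) {}
};

// Mock of taichi's DataType handle; only ptr_removed() and is<T>() are
// sketched here.
struct DataType {
  const Type *t;
  // Strip one level of pointer, mirroring DataType::ptr_removed().
  DataType ptr_removed() const {
    if (auto *p = dynamic_cast<const PointerType *>(t))
      return DataType{p->pointee};
    return *this;
  }
  template <typename T>
  bool is() const {
    return dynamic_cast<const T *>(t) != nullptr;
  }
};

int main() {
  TensorType tensor;
  PointerType slot(&tensor);  // an alloca's ret_type is now an address type
  DataType alloca_ret{&slot};
  assert(!alloca_ret.is<TensorType>());               // pre-refactor check: false
  assert(alloca_ret.ptr_removed().is<TensorType>());  // refactored check: true
  return 0;
}

Each hunk below applies this one rule at a call site written before the refactor.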
7 changes: 4 additions & 3 deletions taichi/ir/control_flow_graph.cpp
@@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
// result: the value to store
Stmt *result = irpass::analysis::get_store_data(
block->statements[last_def_position].get());
- bool is_tensor_involved = var->ret_type->is<TensorType>();
+ bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
@@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {

// Check for aliased address
// There's a store to the same dest_addr before this stmt
- bool is_tensor_involved = var->ret_type->is<TensorType>();
+ bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
// In between the store stmt and current stmt,
// if there's a third-stmt that "may" have stored a "different value" to
@@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
continue;

// special case of alloca (initialized to 0)
- auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
+ auto zero = Stmt::make<ConstStmt>(
+     TypedConstant(result->ret_type.ptr_removed(), 0));
replace_with(i, std::move(zero), true);
} else {
if (result->ret_type.ptr_removed()->is<TensorType>() &&
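Worth noting in the hunk above: the alloca zero-initialization special case now builds its TypedConstant from result->ret_type.ptr_removed(); with the refactored Alloca the raw ret_type is a pointer type, and a constant presumably has to be constructed from the pointee value type rather than the address type.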
11 changes: 6 additions & 5 deletions taichi/transforms/frontend_type_check.cpp
@@ -7,11 +7,12 @@ namespace taichi::lang {

class FrontendTypeCheck : public IRVisitor {
void check_cond_type(const Expr &cond, std::string stmt_name) {
- if (!cond->ret_type->is<PrimitiveType>() || !is_integral(cond->ret_type))
+ auto cond_type = cond.get_rvalue_type();
+ if (!cond_type->is<PrimitiveType>() || !is_integral(cond_type))
throw TaichiTypeError(fmt::format(
"`{0}` conditions must be an integer; found {1}. Consider using "
"`{0} x != 0` instead of `{0} x` for float values.",
- stmt_name, cond->ret_type->to_string()));
+ stmt_name, cond_type->to_string()));
}

public:
@@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor {
}

void visit(FrontendAssignStmt *stmt) override {
- auto lhs_type = stmt->lhs->ret_type;
- auto rhs_type = stmt->rhs->ret_type;
+ auto lhs_type = stmt->lhs->ret_type.ptr_removed();
+ auto rhs_type = stmt->rhs->ret_type.ptr_removed();

auto error = [&]() {
throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'",
@@ -85,7 +86,7 @@

Expr const &expr = std::get<Expr>(content);
TI_ASSERT(expr.expr != nullptr);
- DataType data_type = expr->ret_type;
+ DataType data_type = expr.get_rvalue_type();
if (data_type->is<TensorType>()) {
data_type = DataType(data_type->as<TensorType>()->get_element_type());
}
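The frontend checks now read expression types through Expr::get_rvalue_type() instead of ret_type directly; judging by the parallel call sites in this PR, that accessor presumably yields the expression's value type with any pointer stripped, matching the explicit ret_type.ptr_removed() calls used on statements elsewhere.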
6 changes: 3 additions & 3 deletions taichi/transforms/lower_ast.cpp
@@ -68,15 +68,15 @@ class LowerAST : public IRVisitor {
auto ident = stmt->ident;
TI_ASSERT(block->local_var_to_stmt.find(ident) ==
block->local_var_to_stmt.end());
- if (stmt->ret_type->is<TensorType>()) {
-   auto tensor_type = stmt->ret_type->cast<TensorType>();
+ auto alloca_type = stmt->ret_type.ptr_removed();
+ if (auto tensor_type = alloca_type->cast<TensorType>()) {
auto lowered = std::make_unique<AllocaStmt>(
tensor_type->get_shape(), tensor_type->get_element_type(),
stmt->is_shared);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
} else {
- auto lowered = std::make_unique<AllocaStmt>(stmt->ret_type);
+ auto lowered = std::make_unique<AllocaStmt>(alloca_type);
block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
stmt->parent->replace_with(stmt, std::move(lowered));
}
10 changes: 5 additions & 5 deletions taichi/transforms/offload.cpp
@@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end())
return;
VecStatement replacement;
- auto ret_type = stmt->ret_type;
- local_to_global_vector_type_[stmt] = ret_type;
+ auto alloca_type = stmt->ret_type.ptr_removed();
+ local_to_global_vector_type_[stmt] = alloca_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
-     local_to_global_offset_.at(stmt), ret_type);
+     local_to_global_offset_.at(stmt), alloca_type);
auto offloaded = stmt_to_offloaded_[stmt];
stmt_to_offloaded_[ptr] = offloaded;

- TypedConstant zero(stmt->ret_type.get_element_type());
+ TypedConstant zero(alloca_type.get_element_type());
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
- if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
+ if (auto tensor_type = alloca_type->cast<TensorType>()) {
std::vector<Stmt *> zero_values(tensor_type->get_num_elements(),
const_zero_stmt);
auto zero_matrix_init_stmt =
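The offload pass above is another instance of the same rule: the GlobalTemporaryStmt that replaces a cross-offload alloca, its recorded vector type, and the zero constant used to initialize it are all built from alloca_type, the pointee type, since a global temporary describes the stored value rather than the stack address.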
20 changes: 11 additions & 9 deletions taichi/transforms/scalarize.cpp
@@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor {
template <typename T>
void scalarize_store_stmt(T *stmt) {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
- auto val_dtype = stmt->val->ret_type;
+ auto val_dtype = stmt->val->ret_type.ptr_removed();
if (dest_dtype->template is<TensorType>() &&
val_dtype->template is<TensorType>()) {
// Needs scalarize
@@ -185,7 +185,8 @@
stmt->replace_all_usages_with(tmp)
*/
void visit(UnaryOpStmt *stmt) override {
- auto operand_dtype = stmt->operand->ret_type;
+ auto operand_dtype = stmt->operand->ret_type.ptr_removed();
+ auto stmt_dtype = stmt->ret_type.ptr_removed();
if (operand_dtype->is<TensorType>()) {
// Needs scalarize
auto operand_tensor_type = operand_dtype->as<TensorType>();
@@ -198,7 +199,7 @@

std::vector<Stmt *> matrix_init_values;
int num_elements = operand_tensor_type->get_num_elements();
- auto primitive_type = stmt->ret_type.get_element_type();
+ auto primitive_type = stmt_dtype.get_element_type();
for (size_t i = 0; i < num_elements; i++) {
auto unary_stmt = std::make_unique<UnaryOpStmt>(
stmt->op_type, operand_matrix_init_stmt->values[i]);
@@ -246,8 +247,9 @@
stmt->replace_all_usages_with(tmp)
*/
void visit(BinaryOpStmt *stmt) override {
- auto lhs_dtype = stmt->lhs->ret_type;
- auto rhs_dtype = stmt->rhs->ret_type;
+ auto lhs_dtype = stmt->lhs->ret_type.ptr_removed();
+ auto rhs_dtype = stmt->rhs->ret_type.ptr_removed();
+ auto stmt_dtype = stmt->ret_type.ptr_removed();
if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// BinaryOpExpression::type_check().
@@ -270,7 +272,7 @@
TI_ASSERT(rhs_vals.size() == lhs_vals.size());

size_t num_elements = lhs_vals.size();
- auto primitive_type = stmt->ret_type.get_element_type();
+ auto primitive_type = stmt_dtype.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto binary_stmt = std::make_unique<BinaryOpStmt>(
@@ -581,9 +583,9 @@
stmt->replace_all_usages_with(tmp)
*/
void visit(TernaryOpStmt *stmt) override {
- auto cond_dtype = stmt->op1->ret_type;
- auto op2_dtype = stmt->op2->ret_type;
- auto op3_dtype = stmt->op3->ret_type;
+ auto cond_dtype = stmt->op1->ret_type.ptr_removed();
+ auto op2_dtype = stmt->op2->ret_type.ptr_removed();
+ auto op3_dtype = stmt->op3->ret_type.ptr_removed();
if (cond_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// TernaryOpExpression::type_check().
10 changes: 6 additions & 4 deletions taichi/transforms/type_check.cpp
@@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor {
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
+ auto val_type = val->ret_type.ptr_removed();
if (is_quant(dst_type)) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();
}
- if (dst_type != val->ret_type) {
-   auto promoted = promoted_type(dst_type, val->ret_type);
+ if (dst_type != val_type) {
+   auto promoted = promoted_type(dst_type, val_type);
if (dst_type != promoted) {
TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(),
stmt_name, dst_type->to_string(), val->ret_data_type_name(),
@@ -88,13 +89,14 @@
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>() ||
stmt->src->is<MatrixOfMatrixPtrStmt>());
if (auto ptr_offset_stmt = stmt->src->cast<MatrixPtrStmt>()) {
- auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as<TensorType>()
+ auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed()
+                            ->as<TensorType>()
->get_element_type())
.ptr_removed();
stmt->ret_type = lookup;
} else {
auto lookup = stmt->src->ret_type;
- stmt->ret_type = lookup;
+ stmt->ret_type = lookup.ptr_removed();
}
}

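With the last hunk, both branches of the local-load visitor strip the pointer: the MatrixPtrStmt branch already ended in .ptr_removed(), and the plain-alloca branch now matches it, so a load's ret_type is consistently the value type of its source address.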