[Lang] Migrate irpass::scalarize() after irpass::offload() (#7919)
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at c9e5993</samp>

This pull request refactors scalarization for matrix operations to
optimize different backends. It moves scalarization from
`compile_to_offloads` to `compile_to_executable` for LLVM, and keeps it
in `compile_to_offloads` for Metal.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at c9e5993</samp>

* Remove scalarization before offloading for non-Metal backends
([link](https://github.com/taichi-dev/taichi/pull/7919/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL126-L133))
* Add scalarization before offloading for Metal backend only
([link](https://github.com/taichi-dev/taichi/pull/7919/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR138-R145))
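
For context, `irpass::scalarize()` lowers `TensorType` (matrix/vector) values into per-element scalar statements; with this change it runs after `irpass::offload()` in `compile_to_offloads()`, so the offloading passes still operate on tensor-typed IR. Below is a minimal standalone C++ analogy of the lowering itself (illustrative only; it does not use Taichi's IR classes):

```cpp
// Minimal analogy of what scalarization does to a 4-element vector store:
// one aggregate operation becomes four independent scalar operations.
#include <array>
#include <cstddef>
#include <cstdio>

using vec4i = std::array<int, 4>;

// "Vectorized" form: a single store of the whole TensorType-like value.
void store_vector(vec4i &dst, const vec4i &src) {
  dst = src;
}

// "Scalarized" form: the same store lowered to one scalar store per element,
// which is roughly the shape of the IR the pass emits.
void store_scalarized(vec4i &dst, const vec4i &src) {
  for (std::size_t i = 0; i < src.size(); ++i) {
    dst[i] = src[i];
  }
}

int main() {
  vec4i a{};
  store_scalarized(a, {1, 2, 3, 4});
  std::printf("%d %d %d %d\n", a[0], a[1], a[2], a[3]);
  return 0;
}
```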
jim19930609 authored May 7, 2023
1 parent b248442 commit 1ee025b
Showing 3 changed files with 67 additions and 36 deletions.
16 changes: 8 additions & 8 deletions taichi/transforms/compile_to_offloads.cpp
@@ -127,14 +127,6 @@ void compile_to_offloads(IRNode *ir,
print("Access flagged I");
irpass::analysis::verify(ir);

if (config.real_matrix_scalarize) {
irpass::scalarize(ir);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false});
print("Simplified II");
irpass::analysis::verify(ir);
@@ -143,6 +135,14 @@ void compile_to_offloads(IRNode *ir,
print("Offloaded");
irpass::analysis::verify(ir);

if (config.real_matrix_scalarize) {
irpass::scalarize(ir);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

// TODO: This pass may be redundant as cfg_optimization() is already called
// in full_simplify().
if (config.opt_level > 0 && config.cfg_optimization) {
27 changes: 11 additions & 16 deletions taichi/transforms/offload.cpp
@@ -533,24 +533,19 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
local_to_global_offset_.at(stmt), ret_type);
auto offloaded = stmt_to_offloaded_[stmt];
stmt_to_offloaded_[ptr] = offloaded;

TypedConstant zero(stmt->ret_type.get_element_type());
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
TypedConstant zero(tensor_type->get_element_type());
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
stmt_to_offloaded_[const_zero_stmt] = offloaded;
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
auto const_offset_stmt =
replacement.push_back<ConstStmt>(TypedConstant(i));
auto ptr_offset_stmt =
replacement.push_back<MatrixPtrStmt>(ptr, const_offset_stmt);
auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
ptr_offset_stmt, const_zero_stmt);
stmt_to_offloaded_[const_offset_stmt] = offloaded;
stmt_to_offloaded_[ptr_offset_stmt] = offloaded;
stmt_to_offloaded_[global_store_stmt] = offloaded;
}
std::vector<Stmt *> zero_values(tensor_type->get_num_elements(),
const_zero_stmt);
auto zero_matrix_init_stmt =
replacement.push_back<MatrixInitStmt>(zero_values);
zero_matrix_init_stmt->ret_type = stmt->ret_type.ptr_removed();
auto global_store_stmt =
replacement.push_back<GlobalStoreStmt>(ptr, zero_matrix_init_stmt);
stmt_to_offloaded_[global_store_stmt] = offloaded;
} else {
TypedConstant zero(stmt->ret_type);
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
auto global_store_stmt =
replacement.push_back<GlobalStoreStmt>(ptr, const_zero_stmt);
stmt_to_offloaded_[global_store_stmt] = offloaded;
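
The `offload.cpp` change above replaces the per-element zero-initialization of a tensor-typed cross-offload value (one `ConstStmt` offset, one `MatrixPtrStmt`, and one `GlobalStoreStmt` per element) with a single `MatrixInitStmt` of repeated zeros stored through the tensor-typed pointer in one `GlobalStoreStmt`, presumably leaving that tensor-typed store to be lowered by the scalarize pass, which now runs after offloading. A back-of-the-envelope statement-count comparison (illustrative only, not Taichi code):

```cpp
// Count of IR statements emitted for zero-initializing a TensorType value
// with n elements, under the old and new lowerings (illustrative sketch).
#include <cstdio>

// Old lowering: one shared zero ConstStmt, plus a ConstStmt offset,
// a MatrixPtrStmt, and a GlobalStoreStmt for every element.
int old_stmt_count(int n) { return 1 + 3 * n; }

// New lowering: one zero ConstStmt, one MatrixInitStmt that repeats it
// n times, and a single GlobalStoreStmt through the tensor-typed pointer.
int new_stmt_count(int /*n*/) { return 3; }

int main() {
  for (int n : {1, 4, 16}) {
    std::printf("n=%2d  old=%2d  new=%d\n", n, old_stmt_count(n),
                new_stmt_count(n));
  }
  return 0;
}
```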
60 changes: 48 additions & 12 deletions taichi/transforms/scalarize.cpp
@@ -881,7 +881,7 @@ class GatherScalarizableLocalPointers : public BasicStmtVisitor {
}
};

class ScalarizeLocalPointers : public BasicStmtVisitor {
class ScalarizePointers : public BasicStmtVisitor {
public:
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;
@@ -890,7 +890,7 @@ class ScalarizeLocalPointers : public BasicStmtVisitor {
// { original_alloca_stmt : [scalarized_alloca_stmt0, ...] }
std::unordered_map<Stmt *, std::vector<Stmt *>> scalarized_local_tensor_map_;

explicit ScalarizeLocalPointers(
explicit ScalarizePointers(
IRNode *node,
const std::unordered_set<Stmt *> &scalarizable_allocas)
: immediate_modifier_(node), scalarizable_allocas_(scalarizable_allocas) {
@@ -948,16 +948,16 @@ class ScalarizeLocalPointers : public BasicStmtVisitor {
}
}

/*
Before:
MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset)
After:
scalarized_alloca_stmt =
scalarized_local_tensor_map_[alloca_stmt][offset]
stmt->replace_all_usages_with(scalarized_alloca_stmt)
*/
void visit(MatrixPtrStmt *stmt) override {
/*
Before:
MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset)
After:
scalarized_alloca_stmt =
scalarized_local_tensor_map_[alloca_stmt][offset]
stmt->replace_all_usages_with(scalarized_alloca_stmt)
*/
if (stmt->origin->is<AllocaStmt>() &&
scalarizable_allocas_.count(stmt->origin) == 1) {
auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
@@ -979,6 +979,34 @@ class ScalarizeLocalPointers : public BasicStmtVisitor {

immediate_modifier_.replace_usages_with(stmt, new_stmt);
delayed_modifier_.erase(stmt);
return;
}

/*
Before:
TensorType<4 x i32>* ptr = GlobalTempStmt(offset_0)
i32* ptr_1 = MatrixPtrStmt(ptr, offset_1)
After:
i32* $1 = GlobalTempStmt(offset_0 + offset_1 * sizeof(i32))
replace_all_usages_with(ptr_1, $1)
*/
if (stmt->origin->is<GlobalTemporaryStmt>() &&
stmt->offset->is<ConstStmt>()) {
auto global_temp_stmt = stmt->origin->as<GlobalTemporaryStmt>();
auto offset_0 = global_temp_stmt->offset;
auto offset_1 = stmt->offset->as<ConstStmt>()->val.val_int32();
auto new_offset =
offset_0 + offset_1 * data_type_size(stmt->ret_type.ptr_removed());

auto new_global_temp_stmt = std::make_unique<GlobalTemporaryStmt>(
new_offset, stmt->ret_type.ptr_removed().get_element_type());
new_global_temp_stmt->ret_type.set_is_pointer(true);

stmt->replace_usages_with(new_global_temp_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(new_global_temp_stmt));
delayed_modifier_.erase(stmt);
return;
}
}

@@ -1027,6 +1055,14 @@ class ExtractLocalPointers : public BasicStmtVisitor {
delayed_modifier_.modify_ir();
}

void visit(OffloadedStmt *stmt) override {
// Extract to OffloadStmt
Block *orig_top_level = top_level_;
top_level_ = stmt->body.get();
stmt->all_blocks_accept(this);
top_level_ = orig_top_level;
}

void visit(MatrixPtrStmt *stmt) override {
if (stmt->origin->is<AllocaStmt>()) {
auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
@@ -1118,7 +1154,7 @@ void scalarize(IRNode *root) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);
auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root);
ScalarizeLocalPointers scalarize_pointers_pass(root, scalarizable_allocas);
ScalarizePointers scalarize_pointers_pass(root, scalarizable_allocas);
ExtractLocalPointers extract_pointers_pass(root);
MergeExternalAndMatrixPtr::run(root);
}
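
For the new `GlobalTemporaryStmt` case added to `ScalarizePointers::visit(MatrixPtrStmt *)` above, a constant element offset is folded into the global temporary's byte offset: `new_offset = offset_0 + offset_1 * data_type_size(element_type)`. A small standalone example of that arithmetic with made-up numbers (using `sizeof` in place of Taichi's `data_type_size()`):

```cpp
// Worked example of the GlobalTemporaryStmt offset folding (illustrative).
#include <cstddef>
#include <cstdio>

// Byte offset of element `offset_1` of a global temporary whose storage
// starts at byte `offset_0`, for an element type of size `elem_size`.
std::size_t fold_offset(std::size_t offset_0, std::size_t offset_1,
                        std::size_t elem_size) {
  return offset_0 + offset_1 * elem_size;
}

int main() {
  // A TensorType<4 x i32> global temporary at byte offset 64; element #3
  // lands at 64 + 3 * sizeof(i32) = 76.
  std::printf("%zu\n", fold_offset(64, 3, sizeof(int)));
  return 0;
}
```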
