Skip to content

Commit

Permalink
[Lang] Migrate irpass::scalarize() after irpass::detect_read_only() (t…
Browse files Browse the repository at this point in the history
…aichi-dev#7939)

Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at e7b3921</samp>

This pull request refactors the IR scalarization pass and the local
pointer extraction pass, and moves the scalarization pass to a later
stage in the compilation pipeline. These changes aim to separate the IR
transformation and code generation stages, and to enable more
optimizations for scalarized matrices.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at e7b3921</samp>

* Refactor the IR transformation and code generation stages to separate
the scalarization pass from the `compile_to_offloads` function and move
it to the `offload_to_executable` function
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL130-L137),
[link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR182-R189),
[link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1121-R1157))
* Add a full simplification pass after scalarization to optimize the
scalarized IR and eliminate redundant statements
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR182-R189))
* Rename the class `ScalarizeLocalPointers` to `ScalarizePointers` and
update the constructor to reflect its ability to handle both local and
global matrix pointers, as well as matrix pointers from external arrays
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L884-R884),
[link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L893-R893),
[link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1121-R1157))
* Simplify the logic for initializing a local tensor with zero values by
using a `MatrixInitStmt` instead of multiple `GlobalStoreStmt`s in the
`ExtractLocalPointers` class
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-d47b571f975c1002b8cb93634ac2a3d5f090f3fa9676ec3e0004c2ec4116ee21L536-R548))
* Move the comment explaining the logic of the `visit` function for
`MatrixPtrStmt` in the `ScalarizePointers` class to improve the
readability of the code
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L951-R960))
* Add a new logic branch to the `visit` function for `MatrixPtrStmt` in
the `ScalarizePointers` class to handle the case where the matrix
pointer originates from a global temporary statement, and simplify the
global temporary statement by adding the matrix offset to the original
offset
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L982-R1010))
* Add a new `visit` function for `OffloadedStmt` in the
`ExtractLocalPointers` class to ensure that the extraction process is
applied to each offloaded task, which may contain new local pointers
introduced by scalarization
([link](https://github.com/taichi-dev/taichi/pull/7939/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528R1058-R1065))
  • Loading branch information
jim19930609 authored and quadpixels committed May 13, 2023
1 parent 9f4c2b5 commit 47397a3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
16 changes: 8 additions & 8 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,6 @@ void compile_to_offloads(IRNode *ir,
print("Offloaded");
irpass::analysis::verify(ir);

if (config.real_matrix_scalarize) {
irpass::scalarize(ir);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

// TODO: This pass may be redundant as cfg_optimization() is already called
// in full_simplify().
if (config.opt_level > 0 && config.cfg_optimization) {
Expand Down Expand Up @@ -187,6 +179,14 @@ void offload_to_executable(IRNode *ir,
print("Detect read-only accesses");
}

if (config.real_matrix_scalarize) {
irpass::scalarize(ir);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false});
print("Scalarized");
}

irpass::demote_atomics(ir, config);
print("Atomics demoted I");
irpass::analysis::verify(ir);
Expand Down
8 changes: 6 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1049,8 +1049,12 @@ class ExtractLocalPointers : public BasicStmtVisitor {
Block *top_level_;

explicit ExtractLocalPointers(IRNode *root) : immediate_modifier_(root) {
TI_ASSERT(root->is<Block>());
top_level_ = root->as<Block>();
if (root->is<OffloadedStmt>()) {
top_level_ = root->as<OffloadedStmt>()->body.get();
} else {
TI_ASSERT(root->is<Block>());
top_level_ = root->as<Block>();
}
root->accept(this);
delayed_modifier_.modify_ir();
}
Expand Down

0 comments on commit 47397a3

Please sign in to comment.