Skip to content

Commit

Permalink
[opt] Add ExtractPointers pass for dynamic index (#7051)
Browse files Browse the repository at this point in the history
Issue: #2590

### Brief Summary

Under the pure `dynamic_index` setting, `MatrixPtrStmt`s are not scalarized.
This actually produces `2n` more instructions (`n` `ConstStmt`s and `n`
`MatrixPtrStmt`s) than the scalarized setting, where `n` is the number
of usages of `MatrixPtrStmt`s. This PR adds an `ExtractPointers` pass to
eliminate all the redundant instructions. See the comments in the code for
details.

After this PR, the number of instructions after the `scalarize()` pass
of the script in #6933 under dynamic index drops from 49589 to 26581,
and the compilation time drops from 20.02s to 7.82s.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jan 5, 2023
1 parent 523dd47 commit ae0882c
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,86 @@ class ScalarizePointers : public BasicStmtVisitor {
using BasicStmtVisitor::visit;
};

// The ExtractPointers pass aims at removing redundant ConstStmts and
// MatrixPtrStmts generated for any (AllocaStmt, integer) pair by extracting
// a unique copy for any future usage.
//
// Example for redundant stmts:
// <i32> $0 = const 0
// <i32> $1 = const 1
// ...
// <[Tensor (3, 3) f32]> $47738 = alloca
// <i32> $47739 = const 0 [REDUNDANT]
// <*f32> $47740 = shift ptr [$47738 + $47739]
// $47741 : local store [$47740 <- $47713]
// <i32> $47742 = const 1 [REDUNDANT]
// <*f32> $47743 = shift ptr [$47738 + $47742]
// $47744 : local store [$47743 <- $47716]
// ...
// <i32> $47812 = const 1 [REDUNDANT]
// <*f32> $47813 = shift ptr [$47738 + $47812] [REDUNDANT]
// <f32> $47814 = local load [$47813]
class ExtractPointers : public BasicStmtVisitor {
public:
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;

std::unordered_map<std::pair<Stmt *, int>,
Stmt *,
hashing::Hasher<std::pair<Stmt *, int>>>
first_matrix_ptr_; // mapping an (AllocaStmt, integer) pair to the first
// MatrixPtrStmt representing it
std::unordered_map<int, Stmt *>
first_const_; // mapping an integer to the first ConstStmt representing
// it
Block *top_level_;

explicit ExtractPointers(IRNode *root) : immediate_modifier_(root) {
TI_ASSERT(root->is<Block>());
top_level_ = root->as<Block>();
root->accept(this);
delayed_modifier_.modify_ir();
}

void visit(MatrixPtrStmt *stmt) override {
if (stmt->origin->is<AllocaStmt>()) {
auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
auto tensor_type =
alloca_stmt->ret_type.ptr_removed()->cast<TensorType>();
TI_ASSERT(tensor_type != nullptr);
if (stmt->offset->is<ConstStmt>()) {
int offset = stmt->offset->cast<ConstStmt>()->val.val_int32();
if (first_const_.count(offset) == 0) {
first_const_[offset] = stmt->offset;
delayed_modifier_.extract_to_block_front(stmt->offset, top_level_);
}
auto key = std::make_pair(alloca_stmt, offset);
if (first_matrix_ptr_.count(key) == 0) {
auto extracted = std::make_unique<MatrixPtrStmt>(
alloca_stmt, first_const_[offset]);
first_matrix_ptr_[key] = extracted.get();
delayed_modifier_.insert_after(alloca_stmt, std::move(extracted));
}
auto new_stmt = first_matrix_ptr_[key];
immediate_modifier_.replace_usages_with(stmt, new_stmt);
delayed_modifier_.erase(stmt);
}
}
}

private:
using BasicStmtVisitor::visit;
};

namespace irpass {

// Entry point of the scalarize transform. Always runs the Scalarize pass on
// `root` first, then picks the pointer pass from `config.dynamic_index`:
// ExtractPointers when it is set, ScalarizePointers otherwise. Each pass
// object is constructed on `root` and discarded immediately.
void scalarize(IRNode *root, const CompileConfig &config) {
  TI_AUTO_PROF;
  Scalarize scalarize_pass(root);
  if (config.dynamic_index) {
    ExtractPointers extract_pointers_pass(root);
  } else {
    ScalarizePointers scalarize_pointers_pass(root);
  }
}

Expand Down

0 comments on commit ae0882c

Please sign in to comment.