Skip to content

Commit

Permalink
[bug] Exclude quant type when doing store-to-load forwarding and skip…
Browse files Browse the repository at this point in the history
… bit struct store fusion when unfeasible (#8023)

Issue: #7981

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at cef23dc</samp>

Exclude `qi4` store pointers from forwarding optimization in live
variable analysis. This prevents precision loss when quantizing integers
in the Taichi compiler.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at cef23dc</samp>

* Exclude quantized integer store pointers from forwarding optimization
to avoid precision loss
([link](https://github.com/taichi-dev/taichi/pull/8023/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL128-R139))
  • Loading branch information
dream189free authored May 24, 2023
1 parent 258fa5a commit 3660ac8
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 16 deletions.
17 changes: 16 additions & 1 deletion taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,22 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
// store_ptr: prev-store dest_addr
for (auto store_ptr :
irpass::analysis::get_store_destination(block->statements[i].get())) {
if (irpass::analysis::definitely_same_address(var, store_ptr)) {
// Exclude `store_ptr` as a potential store destination due to mixed
// semantics of store statements for quant types. The store operation
// involves implicit casting before storing, which may result in a loss of
// precision. For example:
// <i32> $3 = const 233333
// <*qi4> $4 = global ptr [S3place<qi4><bit>], index [$1] activate=false
// $5 : global store [$4 <- $3]
// <i32> $6 = global load $4
// The store cannot be forwarded because $3 is first cast to a qi4 and
// then stored into $4. Since 233333 won't fit into a qi4, the store leads
// to truncation, resulting in a different value stored in $4 compared to
// $3.
// TODO: Still forward the store if the value can be statically proven to
// fit into the quant type.
if (!is_quant(store_ptr->ret_type.ptr_removed()) &&
irpass::analysis::definitely_same_address(var, store_ptr)) {
last_def_position = i;
break;
}
Expand Down
63 changes: 48 additions & 15 deletions taichi/transforms/optimize_bit_struct_stores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,33 @@ class MergeBitStructStores : public BasicStmtVisitor {
auto &statements = block->statements;
std::unordered_map<Stmt *, std::vector<BitStructStoreStmt *>>
ptr_to_bit_struct_stores;
std::unordered_set<Stmt *> loaded_after_store;
std::vector<Stmt *> statements_to_delete;
for (int i = 0; i <= (int)statements.size(); i++) {
// TODO: in some cases BitStructStoreStmts across container statements can
// still be merged, similar to basic block v.s. CFG optimizations.
if (i == statements.size() || statements[i]->is_container_statement()) {
for (const auto &item : ptr_to_bit_struct_stores) {
auto ptr = item.first;
auto stores = item.second;
if (stores.size() == 1)
for (const auto &[ptr, stores] : ptr_to_bit_struct_stores) {
if (stores.size() == 1) {
continue;
std::map<int, Stmt *> values;
for (auto s : stores) {
for (int j = 0; j < (int)s->ch_ids.size(); j++) {
values[s->ch_ids[j]] = s->values[j];
}
}

std::vector<int> ch_ids;
std::vector<Stmt *> store_values;
for (auto &ch_id_and_value : values) {
ch_ids.push_back(ch_id_and_value.first);
store_values.push_back(ch_id_and_value.second);
for (auto s : stores) {
for (int j = 0; j < (int)s->ch_ids.size(); j++) {
auto const &ch_id = s->ch_ids[j];
auto const &store_value = s->values[j];
ch_ids.push_back(ch_id);
store_values.push_back(store_value);
}
}

auto ch_ids_dup = [&ch_ids = std::as_const(ch_ids)]() {
std::unordered_set<int> ch_ids_set(ch_ids.begin(), ch_ids.end());
return ch_ids_set.size() != ch_ids.size();
};
TI_ASSERT(!ch_ids_dup());
// Now erase all (except the last) related BitStructStoreStmts.
// Replace the last one with a merged version.
for (int j = 0; j < (int)stores.size() - 1; j++) {
Expand All @@ -94,10 +99,38 @@ class MergeBitStructStores : public BasicStmtVisitor {
modified_ = true;
}
ptr_to_bit_struct_stores.clear();
continue;
loaded_after_store.clear();
}
if (auto stmt = statements[i]->cast<BitStructStoreStmt>()) {
ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt);
// Skip bit store fusion when there's a load between multiple stores.
// Example:
// <^qi16> $18 = get child [...] $17
// $178 : atomic bit_struct_store $17, ch_ids=[0], values=[$11]
// <i32> $20 = global load $18
// print "x[i]=", $20, "\n"
// <i32> $22 = add $11 $2
// <^qi16> $23 = get child [...] $17
// $179 : atomic bit_struct_store $17, ch_ids=[1], values=[$22]
// <i32> $25 = global load $23
// print "y[i]=", $25, "\n"
// In this case, $178 and $179 cannot be merged into a single store
// because the stored value $11 is loaded as $20 and then printed.
else if (auto stmt = statements[i]->cast<BitStructStoreStmt>()) {
// Phase 2: Find bit store after a marked load
if (loaded_after_store.find(stmt->ptr) != loaded_after_store.end()) {
// Disable store fusion for this bit struct
ptr_to_bit_struct_stores.erase(stmt->ptr);
loaded_after_store.erase(stmt->ptr);
} else {
ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt);
}
} else if (auto load_stmt = statements[i]->cast<GlobalLoadStmt>()) {
// Phase 1: Find and mark any global loads after bit_struct_store
auto const &load_ops = load_stmt->src->get_operands();
auto load_src = load_ops.empty() ? nullptr : load_ops.front();
if (ptr_to_bit_struct_stores.find(load_src) !=
ptr_to_bit_struct_stores.end()) {
loaded_after_store.insert(load_src);
}
}
}

Expand Down
96 changes: 96 additions & 0 deletions tests/python/test_quant_int.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import taichi as ti
from tests import test_utils

# TODO: validation layer support on macos vulkan backend is not working.
vk_on_mac = (ti.vulkan, "Darwin")

# TODO: capfd doesn't function well on CUDA backend on Windows
cuda_on_windows = (ti.cuda, "Windows")


@test_utils.test(require=ti.extension.quant_basic)
def test_quant_int_implicit_cast():
Expand All @@ -17,3 +23,93 @@ def foo():

foo()
assert x[None] == 10


@test_utils.test(
    require=ti.extension.quant_basic,
    arch=[ti.cpu, ti.cuda, ti.vulkan],
    exclude=[vk_on_mac, cuda_on_windows],
    debug=True,
)
def test_quant_store_fusion(capfd):
    # Two 16-bit signed quant ints packed into a single 32-bit bit struct.
    qx = ti.field(dtype=ti.types.quant.int(16, True))
    qy = ti.field(dtype=ti.types.quant.int(16, True))
    packed = ti.BitpackedFields(max_num_bits=32)
    packed.place(qx, qy)
    ti.root.dense(ti.i, 10).place(packed)

    # Both stores target the same bit struct with no intervening load,
    # so the compiler should fuse them into one store.
    @ti.kernel
    def store():
        ti.loop_config(serialize=True)
        for i in range(10):
            qx[i] = i
            qy[i] = i + 1
            print(qx[i], qy[i])

    store()
    ti.sync()

    out, err = capfd.readouterr()
    # One "i i+1" line per loop iteration, each newline-terminated.
    expected_out = "".join(f"{i} {i + 1}\n" for i in range(10))
    assert err == ""
    assert out == expected_out


@test_utils.test(
    require=ti.extension.quant_basic,
    arch=[ti.cpu, ti.cuda, ti.vulkan],
    exclude=[vk_on_mac, cuda_on_windows],
    debug=True,
)
def test_quant_store_no_fusion(capfd):
    # Two 16-bit signed quant ints packed into a single 32-bit bit struct.
    qx = ti.field(dtype=ti.types.quant.int(16, True))
    qy = ti.field(dtype=ti.types.quant.int(16, True))
    packed = ti.BitpackedFields(max_num_bits=32)
    packed.place(qx, qy)
    ti.root.dense(ti.i, 10).place(packed)

    # A load (via print) sits between the two stores, so fusing them
    # would change the printed value — fusion must be skipped here.
    @ti.kernel
    def store():
        ti.loop_config(serialize=True)
        for i in range(10):
            qx[i] = i
            print(qx[i])
            qy[i] = i + 1
            print(qy[i])

    store()
    ti.sync()

    out, err = capfd.readouterr()
    # Each iteration prints i and then i+1 on separate lines.
    expected_out = "".join(f"{i}\n{i + 1}\n" for i in range(10))
    assert err == ""
    assert out == expected_out

0 comments on commit 3660ac8

Please sign in to comment.