diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 732da3eb94051..e211395477727 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -130,7 +130,22 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // store_ptr: prev-store dest_addr for (auto store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { - if (irpass::analysis::definitely_same_address(var, store_ptr)) { + // Exclude `store_ptr` as a potential store destination due to mixed + // semantics of store statements for quant types. The store operation + // involves implicit casting before storing, which may result in a loss of + // precision. For example: + // $3 = const 233333 + // <*qi4> $4 = global ptr [S3place], index [$1] activate=false + // $5 : global store [$4 <- $3] + // $6 = global load $4 + // The store cannot be forwarded because $3 is first cast to a qi4 and + // then stored into $4. Since 233333 won't fit into a qi4, the store leads + // to truncation, resulting in a different value stored in $4 compared to + // $3. + // TODO: Still forward the store if the value can be statically proven to + // fit into the quant type. 
+ if (!is_quant(store_ptr->ret_type.ptr_removed()) && + irpass::analysis::definitely_same_address(var, store_ptr)) { last_def_position = i; break; } diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index 187f644846d7e..d09d946930bd9 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -62,28 +62,33 @@ class MergeBitStructStores : public BasicStmtVisitor { auto &statements = block->statements; std::unordered_map> ptr_to_bit_struct_stores; + std::unordered_set loaded_after_store; std::vector statements_to_delete; for (int i = 0; i <= (int)statements.size(); i++) { // TODO: in some cases BitStructStoreStmts across container statements can // still be merged, similar to basic block v.s. CFG optimizations. if (i == statements.size() || statements[i]->is_container_statement()) { - for (const auto &item : ptr_to_bit_struct_stores) { - auto ptr = item.first; - auto stores = item.second; - if (stores.size() == 1) + for (const auto &[ptr, stores] : ptr_to_bit_struct_stores) { + if (stores.size() == 1) { continue; - std::map values; - for (auto s : stores) { - for (int j = 0; j < (int)s->ch_ids.size(); j++) { - values[s->ch_ids[j]] = s->values[j]; - } } + std::vector ch_ids; std::vector store_values; - for (auto &ch_id_and_value : values) { - ch_ids.push_back(ch_id_and_value.first); - store_values.push_back(ch_id_and_value.second); + for (auto s : stores) { + for (int j = 0; j < (int)s->ch_ids.size(); j++) { + auto const &ch_id = s->ch_ids[j]; + auto const &store_value = s->values[j]; + ch_ids.push_back(ch_id); + store_values.push_back(store_value); + } } + + auto ch_ids_dup = [&ch_ids = std::as_const(ch_ids)]() { + std::unordered_set ch_ids_set(ch_ids.begin(), ch_ids.end()); + return ch_ids_set.size() != ch_ids.size(); + }; + TI_ASSERT(!ch_ids_dup()); // Now erase all (except the last) related BitSturctStoreStmts. 
// Replace the last one with a merged version. for (int j = 0; j < (int)stores.size() - 1; j++) { @@ -94,10 +99,38 @@ class MergeBitStructStores : public BasicStmtVisitor { modified_ = true; } ptr_to_bit_struct_stores.clear(); - continue; + loaded_after_store.clear(); } - if (auto stmt = statements[i]->cast()) { - ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt); + // Skip bit store fusion when there's a load between multiple stores. + // Example: + // <^qi16> $18 = get child [...] $17 + // $178 : atomic bit_struct_store $17, ch_ids=[0], values=[$11] + // $20 = global load $18 + // print "x[i]=", $20, "\n" + // $22 = add $11 $2 + // <^qi16> $23 = get child [...] $17 + // $179 : atomic bit_struct_store $17, ch_ids=[1], values=[$22] + // $25 = global load $23 + // print "y[i]=", $25, "\n" + // In this case, $178 and $179 cannot be merged into a single store + // because the stored value $11 is loaded as $20 and then printed. + else if (auto stmt = statements[i]->cast()) { + // Phase 2: Find bit store after a marked load + if (loaded_after_store.find(stmt->ptr) != loaded_after_store.end()) { + // Disable store fusion for this bit struct + ptr_to_bit_struct_stores.erase(stmt->ptr); + loaded_after_store.erase(stmt->ptr); + } else { + ptr_to_bit_struct_stores[stmt->ptr].push_back(stmt); + } + } else if (auto load_stmt = statements[i]->cast()) { + // Phase 1: Find and mark any global loads after bit_struct_store + auto const &load_ops = load_stmt->src->get_operands(); + auto load_src = load_ops.empty() ? 
nullptr : load_ops.front(); + if (ptr_to_bit_struct_stores.find(load_src) != + ptr_to_bit_struct_stores.end()) { + loaded_after_store.insert(load_src); + } } } diff --git a/tests/python/test_quant_int.py b/tests/python/test_quant_int.py index fe40968703e71..c106c8051d7f6 100644 --- a/tests/python/test_quant_int.py +++ b/tests/python/test_quant_int.py @@ -1,6 +1,12 @@ import taichi as ti from tests import test_utils +# TODO: validation layer support on macos vulkan backend is not working. +vk_on_mac = (ti.vulkan, "Darwin") + +# TODO: capfd doesn't function well on CUDA backend on Windows +cuda_on_windows = (ti.cuda, "Windows") + @test_utils.test(require=ti.extension.quant_basic) def test_quant_int_implicit_cast(): @@ -17,3 +23,93 @@ def foo(): foo() assert x[None] == 10 + + +@test_utils.test( + require=ti.extension.quant_basic, + arch=[ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, cuda_on_windows], + debug=True, +) +def test_quant_store_fusion(capfd): + x = ti.field(dtype=ti.types.quant.int(16, True)) + y = ti.field(dtype=ti.types.quant.int(16, True)) + v = ti.BitpackedFields(max_num_bits=32) + v.place(x, y) + ti.root.dense(ti.i, 10).place(v) + + # should fuse store + @ti.kernel + def store(): + ti.loop_config(serialize=True) + for i in range(10): + x[i] = i + y[i] = i + 1 + print(x[i], y[i]) + + store() + ti.sync() + + out, err = capfd.readouterr() + expected_out = """0 1 +1 2 +2 3 +3 4 +4 5 +5 6 +6 7 +7 8 +8 9 +9 10 +""" + assert out == expected_out and err == "" + + +@test_utils.test( + require=ti.extension.quant_basic, + arch=[ti.cpu, ti.cuda, ti.vulkan], + exclude=[vk_on_mac, cuda_on_windows], + debug=True, +) +def test_quant_store_no_fusion(capfd): + x = ti.field(dtype=ti.types.quant.int(16, True)) + y = ti.field(dtype=ti.types.quant.int(16, True)) + v = ti.BitpackedFields(max_num_bits=32) + v.place(x, y) + ti.root.dense(ti.i, 10).place(v) + + @ti.kernel + def store(): + ti.loop_config(serialize=True) + for i in range(10): + x[i] = i + print(x[i]) + y[i] 
= i + 1 + print(y[i]) + + store() + ti.sync() + + out, err = capfd.readouterr() + expected_out = """0 +1 +1 +2 +2 +3 +3 +4 +4 +5 +5 +6 +6 +7 +7 +8 +8 +9 +9 +10 +""" + assert out == expected_out and err == ""