diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index 03daa799044f6..cf967ed4dac43 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -1501,6 +1501,11 @@ def field(self, **kwargs):
         kwargs.update({"ndim": self.ndim})
         return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)
 
+    def ndarray(self, **kwargs):
+        assert kwargs.get("ndim", self.ndim) == self.ndim
+        kwargs.update({"ndim": self.ndim})
+        return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)
+
     def get_shape(self):
         if self.ndim == 1:
             return (self.n,)
@@ -1598,6 +1603,9 @@ def _instantiate(self, entries):
     def field(self, **kwargs):
         return Vector.field(self.n, dtype=self.dtype, **kwargs)
 
+    def ndarray(self, **kwargs):
+        return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)
+
     def to_string(self):
         dtype_str = self.dtype.to_string() if self.dtype is not None else ""
         return f"VectorType[{self.n}, {dtype_str}]"
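
[Editor's note] A minimal usage sketch of the ndarray() helpers added above; the concrete shapes
and dtypes are illustrative assumptions, not taken from the patch:

    import taichi as ti

    ti.init()

    vec3 = ti.types.vector(3, ti.f32)
    mat2 = ti.types.matrix(2, 2, ti.f32)

    v = vec3.ndarray(shape=(4,))     # mirrors vec3.field(shape=(4,))
    m = mat2.ndarray(shape=(8, 8))   # ndim is inherited from the annotated matrix type
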
diff --git a/python/taichi/linalg/matrixfree_cg.py b/python/taichi/linalg/matrixfree_cg.py
index 0d333f4bbcea0..22bb51a4fd70c 100644
--- a/python/taichi/linalg/matrixfree_cg.py
+++ b/python/taichi/linalg/matrixfree_cg.py
@@ -47,7 +47,15 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
     p = ti.field(dtype=solver_dtype)
     r = ti.field(dtype=solver_dtype)
     Ap = ti.field(dtype=solver_dtype)
-    vector_fields_builder.dense(ti.ij, size).place(p, r, Ap)
+    if len(size) == 1:
+        axes = ti.i
+    elif len(size) == 2:
+        axes = ti.ij
+    elif len(size) == 3:
+        axes = ti.ijk
+    else:
+        raise TaichiRuntimeError(f"MatrixFreeCG only supports 1D, 2D and 3D inputs; your input is {len(size)}-D.")
+    vector_fields_builder.dense(axes, size).place(p, r, Ap)
     vector_fields_snode_tree = vector_fields_builder.finalize()
 
     scalar_builder = ti.FieldsBuilder()
diff --git a/python/taichi/linalg/sparse_matrix.py b/python/taichi/linalg/sparse_matrix.py
index 443afd933aae3..691f76b1cb912 100644
--- a/python/taichi/linalg/sparse_matrix.py
+++ b/python/taichi/linalg/sparse_matrix.py
@@ -284,12 +284,12 @@ def build(self, dtype=f32, _format="CSR"):
         taichi_arch = get_runtime().prog.config().arch
         if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
             sm = self.ptr.build()
-            return SparseMatrix(sm=sm, dtype=dtype)
+            return SparseMatrix(sm=sm, dtype=self.dtype)
         if taichi_arch == _ti_core.Arch.cuda:
-            if dtype != f32:
+            if self.dtype != f32:
                 raise TaichiRuntimeError("CUDA sparse matrix only supports f32.")
             sm = self.ptr.build_cuda()
-            return SparseMatrix(sm=sm, dtype=dtype)
+            return SparseMatrix(sm=sm, dtype=self.dtype)
         raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")
diff --git a/python/taichi/linalg/sparse_solver.py b/python/taichi/linalg/sparse_solver.py
index 05775c0dda496..3850cbee227f4 100644
--- a/python/taichi/linalg/sparse_solver.py
+++ b/python/taichi/linalg/sparse_solver.py
@@ -21,6 +21,7 @@ class SparseSolver:
 
     def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
         self.matrix = None
+        self.dtype = dtype
         solver_type_list = ["LLT", "LDLT", "LU"]
         solver_ordering = ["AMD", "COLAMD"]
         if solver_type in solver_type_list and ordering in solver_ordering:
@@ -70,6 +71,10 @@ def analyze_pattern(self, sparse_matrix):
         """
         if isinstance(sparse_matrix, SparseMatrix):
             self.matrix = sparse_matrix
+            if self.matrix.dtype != self.dtype:
+                raise TaichiRuntimeError(
+                    f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}."
+                )
             self.solver.analyze_pattern(sparse_matrix.matrix)
         else:
             self._type_assert(sparse_matrix)
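
[Editor's note] With the dtype now recorded on both sides, a solver/matrix mismatch fails fast in
analyze_pattern() instead of surfacing as wrong results later. A sketch of the behavior this
enables (assumed API usage; the dtypes are illustrative):

    import taichi as ti

    ti.init(arch=ti.cpu)

    builder = ti.linalg.SparseMatrixBuilder(4, 4, max_num_triplets=16, dtype=ti.f32)
    # ... fill the builder inside a kernel ...
    A = builder.build()            # dtype is taken from the builder (f32)

    solver = ti.linalg.SparseSolver(dtype=ti.f64, solver_type="LLT")
    solver.analyze_pattern(A)      # raises TaichiRuntimeError: f64 solver vs. f32 matrix
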
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index cc8e9472d4f70..d546900f83724 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -384,7 +384,10 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
     auto data_source_ptrs = irpass::analysis::get_store_destination(stmt);
     for (auto data_source_ptr : data_source_ptrs) {
       // stmt provides a data source
-      if (after_lower_access && !(data_source_ptr->is<AllocaStmt>())) {
+      if (after_lower_access &&
+          !((data_source_ptr->is<MatrixPtrStmt>() &&
+             data_source_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
+            data_source_ptr->is<AllocaStmt>())) {
         // After lower_access, we only analyze local variables.
         continue;
       }
@@ -552,6 +555,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
         irpass::analysis::get_load_pointers(stmt, true /*get_alias*/);
     for (auto &load_ptr : load_ptrs) {
       if (!after_lower_access ||
+          (load_ptr->is<MatrixPtrStmt>() &&
+           load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
           (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
         // After lower_access, we only analyze local variables and stacks.
         if (!contain_variable(live_kill, load_ptr)) {
@@ -576,6 +581,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
     }
     for (auto store_ptr : store_ptrs) {
       if (!after_lower_access ||
+          (store_ptr->is<MatrixPtrStmt>() &&
+           store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
           (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
         // After lower_access, we only analyze local variables and stacks.
         live_kill.insert(store_ptr);
@@ -707,6 +714,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
       auto store_ptr = *store_ptrs.begin();
 
       if (!after_lower_access ||
+          (store_ptr->is<MatrixPtrStmt>() &&
+           store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
          (store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
         // !may_contain_variable(live_in_this_node, store_ptr): address is not
         // loaded after this store
@@ -806,6 +815,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
       auto load_ptr = load_ptrs.begin()[0];
 
       if (!after_lower_access ||
+          (load_ptr->is<MatrixPtrStmt>() &&
+           load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
         // live_load_in_this_node[addr]: tracks the
         // next load to the same address
@@ -832,6 +843,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
     // Update live_in_this_node
     for (auto &load_ptr : load_ptrs) {
       if (!after_lower_access ||
+          (load_ptr->is<MatrixPtrStmt>() &&
+           load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
          (load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
         // Addr is used in this node, so it's live in this node
         update_container_with_alias(tensor_to_matrix_ptrs_map,
diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h
index 7b48b949df0c7..14bfc11c2611a 100644
--- a/taichi/ir/ir.h
+++ b/taichi/ir/ir.h
@@ -511,7 +511,7 @@ class Block : public IRNode {
   // variables, and AllocaStmt for other variables.
   std::map<Identifier, Stmt *> local_var_to_stmt;
 
-  Block(Kernel *kernel = nullptr) {
+  explicit Block(Kernel *kernel = nullptr) {
     parent_ = kernel;
   }
diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h
index 95b4d29cc5764..37cdd1ba69207 100644
--- a/taichi/ir/transforms.h
+++ b/taichi/ir/transforms.h
@@ -30,8 +30,7 @@ namespace irpass {
 void re_id(IRNode *root);
 void flag_access(IRNode *root);
 void eliminate_immutable_local_vars(IRNode *root);
-bool scalarize(IRNode *root);
-void vectorize_half2(IRNode *root);
+bool scalarize(IRNode *root, bool half2_optimization_enabled = false);
 void lower_matrix_ptr(IRNode *root);
 bool die(IRNode *root);
 bool simplify(IRNode *root, const CompileConfig &config);
diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp
index 6aa326fbb2584..73d977c8a6590 100644
--- a/taichi/program/ndarray.cpp
+++ b/taichi/program/ndarray.cpp
@@ -2,6 +2,7 @@
 
 #include "taichi/program/ndarray.h"
 #include "taichi/program/program.h"
+#include "fp16.h"
 
 #ifdef TI_WITH_LLVM
 #include "taichi/runtime/llvm/llvm_context.h"
@@ -168,10 +169,20 @@ TypedConstant Ndarray::read(const std::vector<int> &I) const {
   TypedConstant data(get_element_data_type());
   std::memcpy(&data.value_bits, device_arr_ptr, size);
   staging_buf_->device->unmap(*staging_buf_);
+
+  if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
+    float float32 = fp16_ieee_to_fp32_value(data.val_u16);
+    data.val_f32 = float32;
+  }
   return data;
 }
 
 void Ndarray::write(const std::vector<int> &I, TypedConstant val) const {
+  if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
+    uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32);
+    std::memcpy(&val.value_bits, &float16, sizeof(float16));
+  }
+
   size_t index = flatten_index(total_shape_, I);
   size_t size_ = data_type_size(get_element_data_type());
   taichi::lang::Device::AllocParams alloc_params;
diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp
index f1f8742bde5aa..43413f11aaf6a 100644
--- a/taichi/transforms/compile_to_offloads.cpp
+++ b/taichi/transforms/compile_to_offloads.cpp
@@ -271,15 +271,6 @@ void offload_to_executable(IRNode *ir,
   irpass::demote_operations(ir, config);
   print("Operations demoted");
 
-  if (config.real_matrix_scalarize) {
-    if (irpass::scalarize(ir)) {
-      // Remove redundant MatrixInitStmt inserted during scalarization
-      irpass::full_simplify(ir, config,
-                            {lower_global_access, /*autodiff_enabled*/ false});
-      print("Scalarized");
-    }
-  }
-
   irpass::full_simplify(ir, config,
                         {lower_global_access, /*autodiff_enabled*/ false});
   print("Simplified IV");
@@ -294,17 +285,16 @@ void offload_to_executable(IRNode *ir,
     print("Bit struct stores optimized");
   }
 
-  if (config.arch == Arch::cuda && config.half2_vectorization &&
-      !get_custom_cuda_library_path().empty()) {
-    irpass::vectorize_half2(ir);
-
-    irpass::type_check(ir, config);
-
-    irpass::full_simplify(ir, config,
-                          {lower_global_access, /*autodiff_enabled*/ false});
-
-    irpass::flag_access(ir);
-    print("Half2 vectorized");
+  bool half2_optimization_enabled =
+      (config.arch == Arch::cuda && config.half2_vectorization &&
+       !get_custom_cuda_library_path().empty());
+  if (config.real_matrix_scalarize) {
+    if (irpass::scalarize(ir, half2_optimization_enabled)) {
+      // Remove redundant MatrixInitStmt inserted during scalarization
+      irpass::full_simplify(ir, config,
+                            {lower_global_access, /*autodiff_enabled*/ false});
+      print("Scalarized");
+    }
   }
 
   // Final field registration correctness & type checking
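
[Editor's note] The fp16 conversions added to Ndarray::read()/write() are what make scalar host
access to an f16 ndarray round-trip through float32. A sketch (assuming an fp16-capable backend
such as CUDA):

    import taichi as ti

    ti.init(arch=ti.cuda)

    x = ti.ndarray(dtype=ti.f16, shape=(4,))
    x[0] = 1.5      # host float -> f16 bit pattern on write
    print(x[0])     # f16 bits -> Python float on read
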
diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index 7e84c6bd40121..796a9939e491e 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -16,8 +16,11 @@ class Scalarize : public BasicStmtVisitor {
  public:
   ImmediateIRModifier immediate_modifier_;
   DelayedIRModifier delayed_modifier_;
+  bool half2_optimization_enabled_ = false;
 
-  explicit Scalarize(IRNode *node) : immediate_modifier_(node) {
+  explicit Scalarize(IRNode *node, bool half2_optimization)
+      : immediate_modifier_(node),
+        half2_optimization_enabled_(half2_optimization) {
   }
 
   /*
@@ -414,7 +417,90 @@ class Scalarize : public BasicStmtVisitor {
   void visit(AtomicOpStmt *stmt) override {
     auto dest_dtype = stmt->dest->ret_type.ptr_removed();
     auto val_dtype = stmt->val->ret_type;
-    if (dest_dtype->is<TensorType>() || val_dtype->is<TensorType>()) {
+
+    bool half2_optimizable = half2_optimization_enabled_;
+    bool is_tensor_type = false;
+    if (dest_dtype->is<TensorType>()) {
+      is_tensor_type = true;
+      half2_optimizable &=
+          (dest_dtype.get_element_type() == PrimitiveType::f16);
+      half2_optimizable &=
+          (dest_dtype->as<TensorType>()->get_num_elements() == 2);
+    } else {
+      half2_optimizable = false;
+    }
+    is_tensor_type |= val_dtype->is<TensorType>();
+
+    if (half2_optimizable) {
+      /*
+        Before:
+          TensorType<2 x f16> old_val = AtomicStmt(TensorType<2 x f16>* dest,
+                                                   TensorType<2 x f16> val)
+
+        After:
+          TensorType<2 x f16> old_val = AtomicStmt(TensorType<2 x f16>* dest,
+                                                   TensorType<2 x f16> val)
+
+          TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
+          LocalStoreStmt(old_val_alloc, old_val)
+
+          f16* old_val_ptr0 = MatrixPtrStmt(old_val_alloc, 0)
+          f16* old_val_ptr1 = MatrixPtrStmt(old_val_alloc, 1)
+
+          f16 old_val0 = LoadStmt(old_val_ptr0)
+          f16 old_val1 = LoadStmt(old_val_ptr1)
+
+          tmp = MatrixInitStmt(old_val0, old_val1)
+
+          stmt->replace_all_usages_with(tmp)
+      */
+      auto atomic_stmt =
+          std::make_unique<AtomicOpStmt>(stmt->op_type, stmt->dest, stmt->val);
+      atomic_stmt->ret_type = stmt->ret_type;
+
+      auto alloca_stmt = std::make_unique<AllocaStmt>(dest_dtype);
+
+      auto local_store_stmt = std::make_unique<LocalStoreStmt>(
+          alloca_stmt.get(), atomic_stmt.get());
+
+      auto zero =
+          std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 0));
+      auto one =
+          std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 1));
+
+      auto matrix_ptr_0 =
+          std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), zero.get());
+      auto matrix_ptr_1 =
+          std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), one.get());
+      matrix_ptr_0->ret_type = PrimitiveType::f16;
+      matrix_ptr_0->ret_type.set_is_pointer(true);
+      matrix_ptr_1->ret_type = PrimitiveType::f16;
+      matrix_ptr_1->ret_type.set_is_pointer(true);
+
+      auto load_stmt_0 = std::make_unique<LocalLoadStmt>(matrix_ptr_0.get());
+      auto load_stmt_1 = std::make_unique<LocalLoadStmt>(matrix_ptr_1.get());
+      load_stmt_0->ret_type = PrimitiveType::f16;
+      load_stmt_1->ret_type = PrimitiveType::f16;
+
+      auto matrix_init_stmt = std::make_unique<MatrixInitStmt>(
+          std::vector<Stmt *>{load_stmt_0.get(), load_stmt_1.get()});
+      matrix_init_stmt->ret_type = stmt->ret_type;
+
+      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
+      delayed_modifier_.insert_before(stmt, std::move(atomic_stmt));
+      delayed_modifier_.insert_before(stmt, std::move(alloca_stmt));
+      delayed_modifier_.insert_before(stmt, std::move(local_store_stmt));
+      delayed_modifier_.insert_before(stmt, std::move(zero));
+      delayed_modifier_.insert_before(stmt, std::move(one));
+      delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_0));
+      delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_1));
+      delayed_modifier_.insert_before(stmt, std::move(load_stmt_0));
+      delayed_modifier_.insert_before(stmt, std::move(load_stmt_1));
+      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));
+
+      delayed_modifier_.erase(stmt);
+
+    } else if (is_tensor_type) {
       // Make sure broadcasting has been correctly applied by
       // AtomicOpExpression::type_check().
       TI_ASSERT(dest_dtype->is<TensorType>() && val_dtype->is<TensorType>());
@@ -838,8 +924,8 @@ class Scalarize : public BasicStmtVisitor {
     }
   }
 
-  static bool run(IRNode *node) {
-    Scalarize pass(node);
+  static bool run(IRNode *node, bool half2_optimization_enabled) {
+    Scalarize pass(node, half2_optimization_enabled);
     node->accept(&pass);
     return pass.delayed_modifier_.modify_ir();
   }
@@ -1192,11 +1278,11 @@ class MergeExternalAndMatrixPtr : public BasicStmtVisitor {
 
 namespace irpass {
 
-bool scalarize(IRNode *root) {
+bool scalarize(IRNode *root, bool half2_optimization_enabled) {
   TI_AUTO_PROF;
   bool modified = false;
 
-  modified |= Scalarize::run(root);
+  modified |= Scalarize::run(root, half2_optimization_enabled);
   auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root);
   modified |= ScalarizePointers::run(root, scalarizable_allocas);
   modified |= ExtractLocalPointers::run(root);
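
[Editor's note] The half2 branch above fires only for an AtomicOpStmt whose destination is a
2-element f16 tensor: the vector atomic is kept intact (so codegen can emit a CUDA half2 atomic)
and only the returned old value is scalarized. A sketch of the kind of kernel this targets
(half2_vectorization is the config flag read in compile_to_offloads.cpp; shapes and values are
illustrative, and the optimization additionally requires the custom CUDA library to be available):

    import taichi as ti

    ti.init(arch=ti.cuda, half2_vectorization=True)

    grad = ti.Vector.ndarray(2, ti.f16, shape=(1024,))

    @ti.kernel
    def accumulate(g: ti.types.ndarray()):
        for i in range(1024):
            # a 2 x f16 atomic add: the pattern the half2 path vectorizes
            ti.atomic_add(g[i % 8], ti.Vector([0.5, 0.5], dt=ti.f16))

    accumulate(grad)
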
diff --git a/taichi/transforms/vectorize_half2.cpp b/taichi/transforms/vectorize_half2.cpp
deleted file mode 100644
index d159092357b66..0000000000000
--- a/taichi/transforms/vectorize_half2.cpp
+++ /dev/null
@@ -1,595 +0,0 @@
-#include <algorithm>
-
-#include "taichi/ir/ir.h"
-#include "taichi/ir/statements.h"
-#include "taichi/ir/transforms.h"
-#include "taichi/ir/visitors.h"
-#include "taichi/program/program.h"
-#include "taichi/system/profiler.h"
-
-/* Auto vectorization for "half2" - TensorType(fp16, 2)
-
-In general, this Pass detects two AtomicOpStmt-Add with adjacent "dest_ptr" of
-float16*, then merge them into AtomicOpStmt(TensorType(2 x fp16)*). This
-vectorization is a pre-requisite to enable CUDA half2 optimization for float16
-atomic add.
-
-TODO: Remove this pass once irpass::scalarize() is turned off by default in CHI
-IR.
-
-The "dest_ptr" may point to various data structures and we have to handle them
-separately as follow:
-
-[Ndarray] Implemented
-Condition: Two ExternalPtrStmts having the same base_ptr & inner most
-indices are 0 and 1
-
-Before:
-  i32 const_0 = ConstStmt(0)
-  i32 const_1 = ConstStmt(1)
-
-  f16* ptr_0 = ExternalPtrStmt(arg, [$1, const_0])
-  f16* ptr_1 = ExternalPtrStmt(arg, [$1, const_1])
-
-  f16 old_val0 = AtomicStmt(ptr_0, $7)
-  f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-After:
-  TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-  TensorType(2, f16)* ptr = ExternalPtrStmt(arg, [$1])
-  TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-  TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-  StoreStmt(old_val, old_val_alloc)
-
-  f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-  f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-  alloca_stmt0->replace_all_usages_with(old_val0);
-  alloca_stmt1->replace_all_usages_with(old_val1);
-
-[GlobalTemp] Implemented
-Condition: Two GlobalTempStmts' offsets diff is 2
-Before:
-  f16* ptr_0 = GlobalTempStmt(offset0)
-  f16* ptr_1 = GlobalTempStmt(offset0 + 2)
-
-  f16 old_val0 = AtomicStmt(ptr_0, $7)
-  f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-After:
-  TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-  TensorType(2, f16)* ptr = GlobalTempStmt(offset0)
-  TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-  TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-  StoreStmt(old_val, old_val_alloc)
-
-  f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-  f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-  alloca_stmt0->replace_all_usages_with(old_val0);
-  alloca_stmt1->replace_all_usages_with(old_val1);
-
-[Field] Implemented
-Condition: Two GetChStmts having the same container & chids are 0 and 1
-Before:
-  gen* container = SNodeLookupStmt(...)
-
-  fp16* ptr_0 = GetChStmt(container, 0)
-  fp16* ptr_1 = GetChStmt(container, 1)
-
-  f16 old_val0 = AtomicStmt(ptr_0, $7)
-  f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-After:
-  TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-  TensorType(2, f16)* ptr = GetChStmt(container, 0)
-  TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-  TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-  StoreStmt(old_val, old_val_alloc)
-
-  f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-  f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-  alloca_stmt0->replace_all_usages_with(old_val0);
-  alloca_stmt1->replace_all_usages_with(old_val1);
-
-[Alloca] To be implemented
-Before:
-  f16* ptr_0 = AllocaStmt(fp16)
-  f16* ptr_1 = AllocaStmt(fp16)
-
-  f16 old_val0 = AtomicStmt(ptr_0, $7)
-  f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-After:
-  TensorType(2, f16)* ptr = AllocaStmt(TensorType(2, f16))
-  f16* ptr_0 = MatrixPtrStmt(ptr, 0)
-  f16* ptr_1 = MatrixPtrStmt(ptr, 1)
-  alloca_stmt0.replace_all_usages_with(ptr_0)
-  alloca_stmt1.replace_all_usages_with(ptr_1)
-
-  TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-  TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-  TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-  StoreStmt(old_val, old_val_alloc)
-
-  f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-  f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-  alloca_stmt0->replace_all_usages_with(old_val0);
-  alloca_stmt1->replace_all_usages_with(old_val1);
-*/
-
-namespace taichi::lang {
-
-class Half2VectorizationAnalyzer : public BasicStmtVisitor {
- public:
-  explicit Half2VectorizationAnalyzer(IRNode *node) {
-    node->accept(this);
-  }
-
-  std::unordered_set<AtomicOpStmt *> should_remove;
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_extern;  // self: other
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_global_temp;  // self: other
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_get_ch;  // self: other
-
-  void visit(AtomicOpStmt *stmt) override {
-    // opt-out
-    if (stmt->ret_type != PrimitiveType::f16) {
-      return;
-    }
-
-    if (stmt->op_type != AtomicOpType::add) {
-      return;
-    }
-
-    int stmt_type;  // 0: ExternalPtrStmt, 1: GlobalTemporaryStmt, 2: GetChStmt
-    if (stmt->dest->is<ExternalPtrStmt>() &&
-        stmt->dest->cast<ExternalPtrStmt>()->indices.back()->is<ConstStmt>()) {
-      stmt_type = 0;
-    } else if (stmt->dest->is<GlobalTemporaryStmt>()) {
-      stmt_type = 1;
-    } else if (stmt->dest->is<GetChStmt>()) {
-      stmt_type = 2;
-    } else {
-      return;
-    }
-
-    bool found_pair = false;
-    std::vector<AtomicOpStmt *> atomic_ops_to_remove;
-    for (auto iter : recorded_atomic_ops_) {
-      auto *atomic_op = iter;
-      if (stmt_type == 0) {
-        // [AtomicOpStmt with ExternalPtrStmt]
-        // vectorization patterns:
-        // 1. Same ExternalPtrStmt->base_ptr
-        // 2. Absolute diff of ExternalPtrStmt->offset is 1 (element)
-        auto *self_extern_stmt = stmt->dest->cast<ExternalPtrStmt>();
-        auto *other_extern_stmt = atomic_op->dest->cast<ExternalPtrStmt>();
-
-        if (self_extern_stmt->base_ptr != other_extern_stmt->base_ptr) {
-          continue;
-        }
-
-        std::vector<Stmt *> self_extern_indices = self_extern_stmt->indices;
-        std::vector<Stmt *> other_extern_indices = other_extern_stmt->indices;
-        if (self_extern_indices.size() != other_extern_indices.size()) {
-          continue;
-        }
-
-        if (self_extern_indices.back()->cast<ConstStmt>()->val.val_int32() +
-                other_extern_indices.back()
-                    ->cast<ConstStmt>()
-                    ->val.val_int32() ==
-            1) {
-          // Found pair
-          atomic_ops_to_remove.push_back(iter);
-
-          if (self_extern_indices.back()->cast<ConstStmt>()->val.val_int32() ==
-              0) {
-            should_remove.insert(atomic_op);
-            should_replace_extern[stmt] = atomic_op;
-          } else {
-            should_remove.insert(stmt);
-            should_replace_extern[atomic_op] = stmt;
-          }
-          found_pair = true;
-          break;
-        }
-      }
-
-      if (stmt_type == 1) {
-        // [AtomicOpStmt with GlobalTemporaryStmt]
-        // vectorization patterns:
-        // 1. Absolute diff of GlobalTemporaryStmt->offset is 2 (bytes)
-        auto *self_global_temp_stmt = stmt->dest->cast<GlobalTemporaryStmt>();
-        auto *other_global_temp_stmt =
-            atomic_op->dest->cast<GlobalTemporaryStmt>();
-
-        int offset_diff =
-            std::abs(static_cast<int>(self_global_temp_stmt->offset) -
-                     static_cast<int>(other_global_temp_stmt->offset));
-        if (offset_diff == 2) {
-          // Found pair
-          atomic_ops_to_remove.push_back(iter);
-
-          if (self_global_temp_stmt->offset < other_global_temp_stmt->offset) {
-            should_remove.insert(atomic_op);
-            should_replace_global_temp[stmt] = atomic_op;
-          } else {
-            should_remove.insert(stmt);
-            should_replace_global_temp[atomic_op] = stmt;
-          }
-          found_pair = true;
-          break;
-        }
-      }
-
-      if (stmt_type == 2) {
-        // [AtomicOpStmt with GetChStmt]
-        // vectorization patterns:
-        // 1. Same GetChStmt->input_ptr
-        // 2. Absolute diff of GetChStmt->chid is 1 (element)
-        auto *self_get_ch_stmt = stmt->dest->cast<GetChStmt>();
-        auto *other_get_ch_stmt = atomic_op->dest->cast<GetChStmt>();
-
-        if (self_get_ch_stmt->input_ptr != other_get_ch_stmt->input_ptr) {
-          continue;
-        }
-
-        if ((self_get_ch_stmt->chid == 0 && other_get_ch_stmt->chid == 1) ||
-            (self_get_ch_stmt->chid == 1 && other_get_ch_stmt->chid == 0)) {
-          // Found pair
-          atomic_ops_to_remove.push_back(iter);
-
-          if (self_get_ch_stmt->chid == 0) {
-            should_remove.insert(atomic_op);
-            should_replace_get_ch[stmt] = atomic_op;
-          } else {
-            should_remove.insert(stmt);
-            should_replace_get_ch[atomic_op] = stmt;
-          }
-          found_pair = true;
-          break;
-        }
-      }
-    }
-
-    for (auto stmt : atomic_ops_to_remove) {
-      recorded_atomic_ops_.erase(stmt);
-    }
-
-    if (!found_pair)
-      recorded_atomic_ops_.insert(stmt);
-  }
-
-  void visit(OffloadedStmt *stmt) override {
-    if (stmt->tls_prologue)
-      stmt->tls_prologue->accept(this);
-
-    if (stmt->body) {
-      stmt->body->accept(this);
-    }
-
-    if (stmt->tls_epilogue)
-      stmt->tls_epilogue->accept(this);
-  }
-
- private:
-  std::unordered_set<AtomicOpStmt *> recorded_atomic_ops_;
-  using BasicStmtVisitor::visit;
-};
-
-class Half2Vectorize : public BasicStmtVisitor {
- public:
-  DelayedIRModifier delayed_modifier_;
-
-  explicit Half2Vectorize(
-      IRNode *node,
-      const std::unordered_set<AtomicOpStmt *> &should_remove,
-      const std::map<AtomicOpStmt *, AtomicOpStmt *> &should_replace_extern,
-      const std::map<AtomicOpStmt *, AtomicOpStmt *>
-          &should_replace_global_temp,
-      const std::map<AtomicOpStmt *, AtomicOpStmt *> &should_replace_get_ch) {
-    this->should_remove = should_remove;
-    this->should_replace_extern = should_replace_extern;
-    this->should_replace_global_temp = should_replace_global_temp;
-    this->should_replace_get_ch = should_replace_get_ch;
-
-    node->accept(this);
-
-    delayed_modifier_.modify_ir();
-  }
-
-  std::unordered_set<AtomicOpStmt *> should_remove;
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_extern;  // self: other
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_global_temp;  // self: other
-  std::map<AtomicOpStmt *, AtomicOpStmt *>
-      should_replace_get_ch;  // self: other
-
-  void visit(AtomicOpStmt *stmt) override {
-    if (should_remove.find(stmt) != should_remove.end()) {
-      delayed_modifier_.erase(stmt);
-      return;
-    }
-
-    if (should_replace_extern.find(stmt) != should_replace_extern.end()) {
-      auto *self_extern_stmt = stmt->dest->cast<ExternalPtrStmt>();
-      auto *self_ptr = self_extern_stmt->base_ptr;
-      std::vector<Stmt *> self_indices = self_extern_stmt->indices;
-      auto *self_val = stmt->val;
-
-      AtomicOpStmt *other_stmt = should_replace_extern[stmt];
-      auto *other_extern_stmt = other_stmt->dest->cast<ExternalPtrStmt>();
-      std::vector<Stmt *> other_indices = other_extern_stmt->indices;
-      auto *other_val = other_stmt->val;
-
-      // Create MatrixInitStmt
-      std::vector<Stmt *> matrix_init_values;
-      matrix_init_values.push_back(self_val);
-      matrix_init_values.push_back(other_val);
-
-      auto matrix_init_stmt =
-          std::make_unique<MatrixInitStmt>(matrix_init_values);
-      auto tensor_type =
-          TypeFactory::get_instance().get_tensor_type({2}, PrimitiveType::f16);
-      matrix_init_stmt->ret_type = tensor_type;
-
-      // Create ExternalPtrStmt
-      std::vector<Stmt *> new_indices = self_indices;
-      new_indices.pop_back();  // Remove last index
-
-      std::vector<int> element_shape = {2};
-      int element_dim = -1;
-      auto new_extern_stmt = std::make_unique<ExternalPtrStmt>(
-          self_ptr, new_indices, self_extern_stmt->ndim, element_shape,
-          element_dim);
-      new_extern_stmt->overrided_dtype = true;
-      new_extern_stmt->ret_type = tensor_type;
-      new_extern_stmt->ret_type.set_is_pointer(true);
-
-      // Create AtomicStmt
-      auto new_atomic_stmt = std::make_unique<AtomicOpStmt>(
-          AtomicOpType::add, new_extern_stmt.get(), matrix_init_stmt.get());
-      new_atomic_stmt->ret_type = tensor_type;
-
-      // Create AllocaStmt
-      auto new_alloc_stmt =
-          std::make_unique<AllocaStmt>(matrix_init_stmt->ret_type);
-      new_alloc_stmt->ret_type = tensor_type;
-      new_alloc_stmt->ret_type.set_is_pointer(true);
-
-      // Create StoreStmt
-      auto new_store_stmt = std::make_unique<LocalStoreStmt>(
-          new_alloc_stmt.get(), new_atomic_stmt.get());
-
-      // Create MatrixPtrStmt
-      auto const_0 = std::make_unique<ConstStmt>(TypedConstant(0));
-      auto const_1 = std::make_unique<ConstStmt>(TypedConstant(1));
-      const_0->ret_type = PrimitiveType::i32;
-      const_1->ret_type = PrimitiveType::i32;
-
-      auto new_matrix_ptr_stmt0 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_0.get());
-      auto new_matrix_ptr_stmt1 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_1.get());
-      new_matrix_ptr_stmt0->ret_type = PrimitiveType::f16;
-      new_matrix_ptr_stmt1->ret_type = PrimitiveType::f16;
-
-      if (other_indices.back()->cast<ConstStmt>()->val.val_int32() == 1) {
-        other_stmt->replace_usages_with(new_matrix_ptr_stmt1.get());
-        stmt->replace_usages_with(new_matrix_ptr_stmt0.get());
-      } else {
-        other_stmt->replace_usages_with(new_matrix_ptr_stmt0.get());
-        stmt->replace_usages_with(new_matrix_ptr_stmt1.get());
-      }
-
-      delayed_modifier_.insert_before(stmt, std::move(const_0));
-      delayed_modifier_.insert_before(stmt, std::move(const_1));
-      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_extern_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_atomic_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_alloc_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_store_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt0));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt1));
-
-      delayed_modifier_.erase(stmt);
-    }
-
-    if (should_replace_global_temp.find(stmt) !=
-        should_replace_global_temp.end()) {
-      auto *self_global_temp_stmt = stmt->dest->cast<GlobalTemporaryStmt>();
-      size_t self_offset = self_global_temp_stmt->offset;
-      auto *self_val = stmt->val;
-
-      AtomicOpStmt *other_stmt = should_replace_global_temp[stmt];
-      size_t other_offset =
-          other_stmt->dest->cast<GlobalTemporaryStmt>()->offset;
-      auto *other_val = other_stmt->val;
-
-      // Create MatrixInitStmt
-      std::vector<Stmt *> matrix_init_values;
-      matrix_init_values.push_back(self_val);
-      matrix_init_values.push_back(other_val);
-
-      auto matrix_init_stmt =
-          std::make_unique<MatrixInitStmt>(matrix_init_values);
-      auto tensor_type =
-          TypeFactory::get_instance().get_tensor_type({2}, PrimitiveType::f16);
-      matrix_init_stmt->ret_type = tensor_type;
-
-      // Create GlobalTempStmt
-      auto new_global_temp_stmt = std::make_unique<GlobalTemporaryStmt>(
-          std::min(self_offset, other_offset), tensor_type);
-      new_global_temp_stmt->ret_type = tensor_type;
-      new_global_temp_stmt->ret_type.set_is_pointer(true);
-
-      // Create AtomicStmt
-      auto new_atomic_stmt = std::make_unique<AtomicOpStmt>(
-          AtomicOpType::add, new_global_temp_stmt.get(),
-          matrix_init_stmt.get());
-      new_atomic_stmt->ret_type = tensor_type;
-
-      // Create AllocaStmt
-      auto new_alloc_stmt =
-          std::make_unique<AllocaStmt>(matrix_init_stmt->ret_type);
-      new_alloc_stmt->ret_type = tensor_type;
-      new_alloc_stmt->ret_type.set_is_pointer(true);
-
-      // Create StoreStmt
-      auto new_store_stmt = std::make_unique<LocalStoreStmt>(
-          new_alloc_stmt.get(), new_atomic_stmt.get());
-
-      // Create MatrixPtrStmt
-      auto const_0 = std::make_unique<ConstStmt>(TypedConstant(0));
-      auto const_1 = std::make_unique<ConstStmt>(TypedConstant(1));
-      const_0->ret_type = PrimitiveType::i32;
-      const_1->ret_type = PrimitiveType::i32;
-
-      auto new_matrix_ptr_stmt0 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_0.get());
-      auto new_matrix_ptr_stmt1 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_1.get());
-      new_matrix_ptr_stmt0->ret_type = PrimitiveType::f16;
-      new_matrix_ptr_stmt1->ret_type = PrimitiveType::f16;
-
-      if (other_offset > self_offset) {
-        other_stmt->replace_usages_with(new_matrix_ptr_stmt1.get());
-        stmt->replace_usages_with(new_matrix_ptr_stmt0.get());
-      } else {
-        other_stmt->replace_usages_with(new_matrix_ptr_stmt0.get());
-        stmt->replace_usages_with(new_matrix_ptr_stmt1.get());
-      }
-
-      delayed_modifier_.insert_before(stmt, std::move(const_0));
-      delayed_modifier_.insert_before(stmt, std::move(const_1));
-      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_global_temp_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_atomic_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_alloc_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_store_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt0));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt1));
-
-      delayed_modifier_.erase(stmt);
-    }
-
-    if (should_replace_get_ch.find(stmt) != should_replace_get_ch.end()) {
-      auto *self_get_ch_stmt = stmt->dest->cast<GetChStmt>();
-      Stmt *self_input_ptr = self_get_ch_stmt->input_ptr;
-      size_t self_chid = self_get_ch_stmt->chid;
-      auto *self_val = stmt->val;
-
-      AtomicOpStmt *other_stmt = should_replace_get_ch[stmt];
-      auto *other_val = other_stmt->val;
-
-      // Create MatrixInitStmt
-      std::vector<Stmt *> matrix_init_values;
-      matrix_init_values.push_back(self_val);
-      matrix_init_values.push_back(other_val);
-
-      auto matrix_init_stmt =
-          std::make_unique<MatrixInitStmt>(matrix_init_values);
-      auto tensor_type =
-          TypeFactory::get_instance().get_tensor_type({2}, PrimitiveType::f16);
-      matrix_init_stmt->ret_type = tensor_type;
-
-      // Create GlobalTempStmt
-      auto new_get_ch_stmt = std::make_unique<GetChStmt>(
-          self_input_ptr, self_chid, self_get_ch_stmt->is_bit_vectorized);
-      new_get_ch_stmt->input_snode = self_get_ch_stmt->input_snode;
-      new_get_ch_stmt->output_snode = self_get_ch_stmt->output_snode;
-      new_get_ch_stmt->ret_type = tensor_type;
-      new_get_ch_stmt->ret_type.set_is_pointer(true);
-      new_get_ch_stmt->overrided_dtype = true;
-
-      // Create AtomicStmt
-      auto new_atomic_stmt = std::make_unique<AtomicOpStmt>(
-          AtomicOpType::add, new_get_ch_stmt.get(), matrix_init_stmt.get());
-      new_atomic_stmt->ret_type = tensor_type;
-
-      // Create AllocaStmt
-      auto new_alloc_stmt =
-          std::make_unique<AllocaStmt>(matrix_init_stmt->ret_type);
-      new_alloc_stmt->ret_type = tensor_type;
-      new_alloc_stmt->ret_type.set_is_pointer(true);
-
-      // Create StoreStmt
-      auto new_store_stmt = std::make_unique<LocalStoreStmt>(
-          new_alloc_stmt.get(), new_atomic_stmt.get());
-
-      // Create MatrixPtrStmt
-      auto const_0 = std::make_unique<ConstStmt>(TypedConstant(0));
-      auto const_1 = std::make_unique<ConstStmt>(TypedConstant(1));
-      const_0->ret_type = PrimitiveType::i32;
-      const_1->ret_type = PrimitiveType::i32;
-
-      auto new_matrix_ptr_stmt0 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_0.get());
-      auto new_matrix_ptr_stmt1 =
-          std::make_unique<MatrixPtrStmt>(new_alloc_stmt.get(), const_1.get());
-      new_matrix_ptr_stmt0->ret_type = PrimitiveType::f16;
-      new_matrix_ptr_stmt1->ret_type = PrimitiveType::f16;
-
-      stmt->replace_usages_with(new_matrix_ptr_stmt0.get());
-      other_stmt->replace_usages_with(new_matrix_ptr_stmt1.get());
-
-      delayed_modifier_.insert_before(stmt, std::move(const_0));
-      delayed_modifier_.insert_before(stmt, std::move(const_1));
-      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_get_ch_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_atomic_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_alloc_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_store_stmt));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt0));
-      delayed_modifier_.insert_before(stmt, std::move(new_matrix_ptr_stmt1));
-
-      delayed_modifier_.erase(stmt);
-    }
-  }
-
-  void visit(OffloadedStmt *stmt) override {
-    if (stmt->tls_prologue)
-      stmt->tls_prologue->accept(this);
-
-    if (stmt->body) {
-      stmt->body->accept(this);
-    }
-
-    if (stmt->tls_epilogue)
-      stmt->tls_epilogue->accept(this);
-  }
-
- private:
-  using BasicStmtVisitor::visit;
-};
-
-namespace irpass {
-
-void vectorize_half2(IRNode *root) {
-  TI_AUTO_PROF;
-
-  Half2VectorizationAnalyzer analyzer(root);
-
-  Half2Vectorize vectorize_pass(
-      root, analyzer.should_remove, analyzer.should_replace_extern,
-      analyzer.should_replace_global_temp, analyzer.should_replace_get_ch);
-}
-
-}  // namespace irpass
-
-}  // namespace taichi::lang
diff --git a/tests/cpp/transforms/half2_vectorization_test.cpp b/tests/cpp/transforms/half2_vectorization_test.cpp
index 078136d2540f1..31641b66ab5d4 100644
--- a/tests/cpp/transforms/half2_vectorization_test.cpp
+++ b/tests/cpp/transforms/half2_vectorization_test.cpp
@@ -31,104 +31,63 @@ TEST(Half2Vectorization, Ndarray) {
   auto kernel =
       std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
 
-  /*
-    Before:
-      i32 const_0 = ConstStmt(0)
-      i32 const_1 = ConstStmt(1)
-
-      f16* ptr_0 = ExternalPtrStmt(arg, [$1, const_0])
-      f16* ptr_1 = ExternalPtrStmt(arg, [$1, const_1])
-
-      f16 old_val0 = AtomicStmt(ptr_0, $7)
-      f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-    After:
-      TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-      TensorType(2, f16)* ptr = ExternalPtrStmt(arg, [$1])
-      TensorType(2, f16) old_val = AtomicStmt(ptr, val)
+  auto half2_type =
+      TypeFactory::get_instance().create_tensor_type({2}, PrimitiveType::f16);
 
-      TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-      StoreStmt(old_val, old_val_alloc)
-
-      f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-      f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-      alloca_stmt0->replace_all_usages_with(old_val0);
-      alloca_stmt1->replace_all_usages_with(old_val1);
-  */
-
-  auto type = TypeFactory::get_instance().get_ndarray_struct_type(
-      PrimitiveType::f16, 1);
-
-  auto argload_stmt =
-      block->push_back<ArgLoadStmt>(0 /*arg_id*/, type, /*is_ptr*/ true,
-                                    /*create_load*/ false);
+  auto argload_stmt = block->push_back<ArgLoadStmt>(
+      0 /*arg_id*/, PrimitiveType::f16, /*is_ptr*/ true,
+      /*create_load*/ false);
+  argload_stmt->ret_type = half2_type;
 
   auto const_0_stmt = block->push_back<ConstStmt>(TypedConstant(0));
-  auto const_1_stmt = block->push_back<ConstStmt>(TypedConstant(1));
-
-  auto val_0_stmt = block->push_back<ConstStmt>(TypedConstant(10));
-  auto val_1_stmt = block->push_back<ConstStmt>(TypedConstant(20));
 
   std::vector<Stmt *> external_ptr_indices0 = {const_0_stmt};
-  std::vector<Stmt *> external_ptr_indices1 = {const_1_stmt};
   auto external_ptr_stmt_0 =
       block->push_back<ExternalPtrStmt>(argload_stmt, external_ptr_indices0);
-  auto external_ptr_stmt_1 =
-      block->push_back<ExternalPtrStmt>(argload_stmt, external_ptr_indices1);
+  external_ptr_stmt_0->ret_type = half2_type;
+  external_ptr_stmt_0->ret_type.set_is_pointer(true);
+
+  auto val_0_stmt = block->push_back<ConstStmt>(TypedConstant(10));
+  auto val_1_stmt = block->push_back<ConstStmt>(TypedConstant(20));
 
-  block->push_back<AtomicOpStmt>(AtomicOpType::add, external_ptr_stmt_0,
-                                 val_0_stmt);
-  block->push_back<AtomicOpStmt>(AtomicOpType::add, external_ptr_stmt_1,
-                                 val_1_stmt);
+  std::vector<Stmt *> values = {val_0_stmt, val_1_stmt};
+  auto matrix_stmt = block->push_back<MatrixInitStmt>(values);
+  matrix_stmt->ret_type = half2_type;
 
-  irpass::type_check(block.get(), CompileConfig());
+  auto atomic_stmt = block->push_back<AtomicOpStmt>(
+      AtomicOpType::add, external_ptr_stmt_0, matrix_stmt);
+  atomic_stmt->ret_type = half2_type;
 
   /*
     Before:
-      $0 = arg[0]
+      <[Tensor (2) f16]> $0 = argaddr[0]
       $1 = const 0
-      $2 = const 1
+      <*[Tensor (2) f16]> $2 = external_ptr $0, [$1] layout=AOS is_grad=false
       $3 = const 10
       $4 = const 20
-      <*f16> $5 = external_ptr $0, [$1] element_dim=14 layout=SOA is_grad=false
-      <*f16> $6 = external_ptr $0, [$2] element_dim=64 layout=SOA is_grad=false
-      $7 = cast_value $3
-      $8 = atomic add($5, $7)
-      $9 = cast_value $4
-      $10 = atomic add($6, $9)
+      <[Tensor (2) f16]> $5 = [$3, $4]
+      <[Tensor (2) f16]> $6 = atomic add($2, $5)
   */
-  irpass::vectorize_half2(block.get());
+  irpass::scalarize(block.get(), true /*half2_optimization_enabled*/);
+
+  CompileConfig config;
+  irpass::full_simplify(block.get(), config, {false, false});
 
   /*
-    After:
-      $0 = argload[0]
+    After:
+      <[Tensor (2) f16]> $0 = argaddr[0]
       $1 = const 0
-      $2 = const 1
+      <*[Tensor (2) f16]> $2 = external_ptr $0, [$1] layout=AOS is_grad=false
       $3 = const 10
       $4 = const 20
-      <*f16> $5 = external_ptr $0, [$1] element_dim=14 layout=SOA is_grad=false
-      <*f16> $6 = external_ptr $0, [$2] element_dim=0 layout=AOS is_grad=false
-      $7 = cast_value $3
-      $8 = const 0
-      $9 = const 1
-      <[Tensor (2) f16]> $10 = [$7, $17]
-      <*[Tensor (2) f16]> $11 = external_ptr $0, [], (2) element_dim=-1
-      layout=AOS is_grad=false
-      <[Tensor (2) f16]> $12 = atomic add($11, $10)
-      <*[Tensor (2) f16]> $13 = alloca
-      $14 : local store [$13 <- $12]
-      $15 = shift ptr [$13 + $8]
-      $16 = shift ptr [$13 + $9]
-      $17 = cast_value $4
+      <[Tensor (2) f16]> $5 = [$3, $4]
+      <[Tensor (2) f16]> $6 = atomic add($2, $5)
   */
-  EXPECT_EQ(block->size(), 18);
+  EXPECT_EQ(block->size(), 7);
 
   // Check for scalarized statements
-  EXPECT_EQ(block->statements[10]->is<MatrixInitStmt>(), true);
-  EXPECT_EQ(block->statements[11]->is<ExternalPtrStmt>(), true);
-  EXPECT_EQ(block->statements[12]->is<AtomicOpStmt>(), true);
+  EXPECT_EQ(block->statements[5]->is<MatrixInitStmt>(), true);
+  EXPECT_EQ(block->statements[2]->is<ExternalPtrStmt>(), true);
+  EXPECT_EQ(block->statements[6]->is<AtomicOpStmt>(), true);
 }
 
 TEST(Half2Vectorization, GlobalTemporary) {
@@ -141,83 +100,52 @@ TEST(Half2Vectorization, GlobalTemporary) {
   auto func = []() {};
   auto kernel =
       std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
-
-  /*
-    Before:
-      f16* ptr_0 = GlobalTempStmt(offset0)
-      f16* ptr_1 = GlobalTempStmt(offset0 + 2)
-
-      f16 old_val0 = AtomicStmt(ptr_0, $7)
-      f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-    After:
-      TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-      TensorType(2, f16)* ptr = GlobalTempStmt(offset0)
-      TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-      TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-      StoreStmt(old_val, old_val_alloc)
-
-      f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-      f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-      alloca_stmt0->replace_all_usages_with(old_val0);
-      alloca_stmt1->replace_all_usages_with(old_val1);
-  */
+  auto half2_type =
+      TypeFactory::get_instance().create_tensor_type({2}, PrimitiveType::f16);
 
   auto val_0_stmt = block->push_back<ConstStmt>(TypedConstant(10));
   auto val_1_stmt = block->push_back<ConstStmt>(TypedConstant(20));
 
+  std::vector<Stmt *> values = {val_0_stmt, val_1_stmt};
+  auto matrix_stmt = block->push_back<MatrixInitStmt>(values);
+  matrix_stmt->ret_type = half2_type;
+
   auto global_temp_stmt_0 =
-      block->push_back<GlobalTemporaryStmt>(0, PrimitiveType::f16);
-  auto global_temp_stmt_1 =
-      block->push_back<GlobalTemporaryStmt>(2, PrimitiveType::f16);
+      block->push_back<GlobalTemporaryStmt>(0, half2_type);
 
   block->push_back<AtomicOpStmt>(AtomicOpType::add, global_temp_stmt_0,
-                                 val_0_stmt);
-  block->push_back<AtomicOpStmt>(AtomicOpType::add, global_temp_stmt_1,
-                                 val_1_stmt);
+                                 matrix_stmt);
 
   irpass::type_check(block.get(), CompileConfig());
 
   /*
     Before:
       $0 = const 10
-      $1 = const 20
-      <*f16> $2 = global tmp var (offset = 0 B)
-      <*f16> $3 = global tmp var (offset = 2 B)
-      $4 = cast_value $0
-      $5 = atomic add($2, $4)
-      $6 = cast_value $1
-      $7 = atomic add($3, $6)
+      $1 = cast_value $0
+      $2 = const 20
+      $3 = cast_value $2
+      <[Tensor (2) f16]> $4 = [$1, $3]
+      <*[Tensor (2) f16]> $5 = global tmp var (offset = 0 B)
+      <[Tensor (2) f16]> $6 = atomic add($5, $4)
   */
-  irpass::vectorize_half2(block.get());
+  irpass::scalarize(block.get(), true /*half2_optimization_enabled*/);
+
+  CompileConfig config;
+  irpass::full_simplify(block.get(), config, {false, false});
 
   /*
     After:
-      $0 = const 10
-      $1 = const 20
-      <*f16> $2 = global tmp var (offset = 0 B)
-      <*f16> $3 = global tmp var (offset = 2 B)
-      $4 = cast_value $0
-      $5 = const 0
-      $6 = const 1
-      <[Tensor (2) f16]> $7 = [$4, $14]
-      <*[Tensor (2) f16]> $8 = global tmp var (offset = 0 B)
-      <[Tensor (2) f16]> $9 = atomic add($8, $7)
-      <*[Tensor (2) f16]> $10 = alloca
-      $11 : local store [$10 <- $9]
-      $12 = shift ptr [$10 + $5]
-      $13 = shift ptr [$10 + $6]
-      $14 = cast_value $1
+      $0 = const 10.0
+      $1 = const 20.0
+      <[Tensor (2) f16]> $2 = [$0, $1]
+      <*[Tensor (2) f16]> $3 = global tmp var (offset = 0 B)
+      <[Tensor (2) f16]> $4 = atomic add($3, $2)
  */
-  EXPECT_EQ(block->size(), 15);
+  EXPECT_EQ(block->size(), 5);
 
   // Check for scalarized statements
-  EXPECT_EQ(block->statements[7]->is<MatrixInitStmt>(), true);
-  EXPECT_EQ(block->statements[8]->is<GlobalTemporaryStmt>(), true);
-  EXPECT_EQ(block->statements[9]->is<AtomicOpStmt>(), true);
+  EXPECT_EQ(block->statements[2]->is<MatrixInitStmt>(), true);
+  EXPECT_EQ(block->statements[3]->is<GlobalTemporaryStmt>(), true);
+  EXPECT_EQ(block->statements[4]->is<AtomicOpStmt>(), true);
 }
 
 TEST(Half2Vectorization, Field) {
@@ -231,31 +159,6 @@ TEST(Half2Vectorization, Field) {
   auto kernel =
       std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
 
-  /*
-    Before:
-      gen* container = SNodeLookupStmt(...)
-
-      fp16* ptr_0 = GetChStmt(container, 0)
-      fp16* ptr_1 = GetChStmt(container, 1)
-
-      f16 old_val0 = AtomicStmt(ptr_0, $7)
-      f16 old_val1 = AtomicStmt(ptr_1, $8)
-
-    After:
-      TensorType(2, f16) val = MatrixInitStmt([$7, $8])
-
-      TensorType(2, f16)* ptr = GetChStmt(container, 0)
-      TensorType(2, f16) old_val = AtomicStmt(ptr, val)
-
-      TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
-      StoreStmt(old_val, old_val_alloc)
-
-      f16 old_val0 = MatrixPtrStmt(old_val_alloc, 0)
-      f16 old_val1 = MatrixPtrStmt(old_val_alloc, 1)
-
-      alloca_stmt0->replace_all_usages_with(old_val0);
-      alloca_stmt1->replace_all_usages_with(old_val1);
-  */
   auto get_root = block->push_back<GetRootStmt>();
   auto linearized_empty = block->push_back<LinearizeStmt>(
       std::vector<Stmt *>(), std::vector<int>());
@@ -267,66 +170,58 @@ TEST(Half2Vectorization, Field) {
                                                   linearized_empty, false);
 
   auto get_ch_stmt_0 = block->push_back<GetChStmt>(lookup, 0);
-  auto get_ch_stmt_1 = block->push_back<GetChStmt>(lookup, 1);
 
-  get_ch_stmt_0->ret_type = PrimitiveType::f16;
-  get_ch_stmt_1->ret_type = PrimitiveType::f16;
+  auto half2_type =
+      TypeFactory::get_instance().create_tensor_type({2}, PrimitiveType::f16);
+  get_ch_stmt_0->ret_type = half2_type;
   get_ch_stmt_0->ret_type.set_is_pointer(true);
-  get_ch_stmt_1->ret_type.set_is_pointer(true);
   get_ch_stmt_0->as<GetChStmt>()->overrided_dtype = true;
-  get_ch_stmt_1->as<GetChStmt>()->overrided_dtype = true;
 
   auto val_0_stmt = block->push_back<ConstStmt>(TypedConstant(10));
   auto val_1_stmt = block->push_back<ConstStmt>(TypedConstant(20));
 
-  block->push_back<AtomicOpStmt>(AtomicOpType::add, get_ch_stmt_0, val_0_stmt);
-  block->push_back<AtomicOpStmt>(AtomicOpType::add, get_ch_stmt_1, val_1_stmt);
+  std::vector<Stmt *> values = {val_0_stmt, val_1_stmt};
+  auto matrix_stmt = block->push_back<MatrixInitStmt>(values);
+  matrix_stmt->ret_type = half2_type;
 
-  irpass::type_check(block.get(), CompileConfig());
+  block->push_back<AtomicOpStmt>(AtomicOpType::add, get_ch_stmt_0, matrix_stmt);
 
+  irpass::type_check(block.get(), CompileConfig());
   /*
     Before:
      <*gen> $0 = get root nullptr
      $1 = linearized(ind {}, stride {})
      <*gen> $2 = [S1root][root]::lookup($0, $1) activate = false
-     <*f16> $3 = get child [S1root->S2place] $2
-     <*f16> $4 = get child [S1root->S3place] $2
-     $5 = const 10
+     <*[Tensor (2) f16]> $3 = get child [S1root->S2place] $2
+     $4 = const 10
+     $5 = cast_value $4
      $6 = const 20
-     $7 = cast_value $5
-     $8 = atomic add($3, $7)
-     $9 = cast_value $6
-     $10 = atomic add($4, $9)
+     $7 = cast_value $6
+     <[Tensor (2) f16]> $8 = [$5, $7]
+     <[Tensor (2) f16]> $9 = atomic add($3, $8)
  */
-  irpass::vectorize_half2(block.get());
+  irpass::scalarize(block.get(), true /*half2_optimization_enabled*/);
+
+  CompileConfig config;
+  irpass::full_simplify(block.get(), config, {false, false});
 
   /*
     After:
      <*gen> $0 = get root nullptr
-     $1 = linearized(ind {}, stride {})
+     $1 = const 0
     <*gen> $2 = [S1root][root]::lookup($0, $1) activate = false
-     <*f16> $3 = get child [S1root->S2place] $2
-     <*f16> $4 = get child [S1root->S3place] $2
-     $5 = const 10
-     $6 = const 20
-     $7 = cast_value $5
-     $8 = const 0
-     $9 = const 1
-     <[Tensor (2) f16]> $10 = [$7, $17]
-     <*[Tensor (2) f16]> $11 = get child [S1root->S2place] $2
-     <[Tensor (2) f16]> $12 = atomic add($11, $10)
-     <*[Tensor (2) f16]> $13 = alloca
-     $14 : local store [$13 <- $12]
-     $15 = shift ptr [$13 + $8]
-     $16 = shift ptr [$13 + $9]
-     $17 = cast_value $6
+     <*[Tensor (2) f16]> $3 = get child [S1root->S2place] $2
+     $4 = const 10.0
+     $5 = const 20.0
+     <[Tensor (2) f16]> $6 = [$4, $5]
+     <[Tensor (2) f16]> $7 = atomic add($3, $6)
  */
-  EXPECT_EQ(block->size(), 18);
+  EXPECT_EQ(block->size(), 8);
 
   // Check for scalarized statements
-  EXPECT_EQ(block->statements[10]->is<MatrixInitStmt>(), true);
-  EXPECT_EQ(block->statements[11]->is<GetChStmt>(), true);
-  EXPECT_EQ(block->statements[12]->is<AtomicOpStmt>(), true);
+  EXPECT_EQ(block->statements[6]->is<MatrixInitStmt>(), true);
+  EXPECT_EQ(block->statements[3]->is<GetChStmt>(), true);
+  EXPECT_EQ(block->statements[7]->is<AtomicOpStmt>(), true);
 }
 
 }  // namespace taichi::lang