Skip to content

Commit

Permalink
[Lang] Merge irpass::half2_vectorize() with irpass::scalarize() (#8102)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 44b862c</samp>

This pull request enhances the support and optimization for matrices and
vectors in the IR and the code generation, especially for f16 data
types. It adds `ndarray` methods to `MatrixType` and `VectorType`
classes, fixes code generation bugs and data flow analysis for matrix
and vector operations, simplifies and improves the scalarization and
vectorization of matrices in CUDA offloads, and adds a half2
vectorization optimization for f16 matrices and vectors. It also updates
the `IRBuilder` class, the `Ndarray` class, and the test cases
accordingly.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 44b862c</samp>

* Add `ndarray` methods to `MatrixType` and `VectorType` classes to
create `Matrix` and `Vector` objects from IR types
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9R1504-R1508),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9R1606-R1608))
* Fix bug in `defined` function that used incorrect type to check
signedness of binary operands involving `MatrixType` or `VectorType`
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L614-R614),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L661-R661))
* Modify `reaching_definition_analysis`, `live_variable_analysis` and
`dead_store_elimination` functions to handle `MatrixPtrStmt` as local
variables in data dependency analysis
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL387-R390),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR558-R559),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR584-R585),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR717-R718),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR818-R819),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR846-R847))
* Add `taichi/ir/statements.h` header file to `taichi/ir/ir_builder.h`
and change return type of `get_constant` method to `Stmt *` to handle
`MatrixType` and `VectorType` constants
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dR5),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dL140-R141))
* Add `half2_optimization_enabled` parameter to `scalarize` function and
`Scalarize` class to control half2 vectorization optimization for
`MatrixType` and `VectorType` operands with two f16 elements
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L33-R33),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L19-R23),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L417-R503),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L841-R928),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1195-R1285))
* Add `fp16.h` header file to `taichi/program/ndarray.cpp` and modify
`read` and `write` methods of `Ndarray` class to handle f16 data types
correctly
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-c88c6764ffa952681c8b0db12b376c473a8422cb7bb0243a10cc643cc245a5a1R5),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-c88c6764ffa952681c8b0db12b376c473a8422cb7bb0243a10cc643cc245a5a1L171-R185))
* Remove redundant call to `scalarize` function with
`config.real_matrix_scalarize` flag and modify call to `scalarize`
function with `half2_optimization_enabled` flag in
`offload_to_executable` function
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL234-L241),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL296-R297))
* Add `transform_pow_op_impl` method to `DemoteOperations` class to
transform power operation with scalar operands and modify `visit` method
to handle `floordiv`, `bit_sar` and `pow` operations with scalar and
tensor operands
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR19-R129),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR135-R146),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR165),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL67-R198),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL99-R228),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL109-L150),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL159-R254),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL176-R265),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL191-R287))
* Modify `half2_vectorization_test.cpp` to use tensor type operands with
two f16 elements and call `scalarize` function with
`half2_optimization_enabled` flag instead of `vectorize_half2` function
([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L34-R68),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L105-R88),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L144-R115),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L187-R146),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L234-L258))
  • Loading branch information
jim19930609 authored Jun 1, 2023
1 parent d9373bd commit b3b7c64
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 827 deletions.
8 changes: 8 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,11 @@ def field(self, **kwargs):
kwargs.update({"ndim": self.ndim})
return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)

def ndarray(self, **kwargs):
assert kwargs.get("ndim", self.ndim) == self.ndim
kwargs.update({"ndim": self.ndim})
return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)

def get_shape(self):
if self.ndim == 1:
return (self.n,)
Expand Down Expand Up @@ -1598,6 +1603,9 @@ def _instantiate(self, entries):
def field(self, **kwargs):
return Vector.field(self.n, dtype=self.dtype, **kwargs)

def ndarray(self, **kwargs):
return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

def to_string(self):
dtype_str = self.dtype.to_string() if self.dtype is not None else ""
return f"VectorType[{self.n}, {dtype_str}]"
Expand Down
3 changes: 1 addition & 2 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ namespace irpass {
void re_id(IRNode *root);
void flag_access(IRNode *root);
void eliminate_immutable_local_vars(IRNode *root);
bool scalarize(IRNode *root);
void vectorize_half2(IRNode *root);
bool scalarize(IRNode *root, bool half2_optimization_enabled = false);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
Expand Down
11 changes: 11 additions & 0 deletions taichi/program/ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/program/ndarray.h"
#include "taichi/program/program.h"
#include "fp16.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/llvm/llvm_context.h"
Expand Down Expand Up @@ -168,10 +169,20 @@ TypedConstant Ndarray::read(const std::vector<int> &I) const {
TypedConstant data(get_element_data_type());
std::memcpy(&data.value_bits, device_arr_ptr, size);
staging_buf_->device->unmap(*staging_buf_);

if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
float float32 = fp16_ieee_to_fp32_value(data.val_u16);
data.val_f32 = float32;
}
return data;
}

void Ndarray::write(const std::vector<int> &I, TypedConstant val) const {
if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32);
std::memcpy(&val.value_bits, &float16, 4);
}

size_t index = flatten_index(total_shape_, I);
size_t size_ = data_type_size(get_element_data_type());
taichi::lang::Device::AllocParams alloc_params;
Expand Down
36 changes: 4 additions & 32 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,27 +268,9 @@ void offload_to_executable(IRNode *ir,
irpass::analysis::verify(ir);
}

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

irpass::demote_operations(ir, config);
print("Operations demoted");

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Simplified IV");
Expand All @@ -303,28 +285,18 @@ void offload_to_executable(IRNode *ir,
print("Bit struct stores optimized");
}

bool half2_optimization_enabled =
(config.arch == Arch::cuda && config.half2_vectorization &&
!get_custom_cuda_library_path().empty());
if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
if (irpass::scalarize(ir, half2_optimization_enabled)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

if (config.arch == Arch::cuda && config.half2_vectorization &&
!get_custom_cuda_library_path().empty()) {
irpass::vectorize_half2(ir);

irpass::type_check(ir, config);

irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});

irpass::flag_access(ir);
print("Half2 vectorized");
}

// Final field registration correctness & type checking
irpass::type_check(ir, config);
irpass::analysis::verify(ir);
Expand Down
98 changes: 92 additions & 6 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ class Scalarize : public BasicStmtVisitor {
public:
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;
bool half2_optimization_enabled_ = false;

explicit Scalarize(IRNode *node) : immediate_modifier_(node) {
explicit Scalarize(IRNode *node, bool half2_optimization)
: immediate_modifier_(node),
half2_optimization_enabled_(half2_optimization) {
}

/*
Expand Down Expand Up @@ -414,7 +417,90 @@ class Scalarize : public BasicStmtVisitor {
void visit(AtomicOpStmt *stmt) override {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
auto val_dtype = stmt->val->ret_type;
if (dest_dtype->is<TensorType>() || val_dtype->is<TensorType>()) {

bool half2_optimizable = half2_optimization_enabled_;
bool is_tensor_type = false;
if (dest_dtype->is<TensorType>()) {
is_tensor_type = true;
half2_optimizable &=
(dest_dtype.get_element_type() == PrimitiveType::f16);
half2_optimizable &=
(dest_dtype->as<TensorType>()->get_num_elements() == 2);
} else {
half2_optimizable = false;
}
is_tensor_type |= val_dtype->is<TensorType>();

if (half2_optimizable) {
/*
Before:
TensorType<2 x i32> old_val = AtomicStmt(TensorType<2 x i32>* dest,
TensorType<2 x i32> val)
After:
TensorType<2 x i32> old_val = AtomicStmt(TensorType<2 x i32>* dest,
TensorType<2 x i32> val)
TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
LocalStoreStmt(old_val_alloc, old_val)
f16* old_val_ptr0 = MatrixPtrStmt(old_val_alloc, 0)
f16* old_val_ptr1 = MatrixPtrStmt(old_val_alloc, 0)
f16 old_val0 = LoadStmt(old_val_ptr0)
f16 old_val1 = LoadStmt(old_val_ptr1)
tmp = MatrixInitStmt(old_val0, old_val1)
stmt->replace_all_usages_with(tmp)
*/
auto atomic_stmt =
std::make_unique<AtomicOpStmt>(stmt->op_type, stmt->dest, stmt->val);
atomic_stmt->ret_type = stmt->ret_type;

auto alloca_stmt = std::make_unique<AllocaStmt>(dest_dtype);

auto local_store_stmt = std::make_unique<LocalStoreStmt>(
alloca_stmt.get(), atomic_stmt.get());

auto zero =
std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 0));
auto one =
std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 1));

auto matrix_ptr_0 =
std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), zero.get());
auto matrix_ptr_1 =
std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), one.get());
matrix_ptr_0->ret_type = PrimitiveType::f16;
matrix_ptr_0->ret_type.set_is_pointer(true);
matrix_ptr_1->ret_type = PrimitiveType::f16;
matrix_ptr_1->ret_type.set_is_pointer(true);

auto load_stmt_0 = std::make_unique<LocalLoadStmt>(matrix_ptr_0.get());
auto load_stmt_1 = std::make_unique<LocalLoadStmt>(matrix_ptr_1.get());
load_stmt_0->ret_type = PrimitiveType::f16;
load_stmt_1->ret_type = PrimitiveType::f16;

auto matrix_init_stmt = std::make_unique<MatrixInitStmt>(
std::vector<Stmt *>{load_stmt_0.get(), load_stmt_1.get()});
matrix_init_stmt->ret_type = stmt->ret_type;

immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(atomic_stmt));
delayed_modifier_.insert_before(stmt, std::move(alloca_stmt));
delayed_modifier_.insert_before(stmt, std::move(local_store_stmt));
delayed_modifier_.insert_before(stmt, std::move(zero));
delayed_modifier_.insert_before(stmt, std::move(one));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_0));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_1));
delayed_modifier_.insert_before(stmt, std::move(load_stmt_0));
delayed_modifier_.insert_before(stmt, std::move(load_stmt_1));
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);

} else if (is_tensor_type) {
// Make sure broadcasting has been correctly applied by
// AtomicOpExpression::type_check().
TI_ASSERT(dest_dtype->is<TensorType>() && val_dtype->is<TensorType>());
Expand Down Expand Up @@ -838,8 +924,8 @@ class Scalarize : public BasicStmtVisitor {
}
}

static bool run(IRNode *node) {
Scalarize pass(node);
static bool run(IRNode *node, bool half2_optimization_enabled) {
Scalarize pass(node, half2_optimization_enabled);
node->accept(&pass);
return pass.delayed_modifier_.modify_ir();
}
Expand Down Expand Up @@ -1192,11 +1278,11 @@ class MergeExternalAndMatrixPtr : public BasicStmtVisitor {

namespace irpass {

bool scalarize(IRNode *root) {
bool scalarize(IRNode *root, bool half2_optimization_enabled) {
TI_AUTO_PROF;
bool modified = false;

modified |= Scalarize::run(root);
modified |= Scalarize::run(root, half2_optimization_enabled);
auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root);
modified |= ScalarizePointers::run(root, scalarizable_allocas);
modified |= ExtractLocalPointers::run(root);
Expand Down
Loading

0 comments on commit b3b7c64

Please sign in to comment.