Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Migrate irpass::scalarize() after optimize_bit_struct_stores & determine_ad_stack_size #8097

Merged
merged 14 commits on
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
if (is_real(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (is_signed(stmt->ret_type)) {
} else if (is_signed(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down Expand Up @@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#endif
} else if (op == BinaryOpType::bit_sar) {
if (is_signed(stmt->lhs->element_type())) {
if (is_signed(stmt->lhs->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down
15 changes: 14 additions & 1 deletion taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,10 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
auto data_source_ptrs = irpass::analysis::get_store_destination(stmt);
for (auto data_source_ptr : data_source_ptrs) {
// stmt provides a data source
if (after_lower_access && !(data_source_ptr->is<AllocaStmt>())) {
if (after_lower_access &&
!((data_source_ptr->is<MatrixPtrStmt>() &&
data_source_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
data_source_ptr->is<AllocaStmt>())) {
// After lower_access, we only analyze local variables.
continue;
}
Expand Down Expand Up @@ -552,6 +555,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
irpass::analysis::get_load_pointers(stmt, true /*get_alias*/);
for (auto &load_ptr : load_ptrs) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
if (!contain_variable(live_kill, load_ptr)) {
Expand All @@ -576,6 +581,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
}
for (auto store_ptr : store_ptrs) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
live_kill.insert(store_ptr);
Expand Down Expand Up @@ -707,6 +714,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
auto store_ptr = *store_ptrs.begin();

if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// !may_contain_variable(live_in_this_node, store_ptr): address is not
// loaded after this store
Expand Down Expand Up @@ -806,6 +815,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
auto load_ptr = load_ptrs.begin()[0];

if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// live_load_in_this_node[addr]: tracks the
// next load to the same address
Expand All @@ -832,6 +843,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
// Update live_in_this_node
for (auto &load_ptr : load_ptrs) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// Addr is used in this node, so it's live in this node
update_container_with_alias(tensor_to_matrix_ptrs_map,
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/ir/ir.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/statements.h"

namespace taichi::lang {

Expand Down Expand Up @@ -137,7 +138,7 @@ class IRBuilder {
ConstStmt *get_float64(float64 value);

template <typename T>
ConstStmt *get_constant(DataType dt, const T &value) {
Stmt *get_constant(DataType dt, const T &value) {
return insert(Stmt::make_typed<ConstStmt>(TypedConstant(dt, value)));
}

Expand Down
17 changes: 9 additions & 8 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,6 @@ void offload_to_executable(IRNode *ir,
print("Make block local");
}

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

if (is_extension_supported(config.arch, Extension::mesh)) {
irpass::demote_mesh_statements(ir, config, {kernel->get_name()});
print("Demote mesh statements");
Expand Down Expand Up @@ -293,6 +285,15 @@ void offload_to_executable(IRNode *ir,
print("Bit struct stores optimized");
}

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

if (config.arch == Arch::cuda && config.half2_vectorization &&
!get_custom_cuda_library_path().empty()) {
irpass::vectorize_half2(ir);
Expand Down
Loading