[type] Atomic demotion for bit struct stores #2174

Merged 2 commits on Jan 25, 2021
59 changes: 59 additions & 0 deletions taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -163,12 +163,71 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}
};

class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor {
private:
std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
result_;

public:
using BasicStmtVisitor::visit;

UniquelyAccessedBitStructGatherer() {
allow_undefined_visitor = true;
invoke_default_visitor = false;
}

void visit(OffloadedStmt *stmt) override {
if (stmt->task_type == OffloadedTaskType::range_for ||
stmt->task_type == OffloadedTaskType::struct_for) {
auto &loop_unique_bit_struct = result_[stmt];
auto loop_unique_ptr =
irpass::analysis::gather_uniquely_accessed_pointers(stmt);
for (auto &it : loop_unique_ptr) {
auto *snode = it.first;
auto *ptr = it.second;
if (snode->is_bit_level) {
// Find the nearest non-bit-level ancestor
while (snode->is_bit_level) {
snode = snode->parent;
}
// Check whether uniquely accessed
auto accessed_ptr = loop_unique_bit_struct.find(snode);
if (accessed_ptr == loop_unique_bit_struct.end()) {
loop_unique_bit_struct[snode] = ptr;
} else {
if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
ptr)) {
accessed_ptr->second = nullptr; // not uniquely accessed
}
}
}
}
}
// Do not dive into OffloadedStmt
}

static std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
run(IRNode *root) {
UniquelyAccessedBitStructGatherer gatherer;
root->accept(&gatherer);
return gatherer.result_;
}
};

namespace irpass::analysis {
std::unordered_map<const SNode *, GlobalPtrStmt *>
gather_uniquely_accessed_pointers(IRNode *root) {
// TODO: What about SNodeOpStmts?
return UniquelyAccessedSNodeSearcher::run(root);
}

std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
gather_uniquely_accessed_bit_structs(IRNode *root) {
return UniquelyAccessedBitStructGatherer::run(root);
}
} // namespace irpass::analysis

TLANG_NAMESPACE_END
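
The gatherer records, per offloaded task, each non-bit-level container together with the single pointer through which it is accessed (nullptr once two definitely different addresses are seen). A minimal sketch of how a consumer could query the result, assuming `root`, `offload` and `snode` are the kernel IR, the OffloadedStmt and the container of interest (the actual consumer is the DemoteAtomicBitStructStores pass further down in this diff):

// Sketch only; assumes #include "taichi/ir/analysis.h" and the taichi namespace.
auto uniquely_accessed =
    irpass::analysis::gather_uniquely_accessed_bit_structs(root);
auto task_it = uniquely_accessed.find(offload);
if (task_it != uniquely_accessed.end()) {
  auto ptr_it = task_it->second.find(snode);
  if (ptr_it != task_it->second.end() && ptr_it->second != nullptr) {
    // Within this task every access to `snode` goes through one address,
    // so its bit struct stores cannot race across loop iterations and the
    // atomic masked store can be demoted to a plain one.
  }
}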
3 changes: 2 additions & 1 deletion taichi/codegen/codegen_llvm.h
@@ -221,7 +221,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
void store_masked(llvm::Value *byte_ptr,
uint64 mask,
Type *physical_type,
llvm::Value *value);
llvm::Value *value,
bool atomic);

void visit(GlobalStoreStmt *stmt) override;

13 changes: 8 additions & 5 deletions taichi/codegen/codegen_llvm_quant.cpp
@@ -93,13 +93,15 @@ void CodeGenLLVM::store_custom_int(llvm::Value *byte_ptr,
void CodeGenLLVM::store_masked(llvm::Value *byte_ptr,
uint64 mask,
Type *physical_type,
llvm::Value *value) {
llvm::Value *value,
bool atomic) {
uint64 full_mask = (~(uint64)0) >> (64 - data_type_bits(physical_type));
if ((mask & full_mask) == full_mask) {
builder->CreateStore(value, byte_ptr);
return;
}
create_call(fmt::format("set_mask_b{}", data_type_bits(physical_type)),
create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "",
data_type_bits(physical_type)),
{builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)),
tlctx->get_constant(mask),
builder->CreateIntCast(value, llvm_type(physical_type), false)});
@@ -172,7 +174,7 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
// Store all the components
builder->CreateStore(bit_struct_val, llvm_val[stmt->ptr]);
} else {
// Create a mask and use a single atomicCAS
// Create a mask and use a single (atomic)CAS
uint64 mask = 0;
for (auto &ch_id : stmt->ch_ids) {
auto &ch = bit_struct_snode->ch[ch_id];
@@ -188,7 +190,7 @@
bit_struct_snode->ch[ch_id]->bit_offset);
}
store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type,
bit_struct_val);
bit_struct_val, stmt->is_atomic);
}
}

@@ -296,7 +298,8 @@ void CodeGenLLVM::store_floats_with_shared_exponents(BitStructStoreStmt *stmt) {
update_mask(mask, num_digit_bits, digits_bit_offset);
}
}
store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type, masked_val);
store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type, masked_val,
stmt->is_atomic);
}

llvm::Value *CodeGenLLVM::extract_exponent_from_float(llvm::Value *f) {
3 changes: 3 additions & 0 deletions taichi/ir/analysis.h
@@ -72,6 +72,9 @@ std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
gather_snode_read_writes(IRNode *root);
std::vector<Stmt *> gather_statements(IRNode *root,
const std::function<bool(Stmt *)> &test);
std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
gather_uniquely_accessed_bit_structs(IRNode *root);
std::unordered_map<const SNode *, GlobalPtrStmt *>
gather_uniquely_accessed_pointers(IRNode *root);
std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
5 changes: 3 additions & 2 deletions taichi/ir/statements.h
@@ -1182,11 +1182,12 @@ class BitStructStoreStmt : public Stmt {
Stmt *ptr;
std::vector<int> ch_ids;
std::vector<Stmt *> values;
bool is_atomic;

BitStructStoreStmt(Stmt *ptr,
const std::vector<int> &ch_ids,
const std::vector<Stmt *> &values)
: ptr(ptr), ch_ids(ch_ids), values(values) {
: ptr(ptr), ch_ids(ch_ids), values(values), is_atomic(true) {
TI_ASSERT(ch_ids.size() == values.size());
TI_STMT_REG_FIELDS;
}
@@ -1197,7 +1198,7 @@
return false;
}

TI_STMT_DEF_FIELDS(ret_type, ptr, ch_ids, values);
TI_STMT_DEF_FIELDS(ret_type, ptr, ch_ids, values, is_atomic);
TI_DEFINE_ACCEPT_AND_CLONE;
};

6 changes: 5 additions & 1 deletion taichi/ir/transforms.h
@@ -62,7 +62,11 @@ void demote_dense_struct_fors(IRNode *root);
bool demote_atomics(IRNode *root);
void reverse_segments(IRNode *root); // for autograd
void detect_read_only(IRNode *root);
void optimize_bit_struct_stores(IRNode *root);
void optimize_bit_struct_stores(
IRNode *root,
const std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
&uniquely_accessed_bit_structs);

// compile_to_offloads does the basic compilation to create all the offloaded
// tasks of a Taichi kernel. It's worth pointing out that this doesn't demote
2 changes: 2 additions & 0 deletions taichi/program/compile_config.h
@@ -76,6 +76,8 @@ struct CompileConfig {
// Setting 0 effectively means unlimited
int async_max_fuse_per_task{1};

bool quant_opt_atomic_demotion{true};

CompileConfig();
};

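The new flag defaults to true (demotion enabled). A minimal sketch of turning it off, assuming only the CompileConfig declaration above; the same field is exposed to Python through the def_readwrite binding in export_lang.cpp below:

// Sketch only: disabling the optimization keeps every BitStructStoreStmt atomic.
CompileConfig cfg;
cfg.quant_opt_atomic_demotion = false;
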
4 changes: 3 additions & 1 deletion taichi/python/export_lang.cpp
@@ -180,7 +180,9 @@ void export_lang(py::module &m) {
&CompileConfig::async_opt_intermediate_file)
.def_readwrite("async_flush_every", &CompileConfig::async_flush_every)
.def_readwrite("async_max_fuse_per_task",
&CompileConfig::async_max_fuse_per_task);
&CompileConfig::async_max_fuse_per_task)
.def_readwrite("quant_opt_atomic_demotion",
&CompileConfig::quant_opt_atomic_demotion);

m.def("reset_default_compile_config",
[&]() { default_compile_config = CompileConfig(); });
7 changes: 6 additions & 1 deletion taichi/runtime/llvm/runtime.cpp
@@ -1580,6 +1580,11 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
// see #2096 for more details
#define DEFINE_SET_PARTIAL_BITS(N) \
void set_mask_b##N(u##N *ptr, u64 mask, u##N value) { \
u##N mask_N = (u##N)mask; \
*ptr = (*ptr & (~mask_N)) | value; \
} \
\
void atomic_set_mask_b##N(u##N *ptr, u64 mask, u##N value) { \
u##N mask_N = (u##N)mask; \
u##N new_value = 0; \
u##N old_value = *ptr; \
@@ -1594,7 +1599,7 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
\
void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \
u##N mask = ((~(u##N)0) << (N - bits)) >> (N - offset - bits); \
set_mask_b##N(ptr, mask, value << offset); \
atomic_set_mask_b##N(ptr, mask, value << offset); \
} \
\
u##N atomic_add_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, \
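As a concrete check of the mask arithmetic in set_partial_bits (the values below are hypothetical): storing a 16-bit member at bit offset 8 of a 32-bit physical type gives

// Worked example, not part of the diff: N = 32, offset = 8, bits = 16.
// mask = ((~(u32)0) << (32 - 16)) >> (32 - 8 - 16)
//      = 0xFFFF0000 >> 8
//      = 0x00FFFF00              // exactly the 16 bits starting at bit 8
// set_partial_bits_b32 then calls atomic_set_mask_b32(ptr, mask, value << 8),
// which retries a compare-and-swap (in the folded body above) until it
// installs (old & ~mask) | (value << 8); the demoted set_mask_b32 performs
// the same read-modify-write without the CAS loop.
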
11 changes: 10 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
@@ -177,6 +177,15 @@ void offload_to_executable(IRNode *ir,
print("Atomics demoted II");
irpass::analysis::verify(ir);

std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
uniquely_accessed_bit_structs;
if (is_extension_supported(config.arch, Extension::quant) &&
ir->get_config().quant_opt_atomic_demotion) {
uniquely_accessed_bit_structs =
irpass::analysis::gather_uniquely_accessed_bit_structs(ir);
}

irpass::remove_range_assumption(ir);
print("Remove range assumption");

@@ -205,7 +214,7 @@
print("Simplified IV");

if (is_extension_supported(config.arch, Extension::quant)) {
irpass::optimize_bit_struct_stores(ir);
irpass::optimize_bit_struct_stores(ir, uniquely_accessed_bit_structs);
print("Bit struct stores optimized");
}

4 changes: 2 additions & 2 deletions taichi/transforms/ir_printer.cpp
@@ -630,8 +630,8 @@ class IRPrinter : public IRVisitor {
values += ", ";
}
}
print("{} : bit_struct_store {}, ch_ids=[{}], values=[{}]", stmt->name(),
stmt->ptr->name(), ch_ids, values);
print("{} : {}bit_struct_store {}, ch_ids=[{}], values=[{}]", stmt->name(),
stmt->is_atomic ? "atomic " : "", stmt->ptr->name(), ch_ids, values);
}
};

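With hypothetical statement and pointer names, the printer now distinguishes atomic from demoted stores, e.g.:

$38 : atomic bit_struct_store $7, ch_ids=[0, 1], values=[$30, $31]
$39 : bit_struct_store $7, ch_ids=[0, 1], values=[$32, $33]
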
88 changes: 85 additions & 3 deletions taichi/transforms/optimize_bit_struct_stores.cpp
@@ -1,3 +1,4 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
@@ -22,7 +23,7 @@ class CreateBitStructStores : public BasicStmtVisitor {
root->accept(&pass);
}

void visit(GlobalStoreStmt *stmt) {
void visit(GlobalStoreStmt *stmt) override {
auto get_ch = stmt->ptr->cast<GetChStmt>();
if (!get_ch || get_ch->input_snode->type != SNodeType::bit_struct)
return;
@@ -69,7 +70,7 @@ class MergeBitStructStores : public BasicStmtVisitor {
ptr_to_bit_struct_stores;
std::vector<Stmt *> statements_to_delete;
for (int i = 0; i <= (int)statements.size(); i++) {
// TODO: in some cases BitSturctStoreStmts across container statements can
// TODO: in some cases BitStructStoreStmts across container statements can
// still be merged, similar to basic block v.s. CFG optimizations.
if (i == statements.size() || statements[i]->is_container_statement()) {
for (const auto &item : ptr_to_bit_struct_stores) {
@@ -119,16 +120,97 @@ class MergeBitStructStores : public BasicStmtVisitor {
bool modified_{false};
};

class DemoteAtomicBitStructStores : public BasicStmtVisitor {
private:
const std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
&uniquely_accessed_bit_structs_;
std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>::
const_iterator current_iterator_;
bool modified_{false};

public:
using BasicStmtVisitor::visit;
OffloadedStmt *current_offloaded;

explicit DemoteAtomicBitStructStores(
const std::unordered_map<
OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
&uniquely_accessed_bit_structs)
: uniquely_accessed_bit_structs_(uniquely_accessed_bit_structs),
current_offloaded(nullptr) {
allow_undefined_visitor = true;
invoke_default_visitor = false;
}

void visit(BitStructStoreStmt *stmt) override {
bool demote = false;
TI_ASSERT(current_offloaded);
if (current_offloaded->task_type == OffloadedTaskType::serial) {
demote = true;
} else if (current_offloaded->task_type == OffloadedTaskType::range_for ||
current_offloaded->task_type == OffloadedTaskType::struct_for) {
auto *snode = stmt->get_bit_struct_snode();
// Find the nearest non-bit-level ancestor
while (snode->is_bit_level) {
snode = snode->parent;
}
auto accessed_ptr_iterator = current_iterator_->second.find(snode);
if (accessed_ptr_iterator != current_iterator_->second.end() &&
accessed_ptr_iterator->second != nullptr) {
demote = true;
}
}
if (demote) {
stmt->is_atomic = false;
modified_ = true;
}
}

void visit(OffloadedStmt *stmt) override {
current_offloaded = stmt;
if (stmt->task_type == OffloadedTaskType::range_for ||
stmt->task_type == OffloadedTaskType::struct_for) {
current_iterator_ =
uniquely_accessed_bit_structs_.find(current_offloaded);
}
// We don't need to visit TLS/BLS prologues/epilogues.
if (stmt->body) {
stmt->body->accept(this);
}
current_offloaded = nullptr;
}

static bool run(IRNode *node,
const std::unordered_map<
OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
&uniquely_accessed_bit_structs) {
DemoteAtomicBitStructStores demoter(uniquely_accessed_bit_structs);
node->accept(&demoter);
return demoter.modified_;
}
};

} // namespace

TLANG_NAMESPACE_BEGIN

namespace irpass {
void optimize_bit_struct_stores(IRNode *root) {
void optimize_bit_struct_stores(
IRNode *root,
const std::unordered_map<OffloadedStmt *,
std::unordered_map<const SNode *, GlobalPtrStmt *>>
&uniquely_accessed_bit_structs) {
TI_AUTO_PROF;
CreateBitStructStores::run(root);
die(root); // remove unused GetCh
MergeBitStructStores::run(root);
if (root->get_config().quant_opt_atomic_demotion) {
DemoteAtomicBitStructStores::run(root, uniquely_accessed_bit_structs);
}
}

} // namespace irpass