[llvm] [refactor] Merge AtomicOpStmt codegen in CPU and CUDA backends #5086

Merged (2 commits) on Jun 2, 2022
117 changes: 2 additions & 115 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -191,7 +191,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

// Not all reduction statements can be optimized.
// If the operation cannot be optimized, this function returns nullptr.
llvm::Value *optimized_reduction(AtomicOpStmt *stmt) {
llvm::Value *optimized_reduction(AtomicOpStmt *stmt) override {
if (!stmt->is_reduction) {
return nullptr;
}
@@ -227,39 +227,6 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
{llvm_val[stmt->dest], llvm_val[stmt->val]});
}

llvm::Value *custom_type_atomic(AtomicOpStmt *stmt) {
if (stmt->op_type != AtomicOpType::add) {
return nullptr;
}

auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (auto cit = dst_type->cast<CustomIntType>()) {
return atomic_add_custom_int(stmt, cit);
} else if (auto cft = dst_type->cast<CustomFloatType>()) {
return atomic_add_custom_float(stmt, cft);
} else {
return nullptr;
}
}

llvm::Value *integral_type_atomic(AtomicOpStmt *stmt) {
if (!is_integral(stmt->val->ret_type)) {
return nullptr;
}
std::unordered_map<AtomicOpType, llvm::AtomicRMWInst::BinOp> bin_op;
bin_op[AtomicOpType::add] = llvm::AtomicRMWInst::BinOp::Add;
bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::Min;
bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::Max;

bin_op[AtomicOpType::bit_and] = llvm::AtomicRMWInst::BinOp::And;
bin_op[AtomicOpType::bit_or] = llvm::AtomicRMWInst::BinOp::Or;
bin_op[AtomicOpType::bit_xor] = llvm::AtomicRMWInst::BinOp::Xor;
TI_ASSERT(bin_op.find(stmt->op_type) != bin_op.end());
return builder->CreateAtomicRMW(
bin_op.at(stmt->op_type), llvm_val[stmt->dest], llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
}

// A huge hack for supporting f16 atomic add/max/min! Borrowed from
// https://github.com/tensorflow/tensorflow/blob/470d58a83470f8ede3beaa584e6992bc71b7baa6/tensorflow/compiler/xla/service/gpu/ir_emitter.cc#L378-L490
// The reason is that LLVM10 does not support generating atomicCAS for f16 on
@@ -311,7 +278,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
llvm::Value *atomic_op_using_cas(
llvm::Value *output_address,
llvm::Value *val,
std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op) {
std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op) override {
llvm::PointerType *output_address_type =
llvm::dyn_cast<llvm::PointerType>(output_address->getType());
TI_ASSERT(output_address_type != nullptr);
@@ -406,86 +373,6 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
return output_address;
}

llvm::Value *real_or_unsigned_type_atomic(AtomicOpStmt *stmt) {
if (!stmt->val->ret_type->is<PrimitiveType>()) {
return nullptr;
}
AtomicOpType op = stmt->op_type;
if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) {
switch (op) {
case AtomicOpType::add:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateFAdd(v1, v2); });
case AtomicOpType::max:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); });
case AtomicOpType::min:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); });
default:
break;
}
}

PrimitiveTypeID prim_type =
stmt->val->ret_type->cast<PrimitiveType>()->type;

std::unordered_map<PrimitiveTypeID,
std::unordered_map<AtomicOpType, std::string>>
atomics;

atomics[PrimitiveTypeID::f32][AtomicOpType::add] = "atomic_add_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::add] = "atomic_add_f64";
atomics[PrimitiveTypeID::f32][AtomicOpType::min] = "atomic_min_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::min] = "atomic_min_f64";
atomics[PrimitiveTypeID::f32][AtomicOpType::max] = "atomic_max_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::max] = "atomic_max_f64";
atomics[PrimitiveTypeID::u32][AtomicOpType::min] = "atomic_min_u32";
atomics[PrimitiveTypeID::u64][AtomicOpType::min] = "atomic_min_u64";
atomics[PrimitiveTypeID::u32][AtomicOpType::max] = "atomic_max_u32";
atomics[PrimitiveTypeID::u64][AtomicOpType::max] = "atomic_max_u64";

if (atomics.find(prim_type) == atomics.end()) {
return nullptr;
}
if (is_integral(stmt->val->ret_type) &&
atomics.at(prim_type).find(op) == atomics.at(prim_type).end()) {
return nullptr;
}
TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end());

return create_call(atomics.at(prim_type).at(op),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
}

void visit(AtomicOpStmt *stmt) override {
// https://llvm.org/docs/NVPTXUsage.html#address-spaces
bool is_local = stmt->dest->is<AllocaStmt>();
if (is_local) {
TI_ERROR("Local atomics should have been demoted.");
}
TI_ASSERT(stmt->width() == 1);
for (int l = 0; l < stmt->width(); l++) {
llvm::Value *old_value;

if (llvm::Value *result = optimized_reduction(stmt)) {
old_value = result;
} else if (llvm::Value *result = custom_type_atomic(stmt)) {
old_value = result;
} else if (llvm::Value *result = real_or_unsigned_type_atomic(stmt)) {
old_value = result;
} else if (llvm::Value *result = integral_type_atomic(stmt)) {
old_value = result;
} else {
TI_NOT_IMPLEMENTED
}
llvm_val[stmt] = old_value;
}
}

void visit(RangeForStmt *for_stmt) override {
create_naive_range_for(for_stmt);
}
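The f16 comment above refers to a compare-and-swap emulation. As an illustration of that technique (not Taichi code), the host-side C++ sketch below mirrors what atomic_op_using_cas() emits as LLVM IR: reinterpret the destination as an integer word, apply the floating-point operation on a private copy, and retry until no other thread has modified the word in between. The function name, the std::function callback, and the use of std::atomic<uint32_t> are choices made for this example only.

// Illustrative sketch only; the real helper builds LLVM IR instructions.
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>

float atomic_float_op_via_cas(std::atomic<uint32_t> *dest,
                              float val,
                              const std::function<float(float, float)> &op) {
  uint32_t observed = dest->load();
  while (true) {
    float old_val;
    std::memcpy(&old_val, &observed, sizeof(old_val));  // bitcast word -> float
    const float new_val = op(old_val, val);             // e.g. fadd / fmin / fmax
    uint32_t desired;
    std::memcpy(&desired, &new_val, sizeof(desired));   // bitcast float -> word
    // On failure, compare_exchange_weak refreshes `observed` with the value
    // currently stored, so the loop recomputes and retries.
    if (dest->compare_exchange_weak(observed, desired)) {
      return old_val;  // like the IR helper, return the pre-update value
    }
  }
}

int main() {
  std::atomic<uint32_t> word{0};  // bit pattern of 0.0f
  atomic_float_op_via_cas(&word, 1.5f,
                          [](float a, float b) { return a + b; });
  float result;
  uint32_t bits = word.load();
  std::memcpy(&result, &bits, sizeof(result));
  std::cout << result << "\n";  // prints 1.5
}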
199 changes: 105 additions & 94 deletions taichi/codegen/codegen_llvm.cpp
@@ -1207,6 +1207,44 @@ void CodeGenLLVM::visit(SNodeOpStmt *stmt) {
}
}

llvm::Value *CodeGenLLVM::optimized_reduction(AtomicOpStmt *stmt) {
return nullptr;
}

llvm::Value *CodeGenLLVM::custom_type_atomic(AtomicOpStmt *stmt) {
// TODO(type): support all AtomicOpTypes on custom types
if (stmt->op_type != AtomicOpType::add) {
return nullptr;
}

auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (auto cit = dst_type->cast<CustomIntType>()) {
return atomic_add_custom_int(stmt, cit);
} else if (auto cft = dst_type->cast<CustomFloatType>()) {
return atomic_add_custom_float(stmt, cft);
} else {
return nullptr;
}
}

llvm::Value *CodeGenLLVM::integral_type_atomic(AtomicOpStmt *stmt) {
if (!is_integral(stmt->val->ret_type)) {
return nullptr;
}
std::unordered_map<AtomicOpType, llvm::AtomicRMWInst::BinOp> bin_op;
bin_op[AtomicOpType::add] = llvm::AtomicRMWInst::BinOp::Add;
bin_op[AtomicOpType::min] = llvm::AtomicRMWInst::BinOp::Min;
bin_op[AtomicOpType::max] = llvm::AtomicRMWInst::BinOp::Max;

bin_op[AtomicOpType::bit_and] = llvm::AtomicRMWInst::BinOp::And;
bin_op[AtomicOpType::bit_or] = llvm::AtomicRMWInst::BinOp::Or;
bin_op[AtomicOpType::bit_xor] = llvm::AtomicRMWInst::BinOp::Xor;
TI_ASSERT(bin_op.find(stmt->op_type) != bin_op.end());
return builder->CreateAtomicRMW(bin_op.at(stmt->op_type),
llvm_val[stmt->dest], llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
}

llvm::Value *CodeGenLLVM::atomic_op_using_cas(
llvm::Value *dest,
llvm::Value *val,
@@ -1242,104 +1280,77 @@ llvm::Value *CodeGenLLVM::atomic_op_using_cas(
return old_val;
}

llvm::Value *CodeGenLLVM::real_or_unsigned_type_atomic(AtomicOpStmt *stmt) {
if (!stmt->val->ret_type->is<PrimitiveType>()) {
return nullptr;
}
AtomicOpType op = stmt->op_type;
if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) {
switch (op) {
case AtomicOpType::add:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateFAdd(v1, v2); });
case AtomicOpType::max:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); });
case AtomicOpType::min:
return atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); });
default:
break;
}
}

PrimitiveTypeID prim_type = stmt->val->ret_type->cast<PrimitiveType>()->type;

std::unordered_map<PrimitiveTypeID,
std::unordered_map<AtomicOpType, std::string>>
atomics;

atomics[PrimitiveTypeID::f32][AtomicOpType::add] = "atomic_add_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::add] = "atomic_add_f64";
atomics[PrimitiveTypeID::f32][AtomicOpType::min] = "atomic_min_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::min] = "atomic_min_f64";
atomics[PrimitiveTypeID::f32][AtomicOpType::max] = "atomic_max_f32";
atomics[PrimitiveTypeID::f64][AtomicOpType::max] = "atomic_max_f64";
atomics[PrimitiveTypeID::u32][AtomicOpType::min] = "atomic_min_u32";
atomics[PrimitiveTypeID::u64][AtomicOpType::min] = "atomic_min_u64";
atomics[PrimitiveTypeID::u32][AtomicOpType::max] = "atomic_max_u32";
atomics[PrimitiveTypeID::u64][AtomicOpType::max] = "atomic_max_u64";

if (atomics.find(prim_type) == atomics.end()) {
return nullptr;
}
if (is_integral(stmt->val->ret_type) &&
atomics.at(prim_type).find(op) == atomics.at(prim_type).end()) {
return nullptr;
}
TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end());

return create_call(atomics.at(prim_type).at(op),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
}

void CodeGenLLVM::visit(AtomicOpStmt *stmt) {
// auto mask = stmt->parent->mask();
// TODO: deal with mask when vectorized
// TODO(type): support all AtomicOpTypes on custom types
bool is_local = stmt->dest->is<AllocaStmt>();
if (is_local) {
TI_ERROR("Local atomics should have been demoted.");
}
TI_ASSERT(stmt->width() == 1);
for (int l = 0; l < stmt->width(); l++) {
llvm::Value *old_value;
if (stmt->op_type == AtomicOpType::add) {
auto dst_type =
stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (dst_type->is<PrimitiveType>() && is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else if (!dst_type->is<CustomFloatType>() &&
is_real(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else if (auto cit = dst_type->cast<CustomIntType>()) {
old_value = atomic_add_custom_int(stmt, cit);
} else if (auto cft = dst_type->cast<CustomFloatType>()) {
old_value = atomic_add_custom_float(stmt, cft);
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::min) {
if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u32)) {
old_value = create_call("atomic_min_u32",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u64)) {
old_value = create_call("atomic_min_u64",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) {
old_value = atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); });
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) {
old_value = create_call("atomic_min_f32",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) {
old_value = create_call("atomic_min_f64",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::max) {
if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u32)) {
old_value = create_call("atomic_max_u32",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u64)) {
old_value = create_call("atomic_max_u64",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) {
old_value = atomic_op_using_cas(
llvm_val[stmt->dest], llvm_val[stmt->val],
[&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); });
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) {
old_value = create_call("atomic_max_f32",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) {
old_value = create_call("atomic_max_f64",
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_and) {
if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_or) {
if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_xor) {
if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else {
TI_NOT_IMPLEMENTED
}

if (llvm::Value *result = optimized_reduction(stmt)) {
old_value = result;
} else if (llvm::Value *result = custom_type_atomic(stmt)) {
old_value = result;
} else if (llvm::Value *result = real_or_unsigned_type_atomic(stmt)) {
old_value = result;
} else if (llvm::Value *result = integral_type_atomic(stmt)) {
old_value = result;
} else {
TI_NOT_IMPLEMENTED
}
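A note on the semantics shared by the handlers above: CreateAtomicRMW and the runtime atomic_* calls return the value stored at the destination before the update, which is why visit() records each handler's result as old_value. Host-side std::atomic fetch operations follow the same convention, and their default ordering is sequentially consistent, matching the ordering requested in integral_type_atomic. The sketch below is an illustration of that mapping only, not Taichi code; std::atomic has no fetch_min/fetch_max in C++17/20, so min and max would need a CAS loop like the one sketched earlier.

// Host-side analogue of the add/and/or/xor cases in integral_type_atomic;
// every call returns the previous value, matching the old_value convention.
#include <atomic>
#include <cstdint>
#include <iostream>

int main() {
  std::atomic<int32_t> x{6};
  int32_t old_add = x.fetch_add(4);  // like AtomicRMWInst::Add: x == 10, returns 6
  int32_t old_and = x.fetch_and(3);  // like AtomicRMWInst::And: x == 2,  returns 10
  int32_t old_or  = x.fetch_or(5);   // like AtomicRMWInst::Or:  x == 7,  returns 2
  int32_t old_xor = x.fetch_xor(1);  // like AtomicRMWInst::Xor: x == 6,  returns 7
  std::cout << old_add << ' ' << old_and << ' ' << old_or << ' ' << old_xor
            << ' ' << x.load() << "\n";  // 6 10 2 7 6
}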
18 changes: 13 additions & 5 deletions taichi/codegen/codegen_llvm.h
@@ -230,6 +230,19 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
CustomIntType *cit,
llvm::Value *real);

virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt);

virtual llvm::Value *custom_type_atomic(AtomicOpStmt *stmt);

virtual llvm::Value *integral_type_atomic(AtomicOpStmt *stmt);

virtual llvm::Value *atomic_op_using_cas(
llvm::Value *output_address,
llvm::Value *val,
std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op);

virtual llvm::Value *real_or_unsigned_type_atomic(AtomicOpStmt *stmt);

void visit(AtomicOpStmt *stmt) override;

void visit(GlobalPtrStmt *stmt) override;
@@ -392,11 +405,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *get_exponent_offset(llvm::Value *exponent, CustomFloatType *cft);

llvm::Value *atomic_op_using_cas(
llvm::Value *dest,
llvm::Value *val,
std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op);

void visit(FuncCallStmt *stmt) override;

llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type);
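Taken together, the header changes make the AtomicOpStmt lowering a template-method style hook system: CodeGenLLVM owns the visit(AtomicOpStmt *) dispatch, each handler is a virtual hook that returns nullptr to decline a statement, and a backend such as CodeGenLLVMCUDA overrides only the hooks whose lowering differs. The toy sketch below shows the shape of that pattern with hypothetical class and handler names; it is not Taichi code.

// Toy illustration of the nullptr-declining hook pattern (an empty string
// plays the role of nullptr here).
#include <iostream>
#include <string>

class BaseCodegen {
 public:
  virtual ~BaseCodegen() = default;

  // Hook: backends may recognise special cases; the default declines.
  virtual std::string optimized_reduction(int op) { return {}; }

  // Generic lowering shared by every backend.
  std::string generic_atomic(int op) {
    return "generic_atomic(" + std::to_string(op) + ")";
  }

  // Shared dispatch, analogous to CodeGenLLVM::visit(AtomicOpStmt *).
  std::string visit_atomic(int op) {
    if (auto r = optimized_reduction(op); !r.empty())
      return r;
    return generic_atomic(op);
  }
};

class CudaCodegen : public BaseCodegen {
 public:
  // Only the backend-specific part is overridden; the dispatch stays shared.
  std::string optimized_reduction(int op) override {
    return op == 0 ? "warp_reduce_add()" : std::string{};
  }
};

int main() {
  CudaCodegen cuda;
  std::cout << cuda.visit_atomic(0) << "\n";  // warp_reduce_add()
  std::cout << cuda.visit_atomic(1) << "\n";  // generic_atomic(1)
}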