Skip to content

Commit

Permalink
Make atomic_add() return value
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Jan 16, 2020
1 parent 0726d49 commit 39c2bf1
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 31 deletions.
12 changes: 8 additions & 4 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __add__(self, other):
__radd__ = __add__

def __iadd__(self, other):
taichi_lang_core.expr_atomic_add(self.ptr, other.ptr)
self.atomic_add(other)

def __neg__(self):
return Expr(taichi_lang_core.expr_neg(self.ptr), tb=self.stack_info())
Expand All @@ -65,7 +65,9 @@ def __sub__(self, other):
taichi_lang_core.expr_sub(self.ptr, other.ptr), tb=self.stack_info())

def __isub__(self, other):
taichi_lang_core.expr_atomic_sub(self.ptr, other.ptr)
# TODO: add atomic_sub()
import taichi as ti
ti.expr_init(taichi_lang_core.expr_atomic_sub(self.ptr, other.ptr))

def __imul__(self, other):
self.assign(Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr)))
Expand Down Expand Up @@ -97,7 +99,7 @@ def __truediv__(self, other):

def __rtruediv__(self, other):
return Expr(taichi_lang_core.expr_truediv(Expr(other).ptr, self.ptr))

def __floordiv__(self, other):
return Expr(taichi_lang_core.expr_floordiv(self.ptr, Expr(other).ptr))

Expand Down Expand Up @@ -237,7 +239,9 @@ def fill(self, val):
fill_tensor(self, val)

def atomic_add(self, other):
taichi_lang_core.expr_atomic_add(self.ptr, other.ptr)
import taichi as ti
other_ptr = ti.wrap_scalar(other).ptr
return ti.expr_init(taichi_lang_core.expr_atomic_add(self.ptr, other_ptr))

def __pow__(self, power, modulo=None):
assert isinstance(power, int) and power >= 0
Expand Down
3 changes: 1 addition & 2 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def wrap_scalar(x):


def atomic_add(a, b):
a.atomic_add(wrap_scalar(b))

return a.atomic_add(b)

def subscript(value, *indices):
import numpy as np
Expand Down
17 changes: 10 additions & 7 deletions taichi/backends/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,21 +858,24 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
void visit(AtomicOpStmt *stmt) override {
// auto mask = stmt->parent->mask();
// TODO: deal with mask when vectorized
TC_ASSERT(stmt->width() == 1);
for (int l = 0; l < stmt->width(); l++) {
TC_ASSERT(stmt->op_type == AtomicOpType::add);
llvm::Value *old_value;
if (stmt->val->ret_type.data_type == DataType::i32)
builder->CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add,
stmt->dest->value, stmt->val->value,
llvm::AtomicOrdering::SequentiallyConsistent);
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value,
stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent);
else if (stmt->val->ret_type.data_type == DataType::f32) {
builder->CreateCall(get_runtime_function("atomic_add_f32"),
{stmt->dest->value, stmt->val->value});
old_value = builder->CreateCall(get_runtime_function("atomic_add_f32"),
{stmt->dest->value, stmt->val->value});
} else if (stmt->val->ret_type.data_type == DataType::f64) {
builder->CreateCall(get_runtime_function("atomic_add_f64"),
{stmt->dest->value, stmt->val->value});
old_value = builder->CreateCall(get_runtime_function("atomic_add_f64"),
{stmt->dest->value, stmt->val->value});
} else {
TC_NOT_IMPLEMENTED
}
stmt->value = old_value;
}
}

Expand Down
18 changes: 11 additions & 7 deletions taichi/backends/codegen_llvm_ptx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,23 +217,27 @@ class CodeGenLLVMGPU : public CodeGenLLVM {
} else {
for (int l = 0; l < stmt->width(); l++) {
TC_ASSERT(stmt->op_type == AtomicOpType::add);
llvm::Value *old_value;
if (is_integral(stmt->val->ret_type.data_type)) {
builder->CreateAtomicRMW(
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value,
stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f32) {
auto dt = tlctx->get_data_type(DataType::f32);
builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32,
{llvm::PointerType::get(dt, 0)},
{stmt->dest->value, stmt->val->value});
old_value =
builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32,
{llvm::PointerType::get(dt, 0)},
{stmt->dest->value, stmt->val->value});
} else if (stmt->val->ret_type.data_type == DataType::f64) {
auto dt = tlctx->get_data_type(DataType::f64);
builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64,
{llvm::PointerType::get(dt, 0)},
{stmt->dest->value, stmt->val->value});
old_value =
builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64,
{llvm::PointerType::get(dt, 0)},
{stmt->dest->value, stmt->val->value});
} else {
TC_NOT_IMPLEMENTED
}
stmt->value = old_value;
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions taichi/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,41 @@ class IdExpression : public Expression {
}
};

// This is just a wrapper class of FrontendAtomicStmt, so that we can turn
// ti.atomic_op() into an expression (with side effect).
class AtomicOpExpression : public Expression {
// TODO(issue#332): Flatten this into AtomicOpStmt directly, then we can
// deprecate FrontendAtomicStmt.
public:
AtomicOpType op_type;
Expr dest, val;

AtomicOpExpression(AtomicOpType op_type, Expr dest, Expr val)
: op_type(op_type), dest(dest), val(val) {
}

std::string serialize() override {
if (op_type == AtomicOpType::add) {
return fmt::format("atomic_add({}, {})", dest.serialize(),
val.serialize());
} else if (op_type == AtomicOpType::sub) {
return fmt::format("atomic_sub({}, {})", dest.serialize(),
val.serialize());
} else {
// min/max not supported in the LLVM backend yet.
TC_NOT_IMPLEMENTED;
}
}

void flatten(VecStatement &ret) override {
// FrontendAtomicStmt is the correct place to flatten sub-exprs like |dest|
// and |val| (See LowerAST). This class only wraps the frontend atomic_op()
// stmt as an expression.
ret.push_back<FrontendAtomicStmt>(op_type, dest, val);
stmt = ret.back().get();
}
};

class SNodeOpExpression : public Expression {
public:
SNode *snode;
Expand Down
8 changes: 4 additions & 4 deletions taichi/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ void export_lang(py::module &m) {
m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));

m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::add, ptr_if_global(a), load_if_ptr(b)));
return Expr::make<AtomicOpExpression>(AtomicOpType::add, ptr_if_global(a),
load_if_ptr(b));
});

m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::sub, ptr_if_global(a), load_if_ptr(b)));
return Expr::make<AtomicOpExpression>(AtomicOpType::sub, ptr_if_global(a),
load_if_ptr(b));
});

m.def("expr_add", expr_add);
Expand Down
28 changes: 25 additions & 3 deletions taichi/transforms/demote_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,41 @@ class DemoteAtomics : public BasicStmtVisitor {
auto val = stmt->val;

auto new_stmts = VecStatement();
Stmt *load;
if (is_local) {
TC_ASSERT(stmt->width() == 1);
auto load = new_stmts.push_back<LocalLoadStmt>(LocalAddress(ptr, 0));
load = new_stmts.push_back<LocalLoadStmt>(LocalAddress(ptr, 0));
auto add =
new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add, load, val);
new_stmts.push_back<LocalStoreStmt>(ptr, add);
} else {
auto load = new_stmts.push_back<GlobalLoadStmt>(ptr);
load = new_stmts.push_back<GlobalLoadStmt>(ptr);
auto add =
new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add, load, val);
new_stmts.push_back<GlobalStoreStmt>(ptr, add);
}
stmt->parent->replace_with(stmt, new_stmts);
// For a taichi program like `c = ti.atomic_add(a, b)`, the IR looks
// like the following
//
// $c = # lhs memory
// $d = atomic add($a, $b)
// $e : store [$c <- $d]
//
// If this gets demoted, the IR is translated into:
//
// $c = # lhs memory
// $d' = load $a <-- added by demote_atomic
// $e' = add $d' $b
// $f : store [$a <- $e'] <-- added by demote_atomic
// $g : store [$c <- ???] <-- store the old value into lhs $c
//
// Naively relying on Block::replace_with() would incorrectly fill $f
// into ???, because $f is a store stmt that doesn't have a return
// value. The correct thing is to replace |stmt| $d with the loaded
// old value $d'.
// See also: https://github.com/taichi-dev/taichi/issues/332
stmt->replace_with(load);
stmt->parent->replace_with(stmt, new_stmts, /*replace_usages=*/false);
throw IRModified();
}
}
Expand Down
46 changes: 42 additions & 4 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <set>
#include <unordered_set>
#include "../ir.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -342,6 +343,15 @@ class BasicBlockSimplify : public IRVisitor {
stmt->ptr)) {
has_load = true;
}
if (block->statements[j]->is<AtomicOpStmt>() &&
(block->statements[j]->as<AtomicOpStmt>()->dest ==
stmt->ptr)) {
// $a = alloca
// $b : local store [$a <- v1] <-- prev lstore |bstmt_|
// $c = atomic add($a, v2) <-- cannot eliminate $b
// $d : local store [$a <- v3]
has_load = true;
}
}
if (!has_load) {
stmt->parent->erase(bstmt_);
Expand All @@ -353,7 +363,6 @@ class BasicBlockSimplify : public IRVisitor {
}

// has following load?

if (stmt->parent->locate(stmt->ptr) != -1) {
// optimize local variables only
bool has_related = false;
Expand All @@ -371,6 +380,16 @@ class BasicBlockSimplify : public IRVisitor {
break;
}
}
if (bstmt->is<AtomicOpStmt>()) {
// $a = alloca
// $b : local store [$a <- v1]
// $c = atomic add($a, v2) <-- cannot eliminate $b
auto bstmt_ = bstmt->as<AtomicOpStmt>();
if (bstmt_->dest == stmt->ptr) {
has_related = true;
break;
}
}
}
if (!has_related) {
stmt->parent->erase(stmt);
Expand Down Expand Up @@ -783,6 +802,21 @@ class BasicBlockSimplify : public IRVisitor {
return stmt->is<GlobalStoreStmt>() || stmt->is<AtomicOpStmt>();
}

static bool is_atomic_value_used(const std::vector<pStmt> &clause,
int atomic_stmt_i) {
// Cast type to check precondition
const auto *stmt = clause[atomic_stmt_i]->as<AtomicOpStmt>();
for (size_t i = atomic_stmt_i + 1; i < clause.size(); ++i) {
for (const auto &op : clause[i]->get_operands()) {
// Simpler to do pointer comparison?
if (op && (op->instance_id == stmt->instance_id)) {
return true;
}
}
}
return false;
}

void visit(IfStmt *if_stmt) override {
auto flatten = [&](std::vector<pStmt> &clause, bool true_branch) {
bool plain_clause = true; // no global store, no container
Expand All @@ -792,7 +826,7 @@ class BasicBlockSimplify : public IRVisitor {
// global side effects. LocalStore is kept and specially treated later.

bool global_state_changed = false;
for (int i = 0; i < (int)clause.size(); i++) {
for (int i = 0; i < (int)clause.size() && plain_clause; i++) {
bool has_side_effects = clause[i]->is_container_statement() ||
clause[i]->has_global_side_effect();

Expand All @@ -802,9 +836,11 @@ class BasicBlockSimplify : public IRVisitor {
plain_clause = false;
}

if (is_global_write(clause[i].get()) ||
if (clause[i]->is<GlobalStoreStmt>() ||
clause[i]->is<LocalStoreStmt>() || !has_side_effects) {
// This stmt can be kept.
} else if (clause[i]->is<AtomicOpStmt>()) {
plain_clause = plain_clause && !is_atomic_value_used(clause, i);
} else {
plain_clause = false;
}
Expand All @@ -816,7 +852,9 @@ class BasicBlockSimplify : public IRVisitor {
for (int i = 0; i < (int)clause.size(); i++) {
if (is_global_write(clause[i].get())) {
// do nothing. Keep the statement.
} else if (clause[i]->is<LocalStoreStmt>()) {
continue;
}
if (clause[i]->is<LocalStoreStmt>()) {
auto store = clause[i]->as<LocalStoreStmt>();
auto lanes = LaneAttribute<LocalAddress>();
for (int l = 0; l < store->width(); l++) {
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class TypeCheck : public IRVisitor {
stmt->val = insert_type_cast_before(stmt, stmt->val,
stmt->dest->ret_type.data_type);
}
if (stmt->element_type() == DataType::unknown) {
stmt->ret_type = stmt->dest->ret_type;
}
}

void visit(LocalLoadStmt *stmt) {
Expand Down
Loading

0 comments on commit 39c2bf1

Please sign in to comment.