diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py
index c963ca43215ffd..896be7b0a053a2 100644
--- a/python/taichi/lang/expr.py
+++ b/python/taichi/lang/expr.py
@@ -54,7 +54,7 @@ def __add__(self, other):
   __radd__ = __add__
 
   def __iadd__(self, other):
-    taichi_lang_core.expr_atomic_add(self.ptr, other.ptr)
+    self.atomic_add(other)
 
   def __neg__(self):
     return Expr(taichi_lang_core.expr_neg(self.ptr), tb=self.stack_info())
@@ -65,7 +65,9 @@ def __sub__(self, other):
         taichi_lang_core.expr_sub(self.ptr, other.ptr), tb=self.stack_info())
 
   def __isub__(self, other):
-    taichi_lang_core.expr_atomic_sub(self.ptr, other.ptr)
+    # TODO: add atomic_sub()
+    import taichi as ti
+    ti.expr_init(taichi_lang_core.expr_atomic_sub(self.ptr, other.ptr))
 
   def __imul__(self, other):
     self.assign(Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr)))
@@ -97,7 +99,7 @@ def __truediv__(self, other):
 
   def __rtruediv__(self, other):
     return Expr(taichi_lang_core.expr_truediv(Expr(other).ptr, self.ptr))
-
+
   def __floordiv__(self, other):
     return Expr(taichi_lang_core.expr_floordiv(self.ptr, Expr(other).ptr))
 
@@ -237,7 +239,9 @@ def fill(self, val):
     fill_tensor(self, val)
 
   def atomic_add(self, other):
-    taichi_lang_core.expr_atomic_add(self.ptr, other.ptr)
+    import taichi as ti
+    other_ptr = ti.wrap_scalar(other).ptr
+    return ti.expr_init(taichi_lang_core.expr_atomic_add(self.ptr, other_ptr))
 
   def __pow__(self, power, modulo=None):
     assert isinstance(power, int) and power >= 0
diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py
index 5d4f7212069739..c57c3c320c7da7 100644
--- a/python/taichi/lang/impl.py
+++ b/python/taichi/lang/impl.py
@@ -35,8 +35,7 @@ def wrap_scalar(x):
 
 
 def atomic_add(a, b):
-  a.atomic_add(wrap_scalar(b))
-
+  return a.atomic_add(b)
 
 def subscript(value, *indices):
   import numpy as np
diff --git a/taichi/backends/codegen_llvm.h b/taichi/backends/codegen_llvm.h
index 7ba69ca638da1e..48f08817d2b17a 100644
--- a/taichi/backends/codegen_llvm.h
+++ b/taichi/backends/codegen_llvm.h
@@ -858,21 +858,24 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
 
   void visit(AtomicOpStmt *stmt) override {
     // auto mask = stmt->parent->mask();
     // TODO: deal with mask when vectorized
+    TC_ASSERT(stmt->width() == 1);
     for (int l = 0; l < stmt->width(); l++) {
       TC_ASSERT(stmt->op_type == AtomicOpType::add);
+      llvm::Value *old_value;
       if (stmt->val->ret_type.data_type == DataType::i32)
-        builder->CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add,
-                                 stmt->dest->value, stmt->val->value,
-                                 llvm::AtomicOrdering::SequentiallyConsistent);
+        old_value = builder->CreateAtomicRMW(
+            llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value,
+            stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent);
       else if (stmt->val->ret_type.data_type == DataType::f32) {
-        builder->CreateCall(get_runtime_function("atomic_add_f32"),
-                            {stmt->dest->value, stmt->val->value});
+        old_value = builder->CreateCall(get_runtime_function("atomic_add_f32"),
+                                        {stmt->dest->value, stmt->val->value});
       } else if (stmt->val->ret_type.data_type == DataType::f64) {
-        builder->CreateCall(get_runtime_function("atomic_add_f64"),
-                            {stmt->dest->value, stmt->val->value});
+        old_value = builder->CreateCall(get_runtime_function("atomic_add_f64"),
+                                        {stmt->dest->value, stmt->val->value});
       } else {
         TC_NOT_IMPLEMENTED
       }
+      stmt->value = old_value;
     }
   }
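
With the CPU codegen now keeping the value produced by the atomic RMW (and by the atomic_add_f32/f64 runtime helpers), ti.atomic_add behaves like a fetch-and-add at the kernel level. The following is a minimal user-level sketch of that semantics in the same old-style API the tests below use; the names `counter`, `slots` and `claim` are illustrative and not part of this patch.

import taichi as ti

counter = ti.var(ti.i32)
slots = ti.var(ti.i32)

@ti.layout
def place():
  ti.root.place(counter)
  ti.root.dense(ti.i, 8).place(slots)

@ti.kernel
def claim():
  for i in range(8):
    # atomic_add evaluates to counter's value just before this increment,
    # so every iteration records a distinct slot index.
    slots[i] = ti.atomic_add(counter[None], 1)

Sorting the recorded slot values should give 0..7, which is essentially what run_atomic_add_global_case in the new test file checks for fields.
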
diff --git a/taichi/backends/codegen_llvm_ptx.cpp b/taichi/backends/codegen_llvm_ptx.cpp
index b645ccf4848374..c98d640bbe1501 100644
--- a/taichi/backends/codegen_llvm_ptx.cpp
+++ b/taichi/backends/codegen_llvm_ptx.cpp
@@ -217,23 +217,27 @@
     } else {
       for (int l = 0; l < stmt->width(); l++) {
         TC_ASSERT(stmt->op_type == AtomicOpType::add);
+        llvm::Value *old_value;
         if (is_integral(stmt->val->ret_type.data_type)) {
-          builder->CreateAtomicRMW(
+          old_value = builder->CreateAtomicRMW(
               llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value,
               stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent);
         } else if (stmt->val->ret_type.data_type == DataType::f32) {
           auto dt = tlctx->get_data_type(DataType::f32);
-          builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32,
-                                   {llvm::PointerType::get(dt, 0)},
-                                   {stmt->dest->value, stmt->val->value});
+          old_value =
+              builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32,
+                                       {llvm::PointerType::get(dt, 0)},
+                                       {stmt->dest->value, stmt->val->value});
         } else if (stmt->val->ret_type.data_type == DataType::f64) {
           auto dt = tlctx->get_data_type(DataType::f64);
-          builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64,
-                                   {llvm::PointerType::get(dt, 0)},
-                                   {stmt->dest->value, stmt->val->value});
+          old_value =
+              builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64,
+                                       {llvm::PointerType::get(dt, 0)},
+                                       {stmt->dest->value, stmt->val->value});
         } else {
           TC_NOT_IMPLEMENTED
         }
+        stmt->value = old_value;
       }
     }
   }
diff --git a/taichi/ir.h b/taichi/ir.h
index 483c9df978e547..f3f735735c946e 100644
--- a/taichi/ir.h
+++ b/taichi/ir.h
@@ -1898,6 +1898,41 @@ class IdExpression : public Expression {
   }
 };
 
+// This is just a wrapper class of FrontendAtomicStmt, so that we can turn
+// ti.atomic_op() into an expression (with side effect).
+class AtomicOpExpression : public Expression {
+  // TODO(issue#332): Flatten this into AtomicOpStmt directly, then we can
+  // deprecate FrontendAtomicStmt.
+ public:
+  AtomicOpType op_type;
+  Expr dest, val;
+
+  AtomicOpExpression(AtomicOpType op_type, Expr dest, Expr val)
+      : op_type(op_type), dest(dest), val(val) {
+  }
+
+  std::string serialize() override {
+    if (op_type == AtomicOpType::add) {
+      return fmt::format("atomic_add({}, {})", dest.serialize(),
+                         val.serialize());
+    } else if (op_type == AtomicOpType::sub) {
+      return fmt::format("atomic_sub({}, {})", dest.serialize(),
+                         val.serialize());
+    } else {
+      // min/max not supported in the LLVM backend yet.
+      TC_NOT_IMPLEMENTED;
+    }
+  }
+
+  void flatten(VecStatement &ret) override {
+    // FrontendAtomicStmt is the correct place to flatten sub-exprs like |dest|
+    // and |val| (See LowerAST). This class only wraps the frontend atomic_op()
+    // stmt as an expression.
+    ret.push_back<FrontendAtomicStmt>(op_type, dest, val);
+    stmt = ret.back().get();
+  }
+};
+
 class SNodeOpExpression : public Expression {
  public:
   SNode *snode;
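
Because expr_atomic_add/expr_atomic_sub now build an AtomicOpExpression instead of directly inserting a statement, the Python-side result of ti.atomic_add is an ordinary Expr (with a side effect) and should compose with other expressions. A small sketch under that assumption; the fields `c`, `x`, the size `n` and the kernel `tag` are placeholders chosen for illustration, declared in the same style as the new tests.

import taichi as ti

n = 16
c = ti.var(ti.i32)
x = ti.var(ti.i32)

@ti.layout
def place():
  ti.root.place(c)
  ti.root.dense(ti.i, n).place(x)

@ti.kernel
def tag():
  for i in range(n):
    # The atomic evaluates to c's old value; since atomic_add now returns a
    # regular Expr, that value can feed into a larger expression.
    x[i] = ti.atomic_add(c[None], 1) * 2 + 1
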
diff --git a/taichi/python_bindings.cpp b/taichi/python_bindings.cpp
index da838534be26d4..4a40fe7b9f5c0b 100644
--- a/taichi/python_bindings.cpp
+++ b/taichi/python_bindings.cpp
@@ -238,13 +238,13 @@
   m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));
 
   m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) {
-    current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
-        AtomicOpType::add, ptr_if_global(a), load_if_ptr(b)));
+    return Expr::make<AtomicOpExpression>(AtomicOpType::add, ptr_if_global(a),
+                                          load_if_ptr(b));
   });
 
   m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) {
-    current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
-        AtomicOpType::sub, ptr_if_global(a), load_if_ptr(b)));
+    return Expr::make<AtomicOpExpression>(AtomicOpType::sub, ptr_if_global(a),
+                                          load_if_ptr(b));
   });
 
   m.def("expr_add", expr_add);
diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp
index ba6c0fa22681fa..b1486d1b30cb4d 100644
--- a/taichi/transforms/demote_atomics.cpp
+++ b/taichi/transforms/demote_atomics.cpp
@@ -32,19 +32,41 @@
       auto val = stmt->val;
 
       auto new_stmts = VecStatement();
+      Stmt *load;
       if (is_local) {
         TC_ASSERT(stmt->width() == 1);
-        auto load = new_stmts.push_back<LocalLoadStmt>(LocalAddress(ptr, 0));
+        load = new_stmts.push_back<LocalLoadStmt>(LocalAddress(ptr, 0));
         auto add = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add, load, val);
         new_stmts.push_back<LocalStoreStmt>(ptr, add);
       } else {
-        auto load = new_stmts.push_back<GlobalLoadStmt>(ptr);
+        load = new_stmts.push_back<GlobalLoadStmt>(ptr);
        auto add = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add, load, val);
         new_stmts.push_back<GlobalStoreStmt>(ptr, add);
       }
-      stmt->parent->replace_with(stmt, new_stmts);
+      // For a taichi program like `c = ti.atomic_add(a, b)`, the IR looks
+      // like the following
+      //
+      // $c  = # lhs memory
+      // $d  = atomic add($a, $b)
+      // $e  : store [$c <- $d]
+      //
+      // If this gets demoted, the IR is translated into:
+      //
+      // $c  = # lhs memory
+      // $d' = load $a             <-- added by demote_atomic
+      // $e' = add $d' $b
+      // $f  : store [$a <- $e']   <-- added by demote_atomic
+      // $g  : store [$c <- ???]   <-- store the old value into lhs $c
+      //
+      // Naively relying on Block::replace_with() would incorrectly fill $f
+      // into ???, because $f is a store stmt that doesn't have a return
+      // value. The correct thing is to replace |stmt| $d with the loaded
+      // old value $d'.
+      // See also: https://github.com/taichi-dev/taichi/issues/332
+      stmt->replace_with(load);
+      stmt->parent->replace_with(stmt, new_stmts, /*replace_usages=*/false);
       throw IRModified();
     }
   }
 }
diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp
index 911c6fdd0f38fe..1f6c6ad8294a75 100644
--- a/taichi/transforms/simplify.cpp
+++ b/taichi/transforms/simplify.cpp
@@ -1,4 +1,5 @@
 #include <set>
+#include <unordered_set>
 #include "../ir.h"
 
 TLANG_NAMESPACE_BEGIN
@@ -342,6 +343,15 @@ class BasicBlockSimplify : public IRVisitor {
                  stmt->ptr)) {
               has_load = true;
             }
+            if (block->statements[j]->is<AtomicOpStmt>() &&
+                (block->statements[j]->as<AtomicOpStmt>()->dest ==
+                 stmt->ptr)) {
+              // $a = alloca
+              // $b : local store [$a <- v1]  <-- prev lstore |bstmt_|
+              // $c = atomic add($a, v2)      <-- cannot eliminate $b
+              // $d : local store [$a <- v3]
+              has_load = true;
+            }
           }
           if (!has_load) {
             stmt->parent->erase(bstmt_);
@@ -353,7 +363,6 @@
       }
 
       // has following load?
-
       if (stmt->parent->locate(stmt->ptr) != -1) {
         // optimize local variables only
         bool has_related = false;
@@ -371,6 +380,16 @@
             break;
           }
         }
+        if (bstmt->is<AtomicOpStmt>()) {
+          // $a = alloca
+          // $b : local store [$a <- v1]
+          // $c = atomic add($a, v2)  <-- cannot eliminate $b
+          auto bstmt_ = bstmt->as<AtomicOpStmt>();
+          if (bstmt_->dest == stmt->ptr) {
+            has_related = true;
+            break;
+          }
+        }
       }
       if (!has_related) {
         stmt->parent->erase(stmt);
@@ -783,6 +802,21 @@
     return stmt->is<GlobalStoreStmt>() || stmt->is<AtomicOpStmt>();
   }
 
+  static bool is_atomic_value_used(const std::vector<pStmt> &clause,
+                                   int atomic_stmt_i) {
+    // Cast type to check precondition
+    const auto *stmt = clause[atomic_stmt_i]->as<AtomicOpStmt>();
+    for (size_t i = atomic_stmt_i + 1; i < clause.size(); ++i) {
+      for (const auto &op : clause[i]->get_operands()) {
+        // Simpler to do pointer comparison?
+        if (op && (op->instance_id == stmt->instance_id)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   void visit(IfStmt *if_stmt) override {
     auto flatten = [&](std::vector<pStmt> &clause, bool true_branch) {
       bool plain_clause = true;  // no global store, no container
@@ -792,7 +826,7 @@
       // global side effects. LocalStore is kept and specially treated later.
 
       bool global_state_changed = false;
-      for (int i = 0; i < (int)clause.size(); i++) {
+      for (int i = 0; i < (int)clause.size() && plain_clause; i++) {
         bool has_side_effects = clause[i]->is_container_statement() ||
                                 clause[i]->has_global_side_effect();
 
@@ -802,9 +836,11 @@
-        if (is_global_write(clause[i].get()) ||
+        if (clause[i]->is<GlobalStoreStmt>() || clause[i]->is<LocalStoreStmt>() ||
             !has_side_effects) {
           // This stmt can be kept.
+        } else if (clause[i]->is<AtomicOpStmt>()) {
+          plain_clause = !is_atomic_value_used(clause, i);
         } else {
           plain_clause = false;
         }
@@ -816,7 +852,9 @@
       for (int i = 0; i < (int)clause.size(); i++) {
         if (is_global_write(clause[i].get())) {
           // do nothing. Keep the statement.
-        } else if (clause[i]->is<LocalStoreStmt>()) {
+          continue;
+        }
+        if (clause[i]->is<LocalStoreStmt>()) {
           auto store = clause[i]->as<LocalStoreStmt>();
           auto lanes = LaneAttribute<LocalAddress>();
           for (int l = 0; l < store->width(); l++) {
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index a9d9fe049326a5..0e29ab0143e3ab 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -52,6 +52,9 @@
       stmt->val = insert_type_cast_before(stmt, stmt->val,
                                           stmt->dest->ret_type.data_type);
     }
+    if (stmt->element_type() == DataType::unknown) {
+      stmt->ret_type = stmt->dest->ret_type;
+    }
   }
 
   void visit(LocalLoadStmt *stmt) {
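
The type_check change gives the atomic statement a concrete return type (that of its destination) and, as before, casts the value operand to that type. A small sketch of what this means in practice; the f32 fields `acc` and `old` and the kernel `add_once` are assumed names used only for illustration.

import taichi as ti

acc = ti.var(ti.f32)
old = ti.var(ti.f32)

@ti.layout
def place():
  ti.root.place(acc, old)

@ti.kernel
def add_once():
  # `1` is an i32 literal: type_check casts it to f32 to match acc, and the
  # value returned by the atomic takes acc's type (f32) as well.
  old[None] = ti.atomic_add(acc[None], 1)
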
diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py
new file mode 100644
index 00000000000000..5d0fc206cd2d97
--- /dev/null
+++ b/tests/python/test_atomic.py
@@ -0,0 +1,188 @@
+import taichi as ti
+from pytest import approx
+
+ti.cfg.print_ir = True
+n = 128
+
+
+def run_atomic_add_global_case(vartype, step, valproc=lambda x: x):
+  x = ti.var(vartype)
+  y = ti.var(vartype)
+  c = ti.var(vartype)
+
+  @ti.layout
+  def place():
+    ti.root.dense(ti.i, n).place(x, y)
+    ti.root.place(c)
+
+  @ti.kernel
+  def func():
+    ck = ti.to_numpy_type(vartype)(0)
+    for i in range(n):
+      x[i] = ti.atomic_add(c[None], step)
+      y[i] = ti.atomic_add(ck, step)
+
+  func()
+
+  assert valproc(c[None]) == n * step
+  x_actual = sorted(x.to_numpy())
+  y_actual = sorted(y.to_numpy())
+  expect = [i * step for i in range(n)]
+  for (xa, ya, e) in zip(x_actual, y_actual, expect):
+    assert valproc(xa) == e
+    assert valproc(ya) == e
+
+
+@ti.all_archs
+def test_atomic_add_global_i32():
+  run_atomic_add_global_case(ti.i32, 42)
+
+
+@ti.all_archs
+def test_atomic_add_global_f32():
+  run_atomic_add_global_case(
+      ti.f32, 4.2, valproc=lambda x: approx(x, rel=1e-5))
+
+
+@ti.all_archs
+def test_atomic_add_expr_evaled():
+  c = ti.var(ti.i32)
+  step = 42
+
+  @ti.layout
+  def place():
+    ti.root.place(c)
+
+  @ti.kernel
+  def func():
+    for i in range(n):
+      # this is an expr with side effect, make sure it's not optimized out.
+      ti.atomic_add(c[None], step)
+
+  func()
+
+  assert c[None] == n * step
+
+
+@ti.all_archs
+def test_atomic_add_demoted():
+  # Ensure demoted atomics do not crash the program.
+  x = ti.var(ti.i32)
+  y = ti.var(ti.i32)
+  step = 42
+
+  @ti.layout
+  def place():
+    ti.root.dense(ti.i, n).place(x, y)
+
+  @ti.kernel
+  def func():
+    for i in range(n):
+      s = i
+      # Both adds should get demoted.
+      x[i] = ti.atomic_add(s, step)
+      y[i] = s.atomic_add(step)
+
+  func()
+
+  for i in range(n):
+    assert x[i] == i
+    assert y[i] == i + step
+
+
+@ti.all_archs
+def test_atomic_add_with_local_store_simplify1():
+  # Test for the following LocalStoreStmt simplification case:
+  #
+  # local store [$a <- ...]
+  # atomic add ($a, ...)
+  # local store [$a <- ...]
+  #
+  # Specifically, the second store should not suppress the first one, because
+  # atomic_add can return value.
+  x = ti.var(ti.i32)
+  y = ti.var(ti.i32)
+  step = 42
+
+  @ti.layout
+  def place():
+    ti.root.dense(ti.i, n).place(x, y)
+
+  @ti.kernel
+  def func():
+    for i in range(n):
+      # do a local store
+      j = i
+      x[i] = ti.atomic_add(j, step)
+      # do another local store, make sure the previous one is not optimized out
+      j = x[i]
+      y[i] = j
+
+  func()
+
+  for i in range(n):
+    assert x[i] == i
+    assert y[i] == i
+
+
+@ti.all_archs
+def test_atomic_add_with_local_store_simplify2():
+  # Test for the following LocalStoreStmt simplification case:
+  #
+  # local store [$a <- ...]
+  # atomic add ($a, ...)
+  #
+  # Specifically, the local store should not be removed, because
+  # atomic_add can return its value.
+  x = ti.var(ti.i32)
+  step = 42
+
+  @ti.layout
+  def place():
+    ti.root.dense(ti.i, n).place(x)
+
+  @ti.kernel
+  def func():
+    for i in range(n):
+      j = i
+      x[i] = ti.atomic_add(j, step)
+
+  func()
+
+  for i in range(n):
+    assert x[i] == i
+
+
+@ti.all_archs
+def test_atomic_add_with_if_simplify():
+  # Make sure IfStmt simplification doesn't move stmts depending on the result
+  # of atomic_add()
+  x = ti.var(ti.i32)
+  step = 42
+
+  @ti.layout
+  def place():
+    ti.root.dense(ti.i, n).place(x)
+
+  boundary = n / 2
+  @ti.kernel
+  def func():
+    for i in range(n):
+      if i > boundary:
+        # A sequence of commands designed such that atomic_add() is the only
+        # thing to decide whether the if branch can be simplified.
+        s = i
+        j = s.atomic_add(s)
+        k = j + s
+        x[i] = k
+      else:
+        # If we look at the IR, this branch should be simplified, since nobody
+        # is using atomic_add's result.
+        x[i].atomic_add(i)
+        x[i] += step
+
+  func()
+
+  for i in range(n):
+    expect = i * 3 if i > boundary else (i + step)
+    assert x[i] == expect