diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index c963ca43215ffd..bcee0133fe774c 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -237,7 +237,8 @@ def fill(self, val): fill_tensor(self, val) def atomic_add(self, other): - taichi_lang_core.expr_atomic_add(self.ptr, other.ptr) + import taichi as ti + return ti.expr_init(taichi_lang_core.expr_atomic_add(self.ptr, other.ptr)) def __pow__(self, power, modulo=None): assert isinstance(power, int) and power >= 0 diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 5d4f7212069739..4b7427b39535fc 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -35,7 +35,7 @@ def wrap_scalar(x): def atomic_add(a, b): - a.atomic_add(wrap_scalar(b)) + return a.atomic_add(wrap_scalar(b)) def subscript(value, *indices): diff --git a/taichi/backends/codegen_llvm.h b/taichi/backends/codegen_llvm.h index be8d76b262e86f..121e482ab84905 100644 --- a/taichi/backends/codegen_llvm.h +++ b/taichi/backends/codegen_llvm.h @@ -860,19 +860,22 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { // TODO: deal with mask when vectorized for (int l = 0; l < stmt->width(); l++) { TC_ASSERT(stmt->op_type == AtomicOpType::add); + // FIXME: the loop body never uses `l`, so when stmt->width() > 1 the same atomic is emitted once per lane and stmt->value keeps only the last result.
+ llvm::Value *value; if (stmt->val->ret_type.data_type == DataType::i32) - builder->CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, - stmt->dest->value, stmt->val->value, - llvm::AtomicOrdering::SequentiallyConsistent); + value = builder->CreateAtomicRMW( + llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value, + stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); else if (stmt->val->ret_type.data_type == DataType::f32) { - builder->CreateCall(get_runtime_function("atomic_add_f32"), - {stmt->dest->value, stmt->val->value}); + value = builder->CreateCall(get_runtime_function("atomic_add_f32"), + {stmt->dest->value, stmt->val->value}); } else if (stmt->val->ret_type.data_type == DataType::f64) { - builder->CreateCall(get_runtime_function("atomic_add_f64"), - {stmt->dest->value, stmt->val->value}); + value = builder->CreateCall(get_runtime_function("atomic_add_f64"), + {stmt->dest->value, stmt->val->value}); } else { TC_NOT_IMPLEMENTED } + stmt->value = value; } } diff --git a/taichi/backends/codegen_llvm_ptx.cpp b/taichi/backends/codegen_llvm_ptx.cpp index 44bb143e6572a9..90a84b43a85c10 100644 --- a/taichi/backends/codegen_llvm_ptx.cpp +++ b/taichi/backends/codegen_llvm_ptx.cpp @@ -218,23 +218,27 @@ class CodeGenLLVMGPU : public CodeGenLLVM { } else { for (int l = 0; l < stmt->width(); l++) { TC_ASSERT(stmt->op_type == AtomicOpType::add); + llvm::Value *value; if (is_integral(stmt->val->ret_type.data_type)) { - builder->CreateAtomicRMW( + value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value, stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { auto dt = tlctx->get_data_type(DataType::f32); - builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32, - {llvm::PointerType::get(dt, 0)}, - {stmt->dest->value, stmt->val->value}); + value = + builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32, + {llvm::PointerType::get(dt, 0)}, + 
{stmt->dest->value, stmt->val->value}); } else if (stmt->val->ret_type.data_type == DataType::f64) { auto dt = tlctx->get_data_type(DataType::f64); - builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64, - {llvm::PointerType::get(dt, 0)}, - {stmt->dest->value, stmt->val->value}); + value = + builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64, + {llvm::PointerType::get(dt, 0)}, + {stmt->dest->value, stmt->val->value}); } else { TC_NOT_IMPLEMENTED } + stmt->value = value; } } } diff --git a/taichi/ir.h b/taichi/ir.h index 483c9df978e547..3be6485fa456cd 100644 --- a/taichi/ir.h +++ b/taichi/ir.h @@ -1898,6 +1898,40 @@ class IdExpression : public Expression { } }; +// This is just a wrapper class of FrontendAtomicStmt, so that we can turn +// ti.atomic_op() into an expression with return value. +class FrontendAtomicExpression : public Expression { + public: + AtomicOpType op_type; + Expr dest, val; + + FrontendAtomicExpression(AtomicOpType op_type, Expr dest, Expr val) + : op_type(op_type), dest(dest), val(val) { + } + + std::string serialize() override { + if (op_type == AtomicOpType::add) { + return fmt::format("atomic_add({}, {})", dest.serialize(), + val.serialize()); + } else if (op_type == AtomicOpType::sub) { + // Note that sub is not supported yet, which is handled at a lower level. + return fmt::format("atomic_sub({}, {})", dest.serialize(), + val.serialize()); + } else { + // min/max not supported in the LLVM backend yet. + TC_NOT_IMPLEMENTED; + } + } + + void flatten(VecStatement &ret) override { + // FrontendAtomicStmt is the correct place to flatten sub-exprs like |dest| + // and |val| (See LowerAST). This class only wraps the frontend atomic_op() + // stmt as an expression. 
+ ret.push_back(op_type, dest, val); + stmt = ret.back().get(); + } +}; + class SNodeOpExpression : public Expression { public: SNode *snode; diff --git a/taichi/python_bindings.cpp b/taichi/python_bindings.cpp index df6b8798a964d9..1eaef93842016d 100644 --- a/taichi/python_bindings.cpp +++ b/taichi/python_bindings.cpp @@ -237,13 +237,13 @@ void export_lang(py::module &m) { m.def("value_cast", static_cast(cast)); m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) { - current_ast_builder().insert(Stmt::make( - AtomicOpType::add, ptr_if_global(a), load_if_ptr(b))); + return Expr::make( + AtomicOpType::add, ptr_if_global(a), load_if_ptr(b)); }); m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) { - current_ast_builder().insert(Stmt::make( - AtomicOpType::sub, ptr_if_global(a), load_if_ptr(b))); + return Expr::make( + AtomicOpType::sub, ptr_if_global(a), load_if_ptr(b)); }); m.def("expr_add", expr_add); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 14279f0d3c860a..183c678c56bf82 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -122,15 +122,17 @@ class IRPrinter : public IRVisitor { } void visit(FrontendAtomicStmt *stmt) override { + // FrontendAtomicStmt gets lowered before the type checking, therefore we + // don't print its return type (which is none anyway). 
print("{}{} = atomic {}({}, {})", stmt->type_hint(), stmt->name(), atomic_op_type_name(stmt->op_type), stmt->dest->serialize(), stmt->val->serialize()); } void visit(AtomicOpStmt *stmt) override { - print("{}{} = atomic {}({}, {})", stmt->type_hint(), stmt->name(), + print("{}{} = atomic {}({}, {}) -> {}", stmt->type_hint(), stmt->name(), atomic_op_type_name(stmt->op_type), stmt->dest->name(), - stmt->val->name()); + stmt->val->name(), stmt->ret_data_type_name()); } void visit(IfStmt *if_stmt) override { diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index a9d9fe049326a5..0e29ab0143e3ab 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -52,6 +52,9 @@ class TypeCheck : public IRVisitor { stmt->val = insert_type_cast_before(stmt, stmt->val, stmt->dest->ret_type.data_type); } + if (stmt->element_type() == DataType::unknown) { + stmt->ret_type = stmt->dest->ret_type; + } } void visit(LocalLoadStmt *stmt) {