diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 9d5bd64266f49..ef7261d872e27 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -788,6 +788,7 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(op_type, dest->stmt, src_val); } stmt = ctx->back_stmt(); + stmt->ret_type = stmt->as()->dest->ret_type; stmt->tb = tb; } diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index 4edf8a446b7e3..29a291189fd1c 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -153,4 +153,30 @@ TEST(IRBuilder, Ndarray) { EXPECT_EQ(array.read_int({1}), 3); EXPECT_EQ(array.read_int({2}), 42); } + +TEST(IRBuilder, AtomicOp) { + TestProgram test_prog; + test_prog.setup(); + + IRBuilder builder; + const int size = 10; + auto array = std::make_unique(size); + array[0] = 2; + array[2] = 40; + auto *arg = builder.create_arg_load(/*arg_id=*/0, get_data_type(), + /*is_ptr=*/true); + auto *zero = builder.get_int32(0); + auto *one = builder.get_int32(1); + auto *a0ptr = builder.create_external_ptr(arg, {zero}); + builder.create_atomic_add(a0ptr, one); // a[0] += 1 + auto block = builder.extract_ir(); + auto ker = std::make_unique(*test_prog.prog(), std::move(block)); + ker->insert_arr_arg(get_data_type(), /*total_dim=*/1, {1}); + auto launch_ctx = ker->make_launch_context(); + launch_ctx.set_arg_external_array_with_shape( + /*arg_id=*/0, (uint64)array.get(), size, {size}); + (*ker)(launch_ctx); + + EXPECT_EQ(array[0], 3); +} } // namespace taichi::lang