Skip to content

Commit

Permalink
[lang] Set ret_type for AtomicOpStmt (#6213)
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang authored Oct 1, 2022
1 parent 92875a3 commit ed12813
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, src_val);
}
stmt = ctx->back_stmt();
stmt->ret_type = stmt->as<AtomicOpStmt>()->dest->ret_type;
stmt->tb = tb;
}

Expand Down
26 changes: 26 additions & 0 deletions tests/cpp/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,30 @@ TEST(IRBuilder, Ndarray) {
EXPECT_EQ(array.read_int({1}), 3);
EXPECT_EQ(array.read_int({2}), 42);
}

TEST(IRBuilder, AtomicOp) {
TestProgram test_prog;
test_prog.setup();

IRBuilder builder;
const int size = 10;
auto array = std::make_unique<int[]>(size);
array[0] = 2;
array[2] = 40;
auto *arg = builder.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *zero = builder.get_int32(0);
auto *one = builder.get_int32(1);
auto *a0ptr = builder.create_external_ptr(arg, {zero});
builder.create_atomic_add(a0ptr, one); // a[0] += 1
auto block = builder.extract_ir();
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {1});
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array_with_shape(
/*arg_id=*/0, (uint64)array.get(), size, {size});
(*ker)(launch_ctx);

EXPECT_EQ(array[0], 3);
}
} // namespace taichi::lang

0 comments on commit ed12813

Please sign in to comment.