diff --git a/c_api/tests/c_api_aot_test.cpp b/c_api/tests/c_api_aot_test.cpp
index 536a05c9540fb..47f712bbb3526 100644
--- a/c_api/tests/c_api_aot_test.cpp
+++ b/c_api/tests/c_api_aot_test.cpp
@@ -122,6 +122,48 @@ void texture_aot_kernel_test(TiArch arch) {
   }
 }
 
+// Loads the AOT module found under $TAICHI_AOT_FOLDER_PATH, launches its
+// "run" kernel on three float ndarrays of kArrLen elements, and expects
+// every element of the last one to equal kArrLen: the companion script
+// shared_array_aot_test.py accumulates nBlocks * block_dim = 64 * 128
+// products of 1.0 into each output element.
+static void shared_array_aot_test(TiArch arch) {
+  constexpr uint32_t kArrLen = 8192;
+
+  // getenv() returns nullptr when the variable is unset; fail with a clear
+  // message instead of passing a null path on.
+  const char *folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
+  ASSERT_TRUE(folder_dir != nullptr) << "TAICHI_AOT_FOLDER_PATH is not set";
+
+  ti::Runtime runtime(arch);
+
+  ti::NdArray<float> v_array =
+      runtime.allocate_ndarray<float>({kArrLen}, {}, true);
+  ti::NdArray<float> d_array =
+      runtime.allocate_ndarray<float>({kArrLen}, {}, true);
+  ti::NdArray<float> a_array =
+      runtime.allocate_ndarray<float>({kArrLen}, {}, true);
+  ti::AotModule aot_mod = runtime.load_aot_module(folder_dir);
+  ti::Kernel k_run = aot_mod.get_kernel("run");
+
+  // Argument order must match run(v, d, a) in the Python script.
+  k_run.push_arg(v_array);
+  k_run.push_arg(d_array);
+  k_run.push_arg(a_array);
+  k_run.launch();
+  runtime.wait();
+
+  // Check results. 8192 is exactly representable as a float, so an exact
+  // comparison is fine here.
+  const float *data = reinterpret_cast<const float *>(a_array.map());
+  const float expected = static_cast<float>(kArrLen);
+  for (uint32_t i = 0; i < kArrLen; ++i) {
+    EXPECT_EQ(data[i], expected) << "mismatch at index " << i;
+  }
+
+  a_array.unmap();
+}
+
 TEST_F(CapiTest, AotTestCpuField) {
   TiArch arch = TiArch::TI_ARCH_X64;
   field_aot_test(arch);
@@ -166,3 +208,17 @@ TEST_F(CapiTest, GraphTestVulkanTextureKernel) {
     texture_aot_kernel_test(arch);
   }
 }
+
+TEST_F(CapiTest, AotTestCudaSharedArray) {
+  if (ti::is_arch_available(TI_ARCH_CUDA)) {
+    TiArch arch = TiArch::TI_ARCH_CUDA;
+    shared_array_aot_test(arch);
+  }
+}
+
+TEST_F(CapiTest, AotTestVulkanSharedArray) {
+  if (ti::is_arch_available(TI_ARCH_VULKAN)) {
+    TiArch arch = TiArch::TI_ARCH_VULKAN;
+    shared_array_aot_test(arch);
+  }
+}
diff --git a/tests/cpp/aot/python_scripts/shared_array_aot_test.py b/tests/cpp/aot/python_scripts/shared_array_aot_test.py
new file mode 100644
index 0000000000000..209760e70a113
--- /dev/null
+++ b/tests/cpp/aot/python_scripts/shared_array_aot_test.py
@@ -0,0 +1,71 @@
+import argparse
+import os
+
+import numpy as np
+
+import taichi as ti
+
+
+def shared_array_aot_test(arch):
+    """Save an AOT module containing the shared-array `run` kernel.
+
+    The module is written to the directory named by the
+    TAICHI_AOT_FOLDER_PATH environment variable. Nothing is saved when the
+    active Taichi arch differs from the requested `arch`.
+    """
+    ti.init(arch=arch)
+
+    if ti.lang.impl.current_cfg().arch != arch:
+        return
+    block_dim = 128
+    nBlocks = 64
+    N = nBlocks * block_dim
+    v_arr = np.zeros(N, dtype=np.float32)
+    d_arr = np.zeros(N, dtype=np.float32)
+    a_arr = np.zeros(N, dtype=np.float32)
+
+    @ti.kernel
+    def run(v: ti.types.ndarray(ndim=1), d: ti.types.ndarray(ndim=1),
+            a: ti.types.ndarray(ndim=1)):
+        for i in range(N):
+            v[i] = 1.0
+            d[i] = 1.0
+
+        ti.loop_config(block_dim=block_dim)
+        for i in range(N):
+            tid = i % block_dim
+            pad = ti.simt.block.SharedArray((block_dim, ), ti.f32)
+            acc = 0.0
+            v_val = v[i]
+            for k in range(nBlocks):
+                # Each iteration stages block_dim values of `d` into `pad`,
+                # syncs, then sums over every entry of `pad`.
+                pad[tid] = d[k * block_dim + tid]
+                ti.simt.block.sync()
+                for j in range(block_dim):
+                    acc += v_val * pad[j]
+                # Sync again before the next iteration overwrites `pad`.
+                ti.simt.block.sync()
+            a[i] = acc
+
+    assert "TAICHI_AOT_FOLDER_PATH" in os.environ
+    dir_name = os.environ["TAICHI_AOT_FOLDER_PATH"]
+
+    m = ti.aot.Module()
+    m.add_kernel(run, template_args={'v': v_arr, 'd': d_arr, 'a': a_arr})
+    m.save(dir_name)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--arch", type=str)
+    args = parser.parse_args()
+
+    if args.arch == "cuda":
+        shared_array_aot_test(arch=ti.cuda)
+    elif args.arch == "vulkan":
+        shared_array_aot_test(arch=ti.vulkan)
+    else:
+        # `assert False` is stripped under `python -O`; raise explicitly so an
+        # unsupported arch can never pass silently.
+        raise ValueError(f"Unsupported --arch: {args.arch!r}")
diff --git a/tests/test_config.json b/tests/test_config.json
index beda500f90e91..4005aa8f79237 100644
--- a/tests/test_config.json
+++ b/tests/test_config.json
@@ -183,10 +183,18 @@
             ["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
             "--arch=cuda"
         ],
+        "CapiTest.AotTestCudaSharedArray": [
+            ["cpp", "aot", "python_scripts", "shared_array_aot_test.py"],
+            "--arch=cuda"
+        ],
         "CapiTest.AotTestVulkanKernel": [
             ["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
             "--arch=vulkan"
         ],
+        "CapiTest.AotTestVulkanSharedArray": [
+            ["cpp", "aot", "python_scripts", "shared_array_aot_test.py"],
+            "--arch=vulkan"
+        ],
         "CapiTest.AotTestOpenglKernel": [
             ["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
             "--arch=opengl"