Skip to content

Commit

Permalink
[aot] Add test for shared array (#7387)
Browse files Browse the repository at this point in the history
Fixes #7274
  • Loading branch information
ailzhang authored Feb 20, 2023
1 parent f0d628b commit fadedc6
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
49 changes: 49 additions & 0 deletions c_api/tests/c_api_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,41 @@ void texture_aot_kernel_test(TiArch arch) {
}
}

static void shared_array_aot_test(TiArch arch) {
uint32_t kArrLen = 8192;

const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;

ti::Runtime runtime(arch);

ti::NdArray<float> v_array =
runtime.allocate_ndarray<float>({kArrLen}, {}, true);
ti::NdArray<float> d_array =
runtime.allocate_ndarray<float>({kArrLen}, {}, true);
ti::NdArray<float> a_array =
runtime.allocate_ndarray<float>({kArrLen}, {}, true);
ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str().c_str());
ti::Kernel k_run = aot_mod.get_kernel("run");

k_run.push_arg(v_array);
k_run.push_arg(d_array);
k_run.push_arg(a_array);
k_run.launch();
runtime.wait();

// Check Results
float *data = reinterpret_cast<float *>(a_array.map());

for (int i = 0; i < kArrLen; ++i) {
EXPECT_EQ(data[i], kArrLen);
}

a_array.unmap();
}

TEST_F(CapiTest, AotTestCpuField) {
TiArch arch = TiArch::TI_ARCH_X64;
field_aot_test(arch);
Expand Down Expand Up @@ -166,3 +201,17 @@ TEST_F(CapiTest, GraphTestVulkanTextureKernel) {
texture_aot_kernel_test(arch);
}
}

TEST_F(CapiTest, AotTestCudaSharedArray) {
if (ti::is_arch_available(TI_ARCH_CUDA)) {
TiArch arch = TiArch::TI_ARCH_CUDA;
shared_array_aot_test(arch);
}
}

TEST_F(CapiTest, AotTestVulkanSharedArray) {
if (ti::is_arch_available(TI_ARCH_VULKAN)) {
TiArch arch = TiArch::TI_ARCH_VULKAN;
shared_array_aot_test(arch);
}
}
60 changes: 60 additions & 0 deletions tests/cpp/aot/python_scripts/shared_array_aot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
import os

import numpy as np

import taichi as ti


def shared_array_aot_test(arch):
ti.init(arch=arch)

if ti.lang.impl.current_cfg().arch != arch:
return
block_dim = 128
nBlocks = 64
N = nBlocks * block_dim
v_arr = np.zeros(N).astype(np.float32)
d_arr = np.zeros(N).astype(np.float32)
a_arr = np.zeros(N).astype(np.float32)

@ti.kernel
def run(v: ti.types.ndarray(ndim=1), d: ti.types.ndarray(ndim=1),
a: ti.types.ndarray(ndim=1)):
for i in range(nBlocks * block_dim):
v[i] = 1.0
d[i] = 1.0

ti.loop_config(block_dim=block_dim)
for i in range(nBlocks * block_dim):
tid = i % block_dim
pad = ti.simt.block.SharedArray((block_dim, ), ti.f32)
acc = 0.0
v_val = v[i]
for k in range(nBlocks):
pad[tid] = d[k * block_dim + tid]
ti.simt.block.sync()
for j in range(block_dim):
acc += v_val * pad[j]
ti.simt.block.sync()
a[i] = acc

assert "TAICHI_AOT_FOLDER_PATH" in os.environ.keys()
dir_name = str(os.environ["TAICHI_AOT_FOLDER_PATH"])

m = ti.aot.Module()
m.add_kernel(run, template_args={'v': v_arr, 'd': d_arr, 'a': a_arr})
m.save(dir_name)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--arch", type=str)
args = parser.parse_args()

if args.arch == "cuda":
shared_array_aot_test(arch=ti.cuda)
elif args.arch == "vulkan":
shared_array_aot_test(arch=ti.vulkan)
else:
assert False
8 changes: 8 additions & 0 deletions tests/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,18 @@
["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
"--arch=cuda"
],
"CapiTest.AotTestCudaSharedArray": [
["cpp", "aot", "python_scripts", "shared_array_aot_test.py"],
"--arch=cuda"
],
"CapiTest.AotTestVulkanKernel": [
["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
"--arch=vulkan"
],
"CapiTest.AotTestVulkanSharedArray": [
["cpp", "aot", "python_scripts", "shared_array_aot_test.py"],
"--arch=vulkan"
],
"CapiTest.AotTestOpenglKernel": [
["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
"--arch=opengl"
Expand Down

0 comments on commit fadedc6

Please sign in to comment.