-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[aot] [llvm] LLVM AOT Field #3: Added AOT tests for Fields - CPU back…
…end (#5121) * [aot] [llvm] Implemented FieldCacheData and refactored initialize_llvm_runtime_snodes() * Addressed compilation erros * [aot] [llvm] LLVM AOT Field #1: Adjust serialization/deserialization logics for FieldCacheData * [llvm] [aot] Added Field support for LLVM AOT * [aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTModuleBuilder to support Fields * [aot] [llvm] LLVM AOT Field #3: Added AOT tests for Fields running CPU backend * Added tests for activate/deactivate operations * Fixed merge issue
- Loading branch information
1 parent
d7e7491
commit 692ed11
Showing
6 changed files
with
232 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
|
||
import taichi as ti | ||
|
||
|
||
def compile_aot(): | ||
# Make sure "debug" mode is on | ||
# in both python & C++ tests | ||
ti.init(arch=ti.x64, debug=True) | ||
|
||
x = ti.field(ti.i32) | ||
y = ti.field(ti.i32) | ||
|
||
common = ti.root.dense(ti.i, 4) | ||
common.dense(ti.i, 8).place(x) | ||
|
||
p = common.pointer(ti.i, 2) | ||
p.dense(ti.i, 8).place(y) | ||
|
||
@ti.kernel | ||
def init_fields(base: int): | ||
# Dense SNode | ||
for i in range(4 * 8): | ||
x[i] = base + i | ||
|
||
# Pointer SNode | ||
y[32] = 4 | ||
y[33] = 5 | ||
y[9] = 10 | ||
|
||
@ti.kernel | ||
def check_init_x(base: int): | ||
# Check numerical accuracy for Dense SNodes | ||
for i in range(4 * 8): | ||
assert (x[i] == base + i) | ||
|
||
@ti.kernel | ||
def check_init_y(): | ||
# Check sparsity for Pointer SNodes | ||
for i in range(8): | ||
if i == 1 or i == 4: | ||
assert (ti.is_active(p, [i])) | ||
else: | ||
assert (not ti.is_active(p, [i])) | ||
|
||
# Check numerical accuracy for Pointer SNodes | ||
for i in range(8, 8 + 8): | ||
if i == 9: | ||
assert (y[i] == 10) | ||
else: | ||
assert (y[i] == 0) | ||
|
||
for i in range(32, 32 + 8): | ||
if i == 32: | ||
assert (y[i] == 4) | ||
elif i == 33: | ||
assert (y[i] == 5) | ||
else: | ||
assert (y[i] == 0) | ||
|
||
@ti.kernel | ||
def deactivate_pointer_fields(): | ||
ti.deactivate(p, [1]) | ||
ti.deactivate(p, [4]) | ||
|
||
@ti.kernel | ||
def activate_pointer_fields(): | ||
ti.activate(p, [7]) | ||
ti.activate(p, [3]) | ||
|
||
@ti.kernel | ||
def check_deactivate_pointer_fields(): | ||
assert (not ti.is_active(p, [1])) | ||
assert (not ti.is_active(p, [4])) | ||
|
||
@ti.kernel | ||
def check_activate_pointer_fields(): | ||
assert (ti.is_active(p, [7])) | ||
assert (ti.is_active(p, [3])) | ||
|
||
for i in range(7 * 8, 7 * 8 + 8): | ||
assert (y[i] == 0) | ||
|
||
for i in range(3 * 8, 3 * 8 + 8): | ||
assert (y[i] == 0) | ||
|
||
assert "TAICHI_AOT_FOLDER_PATH" in os.environ.keys() | ||
dir_name = str(os.environ["TAICHI_AOT_FOLDER_PATH"]) | ||
|
||
m = ti.aot.Module(ti.x64) | ||
|
||
m.add_kernel(init_fields, template_args={}) | ||
m.add_kernel(check_init_x, template_args={}) | ||
m.add_kernel(check_init_y, template_args={}) | ||
|
||
m.add_kernel(deactivate_pointer_fields, template_args={}) | ||
m.add_kernel(activate_pointer_fields, template_args={}) | ||
|
||
m.add_kernel(check_deactivate_pointer_fields, template_args={}) | ||
m.add_kernel(check_activate_pointer_fields, template_args={}) | ||
|
||
m.add_field("x", x) | ||
m.add_field("y", y) | ||
|
||
m.save(dir_name, 'x64-aot') | ||
|
||
|
||
compile_aot() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#include "gtest/gtest.h" | ||
|
||
#include "taichi/program/kernel_profiler.h" | ||
#include "taichi/llvm/llvm_program.h" | ||
#include "taichi/system/memory_pool.h" | ||
#include "taichi/backends/cpu/aot_module_loader_impl.h" | ||
#include "taichi/backends/cuda/aot_module_loader_impl.h" | ||
#include "taichi/llvm/llvm_aot_module_loader.h" | ||
#include "taichi/backends/cuda/cuda_driver.h" | ||
#include "taichi/platform/cuda/detect_cuda.h" | ||
|
||
#define TI_RUNTIME_HOST | ||
#include "taichi/program/context.h" | ||
#undef TI_RUNTIME_HOST | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
TEST(LlvmAOTTest, Field) { | ||
CompileConfig cfg; | ||
cfg.arch = Arch::x64; | ||
cfg.kernel_profiler = false; | ||
constexpr KernelProfilerBase *kNoProfiler = nullptr; | ||
LlvmProgramImpl prog{cfg, kNoProfiler}; | ||
auto *compute_device = prog.get_compute_device(); | ||
|
||
// Must have handled all the arch fallback logic by this point. | ||
auto memory_pool = std::make_unique<MemoryPool>(cfg.arch, compute_device); | ||
prog.initialize_host(); | ||
uint64 *result_buffer{nullptr}; | ||
prog.materialize_runtime(memory_pool.get(), kNoProfiler, &result_buffer); | ||
|
||
cpu::AotModuleParams aot_params; | ||
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH"); | ||
|
||
std::stringstream aot_mod_ss; | ||
aot_mod_ss << folder_dir; | ||
aot_params.module_path = aot_mod_ss.str(); | ||
aot_params.program = &prog; | ||
std::unique_ptr<aot::Module> mod = cpu::make_aot_module(aot_params); | ||
|
||
aot::Kernel *k_init_fields = mod->get_kernel("init_fields"); | ||
aot::Kernel *k_check_init_x = mod->get_kernel("check_init_x"); | ||
aot::Kernel *k_check_init_y = mod->get_kernel("check_init_y"); | ||
|
||
aot::Kernel *k_deactivate_pointer_fields = | ||
mod->get_kernel("deactivate_pointer_fields"); | ||
aot::Kernel *k_activate_pointer_fields = | ||
mod->get_kernel("activate_pointer_fields"); | ||
|
||
aot::Kernel *k_check_deactivate_pointer_fields = | ||
mod->get_kernel("check_deactivate_pointer_fields"); | ||
aot::Kernel *k_check_activate_pointer_fields = | ||
mod->get_kernel("check_activate_pointer_fields"); | ||
|
||
// Initialize Fields | ||
aot::Field *field_x = mod->get_field("0" /*snode_tree_id*/); | ||
aot::Field *field_y = mod->get_field("0" /*snode_tree_id*/); | ||
|
||
finalize_aot_field(mod.get(), field_x, result_buffer); | ||
finalize_aot_field(mod.get(), field_y, result_buffer); | ||
|
||
int base_value = 10; | ||
/* -------- Test Case 1 ------ */ | ||
// Kernel: init_fields(int) | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
ctx.set_arg(0, base_value); | ||
k_init_fields->launch(&ctx); | ||
} | ||
|
||
// Kernel: check_init_x(int) | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
ctx.set_arg(0, base_value); | ||
k_check_init_x->launch(&ctx); | ||
} | ||
// Kernel: check_init_y() | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
k_check_init_y->launch(&ctx); | ||
} | ||
|
||
/* -------- Test Case 2 ------ */ | ||
// Kernel: deactivate_pointer_fields() | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
k_deactivate_pointer_fields->launch(&ctx); | ||
} | ||
// Kernel: check_deactivate_pointer_fields() | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
k_check_deactivate_pointer_fields->launch(&ctx); | ||
} | ||
|
||
/* -------- Test Case 3 ------ */ | ||
// Kernel: activate_pointer_fields() | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
k_activate_pointer_fields->launch(&ctx); | ||
} | ||
// Kernel: check_activate_pointer_fields() | ||
{ | ||
RuntimeContext ctx; | ||
ctx.runtime = prog.get_llvm_runtime(); | ||
k_check_activate_pointer_fields->launch(&ctx); | ||
} | ||
|
||
// Check assertion error from ti.kernel | ||
prog.check_runtime_error(result_buffer); | ||
} | ||
|
||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters