[llvm] Unify the llvm context of host and device (taichi-dev#7249)
Issue: fixes taichi-dev#7251
### Brief Summary

The `llvm_context_host_` and `llvm_context_device_` are never both used in the same backend, so we can unify them and use only a single `llvm_context_`.
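The shape of the refactor can be illustrated with a minimal, self-contained sketch. The stub types below are hypothetical stand-ins, not the real Taichi classes (which live under `taichi/runtime/llvm/`); only the member/accessor structure mirrors this commit.

```cpp
#include <memory>

// Hypothetical stand-ins for the real Taichi types, kept tiny so the sketch
// compiles on its own.
struct TaichiLLVMContext {
  explicit TaichiLLVMContext(int arch) : arch_(arch) {}
  int arch_;
};
enum Arch { kX64, kCuda };
inline bool arch_is_cpu(Arch a) { return a == kX64; }
inline Arch host_arch() { return kX64; }

class LlvmRuntimeExecutor {
 public:
  explicit LlvmRuntimeExecutor(Arch config_arch) {
    // After this commit there is exactly one context: built for the host
    // arch on CPU backends and for the device arch otherwise.
    llvm_context_ = std::make_unique<TaichiLLVMContext>(
        arch_is_cpu(config_arch) ? host_arch() : config_arch);
  }

  // Before: get_llvm_context(Arch) chose between llvm_context_host_ and
  // llvm_context_device_. After: no argument, one member.
  TaichiLLVMContext *get_llvm_context() { return llvm_context_.get(); }

 private:
  std::unique_ptr<TaichiLLVMContext> llvm_context_{nullptr};
};

int main() {
  LlvmRuntimeExecutor exec(kCuda);
  return exec.get_llvm_context() != nullptr ? 0 : 1;
}
```

Every call site that previously passed an `Arch` to `get_llvm_context(arch)` now calls the parameterless accessor, as the diffs below show.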
lin-hitonami authored and quadpixels committed May 13, 2023
1 parent 278ac44 commit 192e78e
Showing 9 changed files with 43 additions and 150 deletions.
3 changes: 1 addition & 2 deletions c_api/src/taichi_llvm_impl.cpp
@@ -55,8 +55,7 @@ taichi::lang::Device &LlvmRuntime::get() {
TiMemory LlvmRuntime::allocate_memory(
const taichi::lang::Device::AllocParams &params) {
const taichi::lang::CompileConfig &config = executor_->get_config();
taichi::lang::TaichiLLVMContext *tlctx =
executor_->get_llvm_context(config.arch);
taichi::lang::TaichiLLVMContext *tlctx = executor_->get_llvm_context();
taichi::lang::LLVMRuntime *llvm_runtime = executor_->get_llvm_runtime();
taichi::lang::LlvmDevice *llvm_device = executor_->llvm_device();

2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_loader_impl.cpp
@@ -20,7 +20,7 @@ class AotModuleImpl : public LlvmAotModule {
LlvmOfflineCache::KernelCacheData &&loaded) override {
Arch arch = executor_->get_config().arch;
TI_ASSERT(arch == Arch::x64 || arch == Arch::arm64);
auto *tlctx = executor_->get_llvm_context(arch);
auto *tlctx = executor_->get_llvm_context();

CPUModuleToFunctionConverter converter{tlctx, executor_};
return converter.convert(name, loaded.args,
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_loader_impl.cpp
@@ -20,7 +20,7 @@ class AotModuleImpl : public LlvmAotModule {
LlvmOfflineCache::KernelCacheData &&loaded) override {
Arch arch = executor_->get_config().arch;
TI_ASSERT(arch == Arch::cuda);
auto *tlctx = executor_->get_llvm_context(arch);
auto *tlctx = executor_->get_llvm_context();

CUDAModuleToFunctionConverter converter{tlctx, executor_};
return converter.convert(name, loaded.args,
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_aot_module_loader.cpp
@@ -6,7 +6,7 @@ namespace taichi::lang {
LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache(
const std::string &name) {
TI_ASSERT(cache_reader_ != nullptr);
auto *tlctx = executor_->get_llvm_context(executor_->get_config().arch);
auto *tlctx = executor_->get_llvm_context();
LlvmOfflineCache::KernelCacheData loaded;
auto ok = cache_reader_->get_kernel_cache(loaded, name,
*tlctx->get_this_thread_context());
6 changes: 6 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
@@ -95,6 +95,12 @@ TaichiLLVMContext::TaichiLLVMContext(const CompileConfig &config, Arch arch)
llvm::InitializeNativeTargetAsmParser();
#endif
} else if (arch == Arch::dx12) {
// FIXME: Must initialize these before initializing Arch::dx12
// because it uses the jit of CPU right now.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
// The dx target is used elsewhere, so we need to initialize it too.
#if defined(TI_WITH_DX12)
LLVMInitializeDirectXTarget();
LLVMInitializeDirectXTargetMC();
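The three LLVM initialization calls added in this hunk come from `llvm/Support/TargetSelect.h`. A bare sketch of the pattern (plain LLVM API calls, not Taichi code) looks like this:

```cpp
#include "llvm/Support/TargetSelect.h"

int main() {
  // Any backend that JIT-compiles for the host CPU must register the native
  // target first; the dx12 path currently reuses the CPU JIT, hence the
  // FIXME in the hunk above.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  return 0;
}
```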
118 changes: 25 additions & 93 deletions taichi/runtime/llvm/llvm_runtime_executor.cpp
@@ -54,15 +54,22 @@ LlvmRuntimeExecutor::LlvmRuntimeExecutor(CompileConfig &config,
}
}

if (config.kernel_profiler) {
profiler_ = profiler;
}

snode_tree_buffer_manager_ = std::make_unique<SNodeTreeBufferManager>(this);
thread_pool_ = std::make_unique<ThreadPool>(config.cpu_max_num_threads);
preallocated_device_buffer_ = nullptr;

llvm_runtime_ = nullptr;
llvm_context_host_ = std::make_unique<TaichiLLVMContext>(config, host_arch());

if (config.arch == Arch::cuda) {
if (arch_is_cpu(config.arch)) {
config.max_block_dim = 1024;
device_ = std::make_shared<cpu::CpuDevice>();
}
#if defined(TI_WITH_CUDA)
else if (config.arch == Arch::cuda) {
int num_SMs{1};
CUDADriver::get_instance().device_get_attribute(
&num_SMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, nullptr);
@@ -89,89 +96,41 @@ LlvmRuntimeExecutor::LlvmRuntimeExecutor(CompileConfig &config,
}
config.saturating_grid_dim = num_SMs * query_max_block_per_sm * 2;
}
#endif
} else if (config.arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
// FIXME: set value based on DX12.
config.max_block_dim = 1024;
#endif
}

if (arch_is_cpu(config.arch)) {
config.max_block_dim = 1024;
device_ = std::make_shared<cpu::CpuDevice>();
}

if (config.kernel_profiler) {
profiler_ = profiler;
}

#if defined(TI_WITH_CUDA)
if (config.arch == Arch::cuda) {
if (config.kernel_profiler) {
CUDAContext::get_instance().set_profiler(profiler);
} else {
CUDAContext::get_instance().set_profiler(nullptr);
}
CUDAContext::get_instance().set_debug(config.debug);
device_ = std::make_shared<cuda::CudaDevice>();

this->maybe_initialize_cuda_llvm_context();
}
#endif

#if defined(TI_WITH_AMDGPU)
if (config.arch == Arch::amdgpu) {
else if (config.arch == Arch::amdgpu) {
AMDGPUContext::get_instance().set_debug(config.debug);
device_ = std::make_shared<amdgpu::AmdgpuDevice>();

this->maybe_initialize_amdgpu_llvm_context();
}
#endif

#ifdef TI_WITH_DX12
if (config.arch == Arch::dx12) {
else if (config.arch == Arch::dx12) {
// FIXME: add dx12 device.
// FIXME: set value based on DX12.
config.max_block_dim = 1024;
device_ = std::make_shared<cpu::CpuDevice>();

llvm_context_device_ =
std::make_unique<TaichiLLVMContext>(config, Arch::dx12);
// FIXME: add dx12 JIT.
llvm_context_device_->init_runtime_jit_module();
}
#endif

this->initialize_host();
}

TaichiLLVMContext *LlvmRuntimeExecutor::get_llvm_context(Arch arch) {
if (arch_is_cpu(arch)) {
return llvm_context_host_.get();
} else {
return llvm_context_device_.get();
else {
TI_NOT_IMPLEMENTED
}
llvm_context_ = std::make_unique<TaichiLLVMContext>(
config_, arch_is_cpu(config.arch) ? host_arch() : config.arch);
llvm_context_->init_runtime_jit_module();
}

void LlvmRuntimeExecutor::initialize_host() {
// Note this cannot be placed inside LlvmProgramImpl constructor, see doc
// string for init_runtime_jit_module() for more details.
llvm_context_host_->init_runtime_jit_module();
}

void LlvmRuntimeExecutor::maybe_initialize_cuda_llvm_context() {
if (config_.arch == Arch::cuda && llvm_context_device_ == nullptr) {
llvm_context_device_ =
std::make_unique<TaichiLLVMContext>(config_, Arch::cuda);
llvm_context_device_->init_runtime_jit_module();
}
}

void LlvmRuntimeExecutor::maybe_initialize_amdgpu_llvm_context() {
if (config_.arch == Arch::amdgpu && llvm_context_device_ == nullptr) {
llvm_context_device_ =
std::make_unique<TaichiLLVMContext>(config_, Arch::amdgpu);
llvm_context_device_->init_runtime_jit_module();
}
TaichiLLVMContext *LlvmRuntimeExecutor::get_llvm_context() {
return llvm_context_.get();
}

void LlvmRuntimeExecutor::print_list_manager_info(void *list_manager,
@@ -243,13 +202,7 @@ std::size_t LlvmRuntimeExecutor::get_snode_num_dynamically_allocated(

void LlvmRuntimeExecutor::check_runtime_error(uint64 *result_buffer) {
synchronize();
auto tlctx = llvm_context_host_.get();
if (llvm_context_device_) {
// In case there is a standalone device context (e.g. CUDA without unified
// memory), use the device context instead.
tlctx = llvm_context_device_.get();
}
auto *runtime_jit_module = tlctx->runtime_jit_module;
auto *runtime_jit_module = llvm_context_->runtime_jit_module;
runtime_jit_module->call<void *>("runtime_retrieve_and_reset_error_code",
llvm_runtime_);
auto error_code =
@@ -369,18 +322,7 @@ DevicePtr LlvmRuntimeExecutor::get_snode_tree_device_ptr(int tree_id) {
void LlvmRuntimeExecutor::initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer) {
TaichiLLVMContext *tlctx = nullptr;
if (config_.arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
tlctx = llvm_context_device_.get();
#else
TI_NOT_IMPLEMENTED
#endif
} else {
tlctx = llvm_context_host_.get();
}

auto *const runtime_jit = tlctx->runtime_jit_module;
auto *const runtime_jit = llvm_context_->runtime_jit_module;
// By the time this creator is called, "this" is already destroyed.
// Therefore it is necessary to capture members by values.
size_t root_size = field_cache_data.root_size;
@@ -477,18 +419,11 @@ LlvmDevice *LlvmRuntimeExecutor::llvm_device() {
DeviceAllocation LlvmRuntimeExecutor::allocate_memory_ndarray(
std::size_t alloc_size,
uint64 *result_buffer) {
TaichiLLVMContext *tlctx = nullptr;
if (llvm_context_device_) {
tlctx = llvm_context_device_.get();
} else {
tlctx = llvm_context_host_.get();
}

return llvm_device()->allocate_memory_runtime(
{{alloc_size, /*host_write=*/false, /*host_read=*/false,
/*export_sharing=*/false, AllocUsage::Storage},
config_.ndarray_use_cached_allocator,
tlctx->runtime_jit_module,
llvm_context_->runtime_jit_module,
get_llvm_runtime(),
result_buffer});
}
@@ -538,7 +473,6 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) {
std::size_t prealloc_size = 0;
TaichiLLVMContext *tlctx = nullptr;
if (config_.arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().malloc(
@@ -566,16 +500,14 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,

CUDADriver::get_instance().memset(preallocated_device_buffer_, 0,
prealloc_size);
tlctx = llvm_context_device_.get();
#else
TI_NOT_IMPLEMENTED
#endif
} else {
*result_buffer_ptr = (uint64 *)memory_pool->allocate(
sizeof(uint64) * taichi_result_buffer_entries, 8);
tlctx = llvm_context_host_.get();
}
auto *const runtime_jit = tlctx->runtime_jit_module;
auto *const runtime_jit = llvm_context_->runtime_jit_module;

// Starting random state for the program calculated using the random seed.
// The seed is multiplied by 1048391 so that two programs with different seeds
@@ -656,7 +588,7 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,
}

void LlvmRuntimeExecutor::destroy_snode_tree(SNodeTree *snode_tree) {
get_llvm_context(config_.arch)->delete_snode_tree(snode_tree->id());
get_llvm_context()->delete_snode_tree(snode_tree->id());
snode_tree_buffer_manager_->destroy(snode_tree);
}

25 changes: 3 additions & 22 deletions taichi/runtime/llvm/llvm_runtime_executor.h
@@ -59,7 +59,7 @@ class LlvmRuntimeExecutor {
return config_;
}

TaichiLLVMContext *get_llvm_context(Arch arch);
TaichiLLVMContext *get_llvm_context();

LLVMRuntime *get_llvm_runtime();

@@ -101,14 +101,7 @@
Args &&...args) {
TI_ASSERT(arch_uses_llvm(config_.arch));

TaichiLLVMContext *tlctx = nullptr;
if (llvm_context_device_) {
tlctx = llvm_context_device_.get();
} else {
tlctx = llvm_context_host_.get();
}

auto runtime = tlctx->runtime_jit_module;
auto runtime = llvm_context_->runtime_jit_module;
runtime->call<void *>("runtime_" + key, llvm_runtime_,
std::forward<Args>(args)...);
return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(
Expand All @@ -121,17 +114,6 @@ class LlvmRuntimeExecutor {
cuda::CudaDevice *cuda_device();
cpu::CpuDevice *cpu_device();

void initialize_host();

/**
* Initializes Program#llvm_context_device, if this has not been done.
*
* Not thread safe.
*/
void maybe_initialize_cuda_llvm_context();

void maybe_initialize_amdgpu_llvm_context();

void finalize();

uint64 fetch_result_uint64(int i, uint64 *result_buffer);
Expand All @@ -146,8 +128,7 @@ class LlvmRuntimeExecutor {
//
// TaichiLLVMContext is a thread-safe class with llvm::Module for compilation
// and JITSession/JITModule for runtime loading & execution
std::unique_ptr<TaichiLLVMContext> llvm_context_host_{nullptr};
std::unique_ptr<TaichiLLVMContext> llvm_context_device_{nullptr};
std::unique_ptr<TaichiLLVMContext> llvm_context_{nullptr};
void *llvm_runtime_{nullptr};

std::unique_ptr<ThreadPool> thread_pool_{nullptr};
21 changes: 4 additions & 17 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
@@ -45,23 +45,10 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
SNodeTree *tree) {
auto *const root = tree->root();
std::unique_ptr<StructCompiler> struct_compiler{nullptr};
if (arch_is_cpu(config->arch)) {
auto host_module =
runtime_exec_->llvm_context_host_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
host_arch(), this, std::move(host_module), tree->id());
} else if (config->arch == Arch::dx12) {
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::dx12, this, std::move(device_module), tree->id());
} else {
TI_ASSERT(config->arch == Arch::cuda);
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::cuda, this, std::move(device_module), tree->id());
}
auto module = runtime_exec_->llvm_context_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
arch_is_cpu(config->arch) ? host_arch() : config->arch, this,
std::move(module), tree->id());
struct_compiler->run(*root);
++num_snode_trees_processed_;
return struct_compiler;
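The three per-arch branches collapse into one because the only arch-dependent input was which `Arch` to hand to `StructCompilerLLVM`. A stand-alone sketch of that selection rule follows; `struct_compiler_arch` is a hypothetical helper name used only for illustration, and the enum stubs stand in for the real definitions in `taichi/rhi/arch.h`.

```cpp
#include <cassert>

// Hypothetical stand-ins for the real arch utilities.
enum class Arch { x64, arm64, cuda, dx12 };
inline bool arch_is_cpu(Arch a) { return a == Arch::x64 || a == Arch::arm64; }
inline Arch host_arch() { return Arch::x64; }

// CPU backends compile struct modules for the host arch; every other backend
// compiles for its own device arch. This is the ternary used in the new
// compile_snode_tree_types_impl().
inline Arch struct_compiler_arch(Arch config_arch) {
  return arch_is_cpu(config_arch) ? host_arch() : config_arch;
}

int main() {
  // GPU backends keep their device arch; CPU backends use the host arch
  // (stubbed as x64 here).
  assert(struct_compiler_arch(Arch::cuda) == Arch::cuda);
  assert(struct_compiler_arch(Arch::dx12) == Arch::dx12);
  assert(struct_compiler_arch(Arch::x64) == Arch::x64);
  return 0;
}
```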
14 changes: 1 addition & 13 deletions taichi/runtime/program_impls/llvm/llvm_program.h
@@ -150,18 +150,6 @@ class LlvmProgramImpl : public ProgramImpl {
result_buffer);
}

void initialize_host() {
runtime_exec_->initialize_host();
}

void maybe_initialize_cuda_llvm_context() {
runtime_exec_->maybe_initialize_cuda_llvm_context();
}

void maybe_initialize_amdgpu_llvm_context() {
runtime_exec_->maybe_initialize_amdgpu_llvm_context();
}

uint64 fetch_result_uint64(int i, uint64 *result_buffer) override {
return runtime_exec_->fetch_result_uint64(i, result_buffer);
}
@@ -185,7 +173,7 @@
}

TaichiLLVMContext *get_llvm_context(Arch arch) {
return runtime_exec_->get_llvm_context(arch);
return runtime_exec_->get_llvm_context();
}

void synchronize() override {
