Skip to content

Commit

Permalink
[refactor] Remove unused snode_trees in ProgramImpl interface (#4942)
Browse files Browse the repository at this point in the history
* [refactor] Remove unused snode_trees in ProgramImpl interface

* Update taichi/codegen/codegen_llvm.h
  • Loading branch information
k-ye authored May 10, 2022
1 parent 7657185 commit 2586a9f
Show file tree
Hide file tree
Showing 16 changed files with 49 additions and 91 deletions.
6 changes: 2 additions & 4 deletions taichi/backends/cc/cc_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ void CCProgramImpl::materialize_runtime(MemoryPool *memory_pool,
result_buffer_ = *result_buffer_ptr;
}

void CCProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) {
void CCProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
auto *const root = tree->root();
CCLayoutGen gen(this, root);
layout_ = gen.compile();
Expand Down
4 changes: 1 addition & 3 deletions taichi/backends/cc/cc_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ class CCProgramImpl : public ProgramImpl {
KernelProfilerBase *,
uint64 **result_buffer_ptr) override;

void materialize_snode_tree(SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) override;
void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void synchronize() override {
// Not implemented yet.
Expand Down
6 changes: 2 additions & 4 deletions taichi/backends/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ void Dx11ProgramImpl::synchronize() {
TI_NOT_IMPLEMENTED;
}

void Dx11ProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer_ptr) {
void Dx11ProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer_ptr) {
snode_tree_mgr_->materialize_snode_tree(tree);
}

Expand Down
6 changes: 2 additions & 4 deletions taichi/backends/dx/dx_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ class Dx11ProgramImpl : public ProgramImpl {
void materialize_runtime(MemoryPool *memory_pool,
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;
virtual void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer_ptr) override;
virtual void materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer_ptr) override;
virtual void destroy_snode_tree(SNodeTree *snode_tree) override;
void synchronize() override;

Expand Down
10 changes: 3 additions & 7 deletions taichi/backends/metal/metal_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,12 @@ void MetalProgramImpl::materialize_runtime(MemoryPool *memory_pool,
metal_kernel_mgr_ = std::make_unique<metal::KernelManager>(std::move(params));
}

void MetalProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
void MetalProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
(void)compile_snode_tree_types_impl(tree);
}

void MetalProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) {
void MetalProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
const auto &csnode_tree = compile_snode_tree_types_impl(tree);
metal_kernel_mgr_->add_compiled_snode_tree(csnode_tree);
}
Expand Down
9 changes: 2 additions & 7 deletions taichi/backends/metal/metal_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@ class MetalProgramImpl : public ProgramImpl {
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;

void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) override;
void compile_snode_tree_types(SNodeTree *tree) override;

void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer) override;
void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void synchronize() override {
metal_kernel_mgr_->synchronize();
Expand Down
12 changes: 4 additions & 8 deletions taichi/backends/opengl/opengl_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,17 @@ std::shared_ptr<Device> OpenglProgramImpl::get_device_shared() {
return opengl_runtime_->device;
}

void OpenglProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
void OpenglProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
// TODO: support materializing multiple snode trees
opengl::OpenglStructCompiler scomp;
opengl_struct_compiled_ = scomp.run(*(tree->root()));
TI_TRACE("OpenGL root buffer size: {} B", opengl_struct_compiled_->root_size);
}

void OpenglProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer) {
void OpenglProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
#ifdef TI_WITH_OPENGL
compile_snode_tree_types(tree, snode_trees_);
compile_snode_tree_types(tree);
opengl_runtime_->add_snode_tree(opengl_struct_compiled_->root_size);
#else
TI_NOT_IMPLEMENTED;
Expand Down
11 changes: 3 additions & 8 deletions taichi/backends/opengl/opengl_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,9 @@ class OpenglProgramImpl : public ProgramImpl {
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;

void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) override;

void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer) override;
void compile_snode_tree_types(SNodeTree *tree) override;

void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void synchronize() override {
}
Expand Down
10 changes: 3 additions & 7 deletions taichi/backends/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ void VulkanProgramImpl::materialize_runtime(MemoryPool *memory_pool,
std::make_unique<vulkan::SNodeTreeManager>(vulkan_runtime_.get());
}

void VulkanProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
void VulkanProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
if (vulkan_runtime_) {
snode_tree_mgr_->materialize_snode_tree(tree);
} else {
Expand All @@ -163,10 +161,8 @@ void VulkanProgramImpl::compile_snode_tree_types(
}
}

void VulkanProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) {
void VulkanProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
snode_tree_mgr_->materialize_snode_tree(tree);
}

Expand Down
8 changes: 2 additions & 6 deletions taichi/backends/vulkan/vulkan_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,13 @@ class VulkanProgramImpl : public ProgramImpl {
return 0; // TODO: support sparse in vulkan
}

void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) override;
void compile_snode_tree_types(SNodeTree *tree) override;

void materialize_runtime(MemoryPool *memory_pool,
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;

void materialize_snode_tree(SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) override;
void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void synchronize() override {
vulkan_runtime_->synchronize();
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// The LLVM backend for CPUs/NVPTX/AMDGPU
#pragma once

#ifdef TI_WITH_LLVM

#include <set>
Expand Down
23 changes: 11 additions & 12 deletions taichi/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ void LlvmProgramImpl::synchronize() {

std::unique_ptr<llvm::Module>
LlvmProgramImpl::clone_struct_compiler_initial_context(
const std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
bool has_multiple_snode_trees,
TaichiLLVMContext *tlctx) {
if (!snode_trees_.empty())
if (has_multiple_snode_trees) {
return tlctx->clone_struct_module();
}
return tlctx->clone_runtime_module();
}

Expand Down Expand Up @@ -244,31 +245,29 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree,
}
}

void LlvmProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
auto *const root = tree->root();
const bool has_multiple_snode_trees = (num_snode_trees_processed_ > 0);
if (arch_is_cpu(config->arch)) {
auto host_module = clone_struct_compiler_initial_context(
snode_trees, llvm_context_host_.get());
has_multiple_snode_trees, llvm_context_host_.get());
struct_compiler_ = std::make_unique<StructCompilerLLVM>(
host_arch(), this, std::move(host_module), tree->id());

} else {
TI_ASSERT(config->arch == Arch::cuda);
auto device_module = clone_struct_compiler_initial_context(
snode_trees, llvm_context_device_.get());
has_multiple_snode_trees, llvm_context_device_.get());
struct_compiler_ = std::make_unique<StructCompilerLLVM>(
Arch::cuda, this, std::move(device_module), tree->id());
}
struct_compiler_->run(*root);
++num_snode_trees_processed_;
}

void LlvmProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer) {
compile_snode_tree_types(tree, snode_trees_);
void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
compile_snode_tree_types(tree);
initialize_llvm_runtime_snodes(tree, struct_compiler_.get(), result_buffer);
}

Expand Down
15 changes: 7 additions & 8 deletions taichi/llvm/llvm_program.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#pragma once

#include <cstddef>

#include "taichi/llvm/llvm_device.h"
#include "taichi/llvm/llvm_offline_cache.h"
#include "taichi/system/snode_tree_buffer_manager.h"
Expand Down Expand Up @@ -62,14 +65,9 @@ class LlvmProgramImpl : public ProgramImpl {

FunctionType compile(Kernel *kernel, OffloadedStmt *offloaded) override;

void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) override;
void compile_snode_tree_types(SNodeTree *tree) override;

void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer) override;
void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

template <typename T>
T fetch_result(int i, uint64 *result_buffer) {
Expand Down Expand Up @@ -122,7 +120,7 @@ class LlvmProgramImpl : public ProgramImpl {

private:
std::unique_ptr<llvm::Module> clone_struct_compiler_initial_context(
const std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
bool has_multiple_snode_trees,
TaichiLLVMContext *tlctx);

/**
Expand Down Expand Up @@ -173,6 +171,7 @@ class LlvmProgramImpl : public ProgramImpl {
std::unique_ptr<Runtime> runtime_mem_info_{nullptr};
std::unique_ptr<SNodeTreeBufferManager> snode_tree_buffer_manager_{nullptr};
std::unique_ptr<StructCompiler> struct_compiler_{nullptr};
std::size_t num_snode_trees_processed_{0};
void *llvm_runtime_{nullptr};
void *preallocated_device_buffer_{nullptr}; // TODO: move to memory allocator

Expand Down
5 changes: 2 additions & 3 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,9 @@ SNodeTree *Program::add_snode_tree(std::unique_ptr<SNode> root,
auto tree = std::make_unique<SNodeTree>(id, std::move(root));
tree->root()->set_snode_tree_id(id);
if (compile_only) {
program_impl_->compile_snode_tree_types(tree.get(), snode_trees_);
program_impl_->compile_snode_tree_types(tree.get());
} else {
program_impl_->materialize_snode_tree(tree.get(), snode_trees_,
result_buffer);
program_impl_->materialize_snode_tree(tree.get(), result_buffer);
}
if (id < snode_trees_.size()) {
snode_trees_[id] = std::move(tree);
Expand Down
4 changes: 1 addition & 3 deletions taichi/program/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ namespace lang {
ProgramImpl::ProgramImpl(CompileConfig &config_) : config(&config_) {
}

void ProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
void ProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
// FIXME: Eventually all the backends should implement this
TI_NOT_IMPLEMENTED;
}
Expand Down
10 changes: 3 additions & 7 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,13 @@ class ProgramImpl {
/**
* JIT compiles @param tree to backend-specific data types.
*/
virtual void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees);
virtual void compile_snode_tree_types(SNodeTree *tree);

/**
* Compiles the @param tree types and allocates runtime buffer for it.
*/
virtual void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
uint64 *result_buffer_ptr) = 0;
virtual void materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer_ptr) = 0;

virtual void destroy_snode_tree(SNodeTree *snode_tree) = 0;

Expand Down

0 comments on commit 2586a9f

Please sign in to comment.