Skip to content

Commit

Permalink
[refactor] Remove ir parameter of KernelCodeGen::KernelCodeGen(Kernel…
Browse files Browse the repository at this point in the history
… *kernel, IRNode *ir) (#7046)

Issue: #7002
  • Loading branch information
PGZXB authored Jan 5, 2023
1 parent edb8afa commit 1a7a5ef
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 26 deletions.
18 changes: 8 additions & 10 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,27 @@

namespace taichi::lang {

KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
: prog(kernel->program), kernel(kernel), ir(ir) {
if (ir == nullptr)
this->ir = kernel->ir.get();
KernelCodeGen::KernelCodeGen(Kernel *kernel)
: prog(kernel->program), kernel(kernel) {
this->ir = kernel->ir.get();
}

std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
Kernel *kernel,
Stmt *stmt) {
Kernel *kernel) {
#ifdef TI_WITH_LLVM
if (arch_is_cpu(arch) && arch != Arch::wasm) {
return std::make_unique<KernelCodeGenCPU>(kernel, stmt);
return std::make_unique<KernelCodeGenCPU>(kernel);
} else if (arch == Arch::wasm) {
return std::make_unique<KernelCodeGenWASM>(kernel, stmt);
return std::make_unique<KernelCodeGenWASM>(kernel);
} else if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
return std::make_unique<KernelCodeGenCUDA>(kernel, stmt);
return std::make_unique<KernelCodeGenCUDA>(kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
return std::make_unique<KernelCodeGenDX12>(kernel, stmt);
return std::make_unique<KernelCodeGenDX12>(kernel);
#else
TI_NOT_IMPLEMENTED
#endif
Expand Down
6 changes: 2 additions & 4 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ class KernelCodeGen {
IRNode *ir;

public:
KernelCodeGen(Kernel *kernel, IRNode *ir);
explicit KernelCodeGen(Kernel *kernel);

virtual ~KernelCodeGen() = default;

static std::unique_ptr<KernelCodeGen> create(Arch arch,
Kernel *kernel,
Stmt *stmt = nullptr);
static std::unique_ptr<KernelCodeGen> create(Arch arch, Kernel *kernel);

virtual FunctionType compile_to_function() = 0;
virtual bool supports_offline_cache() const {
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ namespace taichi::lang {

class KernelCodeGenCPU : public KernelCodeGen {
public:
explicit KernelCodeGenCPU(Kernel *kernel, IRNode *ir = nullptr)
: KernelCodeGen(kernel, ir) {
explicit KernelCodeGenCPU(Kernel *kernel) : KernelCodeGen(kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ namespace taichi::lang {

class KernelCodeGenCUDA : public KernelCodeGen {
public:
explicit KernelCodeGenCUDA(Kernel *kernel, IRNode *ir = nullptr)
: KernelCodeGen(kernel, ir) {
explicit KernelCodeGenCUDA(Kernel *kernel) : KernelCodeGen(kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/dx12/codegen_dx12.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ namespace taichi::lang {

class KernelCodeGenDX12 : public KernelCodeGen {
public:
KernelCodeGenDX12(Kernel *kernel, IRNode *ir = nullptr)
: KernelCodeGen(kernel, ir) {
explicit KernelCodeGenDX12(Kernel *kernel) : KernelCodeGen(kernel) {
}
struct CompileResult {
std::vector<std::vector<uint8_t>> task_dxil_source_codes;
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ namespace taichi::lang {

class KernelCodeGenWASM : public KernelCodeGen {
public:
explicit KernelCodeGenWASM(Kernel *kernel, IRNode *ir = nullptr)
: KernelCodeGen(kernel, ir) {
explicit KernelCodeGenWASM(Kernel *kernel) : KernelCodeGen(kernel) {
}

FunctionType compile_to_function() override;
Expand Down
4 changes: 2 additions & 2 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[identifier];
auto &compiled_kernel = module_data.kernels[identifier];

KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
KernelCodeGenDX12 cgen(kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down Expand Up @@ -69,7 +69,7 @@ void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
auto &compiled_kernel = module_data.kernels[tmpl_identifier];

KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
KernelCodeGenDX12 cgen(kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down
3 changes: 1 addition & 2 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto module_info =
KernelCodeGenWASM(kernel, nullptr).compile_kernel_to_module();
auto module_info = KernelCodeGenWASM(kernel).compile_kernel_to_module();
if (module_) {
llvm::Linker::linkModules(*module_, std::move(module_info.module),
llvm::Linker::OverrideFromSrc);
Expand Down

0 comments on commit 1a7a5ef

Please sign in to comment.