diff --git a/taichi/codegen/dx12/codegen_dx12.cpp b/taichi/codegen/dx12/codegen_dx12.cpp index 9b9515ba321b33..8530c9bea3ad24 100644 --- a/taichi/codegen/dx12/codegen_dx12.cpp +++ b/taichi/codegen/dx12/codegen_dx12.cpp @@ -148,8 +148,9 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM { #ifdef TI_WITH_LLVM -static std::vector generate_dxil_from_llvm(LLVMCompiledData &compiled_data, - taichi::lang::Kernel *kernel) { +static std::vector generate_dxil_from_llvm( + LLVMCompiledData &compiled_data, + taichi::lang::Kernel *kernel) { // generate dxil from llvm ir. auto offloaded_local = compiled_data.tasks; auto module = compiled_data.module.get(); @@ -162,7 +163,8 @@ static std::vector generate_dxil_from_llvm(LLVMCompiledData &compiled_d // FIXME: save task.block_dim like // tlctx->mark_function_as_cuda_kernel(func, task.block_dim); } - auto dx_container = directx12::global_optimize_module(module, kernel->program->config); + auto dx_container = + directx12::global_optimize_module(module, kernel->program->config); // validate and sign dx container. return directx12::validate_and_sign(dx_container); } @@ -197,8 +199,7 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() { aot::CompiledOffloadedTask task; // FIXME: build all fields for task. task.name = fmt::format("{}_{}_{}", kernel->get_name(), - offload_stmt->task_name(), - i); + offload_stmt->task_name(), i); task.type = offload_stmt->task_name(); Result.tasks.emplace_back(task); } @@ -209,7 +210,7 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() { LLVMCompiledData KernelCodeGenDX12::modulegen( std::unique_ptr &&module, - OffloadedStmt *stmt) { + OffloadedStmt *stmt) { TaskCodeGenLLVMDX12 gen(kernel, stmt); return gen.run_compilation(); } @@ -217,7 +218,6 @@ LLVMCompiledData KernelCodeGenDX12::modulegen( FunctionType KernelCodeGenDX12::codegen() { // FIXME: implement codegen. - return [](RuntimeContext &ctx) { - }; + return [](RuntimeContext &ctx) {}; } TLANG_NAMESPACE_END diff --git a/taichi/codegen/dx12/dx12_llvm_passes.h b/taichi/codegen/dx12/dx12_llvm_passes.h index 6a83a1e94ba7a0..59f788f4d78c20 100644 --- a/taichi/codegen/dx12/dx12_llvm_passes.h +++ b/taichi/codegen/dx12/dx12_llvm_passes.h @@ -9,7 +9,7 @@ class Function; class Module; class Type; class GlobalVariable; -} +} // namespace llvm namespace taichi { namespace lang { @@ -26,14 +26,13 @@ enum class BufferSpaceId { Args = 3, Runtime = 4, Result = 5, - UtilCBuffer = 6, // For things like Num work groups. + UtilCBuffer = 6, // For things like Num work groups. }; enum ResourceAddressSpace { CBuffer = 4, }; - void mark_function_as_cs_entry(llvm::Function *); bool is_cs_entry(llvm::Function *); void set_num_threads(llvm::Function *, unsigned x, unsigned y, unsigned z); @@ -52,7 +51,6 @@ extern const char *NumWorkGroupsCBName; } // namespace lang } // namespace taichi - namespace llvm { class ModulePass; class PassRegistry; @@ -64,11 +62,10 @@ void initializeTaichiRuntimeContextLowerPass(PassRegistry &); /// Pass to convert modules into DXIL-compatable modules ModulePass *createTaichiRuntimeContextLowerPass(); - /// Initializer for taichi intrinsic lower. void initializeTaichiIntrinsicLowerPass(PassRegistry &); /// Pass to lower taichi intrinsic into DXIL intrinsic. ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config); -} +} // namespace llvm diff --git a/taichi/codegen/dx12/dx12_lower_runtime_context.cpp b/taichi/codegen/dx12/dx12_lower_runtime_context.cpp index b6f9d1cf80f3ef..7f81f99862dbe7 100644 --- a/taichi/codegen/dx12/dx12_lower_runtime_context.cpp +++ b/taichi/codegen/dx12/dx12_lower_runtime_context.cpp @@ -18,7 +18,6 @@ using namespace taichi::lang::directx12; #define DEBUG_TYPE "dxil-taichi-runtime-context-lower" - namespace { class TaichiRuntimeContextLower : public ModulePass { @@ -38,12 +37,11 @@ char TaichiRuntimeContextLower::ID = 0; } // end anonymous namespace - INITIALIZE_PASS(TaichiRuntimeContextLower, - DEBUG_TYPE, - "Lower taichi RuntimeContext", - false, - false) + DEBUG_TYPE, + "Lower taichi RuntimeContext", + false, + false) llvm::ModulePass *llvm::createTaichiRuntimeContextLowerPass() { return new TaichiRuntimeContextLower(); diff --git a/taichi/rhi/arch.cpp b/taichi/rhi/arch.cpp index fbf4ab11f3cdb9..3685d1948ed011 100644 --- a/taichi/rhi/arch.cpp +++ b/taichi/rhi/arch.cpp @@ -45,8 +45,7 @@ bool arch_is_cpu(Arch arch) { bool arch_uses_llvm(Arch arch) { return (arch == Arch::x64 || arch == Arch::arm64 || arch == Arch::cuda || - arch == Arch::dx12 || - arch == Arch::wasm); + arch == Arch::dx12 || arch == Arch::wasm); } bool arch_is_gpu(Arch arch) { diff --git a/taichi/rhi/dx12/dx12_api.h b/taichi/rhi/dx12/dx12_api.h index 9a7614227601a7..485ddacb087f71 100644 --- a/taichi/rhi/dx12/dx12_api.h +++ b/taichi/rhi/dx12/dx12_api.h @@ -14,8 +14,9 @@ bool is_dx12_api_available(); std::shared_ptr make_dx12_device(); -std::vector validate_and_sign(std::vector &input_dxil_container); +std::vector validate_and_sign( + std::vector &input_dxil_container); -} // namespace directx11 +} // namespace directx12 } // namespace lang } // namespace taichi diff --git a/taichi/runtime/dx12/aot_graph_data.h b/taichi/runtime/dx12/aot_graph_data.h index 9420d42792c6b1..d8ee16572f0edc 100644 --- a/taichi/runtime/dx12/aot_graph_data.h +++ b/taichi/runtime/dx12/aot_graph_data.h @@ -11,7 +11,6 @@ class KernelImpl : public aot::Kernel { void launch(RuntimeContext *ctx) override { } - }; } // namespace directx12 } // namespace lang diff --git a/taichi/runtime/dx12/aot_module_builder_impl.cpp b/taichi/runtime/dx12/aot_module_builder_impl.cpp index 5f26750d224056..1a9f4ef5e61aee 100644 --- a/taichi/runtime/dx12/aot_module_builder_impl.cpp +++ b/taichi/runtime/dx12/aot_module_builder_impl.cpp @@ -17,8 +17,6 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) { void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, Kernel *kernel) { - - auto &dxil_codes = module_data.dxil_codes[identifier]; auto &compiled_kernel = module_data.kernels[identifier]; @@ -50,11 +48,11 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, // matter too much for now. TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent), "AOT: only supports dense field"); - + const auto &field = prog->get_cached_field(rep_snode->get_snode_tree_id()); - //const auto &dense_desc = - // compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id); + // const auto &dense_desc = + // compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id); aot::CompiledFieldData field_data; field_data.field_name = identifier; @@ -64,7 +62,7 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, field_data.shape = shape; // FIXME: calc mem_offset_in_parent for llvm path. field_data.mem_offset_in_parent = field.snode_metas[0].chunk_size; - //dense_desc.mem_offset_in_parent_cell; + // dense_desc.mem_offset_in_parent_cell; if (!is_scalar) { field_data.element_shape = {row_num, column_num}; } @@ -122,14 +120,14 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir, } } - const std::string json_path = fmt::format("{}/metadata_dx12.json", output_dir); + const std::string json_path = + fmt::format("{}/metadata_dx12.json", output_dir); tmp_module_data.dump_json(json_path); // FIXME: dump graph to different file. - //dump_graph(output_dir); + // dump_graph(output_dir); } - } // namespace directx12 } // namespace lang } // namespace taichi diff --git a/taichi/runtime/dx12/aot_module_loader_impl.cpp b/taichi/runtime/dx12/aot_module_loader_impl.cpp index b178a623502192..25e1690f60fc1f 100644 --- a/taichi/runtime/dx12/aot_module_loader_impl.cpp +++ b/taichi/runtime/dx12/aot_module_loader_impl.cpp @@ -13,8 +13,7 @@ namespace directx12 { namespace { class FieldImpl : public aot::Field { public: - explicit FieldImpl(const aot::CompiledFieldData &field) - : field_(field) { + explicit FieldImpl(const aot::CompiledFieldData &field) : field_(field) { } private: @@ -40,9 +39,9 @@ class AotModuleImpl : public aot::Module { } // FIXME: enable once write graph to graphs_dx12.tcb. - //const std::string graph_path = + // const std::string graph_path = // fmt::format("{}/graphs_dx12.tcb", params.module_path); - //read_from_binary_file(graphs_, graph_path); + // read_from_binary_file(graphs_, graph_path); } std::unique_ptr get_graph(std::string name) override { @@ -103,7 +102,7 @@ class AotModuleImpl : public aot::Module { } std::vector read_dxil_container(const std::string &output_dir, - const std::string &name) { + const std::string &name) { const std::string path = fmt::format("{}/{}.dxc", output_dir, name); std::vector source_code; std::ifstream fs(path, std::ios_base::binary | std::ios::ate); @@ -127,6 +126,6 @@ std::unique_ptr make_aot_module(std::any mod_params, return std::make_unique(params, device_api_backend); } -} // namespace gfx +} // namespace directx12 } // namespace lang } // namespace taichi diff --git a/taichi/runtime/llvm/CMakeLists.txt b/taichi/runtime/llvm/CMakeLists.txt index da7bed2826c0e9..bd8448d63f86a2 100644 --- a/taichi/runtime/llvm/CMakeLists.txt +++ b/taichi/runtime/llvm/CMakeLists.txt @@ -36,4 +36,4 @@ endif() if (TI_WITH_DX12) target_link_libraries(llvm_runtime PRIVATE ${llvm_directx_libs}) target_link_libraries(llvm_runtime PRIVATE dx12_rhi) -endif() \ No newline at end of file +endif() diff --git a/taichi/runtime/llvm/llvm_runtime_executor.cpp b/taichi/runtime/llvm/llvm_runtime_executor.cpp index cd6216a0cf366b..eced0fc29dd5b2 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.cpp +++ b/taichi/runtime/llvm/llvm_runtime_executor.cpp @@ -121,7 +121,7 @@ LlvmRuntimeExecutor::LlvmRuntimeExecutor(CompileConfig &config, llvm_context_device_ = std::make_unique(config_, Arch::dx12); // FIXME: add dx12 JIT. - //llvm_context_device_->init_runtime_jit_module(); + // llvm_context_device_->init_runtime_jit_module(); } #endif diff --git a/tests/cpp/aot/dx12/aot_save_load_test.cpp b/tests/cpp/aot/dx12/aot_save_load_test.cpp index 3039884366a2dc..7f9905dcdb32c1 100644 --- a/tests/cpp/aot/dx12/aot_save_load_test.cpp +++ b/tests/cpp/aot/dx12/aot_save_load_test.cpp @@ -102,7 +102,6 @@ using namespace lang; aot_builder->dump(".", ""); } - #ifdef TI_WITH_DX12 TEST(AotSaveLoad, DX12) { @@ -134,7 +133,6 @@ TEST(AotSaveLoad, DX12) { EXPECT_FALSE(ret2_kernel); // FIXME: test run kernels and check the result. - } #endif