// Patch summary (reviewer preamble; text ahead of the first "diff --git"
// header is not part of any hunk).
//
// taichi/codegen/dx12/CMakeLists.txt
//   Adds dx12_global_optimize_module.cpp to the private sources of the
//   dx12_codegen library target.
//
// taichi/codegen/dx12/codegen_dx12.cpp
//   Includes dx12_llvm_passes.h and replaces the TI_NOT_IMPLEMENTED body of
//   generate_dxil_from_llvm(compiled_data, kernel) with a real one:
//     1. Copies compiled_data.tasks and takes the raw module pointer from
//        compiled_data.module.
//     2. For each task, looks up the llvm::Function named task.name
//        (TI_ASSERT on a null result), marks it as a compute-shader entry and
//        sets its thread-group size to
//        (kernel->program->config.default_gpu_block_dim, 1, 1).
//        The per-task block_dim is not used yet (see the FIXME in the hunk).
//     3. Passes the module to directx12::global_optimize_module and returns
//        directx12::validate_and_sign() applied to the resulting bytes.
//   NOTE(review): `auto offloaded_local = compiled_data.tasks;` makes a copy
//   that is only read in the loop; a reference looks sufficient - confirm.
//   NOTE(review): validate_and_sign is not declared in dx12_llvm_passes.h as
//   added here; presumably it comes from another included header - verify.
//
// taichi/codegen/dx12/dx12_global_optimize_module.cpp (new file)
//   mark_function_as_cs_entry(F): adds the string function attribute
//     "hlsl.shader" with value "compute".
//   is_cs_entry(F): returns whether F carries an "hlsl.shader" attribute.
//     The attribute value is not inspected, only its presence.
//   set_num_threads(F, x, y, z): adds the string function attribute
//     "hlsl.numthreads" with value "x,y,z" (formatted via llvm::formatv).
//   global_optimize_module(module, config):
//     - Runs llvm::verifyModule; on failure prints the module to llvm::errs()
//       and raises TI_ERROR("Module broken").
//     - Adds AlwaysInline to every function in the module that is not a
//       compute-shader entry.
//     - Sets the target triple to "dxil-pc-shadermodel6.3-compute" (hard-coded,
//       see the FIXME in the hunk) and clears the source file name.
//     - Looks up the target (TI_ERROR_UNLESS on failure) and fills
//       TargetOptions from config.fast_math.
//     - Creates the target machine with CodeGenOpt::Aggressive when
//       config.opt_level > 0 and CodeGenOpt::None otherwise, then applies its
//       data layout to the module.
//     - Populates a legacy function pass manager and module pass manager from
//       a PassManagerBuilder with OptLevel 3, a function inlining pass, and
//       loop/SLP vectorization enabled.
//     - Appends the target's emit-file passes (CGFT_ObjectFile) to the module
//       pass manager, writing into an llvm::SmallString<256> buffer.
//     - Runs the function passes over every function, then the module passes.
//     - When config.print_kernel_llvm_ir_optimized is set, writes the module
//       through a function-local static FileSequenceWriter. This happens
//       after the module pass manager (including the emit passes) has run.
//     - Returns a std::vector built from the bytes of the buffer.
//   NOTE(review): the return value of addPassesToEmitFile is ignored here;
//   LLVM's API appears to return true when emission is unsupported - confirm
//   and consider failing loudly instead of returning an empty container.
//   NOTE(review): the PassManagerBuilder always uses OptLevel 3, while only
//   the target machine's CodeGenOpt level follows config.opt_level - confirm
//   this is intended.
//
// taichi/codegen/dx12/dx12_llvm_passes.h (new file)
//   Forward-declares llvm::Function, llvm::Module and CompileConfig, and
//   declares the four directx12 functions above plus
//   `extern const char *NumWorkGroupsCBName;`.
//   NOTE(review): no definition of NumWorkGroupsCBName is visible in this
//   patch - verify it is defined elsewhere before anything references it.
//
// NOTE(review): template arguments and angle-bracket include targets look
// like they were stripped from this copy of the patch (bare `std::vector`,
// `std::unique_ptr`, and `#include` with no header name) - compare against
// the upstream commit before applying.
diff --git a/taichi/codegen/dx12/CMakeLists.txt b/taichi/codegen/dx12/CMakeLists.txt index 8a1b70aaa6ea4..24dca8d27238b 100644 --- a/taichi/codegen/dx12/CMakeLists.txt +++ b/taichi/codegen/dx12/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(dx12_codegen) target_sources(dx12_codegen PRIVATE codegen_dx12.cpp + dx12_global_optimize_module.cpp ) target_include_directories(dx12_codegen diff --git a/taichi/codegen/dx12/codegen_dx12.cpp b/taichi/codegen/dx12/codegen_dx12.cpp index 29b446a4cbe4a..4be95a53f7b1b 100644 --- a/taichi/codegen/dx12/codegen_dx12.cpp +++ b/taichi/codegen/dx12/codegen_dx12.cpp @@ -1,5 +1,5 @@ #include "taichi/codegen/dx12/codegen_dx12.h" - +#include "taichi/codegen/dx12/dx12_llvm_passes.h" #include "taichi/rhi/dx12/dx12_api.h" #include "taichi/runtime/program_impls/llvm/llvm_program.h" #include "taichi/common/core.h" @@ -228,7 +228,24 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM { static std::vector generate_dxil_from_llvm( LLVMCompiledData &compiled_data, - taichi::lang::Kernel *kernel){TI_NOT_IMPLEMENTED} + taichi::lang::Kernel *kernel) { + // generate dxil from llvm ir. + auto offloaded_local = compiled_data.tasks; + auto module = compiled_data.module.get(); + for (auto &task : offloaded_local) { + llvm::Function *func = module->getFunction(task.name); + TI_ASSERT(func); + directx12::mark_function_as_cs_entry(func); + directx12::set_num_threads( + func, kernel->program->config.default_gpu_block_dim, 1, 1); + // FIXME: save task.block_dim like + // tlctx->mark_function_as_cuda_kernel(func, task.block_dim); + } + auto dx_container = + directx12::global_optimize_module(module, kernel->program->config); + // validate and sign dx container. 
+ return directx12::validate_and_sign(dx_container); +} KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() { TI_AUTO_PROF; diff --git a/taichi/codegen/dx12/dx12_global_optimize_module.cpp b/taichi/codegen/dx12/dx12_global_optimize_module.cpp new file mode 100644 index 0000000000000..94a9d7c003328 --- /dev/null +++ b/taichi/codegen/dx12/dx12_global_optimize_module.cpp @@ -0,0 +1,152 @@ + +#include "taichi/common/core.h" +#include "taichi/util/io.h" +#include "taichi/program/program.h" +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/util/statistics.h" +#include "taichi/util/file_sequence_writer.h" +#include "taichi/runtime/llvm/llvm_context.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/IR/Function.h" + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/Support/Host.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/IR/GlobalVariable.h" + +using namespace llvm; + +namespace taichi { +namespace lang { +namespace directx12 { + +const llvm::StringRef ShaderAttrKindStr = "hlsl.shader"; + +void mark_function_as_cs_entry(::llvm::Function *F) { + F->addFnAttr(ShaderAttrKindStr, "compute"); +} +bool is_cs_entry(::llvm::Function *F) { + return F->hasFnAttribute(ShaderAttrKindStr); +} + +void set_num_threads(llvm::Function *F, unsigned x, unsigned y, unsigned z) { + const llvm::StringRef NumThreadsAttrKindStr = 
"hlsl.numthreads"; + std::string Str = llvm::formatv("{0},{1},{2}", x, y, z); + F->addFnAttr(NumThreadsAttrKindStr, Str); +} + +std::vector global_optimize_module(llvm::Module *module, + CompileConfig &config) { + TI_AUTO_PROF + if (llvm::verifyModule(*module, &llvm::errs())) { + module->print(llvm::errs(), nullptr); + TI_ERROR("Module broken"); + } + + for (llvm::Function &F : module->functions()) { + if (directx12::is_cs_entry(&F)) + continue; + F.addFnAttr(llvm::Attribute::AlwaysInline); + } + // FIXME: choose shader model based on feature used. + llvm::StringRef triple = "dxil-pc-shadermodel6.3-compute"; + module->setTargetTriple(triple); + module->setSourceFileName(""); + std::string err_str; + const llvm::Target *target = + TargetRegistry::lookupTarget(triple.str(), err_str); + TI_ERROR_UNLESS(target, err_str); + + TargetOptions options; + if (config.fast_math) { + options.AllowFPOpFusion = FPOpFusion::Fast; + options.UnsafeFPMath = 1; + options.NoInfsFPMath = 1; + options.NoNaNsFPMath = 1; + } else { + options.AllowFPOpFusion = FPOpFusion::Strict; + options.UnsafeFPMath = 0; + options.NoInfsFPMath = 0; + options.NoNaNsFPMath = 0; + } + options.HonorSignDependentRoundingFPMathOption = false; + options.NoZerosInBSS = false; + options.GuaranteedTailCallOpt = false; + + legacy::FunctionPassManager function_pass_manager(module); + legacy::PassManager module_pass_manager; + + llvm::StringRef mcpu = ""; + std::unique_ptr target_machine(target->createTargetMachine( + triple.str(), mcpu.str(), "", options, llvm::Reloc::PIC_, + llvm::CodeModel::Small, + config.opt_level > 0 ? 
CodeGenOpt::Aggressive : CodeGenOpt::None)); + + TI_ERROR_UNLESS(target_machine.get(), "Could not allocate target machine!"); + + module->setDataLayout(target_machine->createDataLayout()); + + module_pass_manager.add(createTargetTransformInfoWrapperPass( + target_machine->getTargetIRAnalysis())); + function_pass_manager.add(createTargetTransformInfoWrapperPass( + target_machine->getTargetIRAnalysis())); + + PassManagerBuilder b; + b.OptLevel = 3; + b.Inliner = createFunctionInliningPass(b.OptLevel, 0, false); + b.LoopVectorize = true; + b.SLPVectorize = true; + + target_machine->adjustPassManager(b); + + b.populateFunctionPassManager(function_pass_manager); + b.populateModulePassManager(module_pass_manager); + llvm::SmallString<256> str; + llvm::raw_svector_ostream OS(str); + // Write DXIL container to OS. + target_machine->addPassesToEmitFile(module_pass_manager, OS, nullptr, + CGFT_ObjectFile); + + { + TI_PROFILER("llvm_function_pass"); + function_pass_manager.doInitialization(); + for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) + function_pass_manager.run(*i); + + function_pass_manager.doFinalization(); + } + + { + TI_PROFILER("llvm_module_pass"); + module_pass_manager.run(*module); + } + if (config.print_kernel_llvm_ir_optimized) { + static FileSequenceWriter writer( + "taichi_kernel_dx12_llvm_ir_optimized_{:04d}.ll", + "optimized LLVM IR (DX12)"); + writer.write(module); + } + return std::vector(str.begin(), str.end()); +} + +} // namespace directx12 +} // namespace lang +} // namespace taichi diff --git a/taichi/codegen/dx12/dx12_llvm_passes.h b/taichi/codegen/dx12/dx12_llvm_passes.h new file mode 100644 index 0000000000000..c07896abba1a3 --- /dev/null +++ b/taichi/codegen/dx12/dx12_llvm_passes.h @@ -0,0 +1,29 @@ + +#pragma once + +#include +#include + +namespace llvm { +class Function; +class Module; +} // namespace llvm + +namespace taichi { +namespace lang { +struct CompileConfig; + +namespace directx12 { + +void 
mark_function_as_cs_entry(llvm::Function *); +bool is_cs_entry(llvm::Function *); +void set_num_threads(llvm::Function *, unsigned x, unsigned y, unsigned z); + +std::vector global_optimize_module(llvm::Module *module, + CompileConfig &config); + +extern const char *NumWorkGroupsCBName; + +} // namespace directx12 +} // namespace lang +} // namespace taichi