Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dx12] Add aot for dx12. #6099

Merged
merged 13 commits into from
Sep 27, 2022
Merged
13 changes: 7 additions & 6 deletions .github/workflows/scripts/win_build.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ if (-not (Test-Path $libsDir)) {

$RepoURL = 'https://github.com/taichi-dev/taichi'

SetupCCacheLocal "$libsDir/ccache"
# SetupCCacheLocal "$libsDir/ccache"
python3kgae marked this conversation as resolved.
Show resolved Hide resolved

if ($clone) {
Info("Clone the repository")
Expand Down Expand Up @@ -58,14 +58,15 @@ if ($llvmVer -eq "10") {
$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=$($libsDir -replace "\\", "\\")\\taichi_clang\\bin\\clang++.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=$($libsDir -replace "\\", "\\")\\taichi_llvm\\bin\\llvm-as.exe"
} elseif ($llvmVer -eq "15") {
DownloadDep LLVM-15 llvm-15.zip taichi_llvm_15 `
https://github.com/python3kgae/taichi_assets/releases/download/llvm15_vs2019_clang/taichi-llvm-15.0.0-msvc2019.zip
DownloadDep LLVM-15 llvm-15.zip taichi_llvm_15_test `
https://github.com/python3kgae/taichi_assets/releases/download/llvm15_vs2019_clang/taichi-llvm-15.0.0-msvc2019-patched.zip
DownloadDep Clang-15 clang-15.zip taichi_clang_15 `
https://github.com/python3kgae/taichi_assets/releases/download/llvm15_vs2022_clang/clang-15.0.0-win.zip
$env:LLVM_DIR = "$libsDir\taichi_llvm_15"
$env:LLVM_DIR = "$libsDir\taichi_llvm_15_test"
$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=$($libsDir -replace "\\", "\\")\\taichi_clang_15\\bin\\clang++.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=$($libsDir -replace "\\", "\\")\\taichi_llvm_15\\bin\\llvm-as.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=$($libsDir -replace "\\", "\\")\\taichi_llvm_15_test\\bin\\llvm-as.exe"
$env:TAICHI_CMAKE_ARGS += " -DTI_LLVM_15:BOOL=ON"
$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_DX12:BOOL=ON"
} else {
throw "Unsupported LLVM version"
}
Expand Down Expand Up @@ -112,4 +113,4 @@ if ($install) {
Info("Build finished")
}

ccache -s -v
#ccache -s -v
python3kgae marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ jobs:
TAICHI_CMAKE_ARGS: >-
-DTI_WITH_OPENGL:BOOL=ON
-DTI_WITH_DX11:BOOL=ON
-DTI_WITH_DX12:BOOL=ON
-DTI_WITH_CC:BOOL=OFF
-DTI_BUILD_TESTS:BOOL=ON
-DTI_WITH_C_API=ON
Expand Down
1 change: 1 addition & 0 deletions ci/windows/win_build_test.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ if (!$llvmVer.CompareTo("10")) {
$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=C:\\taichi_clang_15\\bin\\clang++.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=C:\\taichi_llvm_15\\bin\\llvm-as.exe"
$env:TAICHI_CMAKE_ARGS += " -DTI_LLVM_15:BOOL=ON"
$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_DX12:BOOL=ON"
python3kgae marked this conversation as resolved.
Show resolved Hide resolved
}

$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF"
Expand Down
2 changes: 2 additions & 0 deletions cmake/TaichiCore.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ endif()

if (TI_WITH_DX12)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_DX12")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -gcodeview")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -gcodeview")
endif()

## TODO: Remove CC backend
Expand Down
10 changes: 10 additions & 0 deletions cmake/TaichiTests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ if(TI_WITH_OPENGL)
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_OPENGL_SOURCE})
endif()

if(TI_WITH_DX12)
file(GLOB TAICHI_TESTS_DX12_SOURCE "tests/cpp/aot/dx12/*.cpp")
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_DX12_SOURCE})
endif()

add_executable(${TESTS_NAME} ${TAICHI_TESTS_SOURCE})
if (WIN32)
# Output the executable to build/ instead of build/Debug/...
Expand All @@ -64,6 +69,11 @@ if (TI_WITH_OPENGL)
target_link_libraries(${TESTS_NAME} PRIVATE opengl_rhi)
endif()

if (TI_WITH_DX12)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_runtime)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_rhi)
endif()

target_include_directories(${TESTS_NAME}
PRIVATE
${PROJECT_SOURCE_DIR}
Expand Down
8 changes: 4 additions & 4 deletions cpp_examples/aot_save.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

void aot_save() {
void aot_save(taichi::Arch arch) {
using namespace taichi;
using namespace lang;
auto program = Program(Arch::vulkan);
auto program = Program(arch);

program.this_thread_config().advanced_optimization = false;

Expand All @@ -18,7 +18,7 @@ void aot_save() {
place->dt = PrimitiveType::i32;
program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true);

auto aot_builder = program.make_aot_module_builder(Arch::vulkan);
auto aot_builder = program.make_aot_module_builder(arch);

std::unique_ptr<Kernel> kernel_init, kernel_ret;

Expand Down Expand Up @@ -73,6 +73,6 @@ void aot_save() {
aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
aot_builder->dump(".", "aot.tcb");
aot_builder->dump(".", taichi::arch_name(arch) + "_aot.tcb");
std::cout << "done" << std::endl;
}
9 changes: 7 additions & 2 deletions cpp_examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

void run_snode();
void autograd();
void aot_save();
void aot_save(taichi::Arch arch);

/// Runs the C++ examples in sequence: SNode usage, autograd, then one AOT
/// save per graphics backend that was enabled at build time.
int main() {
  run_snode();
  autograd();
  // aot_save() now takes the target arch explicitly; the former zero-argument
  // call is gone, so each backend is exercised only when it was compiled in.
#ifdef TI_WITH_VULKAN
  aot_save(taichi::Arch::vulkan);
#endif
#ifdef TI_WITH_DX12
  aot_save(taichi::Arch::dx12);
#endif
  return 0;
}
5 changes: 5 additions & 0 deletions taichi/aot/module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/runtime/gfx/aot_module_loader_impl.h"
#include "taichi/runtime/metal/aot_module_loader_impl.h"
#include "taichi/runtime/dx12/aot_module_loader_impl.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -41,6 +42,10 @@ std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) {
} else if (arch == Arch::dx11) {
#ifdef TI_WITH_DX11
return gfx::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
return directx12::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::metal) {
#ifdef TI_WITH_METAL
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/dx12/dx12_global_optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,
F.addFnAttr(llvm::Attribute::AlwaysInline);
}
// FIXME: choose shader model based on feature used.
llvm::StringRef triple = "dxil-pc-shadermodel6.3-compute";
llvm::StringRef triple = "dxil-pc-shadermodel6.0-compute";
module->setTargetTriple(triple);
module->setSourceFileName("");
std::string err_str;
Expand Down
7 changes: 7 additions & 0 deletions taichi/jit/jit_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ std::unique_ptr<JITSession> JITSession::create(TaichiLLVMContext *tlctx,
return create_llvm_jit_session_cuda(tlctx, config, arch);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
python3kgae marked this conversation as resolved.
Show resolved Hide resolved
// NOTE: there's no jit for dx12. Create cpu session to avoid crash.
return create_llvm_jit_session_cpu(tlctx, config, Arch::x64);
#else
TI_NOT_IMPLEMENTED
#endif
}
#else
Expand Down
1 change: 1 addition & 0 deletions taichi/runtime/dx12/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(dx12_runtime)
target_sources(dx12_runtime
PRIVATE
aot_module_builder_impl.cpp
aot_module_loader_impl.cpp
)

target_include_directories(dx12_runtime
Expand Down
17 changes: 17 additions & 0 deletions taichi/runtime/dx12/aot_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
namespace directx12 {
/// DX12 implementation of the AOT kernel interface.
///
/// Currently a placeholder: it holds no state and launch() has an empty body,
/// so launching a kernel through this class performs no work.
class KernelImpl : public aot::Kernel {
 public:
  explicit KernelImpl() = default;

  /// No-op for now; `ctx` is accepted but never read.
  void launch(RuntimeContext *ctx) override {
  }
};
} // namespace directx12
} // namespace lang
} // namespace taichi
86 changes: 82 additions & 4 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {

/// Compiles `kernel` with the DX12 code generator and records the result in
/// `module_data` under `identifier`.
///
/// Every DXIL blob produced by the compile is appended to
/// `module_data.dxil_codes[identifier]`, and the compiled task list is stored
/// in `module_data.kernels[identifier].tasks`.
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
                                           Kernel *kernel) {
  // operator[] creates both entries on first use of this identifier.
  auto &dxil_codes = module_data.dxil_codes[identifier];
  auto &compiled_kernel = module_data.kernels[identifier];

  KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
  auto compiled_data = cgen.compile();
  for (const auto &dxil : compiled_data.task_dxil_source_codes) {
    dxil_codes.emplace_back(dxil);
  }
  // FIXME: set compiled kernel.
  // Only the task list is filled in so far; dump() relies on it having one
  // entry per DXIL blob.
  compiled_kernel.tasks = compiled_data.tasks;
}

void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
Expand All @@ -32,18 +41,87 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
std::vector<int> shape,
int row_num,
int column_num) {
TI_NOT_IMPLEMENTED;
// FIXME: support sparse fields.
TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
"AOT: D12 supports only dense fields for now");

const auto &field = prog->get_cached_field(rep_snode->get_snode_tree_id());

// const auto &dense_desc =
// compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

aot::CompiledFieldData field_data;
field_data.field_name = identifier;
field_data.is_scalar = is_scalar;
field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
field_data.dtype_name = dt.to_string();
field_data.shape = shape;
// FIXME: calc mem_offset_in_parent for llvm path.
field_data.mem_offset_in_parent = field.snode_metas[0].chunk_size;
// dense_desc.mem_offset_in_parent_cell;
if (!is_scalar) {
field_data.element_shape = {row_num, column_num};
}

module_data.fields.emplace_back(field_data);
}

/// Compiles a templated `kernel` with the DX12 code generator and records the
/// result in `module_data` under the combined key `identifier + "|" + key`.
///
/// Every DXIL blob produced by the compile is appended to
/// `module_data.dxil_codes[<combined key>]`, and the compiled task list is
/// stored in `module_data.kernels[<combined key>].tasks`.
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
                                                const std::string &key,
                                                Kernel *kernel) {
  // FIXME: share code with add_per_backend.
  const auto tmpl_identifier = identifier + "|" + key;

  // operator[] creates both entries on first use of this identifier.
  auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
  auto &compiled_kernel = module_data.kernels[tmpl_identifier];

  KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
  auto compiled_data = cgen.compile();
  for (const auto &dxil : compiled_data.task_dxil_source_codes) {
    dxil_codes.emplace_back(dxil);
  }
  // Store the task list, exactly as add_per_backend does. dump() asserts that
  // each kernel has as many tasks as DXIL blobs, so leaving `tasks` empty
  // would trip that assertion for any templated kernel that compiled to at
  // least one task.
  compiled_kernel.tasks = compiled_data.tasks;
}

/// Writes a DXIL blob to `<output_dir>/<name>.dxc` and returns that path.
///
/// The file is opened in binary mode and truncated, so an existing file with
/// the same name is overwritten.
///
/// @param output_dir  Directory to write into; used verbatim as a path prefix.
/// @param name        File stem; the ".dxc" extension is appended.
/// @param source_code Raw bytes to store.
/// @return The path that was written to.
///
/// NOTE(review): stream failures (e.g. a missing directory) are not reported;
/// the path is returned regardless, as before.
std::string write_dxil_container(const std::string &output_dir,
                                 const std::string &name,
                                 const std::vector<uint8_t> &source_code) {
  const std::string path = output_dir + "/" + name + ".dxc";
  std::ofstream fs(path, std::ios_base::binary | std::ios::trunc);
  // A named cast keeps the buffer const; uint8_t elements are one byte each,
  // so the element count is also the byte count.
  fs.write(reinterpret_cast<const char *>(source_code.data()),
           static_cast<std::streamsize>(source_code.size()));
  fs.close();
  return path;
}

/// Writes the AOT module into `output_dir`.
///
/// Produces three kinds of output:
///  - `metadata_dx12.tcb`: binary serialization of `module_data` as it stands
///    (task source paths are not filled in for this file);
///  - one `.dxc` file per task, named after the task, holding its DXIL blob;
///  - `metadata_dx12.json`: JSON dump of a copy of `module_data` in which each
///    task's `source_path` points at its `.dxc` file.
///
/// @param output_dir Directory all files are written into.
/// @param filename   Not used for naming; a warning is logged when non-empty.
void AotModuleBuilderImpl::dump(const std::string &output_dir,
                                const std::string &filename) const {
  TI_WARN_IF(!filename.empty(),
             "Filename prefix is ignored on Unified Device API backends.");
  const std::string bin_path = fmt::format("{}/metadata_dx12.tcb", output_dir);
  write_to_binary_file(module_data, bin_path);
  // Copy module_data to update task.source_path (this method is const, so the
  // member itself cannot be modified).
  auto tmp_module_data = module_data;
  for (auto &[name, compiled_kernel] : tmp_module_data.kernels) {
    const auto it = tmp_module_data.dxil_codes.find(name);
    TI_ASSERT(it != tmp_module_data.dxil_codes.end());
    const auto &dxil_codes = it->second;
    auto &tasks = compiled_kernel.tasks;
    // Blobs and tasks are paired by index below, so the counts must agree.
    TI_ASSERT(dxil_codes.size() == tasks.size());
    // Unsigned index: avoids the signed/unsigned comparison against size().
    for (std::size_t i = 0; i < tasks.size(); ++i) {
      auto &task = tasks[i];
      task.source_path =
          write_dxil_container(output_dir, task.name, dxil_codes[i]);
    }
  }

  const std::string json_path =
      fmt::format("{}/metadata_dx12.json", output_dir);
  tmp_module_data.dump_json(json_path);

  // FIXME: dump graph to different file.
  // dump_graph(output_dir);
}

} // namespace directx12
Expand Down
Loading