Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dx12] Add aot for dx12. #6099

Merged
merged 13 commits into from
Sep 27, 2022
Merged
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ jobs:
TAICHI_CMAKE_ARGS: >-
-DTI_WITH_OPENGL:BOOL=ON
-DTI_WITH_DX11:BOOL=ON
-DTI_WITH_DX12:BOOL=ON
-DTI_WITH_CC:BOOL=OFF
-DTI_BUILD_TESTS:BOOL=ON
-DTI_WITH_C_API=ON
Expand Down
1 change: 1 addition & 0 deletions ci/windows/win_build_test.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ if (!$llvmVer.CompareTo("10")) {
$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=C:\\taichi_clang_15\\bin\\clang++.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=C:\\taichi_llvm_15\\bin\\llvm-as.exe"
$env:TAICHI_CMAKE_ARGS += " -DTI_LLVM_15:BOOL=ON"
$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_DX12:BOOL=ON"
python3kgae marked this conversation as resolved.
Show resolved Hide resolved
}

$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF"
Expand Down
10 changes: 10 additions & 0 deletions cmake/TaichiTests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ if(TI_WITH_OPENGL)
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_OPENGL_SOURCE})
endif()

if(TI_WITH_DX12)
file(GLOB TAICHI_TESTS_DX12_SOURCE "tests/cpp/aot/dx12/*.cpp")
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_DX12_SOURCE})
endif()

add_executable(${TESTS_NAME} ${TAICHI_TESTS_SOURCE})
if (WIN32)
# Output the executable to build/ instead of build/Debug/...
Expand All @@ -64,6 +69,11 @@ if (TI_WITH_OPENGL)
target_link_libraries(${TESTS_NAME} PRIVATE opengl_rhi)
endif()

if (TI_WITH_DX12)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_runtime)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_rhi)
endif()

target_include_directories(${TESTS_NAME}
PRIVATE
${PROJECT_SOURCE_DIR}
Expand Down
8 changes: 4 additions & 4 deletions cpp_examples/aot_save.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

void aot_save() {
void aot_save(taichi::Arch arch) {
using namespace taichi;
using namespace lang;
auto program = Program(Arch::vulkan);
auto program = Program(arch);

program.this_thread_config().advanced_optimization = false;

Expand All @@ -18,7 +18,7 @@ void aot_save() {
place->dt = PrimitiveType::i32;
program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true);

auto aot_builder = program.make_aot_module_builder(Arch::vulkan);
auto aot_builder = program.make_aot_module_builder(arch);

std::unique_ptr<Kernel> kernel_init, kernel_ret;

Expand Down Expand Up @@ -73,6 +73,6 @@ void aot_save() {
aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
aot_builder->dump(".", "aot.tcb");
aot_builder->dump(".", taichi::arch_name(arch) + "_aot.tcb");
std::cout << "done" << std::endl;
}
9 changes: 7 additions & 2 deletions cpp_examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

void run_snode();
void autograd();
void aot_save();
void aot_save(taichi::Arch arch);

int main() {
run_snode();
autograd();
aot_save();
#ifdef TI_WITH_VULKAN
aot_save(taichi::Arch::vulkan);
#endif
#ifdef TI_WITH_DX12
aot_save(taichi::Arch::dx12);
#endif
return 0;
}
5 changes: 5 additions & 0 deletions taichi/aot/module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/runtime/gfx/aot_module_loader_impl.h"
#include "taichi/runtime/metal/aot_module_loader_impl.h"
#include "taichi/runtime/dx12/aot_module_loader_impl.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -41,6 +42,10 @@ std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) {
} else if (arch == Arch::dx11) {
#ifdef TI_WITH_DX11
return gfx::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
return directx12::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::metal) {
#ifdef TI_WITH_METAL
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/dx12/dx12_global_optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,
F.addFnAttr(llvm::Attribute::AlwaysInline);
}
// FIXME: choose shader model based on feature used.
llvm::StringRef triple = "dxil-pc-shadermodel6.3-compute";
llvm::StringRef triple = "dxil-pc-shadermodel6.0-compute";
module->setTargetTriple(triple);
module->setSourceFileName("");
std::string err_str;
Expand Down
7 changes: 7 additions & 0 deletions taichi/jit/jit_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ std::unique_ptr<JITSession> JITSession::create(TaichiLLVMContext *tlctx,
return create_llvm_jit_session_cuda(tlctx, config, arch);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
python3kgae marked this conversation as resolved.
Show resolved Hide resolved
// NOTE: there's no jit for dx12. Create cpu session to avoid crash.
return create_llvm_jit_session_cpu(tlctx, config, Arch::x64);
#else
TI_NOT_IMPLEMENTED
#endif
}
#else
Expand Down
1 change: 1 addition & 0 deletions taichi/runtime/dx12/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(dx12_runtime)
target_sources(dx12_runtime
PRIVATE
aot_module_builder_impl.cpp
aot_module_loader_impl.cpp
)

target_include_directories(dx12_runtime
Expand Down
17 changes: 17 additions & 0 deletions taichi/runtime/dx12/aot_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
namespace directx12 {
class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl() {
}

void launch(RuntimeContext *ctx) override {
}
};
} // namespace directx12
} // namespace lang
} // namespace taichi
86 changes: 82 additions & 4 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
TI_NOT_IMPLEMENTED;
auto &dxil_codes = module_data.dxil_codes[identifier];
auto &compiled_kernel = module_data.kernels[identifier];

KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
}
// FIXME: set compiled kernel.
compiled_kernel.tasks = compiled_data.tasks;
}

void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
Expand All @@ -32,18 +41,87 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
std::vector<int> shape,
int row_num,
int column_num) {
TI_NOT_IMPLEMENTED;
// FIXME: support sparse fields.
TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
"AOT: D12 supports only dense fields for now");

const auto &field = prog->get_cached_field(rep_snode->get_snode_tree_id());

// const auto &dense_desc =
// compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

aot::CompiledFieldData field_data;
field_data.field_name = identifier;
field_data.is_scalar = is_scalar;
field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
field_data.dtype_name = dt.to_string();
field_data.shape = shape;
// FIXME: calc mem_offset_in_parent for llvm path.
field_data.mem_offset_in_parent = field.snode_metas[0].chunk_size;
// dense_desc.mem_offset_in_parent_cell;
if (!is_scalar) {
field_data.element_shape = {row_num, column_num};
}

module_data.fields.emplace_back(field_data);
}

void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
TI_NOT_IMPLEMENTED;
// FIXME: share code with add_per_backend.
auto tmpl_identifier = identifier + "|" + key;

auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
auto &compiled_kernel = module_data.kernels[tmpl_identifier];

KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
}
// set compiled kernel.
}

std::string write_dxil_container(const std::string &output_dir,
const std::string &name,
const std::vector<uint8_t> &source_code) {
const std::string path = fmt::format("{}/{}.dxc", output_dir, name);
std::ofstream fs(path, std::ios_base::binary | std::ios::trunc);
fs.write((char *)source_code.data(), source_code.size() * sizeof(uint8_t));
fs.close();
return path;
}

void AotModuleBuilderImpl::dump(const std::string &output_dir,
const std::string &filename) const {
TI_NOT_IMPLEMENTED;
TI_WARN_IF(!filename.empty(),
"Filename prefix is ignored on Unified Device API backends.");
const std::string bin_path = fmt::format("{}/metadata_dx12.tcb", output_dir);
write_to_binary_file(module_data, bin_path);
// Copy module_data to update task.source_path.
auto tmp_module_data = module_data;
for (auto &[name, compiled_kernel] : tmp_module_data.kernels) {
auto it = tmp_module_data.dxil_codes.find(name);
TI_ASSERT(it != tmp_module_data.dxil_codes.end());
auto &dxil_codes = it->second;
auto &tasks = compiled_kernel.tasks;
TI_ASSERT(dxil_codes.size() == tasks.size());
for (int i = 0; i < tasks.size(); ++i) {
auto &dxil_code = dxil_codes[i];
auto &task = tasks[i];
std::string dxil_path =
write_dxil_container(output_dir, task.name, dxil_code);
task.source_path = dxil_path;
}
}

const std::string json_path =
fmt::format("{}/metadata_dx12.json", output_dir);
tmp_module_data.dump_json(json_path);

// FIXME: dump graph to different file.
// dump_graph(output_dir);
}

} // namespace directx12
Expand Down
132 changes: 132 additions & 0 deletions taichi/runtime/dx12/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "aot_module_loader_impl.h"
#include "aot_module_builder_impl.h"
#include "aot_graph_data.h"
#include <fstream>
#include <type_traits>

#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
namespace directx12 {
namespace {
class FieldImpl : public aot::Field {
public:
explicit FieldImpl(const aot::CompiledFieldData &field) : field_(field) {
}

private:
aot::CompiledFieldData field_;
};

class AotModuleImpl : public aot::Module {
public:
explicit AotModuleImpl(const AotModuleParams &params, Arch device_api_backend)
: device_api_backend_(device_api_backend) {
const std::string bin_path =
fmt::format("{}/metadata_dx12.tcb", params.module_path);
read_from_binary_file(module_data, bin_path);

for (auto &[name, compiled_kernel] : module_data.kernels) {
auto &dxil_codes = module_data.dxil_codes[name];
auto &tasks = compiled_kernel.tasks;
for (int i = 0; i < tasks.size(); ++i) {
auto &task = tasks[i];
dxil_codes.emplace_back(
read_dxil_container(params.module_path, task.name));
}
}

// FIXME: enable once write graph to graphs_dx12.tcb.
// const std::string graph_path =
// fmt::format("{}/graphs_dx12.tcb", params.module_path);
// read_from_binary_file(graphs_, graph_path);
}

std::unique_ptr<aot::CompiledGraph> get_graph(
const std::string &name) override {
TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name);
std::vector<aot::CompiledDispatch> dispatches;
for (auto &dispatch : graphs_[name].dispatches) {
dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args,
get_kernel(dispatch.kernel_name)});
}
aot::CompiledGraph graph{dispatches};
return std::make_unique<aot::CompiledGraph>(std::move(graph));
}

size_t get_root_size() const override {
return module_data.root_buffer_size;
}

// Module metadata
Arch arch() const override {
return device_api_backend_;
}
uint64_t version() const override {
TI_NOT_IMPLEMENTED;
}

private:
bool get_field_data_by_name(const std::string &name,
aot::CompiledFieldData &field) {
for (int i = 0; i < module_data.fields.size(); ++i) {
if (module_data.fields[i].field_name.rfind(name, 0) == 0) {
field = module_data.fields[i];
return true;
}
}
return false;
}

std::unique_ptr<aot::Kernel> make_new_kernel(
const std::string &name) override {
if (module_data.kernels.find(name) == module_data.kernels.end())
return nullptr;
return std::make_unique<KernelImpl>();
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
const std::string &name) override {
TI_NOT_IMPLEMENTED;
return nullptr;
}

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
aot::CompiledFieldData field;
if (!get_field_data_by_name(name, field)) {
TI_DEBUG("Failed to load field {}", name);
return nullptr;
}
return std::make_unique<FieldImpl>(field);
}

std::vector<uint8_t> read_dxil_container(const std::string &output_dir,
const std::string &name) {
const std::string path = fmt::format("{}/{}.dxc", output_dir, name);
std::vector<uint8_t> source_code;
std::ifstream fs(path, std::ios_base::binary | std::ios::ate);
size_t size = fs.tellg();
fs.seekg(0, std::ios::beg);
source_code.resize(size / sizeof(uint8_t));
fs.read((char *)source_code.data(), size);
fs.close();
return source_code;
}

ModuleDataDX12 module_data;
Arch device_api_backend_;
};

} // namespace

std::unique_ptr<aot::Module> make_aot_module(std::any mod_params,
Arch device_api_backend) {
AotModuleParams params = std::any_cast<AotModuleParams &>(mod_params);
return std::make_unique<AotModuleImpl>(params, device_api_backend);
}

} // namespace directx12
} // namespace lang
} // namespace taichi
Loading