Skip to content

Commit

Permalink
[dx12] Add aot for dx12. (taichi-dev#6099)
Browse files Browse the repository at this point in the history
This change only makes sure the pipeline generates something.
No real DXIL is generated yet.

Move the DX12 build to the GPU CI, which will run the AOT test.

Issue: taichi-dev#5276

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
python3kgae and pre-commit-ci[bot] authored Sep 27, 2022
1 parent be1b139 commit 73dac71
Show file tree
Hide file tree
Showing 14 changed files with 439 additions and 11 deletions.
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ jobs:
TAICHI_CMAKE_ARGS: >-
-DTI_WITH_OPENGL:BOOL=ON
-DTI_WITH_DX11:BOOL=ON
-DTI_WITH_DX12:BOOL=ON
-DTI_WITH_CC:BOOL=OFF
-DTI_BUILD_TESTS:BOOL=ON
-DTI_WITH_C_API=ON
Expand Down
1 change: 1 addition & 0 deletions ci/windows/win_build_test.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ if (!$llvmVer.CompareTo("10")) {
$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=C:\\taichi_clang_15\\bin\\clang++.exe"
$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=C:\\taichi_llvm_15\\bin\\llvm-as.exe"
$env:TAICHI_CMAKE_ARGS += " -DTI_LLVM_15:BOOL=ON"
$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_DX12:BOOL=ON"
}

$env:TAICHI_CMAKE_ARGS += " -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF"
Expand Down
10 changes: 10 additions & 0 deletions cmake/TaichiTests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ if(TI_WITH_OPENGL)
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_OPENGL_SOURCE})
endif()

if(TI_WITH_DX12)
file(GLOB TAICHI_TESTS_DX12_SOURCE "tests/cpp/aot/dx12/*.cpp")
list(APPEND TAICHI_TESTS_SOURCE ${TAICHI_TESTS_DX12_SOURCE})
endif()

add_executable(${TESTS_NAME} ${TAICHI_TESTS_SOURCE})
if (WIN32)
# Output the executable to build/ instead of build/Debug/...
Expand All @@ -64,6 +69,11 @@ if (TI_WITH_OPENGL)
target_link_libraries(${TESTS_NAME} PRIVATE opengl_rhi)
endif()

if (TI_WITH_DX12)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_runtime)
target_link_libraries(${TESTS_NAME} PRIVATE dx12_rhi)
endif()

target_include_directories(${TESTS_NAME}
PRIVATE
${PROJECT_SOURCE_DIR}
Expand Down
8 changes: 4 additions & 4 deletions cpp_examples/aot_save.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

void aot_save() {
void aot_save(taichi::Arch arch) {
using namespace taichi;
using namespace lang;
auto program = Program(Arch::vulkan);
auto program = Program(arch);

program.this_thread_config().advanced_optimization = false;

Expand All @@ -18,7 +18,7 @@ void aot_save() {
place->dt = PrimitiveType::i32;
program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true);

auto aot_builder = program.make_aot_module_builder(Arch::vulkan);
auto aot_builder = program.make_aot_module_builder(arch);

std::unique_ptr<Kernel> kernel_init, kernel_ret;

Expand Down Expand Up @@ -73,6 +73,6 @@ void aot_save() {
aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
aot_builder->dump(".", "aot.tcb");
aot_builder->dump(".", taichi::arch_name(arch) + "_aot.tcb");
std::cout << "done" << std::endl;
}
9 changes: 7 additions & 2 deletions cpp_examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

void run_snode();
void autograd();
void aot_save();
void aot_save(taichi::Arch arch);

// Entry point of the C++ examples: runs the SNode and autograd examples, then
// the AOT save example once for every backend enabled at build time
// (TI_WITH_VULKAN and/or TI_WITH_DX12).
int main() {
  run_snode();
  autograd();
  // The former zero-argument aot_save() call is gone: aot_save now takes the
  // target arch explicitly.
#ifdef TI_WITH_VULKAN
  aot_save(taichi::Arch::vulkan);
#endif
#ifdef TI_WITH_DX12
  aot_save(taichi::Arch::dx12);
#endif
  return 0;
}
5 changes: 5 additions & 0 deletions taichi/aot/module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/runtime/gfx/aot_module_loader_impl.h"
#include "taichi/runtime/metal/aot_module_loader_impl.h"
#include "taichi/runtime/dx12/aot_module_loader_impl.h"

namespace taichi::lang {
namespace aot {
Expand Down Expand Up @@ -40,6 +41,10 @@ std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) {
} else if (arch == Arch::dx11) {
#ifdef TI_WITH_DX11
return gfx::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
return directx12::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::metal) {
#ifdef TI_WITH_METAL
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/dx12/dx12_global_optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,
F.addFnAttr(llvm::Attribute::AlwaysInline);
}
// FIXME: choose shader model based on feature used.
llvm::StringRef triple = "dxil-pc-shadermodel6.3-compute";
llvm::StringRef triple = "dxil-pc-shadermodel6.0-compute";
module->setTargetTriple(triple);
module->setSourceFileName("");
std::string err_str;
Expand Down
7 changes: 7 additions & 0 deletions taichi/jit/jit_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ std::unique_ptr<JITSession> JITSession::create(TaichiLLVMContext *tlctx,
return create_llvm_jit_session_cuda(tlctx, config, arch);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#ifdef TI_WITH_DX12
// NOTE: there's no jit for dx12. Create cpu session to avoid crash.
return create_llvm_jit_session_cpu(tlctx, config, Arch::x64);
#else
TI_NOT_IMPLEMENTED
#endif
}
#else
Expand Down
1 change: 1 addition & 0 deletions taichi/runtime/dx12/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(dx12_runtime)
target_sources(dx12_runtime
PRIVATE
aot_module_builder_impl.cpp
aot_module_loader_impl.cpp
)

target_include_directories(dx12_runtime
Expand Down
17 changes: 17 additions & 0 deletions taichi/runtime/dx12/aot_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
namespace directx12 {
class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl() {
}

void launch(RuntimeContext *ctx) override {
}
};
} // namespace directx12
} // namespace lang
} // namespace taichi
86 changes: 82 additions & 4 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {

// Compiles `kernel` with KernelCodeGenDX12 and records the result in
// `module_data` under `identifier`: every DXIL blob returned by the code
// generator is appended to `module_data.dxil_codes[identifier]`, and the
// task list is stored in `module_data.kernels[identifier].tasks`.
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
                                           Kernel *kernel) {
  // The stale TI_NOT_IMPLEMENTED marker of the old stub has been dropped so
  // that the implementation below is what actually runs.
  auto &dxil_codes = module_data.dxil_codes[identifier];
  auto &compiled_kernel = module_data.kernels[identifier];

  KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
  auto compiled_data = cgen.compile();
  for (auto &dxil : compiled_data.task_dxil_source_codes) {
    dxil_codes.emplace_back(dxil);
  }
  // FIXME: set compiled kernel.
  compiled_kernel.tasks = compiled_data.tasks;
}

void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
Expand All @@ -31,18 +40,87 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
std::vector<int> shape,
int row_num,
int column_num) {
TI_NOT_IMPLEMENTED;
// FIXME: support sparse fields.
TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
"AOT: D12 supports only dense fields for now");

const auto &field = prog->get_cached_field(rep_snode->get_snode_tree_id());

// const auto &dense_desc =
// compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

aot::CompiledFieldData field_data;
field_data.field_name = identifier;
field_data.is_scalar = is_scalar;
field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
field_data.dtype_name = dt.to_string();
field_data.shape = shape;
// FIXME: calc mem_offset_in_parent for llvm path.
field_data.mem_offset_in_parent = field.snode_metas[0].chunk_size;
// dense_desc.mem_offset_in_parent_cell;
if (!is_scalar) {
field_data.element_shape = {row_num, column_num};
}

module_data.fields.emplace_back(field_data);
}

// Template-kernel variant of add_per_backend(). The compiled kernel is stored
// in `module_data` under the combined identifier "<identifier>|<key>".
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
                                                const std::string &key,
                                                Kernel *kernel) {
  // FIXME: share code with add_per_backend.
  auto tmpl_identifier = identifier + "|" + key;

  auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
  auto &compiled_kernel = module_data.kernels[tmpl_identifier];

  KernelCodeGenDX12 cgen(kernel, /*ir*/ nullptr);
  auto compiled_data = cgen.compile();
  for (auto &dxil : compiled_data.task_dxil_source_codes) {
    dxil_codes.emplace_back(dxil);
  }
  // Record the tasks exactly like add_per_backend() does. dump() asserts that
  // every kernel has as many tasks as DXIL blobs, so leaving `tasks` empty
  // here would trip that assertion for any template kernel with blobs.
  compiled_kernel.tasks = compiled_data.tasks;
}

// Writes `source_code` as raw bytes to "<output_dir>/<name>.dxc", truncating
// any existing file, and returns the path that was written to.
// NOTE(review): the stream state is never inspected, so a failed open or
// write goes unnoticed and the path is returned regardless.
std::string write_dxil_container(const std::string &output_dir,
                                 const std::string &name,
                                 const std::vector<uint8_t> &source_code) {
  std::string container_path = fmt::format("{}/{}.dxc", output_dir, name);

  std::ofstream out(container_path, std::ios::binary | std::ios::trunc);
  const char *bytes = reinterpret_cast<const char *>(source_code.data());
  out.write(bytes, source_code.size());
  out.close();

  return container_path;
}

// Dumps the module into `output_dir`:
//   - metadata_dx12.tcb  : binary serialization of `module_data`, written
//                          before any task.source_path is filled in;
//   - <task name>.dxc    : one file per task, holding that task's DXIL blob;
//   - metadata_dx12.json : JSON dump of a copy of `module_data` whose tasks
//                          carry the path of their .dxc file.
// `filename` is not used; a warning is emitted when it is non-empty.
void AotModuleBuilderImpl::dump(const std::string &output_dir,
                                const std::string &filename) const {
  TI_WARN_IF(!filename.empty(),
             "Filename prefix is ignored on Unified Device API backends.");
  const std::string bin_path = fmt::format("{}/metadata_dx12.tcb", output_dir);
  write_to_binary_file(module_data, bin_path);
  // Copy module_data to update task.source_path (this method is const).
  auto tmp_module_data = module_data;
  for (auto &[name, compiled_kernel] : tmp_module_data.kernels) {
    auto it = tmp_module_data.dxil_codes.find(name);
    TI_ASSERT(it != tmp_module_data.dxil_codes.end());
    auto &dxil_codes = it->second;
    auto &tasks = compiled_kernel.tasks;
    // Blobs and tasks are paired by index, so the counts have to agree.
    TI_ASSERT(dxil_codes.size() == tasks.size());
    for (size_t i = 0; i < tasks.size(); ++i) {
      auto &task = tasks[i];
      task.source_path =
          write_dxil_container(output_dir, task.name, dxil_codes[i]);
    }
  }

  const std::string json_path =
      fmt::format("{}/metadata_dx12.json", output_dir);
  tmp_module_data.dump_json(json_path);

  // FIXME: dump graph to different file.
  // dump_graph(output_dir);
}

} // namespace directx12
Expand Down
132 changes: 132 additions & 0 deletions taichi/runtime/dx12/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "aot_module_loader_impl.h"
#include "aot_module_builder_impl.h"
#include "aot_graph_data.h"
#include <fstream>
#include <type_traits>

#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
namespace directx12 {
namespace {
// aot::Field implementation for DX12. It keeps its own copy of the compiled
// field metadata it was constructed from.
class FieldImpl : public aot::Field {
 public:
  explicit FieldImpl(const aot::CompiledFieldData &field_data)
      : field_(field_data) {
  }

 private:
  // Metadata copied in at construction time.
  aot::CompiledFieldData field_;
};

// aot::Module implementation that loads a DX12 AOT module from a directory.
class AotModuleImpl : public aot::Module {
 public:
  // Reads "<module_path>/metadata_dx12.tcb" into `module_data`, then, for
  // every task of every kernel, loads "<module_path>/<task name>.dxc" and
  // appends it to `module_data.dxil_codes[<kernel name>]`.
  explicit AotModuleImpl(const AotModuleParams &params, Arch device_api_backend)
      : device_api_backend_(device_api_backend) {
    const std::string bin_path =
        fmt::format("{}/metadata_dx12.tcb", params.module_path);
    read_from_binary_file(module_data, bin_path);

    for (auto &[name, compiled_kernel] : module_data.kernels) {
      auto &dxil_codes = module_data.dxil_codes[name];
      for (const auto &task : compiled_kernel.tasks) {
        dxil_codes.emplace_back(
            read_dxil_container(params.module_path, task.name));
      }
    }

    // FIXME: enable once write graph to graphs_dx12.tcb.
    // const std::string graph_path =
    //    fmt::format("{}/graphs_dx12.tcb", params.module_path);
    // read_from_binary_file(graphs_, graph_path);
  }

  // Builds a CompiledGraph for the graph called `name`, resolving each of its
  // dispatches to a kernel through get_kernel(). Reports an error through
  // TI_ERROR_IF when `graphs_` has no entry for `name`.
  std::unique_ptr<aot::CompiledGraph> get_graph(
      const std::string &name) override {
    TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name);
    std::vector<aot::CompiledDispatch> dispatches;
    for (auto &dispatch : graphs_[name].dispatches) {
      dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args,
                            get_kernel(dispatch.kernel_name)});
    }
    aot::CompiledGraph graph{dispatches};
    return std::make_unique<aot::CompiledGraph>(std::move(graph));
  }

  // Returns the root buffer size recorded in the module metadata.
  size_t get_root_size() const override {
    return module_data.root_buffer_size;
  }

  // Module metadata
  // Returns the arch this module was created for.
  Arch arch() const override {
    return device_api_backend_;
  }
  // Not implemented for DX12 yet.
  uint64_t version() const override {
    TI_NOT_IMPLEMENTED;
  }

 private:
  // Copies into `field` the first entry of `module_data.fields` whose
  // field_name starts with `name` and returns true; returns false (leaving
  // `field` untouched) when there is no such entry.
  bool get_field_data_by_name(const std::string &name,
                              aot::CompiledFieldData &field) {
    for (const auto &candidate : module_data.fields) {
      // rfind(name, 0) == 0 holds exactly when field_name begins with `name`.
      if (candidate.field_name.rfind(name, 0) == 0) {
        field = candidate;
        return true;
      }
    }
    return false;
  }

  // Returns a new KernelImpl when `name` is a known kernel, nullptr otherwise.
  std::unique_ptr<aot::Kernel> make_new_kernel(
      const std::string &name) override {
    if (module_data.kernels.find(name) == module_data.kernels.end())
      return nullptr;
    return std::make_unique<KernelImpl>();
  }

  // Kernel templates are not supported yet.
  std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
      const std::string &name) override {
    TI_NOT_IMPLEMENTED;
    return nullptr;
  }

  // Returns a FieldImpl for the field matched by get_field_data_by_name(), or
  // nullptr (after a debug log) when nothing matches.
  std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
    aot::CompiledFieldData field;
    if (!get_field_data_by_name(name, field)) {
      TI_DEBUG("Failed to load field {}", name);
      return nullptr;
    }
    return std::make_unique<FieldImpl>(field);
  }

  // Reads the whole file "<output_dir>/<name>.dxc" and returns its bytes.
  // Failures to open, size or read the file are reported through TI_ERROR_IF,
  // the same mechanism get_graph() uses. Previously a missing file made
  // tellg() return -1, which was converted to a huge size_t and handed to
  // vector::resize.
  std::vector<uint8_t> read_dxil_container(const std::string &output_dir,
                                           const std::string &name) {
    const std::string path = fmt::format("{}/{}.dxc", output_dir, name);
    // Open at the end so that tellg() yields the file size.
    std::ifstream fs(path, std::ios_base::binary | std::ios::ate);
    TI_ERROR_IF(!fs.is_open(), "Cannot open DXIL container {}", path);

    const std::streamoff size = fs.tellg();
    TI_ERROR_IF(size < 0, "Cannot determine size of DXIL container {}", path);

    std::vector<uint8_t> source_code(static_cast<size_t>(size));
    fs.seekg(0, std::ios::beg);
    fs.read(reinterpret_cast<char *>(source_code.data()),
            static_cast<std::streamsize>(size));
    TI_ERROR_IF(!fs, "Failed to read DXIL container {}", path);
    fs.close();
    return source_code;
  }

  ModuleDataDX12 module_data;
  Arch device_api_backend_;
};

} // namespace

std::unique_ptr<aot::Module> make_aot_module(std::any mod_params,
Arch device_api_backend) {
AotModuleParams params = std::any_cast<AotModuleParams &>(mod_params);
return std::make_unique<AotModuleImpl>(params, device_api_backend);
}

} // namespace directx12
} // namespace lang
} // namespace taichi
Loading

0 comments on commit 73dac71

Please sign in to comment.