Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 3, 2022
1 parent 398484c commit 841a6d7
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 44 deletions.
16 changes: 8 additions & 8 deletions taichi/codegen/dx12/codegen_dx12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {

#ifdef TI_WITH_LLVM

static std::vector<uint8_t> generate_dxil_from_llvm(LLVMCompiledData &compiled_data,
taichi::lang::Kernel *kernel) {
static std::vector<uint8_t> generate_dxil_from_llvm(
LLVMCompiledData &compiled_data,
taichi::lang::Kernel *kernel) {
// generate dxil from llvm ir.
auto offloaded_local = compiled_data.tasks;
auto module = compiled_data.module.get();
Expand All @@ -162,7 +163,8 @@ static std::vector<uint8_t> generate_dxil_from_llvm(LLVMCompiledData &compiled_d
// FIXME: save task.block_dim like
// tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
}
auto dx_container = directx12::global_optimize_module(module, kernel->program->config);
auto dx_container =
directx12::global_optimize_module(module, kernel->program->config);
// validate and sign dx container.
return directx12::validate_and_sign(dx_container);
}
Expand Down Expand Up @@ -197,8 +199,7 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
aot::CompiledOffloadedTask task;
// FIXME: build all fields for task.
task.name = fmt::format("{}_{}_{}", kernel->get_name(),
offload_stmt->task_name(),
i);
offload_stmt->task_name(), i);
task.type = offload_stmt->task_name();
Result.tasks.emplace_back(task);
}
Expand All @@ -209,15 +210,14 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {

LLVMCompiledData KernelCodeGenDX12::modulegen(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
OffloadedStmt *stmt) {
TaskCodeGenLLVMDX12 gen(kernel, stmt);
return gen.run_compilation();
}
#endif // TI_WITH_LLVM

FunctionType KernelCodeGenDX12::codegen() {
// FIXME: implement codegen.
return [](RuntimeContext &ctx) {
};
return [](RuntimeContext &ctx) {};
}
TLANG_NAMESPACE_END
9 changes: 3 additions & 6 deletions taichi/codegen/dx12/dx12_llvm_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Function;
class Module;
class Type;
class GlobalVariable;
}
} // namespace llvm

namespace taichi {
namespace lang {
Expand All @@ -26,14 +26,13 @@ enum class BufferSpaceId {
Args = 3,
Runtime = 4,
Result = 5,
UtilCBuffer = 6, // For things like Num work groups.
UtilCBuffer = 6, // For things like Num work groups.
};

enum ResourceAddressSpace {
CBuffer = 4,
};


void mark_function_as_cs_entry(llvm::Function *);
bool is_cs_entry(llvm::Function *);
void set_num_threads(llvm::Function *, unsigned x, unsigned y, unsigned z);
Expand All @@ -52,7 +51,6 @@ extern const char *NumWorkGroupsCBName;
} // namespace lang
} // namespace taichi


namespace llvm {
class ModulePass;
class PassRegistry;
Expand All @@ -64,11 +62,10 @@ void initializeTaichiRuntimeContextLowerPass(PassRegistry &);
/// Pass to convert modules into DXIL-compatable modules
ModulePass *createTaichiRuntimeContextLowerPass();


/// Initializer for taichi intrinsic lower.
void initializeTaichiIntrinsicLowerPass(PassRegistry &);

/// Pass to lower taichi intrinsic into DXIL intrinsic.
ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config);

}
} // namespace llvm
10 changes: 4 additions & 6 deletions taichi/codegen/dx12/dx12_lower_runtime_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using namespace taichi::lang::directx12;

#define DEBUG_TYPE "dxil-taichi-runtime-context-lower"


namespace {

class TaichiRuntimeContextLower : public ModulePass {
Expand All @@ -38,12 +37,11 @@ char TaichiRuntimeContextLower::ID = 0;

} // end anonymous namespace


INITIALIZE_PASS(TaichiRuntimeContextLower,
DEBUG_TYPE,
"Lower taichi RuntimeContext",
false,
false)
DEBUG_TYPE,
"Lower taichi RuntimeContext",
false,
false)

llvm::ModulePass *llvm::createTaichiRuntimeContextLowerPass() {
return new TaichiRuntimeContextLower();
Expand Down
3 changes: 1 addition & 2 deletions taichi/rhi/arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ bool arch_is_cpu(Arch arch) {

bool arch_uses_llvm(Arch arch) {
return (arch == Arch::x64 || arch == Arch::arm64 || arch == Arch::cuda ||
arch == Arch::dx12 ||
arch == Arch::wasm);
arch == Arch::dx12 || arch == Arch::wasm);
}

bool arch_is_gpu(Arch arch) {
Expand Down
5 changes: 3 additions & 2 deletions taichi/rhi/dx12/dx12_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ bool is_dx12_api_available();

std::shared_ptr<Device> make_dx12_device();

std::vector<uint8_t> validate_and_sign(std::vector<uint8_t> &input_dxil_container);
std::vector<uint8_t> validate_and_sign(
std::vector<uint8_t> &input_dxil_container);

} // namespace directx11
} // namespace directx12
} // namespace lang
} // namespace taichi
1 change: 0 additions & 1 deletion taichi/runtime/dx12/aot_graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class KernelImpl : public aot::Kernel {

void launch(RuntimeContext *ctx) override {
}

};
} // namespace directx12
} // namespace lang
Expand Down
16 changes: 7 additions & 9 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {


auto &dxil_codes = module_data.dxil_codes[identifier];
auto &compiled_kernel = module_data.kernels[identifier];

Expand Down Expand Up @@ -50,11 +48,11 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
// matter too much for now.
TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
"AOT: only supports dense field");

const auto &field = prog->get_cached_field(rep_snode->get_snode_tree_id());

//const auto &dense_desc =
// compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);
// const auto &dense_desc =
// compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

aot::CompiledFieldData field_data;
field_data.field_name = identifier;
Expand All @@ -64,7 +62,7 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
field_data.shape = shape;
// FIXME: calc mem_offset_in_parent for llvm path.
field_data.mem_offset_in_parent = field.snode_metas[0].chunk_size;
//dense_desc.mem_offset_in_parent_cell;
// dense_desc.mem_offset_in_parent_cell;
if (!is_scalar) {
field_data.element_shape = {row_num, column_num};
}
Expand Down Expand Up @@ -122,14 +120,14 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
}
}

const std::string json_path = fmt::format("{}/metadata_dx12.json", output_dir);
const std::string json_path =
fmt::format("{}/metadata_dx12.json", output_dir);
tmp_module_data.dump_json(json_path);

// FIXME: dump graph to different file.
//dump_graph(output_dir);
// dump_graph(output_dir);
}


} // namespace directx12
} // namespace lang
} // namespace taichi
11 changes: 5 additions & 6 deletions taichi/runtime/dx12/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ namespace directx12 {
namespace {
class FieldImpl : public aot::Field {
public:
explicit FieldImpl(const aot::CompiledFieldData &field)
: field_(field) {
explicit FieldImpl(const aot::CompiledFieldData &field) : field_(field) {
}

private:
Expand All @@ -40,9 +39,9 @@ class AotModuleImpl : public aot::Module {
}

// FIXME: enable once write graph to graphs_dx12.tcb.
//const std::string graph_path =
// const std::string graph_path =
// fmt::format("{}/graphs_dx12.tcb", params.module_path);
//read_from_binary_file(graphs_, graph_path);
// read_from_binary_file(graphs_, graph_path);
}

std::unique_ptr<aot::CompiledGraph> get_graph(std::string name) override {
Expand Down Expand Up @@ -103,7 +102,7 @@ class AotModuleImpl : public aot::Module {
}

std::vector<uint8_t> read_dxil_container(const std::string &output_dir,
const std::string &name) {
const std::string &name) {
const std::string path = fmt::format("{}/{}.dxc", output_dir, name);
std::vector<uint8_t> source_code;
std::ifstream fs(path, std::ios_base::binary | std::ios::ate);
Expand All @@ -127,6 +126,6 @@ std::unique_ptr<aot::Module> make_aot_module(std::any mod_params,
return std::make_unique<AotModuleImpl>(params, device_api_backend);
}

} // namespace gfx
} // namespace directx12
} // namespace lang
} // namespace taichi
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ endif()
if (TI_WITH_DX12)
target_link_libraries(llvm_runtime PRIVATE ${llvm_directx_libs})
target_link_libraries(llvm_runtime PRIVATE dx12_rhi)
endif()
endif()
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_runtime_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ LlvmRuntimeExecutor::LlvmRuntimeExecutor(CompileConfig &config,
llvm_context_device_ =
std::make_unique<TaichiLLVMContext>(config_, Arch::dx12);
// FIXME: add dx12 JIT.
//llvm_context_device_->init_runtime_jit_module();
// llvm_context_device_->init_runtime_jit_module();
}
#endif

Expand Down
2 changes: 0 additions & 2 deletions tests/cpp/aot/dx12/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ using namespace lang;
aot_builder->dump(".", "");
}


#ifdef TI_WITH_DX12

TEST(AotSaveLoad, DX12) {
Expand Down Expand Up @@ -134,7 +133,6 @@ TEST(AotSaveLoad, DX12) {
EXPECT_FALSE(ret2_kernel);

// FIXME: test run kernels and check the result.

}

#endif

0 comments on commit 841a6d7

Please sign in to comment.