Skip to content

Commit

Permalink
[amdgpu] Part5 enable the api of amdgpu (taichi-dev#7202)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#6434

### Brief Summary
1. enable amdgpu api in taichi(except struct for)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 7e41de2 commit d7287d1
Show file tree
Hide file tree
Showing 31 changed files with 294 additions and 55 deletions.
2 changes: 2 additions & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ typedef enum TiArch {
TI_ARCH_OPENGL = 6,
// OpenGL ES GPU backend.
TI_ARCH_GLES = 7,
// AMDGPU backend
TI_ARCH_AMDGPU = 8,
TI_ARCH_MAX_ENUM = 0xffffffff,
} TiArch;

Expand Down
5 changes: 3 additions & 2 deletions cmake/TaichiCore.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ file(GLOB TAICHI_CORE_SOURCE
"taichi/system/*"
"taichi/transforms/*"
"taichi/aot/*.cpp" "taichi/aot/*.h"
"taichi/platform/cuda/*" "taichi/platform/mac/*" "taichi/platform/windows/*"
"taichi/platform/cuda/*" "taichi/platform/amdgpu/*"
"taichi/platform/mac/*" "taichi/platform/windows/*"
"taichi/codegen/*.cpp" "taichi/codegen/*.h"
"taichi/runtime/*.h" "taichi/runtime/*.cpp"
"taichi/rhi/*.h" "taichi/rhi/*.cpp"
Expand All @@ -116,7 +117,7 @@ endif()

if (TI_WITH_AMDGPU)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_AMDGPU")
# file(GLOB TAICHI_AMDGPU_RUNTIME_SOURCE "taichi/runtime/amdgpu/runtime.cpp")
file(GLOB TAICHI_AMDGPU_RUNTIME_SOURCE "taichi/runtime/amdgpu/runtime.cpp")
list(APPEND TAIHI_CORE_SOURCE ${TAICHI_AMDGPU_RUNTIME_SOURCE})
endif()

Expand Down
18 changes: 12 additions & 6 deletions python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@
"""
# ----------------------

amdgpu = _ti_core.amdgpu
"""The AMDGPU backend.
"""
# ----------------------

metal = _ti_core.metal
"""The Apple Metal backend.
"""
Expand Down Expand Up @@ -159,9 +164,9 @@
"""
# ----------------------

gpu = [cuda, metal, vulkan, opengl, dx11, dx12, gles]
gpu = [cuda, metal, vulkan, opengl, dx11, dx12, gles, amdgpu]
"""A list of GPU backends supported on the current system.
Currently contains 'cuda', 'metal', 'opengl', 'vulkan', 'dx11', 'dx12', 'gles'.
Currently contains 'cuda', 'metal', 'opengl', 'vulkan', 'dx11', 'dx12', 'gles', 'amdgpu'.
When this is used, Taichi automatically picks the matching GPU backend. If no
GPU is detected, Taichi falls back to the CPU backend.
Expand Down Expand Up @@ -726,6 +731,7 @@ def is_arch_supported(arch):

arch_table = {
cuda: _ti_core.with_cuda,
amdgpu: _ti_core.with_amdgpu,
metal: _ti_core.with_metal,
opengl: functools.partial(_ti_core.with_opengl, False),
gles: functools.partial(_ti_core.with_opengl, True),
Expand Down Expand Up @@ -773,8 +779,8 @@ def get_compute_stream_device_time_elapsed_us() -> float:
__all__ = [
'i', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'j', 'jk', 'jkl', 'jl',
'k', 'kl', 'l', 'x86_64', 'x64', 'dx11', 'dx12', 'wasm', 'arm64', 'cc',
'cpu', 'cuda', 'gles', 'gpu', 'metal', 'opengl', 'vulkan', 'extension',
'loop_config', 'global_thread_idx', 'assume_in_range', 'block_local',
'cache_read_only', 'init', 'mesh_local', 'no_activate', 'reset',
'mesh_patch_idx', 'get_compute_stream_device_time_elapsed_us'
'cpu', 'cuda', 'amdgpu', 'gles', 'gpu', 'metal', 'opengl', 'vulkan',
'extension', 'loop_config', 'global_thread_idx', 'assume_in_range',
'block_local', 'cache_read_only', 'init', 'mesh_local', 'no_activate',
'reset', 'mesh_patch_idx', 'get_compute_stream_device_time_elapsed_us'
]
30 changes: 25 additions & 5 deletions taichi/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,32 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
}

void visit(GlobalLoadStmt *stmt) override {
if (auto get_ch = stmt->src->cast<GetChStmt>()) {
bool should_cache_as_read_only = current_offload->mem_access_opt.has_flag(
get_ch->output_snode, SNodeAccessFlag::read_only);
create_global_load(stmt, should_cache_as_read_only);
auto ptr = llvm_val[stmt->src];
auto ptr_type = stmt->src->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
auto get_ch = stmt->src->as<GetChStmt>();
auto physical_type =
tlctx->get_data_type(get_ch->input_snode->physical_type);
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = builder->CreateLoad(physical_type, byte_ptr);
if (auto qit = val_type->cast<QuantIntType>()) {
llvm_val[stmt] = extract_quant_int(physical_value, bit_offset, qit);
} else if (auto qfxt = val_type->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
auto digits = extract_quant_int(physical_value, bit_offset, qit);
llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
} else {
TI_ASSERT(val_type->is<QuantFloatType>());
TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
llvm_val[stmt] = extract_quant_float(
physical_value, get_ch->input_snode->dt->as<BitStructType>(),
get_ch->output_snode->id_in_bit_struct);
}
} else {
create_global_load(stmt, false);
// Byte pointer case.
llvm_val[stmt] =
builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
}
}

Expand Down
9 changes: 9 additions & 0 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#if defined(TI_WITH_DX12)
#include "taichi/codegen/dx12/codegen_dx12.h"
#endif
#if defined(TI_WITH_AMDGPU)
#include "taichi/codegen/amdgpu/codegen_amdgpu.h"
#endif
#include "taichi/system/timer.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/transforms.h"
Expand Down Expand Up @@ -47,6 +50,12 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(
return std::make_unique<KernelCodeGenDX12>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::amdgpu) {
#if defined(TI_WITH_AMDGPU)
return std::make_unique<KernelCodeGenAMDGPU>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else {
TI_NOT_IMPLEMENTED
Expand Down
6 changes: 6 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,12 @@ LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
TI_ASSERT(func);
tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
}
} else if (compile_config.arch == Arch::amdgpu) {
for (const auto &task : offloaded_tasks) {
llvm::Function *func = module->getFunction(task.name);
TI_ASSERT(func);
tlctx->mark_function_as_amdgpu_kernel(func);
}
}

return {std::move(offloaded_tasks), std::move(module),
Expand Down
2 changes: 1 addition & 1 deletion taichi/inc/archs.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ PER_ARCH(opengl) // OpenGL Compute Shaders
PER_ARCH(dx11) // Microsoft DirectX 11, WIP
PER_ARCH(dx12) // Microsoft DirectX 12, WIP
PER_ARCH(opencl) // OpenCL, N/A
PER_ARCH(amdgpu) // AMD GPU, WIP
PER_ARCH(amdgpu) // AMD GPU
PER_ARCH(vulkan) // Vulkan
PER_ARCH(gles) // OpenGL ES
7 changes: 4 additions & 3 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) {
strictly_serialized = config.strictly_serialized;
mem_access_opt = config.mem_access_opt;
block_dim = config.block_dim;
if (arch == Arch::cuda) {
if (arch == Arch::cuda || arch == Arch::amdgpu) {
num_cpu_threads = 1;
TI_ASSERT(block_dim <= taichi_max_gpu_block_dim);
} else { // cpu
Expand Down Expand Up @@ -1284,8 +1284,9 @@ void ASTBuilder::insert_for(const Expr &s,

Expr ASTBuilder::insert_thread_idx_expr() {
auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr;
TI_ERROR_IF(arch_ != Arch::cuda && !arch_is_cpu(arch_),
"ti.thread_idx() is only available in cuda or cpu context.");
TI_ERROR_IF(
arch_ != Arch::cuda && !arch_is_cpu(arch_) && arch_ != Arch::amdgpu,
"ti.thread_idx() is only available in cuda or cpu or amdgpu context.");
if (loop != nullptr) {
auto i = stack_.size() - 1;
while (!(loop->is<FrontendForStmt>())) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ class ASTBuilder {
}

void block_dim(int v) {
if (arch_ == Arch::cuda || arch_ == Arch::vulkan) {
if (arch_ == Arch::cuda || arch_ == Arch::vulkan || arch_ == Arch::amdgpu) {
TI_ASSERT((v % 32 == 0) || bit::is_power_of_two(v));
} else {
TI_ASSERT(bit::is_power_of_two(v));
Expand Down
11 changes: 11 additions & 0 deletions taichi/jit/jit_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ std::unique_ptr<JITSession> create_llvm_jit_session_cuda(
TaichiLLVMContext *tlctx,
const CompileConfig &config,
Arch arch);

std::unique_ptr<JITSession> create_llvm_jit_session_amdgpu(
TaichiLLVMContext *tlctx,
const CompileConfig &config,
Arch arch);
#endif

JITSession::JITSession(TaichiLLVMContext *tlctx, const CompileConfig &config)
Expand All @@ -40,6 +45,12 @@ std::unique_ptr<JITSession> JITSession::create(TaichiLLVMContext *tlctx,
return create_llvm_jit_session_cpu(tlctx, config, Arch::x64);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
return create_llvm_jit_session_amdgpu(tlctx, config, arch);
#else
TI_NOT_IMPLEMENTED
#endif
}
#else
Expand Down
17 changes: 17 additions & 0 deletions taichi/platform/amdgpu/detect_amdgpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "taichi/platform/amdgpu/detect_amdgpu.h"

#if defined(TI_WITH_AMDGPU)
#include "taichi/rhi/amdgpu/amdgpu_driver.h"
#endif

namespace taichi {

bool is_rocm_api_available() {
#if defined(TI_WITH_AMDGPU)
return lang::AMDGPUDriver::get_instance_without_context().detected();
#else
return false;
#endif
}

} // namespace taichi
5 changes: 5 additions & 0 deletions taichi/platform/amdgpu/detect_amdgpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

namespace taichi {
bool is_rocm_api_available();
} // namespace taichi
2 changes: 1 addition & 1 deletion taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ CompileConfig::CompileConfig() {
print_kernel_nvptx = false;
print_kernel_llvm_ir_optimized = false;

// CUDA backend options:
// CUDA/AMDGPU backend options:
device_memory_GB = 1; // by default, preallocate 1 GB GPU memory
device_memory_fraction = 0.0;

Expand Down
2 changes: 1 addition & 1 deletion taichi/program/compile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct CompileConfig {
bool print_kernel_llvm_ir_optimized;
bool print_kernel_nvptx;

// CUDA backend options:
// CUDA/AMDGPU backend options:
float64 device_memory_GB;
float64 device_memory_fraction;

Expand Down
3 changes: 2 additions & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ void Kernel::operator()(const CompileConfig &compile_config,
compiled_(ctx_builder.get_context());

const auto arch = compile_config.arch;
if (compile_config.debug && (arch_is_cpu(arch) || arch == Arch::cuda)) {
if (compile_config.debug &&
(arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu)) {
program->check_runtime_error();
}
}
Expand Down
6 changes: 4 additions & 2 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "taichi/ir/frontend_ir.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/math/arithmetic.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/llvm/struct_llvm.h"
Expand Down Expand Up @@ -349,7 +350,7 @@ Ndarray *Program::create_ndarray(const DataType type,
auto arr = std::make_unique<Ndarray>(this, type, shape, layout);
if (zero_fill) {
Arch arch = compile_config().arch;
if (arch_is_cpu(arch) || arch == Arch::cuda) {
if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu) {
fill_ndarray_fast_u32(arr.get(), /*data=*/0);
} else if (arch != Arch::dx12) {
// Device api support for dx12 backend are not complete yet
Expand Down Expand Up @@ -408,7 +409,8 @@ Texture *Program::create_texture(const DataType type,
intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
uint64_t *data_ptr{nullptr};
if (arch_is_cpu(compile_config().arch) ||
compile_config().arch == Arch::cuda) {
compile_config().arch == Arch::cuda ||
compile_config().arch == Arch::amdgpu) {
// For the LLVM backends, device allocation is a physical pointer.
data_ptr =
program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
Expand Down
6 changes: 6 additions & 0 deletions taichi/python/export_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
#include "taichi/rhi/cuda/cuda_driver.h"
#endif

#include "taichi/platform/amdgpu/detect_amdgpu.h"
#if defined(TI_WITH_AMDGPU)
#include "taichi/rhi/amdgpu/amdgpu_driver.h"
#endif

#ifdef TI_WITH_VULKAN
#include "taichi/rhi/vulkan/vulkan_loader.h"
#endif
Expand Down Expand Up @@ -144,6 +149,7 @@ void export_misc(py::module &m) {
m.def("pop_python_print_buffer", []() { return py_cout.pop_content(); });
m.def("toggle_python_print_buffer", [](bool opt) { py_cout.enabled = opt; });
m.def("with_cuda", is_cuda_api_available);
m.def("with_amdgpu", is_rocm_api_available);
#ifdef TI_WITH_METAL
m.def("with_metal", taichi::lang::metal::is_metal_api_available);
#else
Expand Down
2 changes: 1 addition & 1 deletion taichi/rhi/arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ bool arch_is_cuda(Arch arch) {

bool arch_uses_llvm(Arch arch) {
return (arch == Arch::x64 || arch == Arch::arm64 || arch == Arch::cuda ||
arch == Arch::dx12 || arch == Arch::wasm);
arch == Arch::dx12 || arch == Arch::wasm || arch == Arch::amdgpu);
}

bool arch_is_gpu(Arch arch) {
Expand Down
4 changes: 4 additions & 0 deletions taichi/rhi/interop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ if (TI_WITH_CUDA)
target_compile_definitions(${INTEROP_RHI} PRIVATE -DTI_WITH_CUDA)
endif()

if (TI_WITH_AMDGPU)
target_compile_definitions(${INTEROP_RHI} PRIVATE -DTI_WITH_AMDGPU)
endif()

if (TI_WITH_VULKAN)
target_compile_definitions(${INTEROP_RHI} PRIVATE -DTI_WITH_VULKAN)
endif()
Expand Down
14 changes: 14 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ TaichiLLVMContext::TaichiLLVMContext(const CompileConfig &config, Arch arch)
LLVMInitializeDirectXTargetMC();
LLVMInitializeDirectXTargetInfo();
LLVMInitializeDirectXAsmPrinter();
#endif
} else if (arch == Arch::amdgpu) {
#if defined(TI_WITH_AMDGPU)
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUAsmPrinter();
LLVMInitializeAMDGPUAsmParser();
#else
TI_NOT_IMPLEMENTED
#endif
} else {
#if defined(TI_WITH_CUDA)
Expand Down Expand Up @@ -803,6 +813,10 @@ void TaichiLLVMContext::mark_function_as_cuda_kernel(llvm::Function *func,
}
}

void TaichiLLVMContext::mark_function_as_amdgpu_kernel(llvm::Function *func) {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
}

void TaichiLLVMContext::eliminate_unused_functions(
llvm::Module *module,
std::function<bool(const std::string &)> export_indicator) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class TaichiLLVMContext {

void mark_function_as_cuda_kernel(llvm::Function *func, int block_dim = 0);

void mark_function_as_amdgpu_kernel(llvm::Function *func);

void fetch_this_thread_struct_module();
llvm::Module *get_this_thread_runtime_module();
llvm::Function *get_runtime_function(const std::string &name);
Expand Down
Loading

0 comments on commit d7287d1

Please sign in to comment.