// Patch summary (leading text before the first `diff --git` header; the patch
// body below is unchanged).
//
// Purpose: make the Metal root buffer optional, so that no root buffer is
// created or bound when the compiled SNode structs report a root size of 0.
//
// taichi/backends/codegen_metal.cpp
//   * Adds the name constants kRootBufferName ("root_addr") and
//     kArgsBufferName ("args_addr"), replacing the inline literals `addr` and
//     `args_addr` in the emitted kernel source.
//   * MetalKernelCodegen's constructor gains a
//     `const StructCompiledResult *compiled_snode_structs` parameter and
//     derives `needs_root_buffer_` from `root_size > 0`.
//   * run() now takes only the Kernel* and reads the struct source code from
//     `compiled_snode_structs_->source_code`.
//   * visit(GetRootStmt*) asserts `needs_root_buffer_` before emitting the
//     root object.
//   * emit_mtl_kernel_func_sig() numbers the [[buffer(N)]] slots with a running
//     counter: root (only if needed), then global temporaries, then args (only
//     if the kernel has args).
//   * MetalCodeGen::gen() passes `struct_compiled_` to the codegen.
//
// taichi/platform/metal/metal_runtime.cpp
//   * CompiledMtlKernel binds `data_buffers.root` only when it is non-null,
//     using the same running buffer index.
//   * MetalRuntime::Impl stores `needs_root_buffer_` (params.root_size > 0),
//     asserts it matches whether the LLVM runtime's root memory size is
//     positive, and logs the outcome with TI_DEBUG.
//   * Removes the fallback that allocated a page from the memory pool when the
//     root size was 0, together with the `root_buffer_mem_` member.
//
// taichi/program.cpp
//   * Removes the TI_INFO log of the Metal root buffer size.
//
// NOTE(review): the line structure of this patch has been flattened; `+`/`-`
//   markers and hunk headers appear inline, so it presumably cannot be applied
//   as-is. Template arguments also look stripped (e.g. `std::make_unique(` and
//   `nsobj_unique_ptr device_`) -- presumably an extraction artifact; confirm
//   against the original commit.
// NOTE(review): `needs_root_buffer_` is initialised by dereferencing
//   `compiled_snode_structs_` with no null check, and relies on that member
//   being declared first (it is, in the member hunk below). Confirm callers
//   never pass nullptr.
// NOTE(review): the kernel signature decides on the root slot from
//   `root_size > 0`, while the host side decides from `data_buffers.root`
//   being non-null. These must agree for the buffer indices to line up; how
//   `data_buffers.root` is populated is not visible here -- TODO confirm.
diff --git a/taichi/backends/codegen_metal.cpp b/taichi/backends/codegen_metal.cpp index 4c5f42afe31d8..a394c8a93ebcd 100644 --- a/taichi/backends/codegen_metal.cpp +++ b/taichi/backends/codegen_metal.cpp @@ -8,15 +8,20 @@ namespace metal { namespace { constexpr char kKernelThreadIdName[] = "utid_"; // 'u' for unsigned +constexpr char kRootBufferName[] = "root_addr"; constexpr char kGlobalTmpsBufferName[] = "global_tmps_addr"; +constexpr char kArgsBufferName[] = "args_addr"; constexpr char kArgsContextName[] = "args_ctx_"; class MetalKernelCodegen : public IRVisitor { public: MetalKernelCodegen(const std::string &mtl_kernel_prefix, - const std::string &root_snode_type_name) + const std::string &root_snode_type_name, + const StructCompiledResult *compiled_snode_structs) : mtl_kernel_prefix_(mtl_kernel_prefix), - root_snode_type_name_(root_snode_type_name) { + root_snode_type_name_(root_snode_type_name), + compiled_snode_structs_(compiled_snode_structs), + needs_root_buffer_(compiled_snode_structs_->root_size > 0) { // allow_undefined_visitor = true; } @@ -26,8 +31,8 @@ class MetalKernelCodegen : public IRVisitor { return mtl_kernels_attribs_; } - void run(const std::string &snode_structs_source_code, Kernel *kernel) { - generate_mtl_header(snode_structs_source_code); + void run(Kernel *kernel) { + generate_mtl_header(compiled_snode_structs_->source_code); generate_kernel_args_struct(kernel); kernel->ir->accept(this); } @@ -79,8 +84,10 @@ class MetalKernelCodegen : public IRVisitor { void visit(GetRootStmt *stmt) override { // Should we assert |root_stmt_| is assigned only once? 
+ TI_ASSERT(needs_root_buffer_); root_stmt_ = stmt; - emit(R"({} {}(addr);)", root_snode_type_name_, stmt->raw_name()); + emit(R"({} {}({});)", root_snode_type_name_, stmt->raw_name(), + kRootBufferName); } void visit(GetChStmt *stmt) override { @@ -512,15 +519,22 @@ class MetalKernelCodegen : public IRVisitor { void emit_mtl_kernel_func_sig(const std::string &kernel_name) { emit("kernel void {}(", kernel_name); - emit(" device byte* addr [[buffer(0)]],"); - emit(" device byte* {} [[buffer(1)]],", kGlobalTmpsBufferName); + int buffer_idx = 0; + if (needs_root_buffer_) { + emit(" device byte* {} [[buffer({})]],", kRootBufferName, + buffer_idx++); + } + emit(" device byte* {} [[buffer({})]],", kGlobalTmpsBufferName, + buffer_idx++); if (args_attribs_.has_args()) { - emit(" device byte* args_addr [[buffer(2)]],"); + emit(" device byte* {} [[buffer({})]],", kArgsBufferName, + buffer_idx++); } emit(" const uint {} [[thread_position_in_grid]]) {{", kKernelThreadIdName); if (args_attribs_.has_args()) { - emit(" {} {}(args_addr);", kernel_args_classname(), kArgsContextName); + emit(" {} {}({});", kernel_args_classname(), kArgsContextName, + kArgsBufferName); } } @@ -539,6 +553,8 @@ class MetalKernelCodegen : public IRVisitor { const std::string mtl_kernel_prefix_; const std::string root_snode_type_name_; + const StructCompiledResult *const compiled_snode_structs_; + const bool needs_root_buffer_; bool is_top_level_{true}; int mtl_kernel_count_{0}; @@ -684,8 +700,9 @@ void MetalCodeGen::lower() { FunctionType MetalCodeGen::gen(const SNode &root_snode, MetalRuntime *runtime) { // Make a copy of the name! 
const std::string taichi_kernel_name = taichi_kernel_name_; - MetalKernelCodegen codegen(taichi_kernel_name, root_snode.node_type_name); - codegen.run(struct_compiled_->source_code, kernel_); + MetalKernelCodegen codegen(taichi_kernel_name, root_snode.node_type_name, + struct_compiled_); + codegen.run(kernel_); metal::MetalKernelArgsAttributes mtl_args_attribs; for (const auto &arg : kernel_->args) { mtl_args_attribs.insert_arg(arg.dt, arg.is_nparray, arg.size, diff --git a/taichi/platform/metal/metal_runtime.cpp b/taichi/platform/metal/metal_runtime.cpp index abdf7d184e811..81e5180097e66 100644 --- a/taichi/platform/metal/metal_runtime.cpp +++ b/taichi/platform/metal/metal_runtime.cpp @@ -98,8 +98,10 @@ class CompiledMtlKernel { set_compute_pipeline_state(encoder.get(), pipeline_state_.get()); int buffer_index = 0; - set_mtl_buffer(encoder.get(), data_buffers.root, /*offset=*/0, - buffer_index++); + if (data_buffers.root) { + set_mtl_buffer(encoder.get(), data_buffers.root, /*offset=*/0, + buffer_index++); + } set_mtl_buffer(encoder.get(), data_buffers.global_tmps, /*offset=*/0, buffer_index++); if (data_buffers.args) { @@ -284,7 +286,8 @@ class MetalRuntime::Impl { explicit Impl(Params params) : config_(params.config), mem_pool_(params.mem_pool), - profiler_(params.profiler) { + profiler_(params.profiler), + needs_root_buffer_(params.root_size > 0) { if (config_->debug) { TI_ASSERT(is_metal_api_available()); } @@ -300,6 +303,7 @@ class MetalRuntime::Impl { const size_t rtm_root_mem_size = llvm_ctx->lookup_function( "Runtime_get_root_mem_size")(llvm_rtm); if (rtm_root_mem_size > 0) { + TI_ASSERT(needs_root_buffer_); // Make sure the runtime's root memory is large enough. 
TI_ASSERT(iroundup(params.root_size, taichi_page_size) <= rtm_root_mem_size); @@ -308,19 +312,12 @@ class MetalRuntime::Impl { TI_ASSERT(rtm_root_mem != nullptr); root_buffer_ = new_mtl_buffer_no_copy(device_.get(), rtm_root_mem, rtm_root_mem_size); + TI_ASSERT(root_buffer_ != nullptr); + TI_DEBUG("Metal root buffer size: {} bytes", rtm_root_mem_size); } else { - // TODO(k-ye) In case no SNodes is defined, we should not allocate - // |root_buffer_| at all. However, that requires us to change the Metal - // kernels so that |root_buffer_| isn't a required argument. - TI_TRACE( - "LLVM root buffer size is 0, allocating directly from the memory " - "pool"); - root_buffer_mem_ = - std::make_unique(taichi_page_size, mem_pool_); - root_buffer_ = new_mtl_buffer_no_copy( - device_.get(), root_buffer_mem_->ptr(), root_buffer_mem_->size()); + TI_ASSERT(!needs_root_buffer_); + TI_DEBUG("Metal root buffer is empty"); } - TI_ASSERT(root_buffer_ != nullptr); // Make sure we don't have to round up global temporaries' buffer size. TI_ASSERT(iroundup(taichi_global_tmp_buffer_size, taichi_page_size) == @@ -427,11 +424,10 @@ class MetalRuntime::Impl { CompileConfig *const config_; MemoryPool *const mem_pool_; ProfilerBase *const profiler_; + const bool needs_root_buffer_; nsobj_unique_ptr device_{nullptr}; nsobj_unique_ptr command_queue_{nullptr}; nsobj_unique_ptr cur_command_buffer_{nullptr}; - // |root_buffer_mem_| is used only when the LLVM's root size is 0. - std::unique_ptr root_buffer_mem_; nsobj_unique_ptr root_buffer_{nullptr}; uint8_t *global_tmps_mem_begin_; nsobj_unique_ptr global_tmps_buffer_{nullptr}; diff --git a/taichi/program.cpp b/taichi/program.cpp index dfde94c654096..89def61830131 100644 --- a/taichi/program.cpp +++ b/taichi/program.cpp @@ -204,7 +204,6 @@ void Program::materialize_layout() { params.profiler = profiler.get(); metal_runtime_ = std::make_unique(std::move(params)); } - TI_INFO("Metal root buffer size: {} B", metal_struct_compiled_->root_size); } }