diff --git a/taichi/platform/metal/metal_runtime.cpp b/taichi/platform/metal/metal_runtime.cpp index 1db530e12aa97..41d348642edb5 100644 --- a/taichi/platform/metal/metal_runtime.cpp +++ b/taichi/platform/metal/metal_runtime.cpp @@ -57,7 +57,11 @@ class BufferMemoryView { void *ptr_; }; -using MTLBuffersInput = std::unordered_map; +// MetalRuntime maintains a series of MTLBuffers that are shared across all the +// Metal kernels mapped by a single Taichi kernel. This map stores those buffers +// from their enum. Each CompiledMtlKernelBase can then decide which specific +// buffers they will need to use in a launch. +using InputBuffersMap = std::unordered_map; // Info for launching a compiled Metal kernel class CompiledMtlKernelBase { @@ -82,7 +86,7 @@ class CompiledMtlKernelBase { return &kernel_attribs_; } - virtual void launch(MTLBuffersInput &input_buffers, + virtual void launch(InputBuffersMap &input_buffers, MTLCommandBuffer *command_buffer) = 0; protected: @@ -118,10 +122,11 @@ class CompiledMtlKernelBase { nsobj_unique_ptr pipeline_state_{nullptr}; }; +// Metal kernel derived from a user Taichi kernel class UserMtlKernel : public CompiledMtlKernelBase { public: using CompiledMtlKernelBase::CompiledMtlKernelBase; - void launch(MTLBuffersInput &input_buffers, + void launch(InputBuffersMap &input_buffers, MTLCommandBuffer *command_buffer) override { // 0 is valid for |num_threads|! TI_ASSERT(kernel_attribs_.num_threads >= 0); @@ -139,6 +144,7 @@ class UserMtlKernel : public CompiledMtlKernelBase { } }; +// Internal Metal kernel used to maintain the kernel runtime data class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { public: struct Params : public CompiledMtlKernelBase::Params { @@ -162,7 +168,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { mem[1] = child_snode_id_; } - void launch(MTLBuffersInput &input_buffers, + void launch(InputBuffersMap &input_buffers, MTLCommandBuffer *command_buffer) override { BindBuffers buffers; for (const auto b : kernel_attribs_.buffers) { @@ -178,6 +184,12 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { private: const int parent_snode_id_; const int child_snode_id_; + // For such Metal kernels, it always takes in an args buffer of two int32's: + // args[0] = parent_snode_id + // args[1] = child_snode_id + // Note that this args buffer has nothing to do with the one passed to Taichi + // kernel. + // See taichi/platform/metal/shaders/runtime_kernels.metal.h std::unique_ptr args_mem_; nsobj_unique_ptr args_buffer_; }; @@ -433,7 +445,7 @@ class MetalRuntime::Impl { TI_INFO("Lauching Taichi kernel <{}>", taichi_kernel_name); } - MTLBuffersInput input_buffers = { + InputBuffersMap input_buffers = { {BufferEnum::Root, root_buffer_.get()}, {BufferEnum::GlobalTmps, global_tmps_buffer_.get()}, {BufferEnum::Runtime, runtime_buffer_.get()},