Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Mar 22, 2020
1 parent 9e31ef7 commit 0d728d1
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions taichi/platform/metal/metal_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ class BufferMemoryView {
void *ptr_;
};

using MTLBuffersInput = std::unordered_map<BufferEnum, MTLBuffer *>;
// MetalRuntime maintains a series of MTLBuffers that are shared across all the
// Metal kernels mapped by a single Taichi kernel. This map stores those buffers
// from their enum. Each CompiledMtlKernelBase can then decide which specific
// buffers they will need to use in a launch.
using InputBuffersMap = std::unordered_map<BufferEnum, MTLBuffer *>;

// Info for launching a compiled Metal kernel
class CompiledMtlKernelBase {
Expand All @@ -82,7 +86,7 @@ class CompiledMtlKernelBase {
return &kernel_attribs_;
}

virtual void launch(MTLBuffersInput &input_buffers,
virtual void launch(InputBuffersMap &input_buffers,
MTLCommandBuffer *command_buffer) = 0;

protected:
Expand Down Expand Up @@ -118,10 +122,11 @@ class CompiledMtlKernelBase {
nsobj_unique_ptr<MTLComputePipelineState> pipeline_state_{nullptr};
};

// Metal kernel derived from a user Taichi kernel
class UserMtlKernel : public CompiledMtlKernelBase {
public:
using CompiledMtlKernelBase::CompiledMtlKernelBase;
void launch(MTLBuffersInput &input_buffers,
void launch(InputBuffersMap &input_buffers,
MTLCommandBuffer *command_buffer) override {
// 0 is valid for |num_threads|!
TI_ASSERT(kernel_attribs_.num_threads >= 0);
Expand All @@ -139,6 +144,7 @@ class UserMtlKernel : public CompiledMtlKernelBase {
}
};

// Internal Metal kernel used to maintain the kernel runtime data
class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
public:
struct Params : public CompiledMtlKernelBase::Params {
Expand All @@ -162,7 +168,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
mem[1] = child_snode_id_;
}

void launch(MTLBuffersInput &input_buffers,
void launch(InputBuffersMap &input_buffers,
MTLCommandBuffer *command_buffer) override {
BindBuffers buffers;
for (const auto b : kernel_attribs_.buffers) {
Expand All @@ -178,6 +184,12 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
private:
const int parent_snode_id_;
const int child_snode_id_;
// For such Metal kernels, it always takes in an args buffer of two int32's:
// args[0] = parent_snode_id
// args[1] = child_snode_id
// Note that this args buffer has nothing to do with the one passed to Taichi
// kernel.
// See taichi/platform/metal/shaders/runtime_kernels.metal.h
std::unique_ptr<BufferMemoryView> args_mem_;
nsobj_unique_ptr<MTLBuffer> args_buffer_;
};
Expand Down Expand Up @@ -433,7 +445,7 @@ class MetalRuntime::Impl {
TI_INFO("Lauching Taichi kernel <{}>", taichi_kernel_name);
}

MTLBuffersInput input_buffers = {
InputBuffersMap input_buffers = {
{BufferEnum::Root, root_buffer_.get()},
{BufferEnum::GlobalTmps, global_tmps_buffer_.get()},
{BufferEnum::Runtime, runtime_buffer_.get()},
Expand Down

0 comments on commit 0d728d1

Please sign in to comment.