Skip to content

Commit

Permalink
Do not instantiate Metal root buffer if it's not used (#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Feb 27, 2020
1 parent 0a97f4d commit 245d760
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 28 deletions.
39 changes: 28 additions & 11 deletions taichi/backends/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ namespace metal {
namespace {

constexpr char kKernelThreadIdName[] = "utid_"; // 'u' for unsigned
constexpr char kRootBufferName[] = "root_addr";
constexpr char kGlobalTmpsBufferName[] = "global_tmps_addr";
constexpr char kArgsBufferName[] = "args_addr";
constexpr char kArgsContextName[] = "args_ctx_";

class MetalKernelCodegen : public IRVisitor {
public:
MetalKernelCodegen(const std::string &mtl_kernel_prefix,
const std::string &root_snode_type_name)
const std::string &root_snode_type_name,
const StructCompiledResult *compiled_snode_structs)
: mtl_kernel_prefix_(mtl_kernel_prefix),
root_snode_type_name_(root_snode_type_name) {
root_snode_type_name_(root_snode_type_name),
compiled_snode_structs_(compiled_snode_structs),
needs_root_buffer_(compiled_snode_structs_->root_size > 0) {
// allow_undefined_visitor = true;
}

Expand All @@ -26,8 +31,8 @@ class MetalKernelCodegen : public IRVisitor {
return mtl_kernels_attribs_;
}

void run(const std::string &snode_structs_source_code, Kernel *kernel) {
generate_mtl_header(snode_structs_source_code);
void run(Kernel *kernel) {
generate_mtl_header(compiled_snode_structs_->source_code);
generate_kernel_args_struct(kernel);
kernel->ir->accept(this);
}
Expand Down Expand Up @@ -79,8 +84,10 @@ class MetalKernelCodegen : public IRVisitor {

void visit(GetRootStmt *stmt) override {
// Should we assert |root_stmt_| is assigned only once?
TI_ASSERT(needs_root_buffer_);
root_stmt_ = stmt;
emit(R"({} {}(addr);)", root_snode_type_name_, stmt->raw_name());
emit(R"({} {}({});)", root_snode_type_name_, stmt->raw_name(),
kRootBufferName);
}

void visit(GetChStmt *stmt) override {
Expand Down Expand Up @@ -512,15 +519,22 @@ class MetalKernelCodegen : public IRVisitor {

void emit_mtl_kernel_func_sig(const std::string &kernel_name) {
emit("kernel void {}(", kernel_name);
emit(" device byte* addr [[buffer(0)]],");
emit(" device byte* {} [[buffer(1)]],", kGlobalTmpsBufferName);
int buffer_idx = 0;
if (needs_root_buffer_) {
emit(" device byte* {} [[buffer({})]],", kRootBufferName,
buffer_idx++);
}
emit(" device byte* {} [[buffer({})]],", kGlobalTmpsBufferName,
buffer_idx++);
if (args_attribs_.has_args()) {
emit(" device byte* args_addr [[buffer(2)]],");
emit(" device byte* {} [[buffer({})]],", kArgsBufferName,
buffer_idx++);
}
emit(" const uint {} [[thread_position_in_grid]]) {{",
kKernelThreadIdName);
if (args_attribs_.has_args()) {
emit(" {} {}(args_addr);", kernel_args_classname(), kArgsContextName);
emit(" {} {}({});", kernel_args_classname(), kArgsContextName,
kArgsBufferName);
}
}

Expand All @@ -539,6 +553,8 @@ class MetalKernelCodegen : public IRVisitor {

const std::string mtl_kernel_prefix_;
const std::string root_snode_type_name_;
const StructCompiledResult *const compiled_snode_structs_;
const bool needs_root_buffer_;

bool is_top_level_{true};
int mtl_kernel_count_{0};
Expand Down Expand Up @@ -684,8 +700,9 @@ void MetalCodeGen::lower() {
FunctionType MetalCodeGen::gen(const SNode &root_snode, MetalRuntime *runtime) {
// Make a copy of the name!
const std::string taichi_kernel_name = taichi_kernel_name_;
MetalKernelCodegen codegen(taichi_kernel_name, root_snode.node_type_name);
codegen.run(struct_compiled_->source_code, kernel_);
MetalKernelCodegen codegen(taichi_kernel_name, root_snode.node_type_name,
struct_compiled_);
codegen.run(kernel_);
metal::MetalKernelArgsAttributes mtl_args_attribs;
for (const auto &arg : kernel_->args) {
mtl_args_attribs.insert_arg(arg.dt, arg.is_nparray, arg.size,
Expand Down
28 changes: 12 additions & 16 deletions taichi/platform/metal/metal_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ class CompiledMtlKernel {

set_compute_pipeline_state(encoder.get(), pipeline_state_.get());
int buffer_index = 0;
set_mtl_buffer(encoder.get(), data_buffers.root, /*offset=*/0,
buffer_index++);
if (data_buffers.root) {
set_mtl_buffer(encoder.get(), data_buffers.root, /*offset=*/0,
buffer_index++);
}
set_mtl_buffer(encoder.get(), data_buffers.global_tmps, /*offset=*/0,
buffer_index++);
if (data_buffers.args) {
Expand Down Expand Up @@ -284,7 +286,8 @@ class MetalRuntime::Impl {
explicit Impl(Params params)
: config_(params.config),
mem_pool_(params.mem_pool),
profiler_(params.profiler) {
profiler_(params.profiler),
needs_root_buffer_(params.root_size > 0) {
if (config_->debug) {
TI_ASSERT(is_metal_api_available());
}
Expand All @@ -300,6 +303,7 @@ class MetalRuntime::Impl {
const size_t rtm_root_mem_size = llvm_ctx->lookup_function<size_t(void *)>(
"Runtime_get_root_mem_size")(llvm_rtm);
if (rtm_root_mem_size > 0) {
TI_ASSERT(needs_root_buffer_);
// Make sure the runtime's root memory is large enough.
TI_ASSERT(iroundup(params.root_size, taichi_page_size) <=
rtm_root_mem_size);
Expand All @@ -308,19 +312,12 @@ class MetalRuntime::Impl {
TI_ASSERT(rtm_root_mem != nullptr);
root_buffer_ = new_mtl_buffer_no_copy(device_.get(), rtm_root_mem,
rtm_root_mem_size);
TI_ASSERT(root_buffer_ != nullptr);
TI_DEBUG("Metal root buffer size: {} bytes", rtm_root_mem_size);
} else {
// TODO(k-ye) In case no SNodes is defined, we should not allocate
// |root_buffer_| at all. However, that requires us to change the Metal
// kernels so that |root_buffer_| isn't a required argument.
TI_TRACE(
"LLVM root buffer size is 0, allocating directly from the memory "
"pool");
root_buffer_mem_ =
std::make_unique<BufferMemoryView>(taichi_page_size, mem_pool_);
root_buffer_ = new_mtl_buffer_no_copy(
device_.get(), root_buffer_mem_->ptr(), root_buffer_mem_->size());
TI_ASSERT(!needs_root_buffer_);
TI_DEBUG("Metal root buffer is empty");
}
TI_ASSERT(root_buffer_ != nullptr);

// Make sure we don't have to round up global temporaries' buffer size.
TI_ASSERT(iroundup(taichi_global_tmp_buffer_size, taichi_page_size) ==
Expand Down Expand Up @@ -427,11 +424,10 @@ class MetalRuntime::Impl {
CompileConfig *const config_;
MemoryPool *const mem_pool_;
ProfilerBase *const profiler_;
const bool needs_root_buffer_;
nsobj_unique_ptr<MTLDevice> device_{nullptr};
nsobj_unique_ptr<MTLCommandQueue> command_queue_{nullptr};
nsobj_unique_ptr<MTLCommandBuffer> cur_command_buffer_{nullptr};
// |root_buffer_mem_| is used only when the LLVM's root size is 0.
std::unique_ptr<BufferMemoryView> root_buffer_mem_;
nsobj_unique_ptr<MTLBuffer> root_buffer_{nullptr};
uint8_t *global_tmps_mem_begin_;
nsobj_unique_ptr<MTLBuffer> global_tmps_buffer_{nullptr};
Expand Down
1 change: 0 additions & 1 deletion taichi/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ void Program::materialize_layout() {
params.profiler = profiler.get();
metal_runtime_ = std::make_unique<metal::MetalRuntime>(std::move(params));
}
TI_INFO("Metal root buffer size: {} B", metal_struct_compiled_->root_size);
}
}

Expand Down

0 comments on commit 245d760

Please sign in to comment.