Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] Add codegen/runtime support for print() #1310

Merged
merged 5 commits into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ namespace shaders {
#define TI_INSIDE_METAL_CODEGEN
#include "taichi/backends/metal/shaders/ad_stack.metal.h"
#include "taichi/backends/metal/shaders/helpers.metal.h"
#include "taichi/backends/metal/shaders/print.metal.h"
#include "taichi/backends/metal/shaders/runtime_kernels.metal.h"
#undef TI_INSIDE_METAL_CODEGEN

#include "taichi/backends/metal/shaders/print.metal.h"
#include "taichi/backends/metal/shaders/runtime_structs.metal.h"

} // namespace shaders

using BuffersEnum = KernelAttributes::Buffers;
Expand All @@ -33,6 +36,8 @@ constexpr char kContextBufferName[] = "ctx_addr";
constexpr char kContextVarName[] = "kernel_ctx_";
constexpr char kRuntimeBufferName[] = "runtime_addr";
constexpr char kRuntimeVarName[] = "runtime_";
constexpr char kPrintBufferName[] = "print_addr";
constexpr char kPrintAllocVarName[] = "print_alloc_";
constexpr char kLinearLoopIndexName[] = "linear_loop_idx_";
constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kRandStateVarName[] = "rand_state_";
Expand All @@ -49,6 +54,8 @@ std::string buffer_to_name(BuffersEnum b) {
return kContextBufferName;
case BuffersEnum::Runtime:
return kRuntimeBufferName;
case BuffersEnum::Print:
return kPrintBufferName;
default:
TI_NOT_IMPLEMENTED;
break;
Expand Down Expand Up @@ -77,16 +84,19 @@ class KernelCodegen : public IRVisitor {
};

public:
// TODO(k-ye): Create a Params to hold these ctor params.
KernelCodegen(const std::string &mtl_kernel_prefix,
const std::string &root_snode_type_name,
Kernel *kernel,
const CompiledStructs *compiled_structs)
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab)
: mtl_kernel_prefix_(mtl_kernel_prefix),
root_snode_type_name_(root_snode_type_name),
kernel_(kernel),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
ctx_attribs_(*kernel_) {
ctx_attribs_(*kernel_),
print_strtab_(print_strtab) {
// allow_undefined_visitor = true;
for (const auto s : kAllSections) {
section_appenders_[s] = LineAppender();
Expand Down Expand Up @@ -548,8 +558,34 @@ class KernelCodegen : public IRVisitor {
}

void visit(PrintStmt *stmt) override {
// TODO: Add a flag to control whether ignoring print() stmt is allowed.
TI_WARN("Cannot print inside Metal kernel, ignored");
mark_print_used();
const auto &contents = stmt->contents;
const int num_entries = contents.size();
const std::string msgbuf_var_name = stmt->raw_name() + "_msgbuf_";
emit("device auto* {} = mtl_print_alloc_buf({}, {});", msgbuf_var_name,
kPrintAllocVarName, num_entries);
// Check for buffer overflow
emit("if ({}) {{", msgbuf_var_name);
{
ScopedIndent s(current_appender());
const std::string msg_var_name = stmt->raw_name() + "_msg_";
emit("PrintMsg {}({}, {});", msg_var_name, msgbuf_var_name, num_entries);
for (int i = 0; i < num_entries; ++i) {
const auto &entry = contents[i];
if (std::holds_alternative<Stmt *>(entry)) {
auto *arg_stmt = std::get<Stmt *>(entry);
const auto dt = arg_stmt->element_type();
TI_ASSERT_INFO(dt == DataType::i32 || dt == DataType::f32,
"print() only supports i32 or f32 scalars for now.");
emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_short_name(dt),
i, arg_stmt->raw_name());
} else {
const int str_id = print_strtab_->put(std::get<std::string>(entry));
emit("{}.pm_set_str({}, {});", msg_var_name, i, str_id);
}
}
}
emit("}}");
}

void visit(StackAllocaStmt *stmt) override {
Expand Down Expand Up @@ -630,6 +666,8 @@ class KernelCodegen : public IRVisitor {
emit("");
current_appender().append_raw(shaders::kMetalAdStackSourceCode);
emit("");
current_appender().append_raw(shaders::kMetalPrintSourceCode);
emit("");
emit_kernel_args_struct();
}

Expand Down Expand Up @@ -697,6 +735,8 @@ class KernelCodegen : public IRVisitor {
result.push_back(BuffersEnum::Context);
}
result.push_back(BuffersEnum::Runtime);
// TODO(k-ye): Bind this buffer only when print() is used.
result.push_back(BuffersEnum::Print);
return result;
}

Expand Down Expand Up @@ -955,6 +995,9 @@ class KernelCodegen : public IRVisitor {
fmt::arg("rtm", kRuntimeVarName),
fmt::arg("lidx", kLinearLoopIndexName),
fmt::arg("nums", kNumRandSeeds));
// Init PrintMsgAllocator
emit("device auto* {} = reinterpret_cast<device PrintMsgAllocator*>({});",
kPrintAllocVarName, kPrintBufferName);
}
// We do not need additional indentation, because |func_ir| itself is a
// block, which will be indented automatically.
Expand Down Expand Up @@ -1043,6 +1086,11 @@ class KernelCodegen : public IRVisitor {
return kernel_name + "_func";
}

void mark_print_used() {
TI_ASSERT(current_kernel_attribs_ != nullptr);
current_kernel_attribs_->uses_print = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Use a UsedFeatures structure like OpenGL does.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.. This is a sloppy design in Metal. I don't have a data structure that records the attributes of a taichi kernel. All i have is a vector of Metal kernels', see

const std::vector<KernelAttributes> &kernels_attribs,
.

I'd like to take this opportunity to clean up this flaw in another PR. Once I have a TaichiKernelAttributes, i can also add a UsedFeatures there. Then we don't even need to iterate over all the metal kernels to figure out if we need to flush the print buffers or not. (I tried adding something, but feel like it's too noisy for this one.) WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ff2 do it iapr, you may also implement a operator|=(UsedFeatures, UsedFeatures) for recording all used feature among kernels :)

}

class SectionGuard {
public:
SectionGuard(KernelCodegen *kg, Section new_sec)
Expand Down Expand Up @@ -1079,6 +1127,7 @@ class KernelCodegen : public IRVisitor {
const CompiledStructs *const compiled_structs_;
const bool needs_root_buffer_;
const KernelContextAttributes ctx_attribs_;
PrintStringTable *const print_strtab_;

bool is_top_level_{true};
int mtl_kernel_count_{0};
Expand Down Expand Up @@ -1111,7 +1160,7 @@ FunctionType CodeGen::compile() {

KernelCodegen codegen(taichi_kernel_name_,
kernel_->program.snode_root->node_type_name, kernel_,
compiled_structs_);
compiled_structs_, kernel_mgr_->print_strtable());
const auto source_code = codegen.run();
kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code,
codegen.kernels_attribs(),
Expand Down
82 changes: 79 additions & 3 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ namespace metal {

namespace {
namespace shaders {
#include "taichi/backends/metal/shaders/print.metal.h"
#include "taichi/backends/metal/shaders/runtime_utils.metal.h"
}
} // namespace shaders

using KernelTaskType = OffloadedStmt::TaskType;
using BufferEnum = KernelAttributes::Buffers;
Expand Down Expand Up @@ -433,7 +434,15 @@ class KernelManager::Impl {
runtime_buffer_ != nullptr,
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
print_mem_ = std::make_unique<BufferMemoryView>(
sizeof(shaders::PrintMsgAllocator) + shaders::kMetalPrintBufferSize,
mem_pool_);
print_buffer_ = new_mtl_buffer_no_copy(device_.get(), print_mem_->ptr(),
print_mem_->size());
TI_ASSERT(print_buffer_ != nullptr);

init_runtime(params.root_id);
init_print_buffer();
}

void register_taichi_kernel(
Expand Down Expand Up @@ -477,20 +486,30 @@ class KernelManager::Impl {
{BufferEnum::Root, root_buffer_.get()},
{BufferEnum::GlobalTmps, global_tmps_buffer_.get()},
{BufferEnum::Runtime, runtime_buffer_.get()},
{BufferEnum::Print, print_buffer_.get()},
};
if (ctx_blitter) {
ctx_blitter->host_to_metal();
input_buffers[BufferEnum::Context] = ctk.ctx_buffer.get();
}

bool uses_print = false;
for (const auto &mk : ctk.compiled_mtl_kernels) {
mk->launch(input_buffers, cur_command_buffer_.get());
uses_print = (uses_print || mk->kernel_attribs()->uses_print);
}
if (ctx_blitter) {

if (ctx_blitter || uses_print) {
// TODO(k-ye): One optimization is to synchronize only when we absolutely
// need to transfer the data back to host. This includes the cases where
// an arg is 1) an array, or 2) used as return value.
synchronize();
ctx_blitter->metal_to_host();
if (ctx_blitter) {
ctx_blitter->metal_to_host();
}
if (uses_print) {
flush_print_buffers();
}
}
}

Expand All @@ -502,6 +521,10 @@ class KernelManager::Impl {
profiler_->stop();
}

PrintStringTable *print_strtable() {
return &print_strtable_;
}

private:
void init_runtime(int root_id) {
using namespace shaders;
Expand Down Expand Up @@ -625,6 +648,47 @@ class KernelManager::Impl {
}
}

void init_print_buffer() {
// This includes setting PrintMsgAllocator::next to zero.
std::memset(print_mem_->ptr(), 0, print_mem_->size());
}

void flush_print_buffers() {
auto *pa =
reinterpret_cast<shaders::PrintMsgAllocator *>(print_mem_->ptr());
const int used_sz = std::min(pa->next, shaders::kMetalPrintBufferSize);
using MsgType = shaders::PrintMsg::Type;
char *buf = reinterpret_cast<char *>(pa + 1);
const char *buf_end = buf + used_sz;

while (buf < buf_end) {
int32_t *msg_ptr = reinterpret_cast<int32_t *>(buf);
const int num_entries = *msg_ptr;
++msg_ptr;
shaders::PrintMsg msg(msg_ptr, num_entries);
for (int i = 0; i < num_entries; ++i) {
const auto dt = msg.pm_get_type(i);
const int32_t x = msg.pm_get_data(i);
if (dt == MsgType::I32) {
std::cout << x;
} else if (dt == MsgType::F32) {
std::cout << *reinterpret_cast<const float *>(&x);
} else if (dt == MsgType::Str) {
std::cout << print_strtable_.get(x);
} else {
TI_ERROR("Unexecpted data type={}", dt);
}
}
buf += shaders::mtl_compute_print_msg_bytes(num_entries);
}

if (pa->next >= shaders::kMetalPrintBufferSize) {
std::cout << "...(maximum print buffer reached)\n";
}

pa->next = 0;
}

static int compute_num_elems_per_chunk(int n) {
const int lb =
(n + shaders::kTaichiNumChunks - 1) / shaders::kTaichiNumChunks;
Expand Down Expand Up @@ -662,8 +726,11 @@ class KernelManager::Impl {
nsobj_unique_ptr<MTLBuffer> global_tmps_buffer_;
std::unique_ptr<BufferMemoryView> runtime_mem_;
nsobj_unique_ptr<MTLBuffer> runtime_buffer_;
std::unique_ptr<BufferMemoryView> print_mem_;
nsobj_unique_ptr<MTLBuffer> print_buffer_;
std::unordered_map<std::string, std::unique_ptr<CompiledTaichiKernel>>
compiled_taichi_kernels_;
PrintStringTable print_strtable_;
};

#else
Expand All @@ -690,6 +757,11 @@ class KernelManager::Impl {
void synchronize() {
TI_ERROR("Metal not supported on the current OS");
}

PrintStringTable *print_strtable() {
TI_ERROR("Metal not supported on the current OS");
return nullptr;
}
};

#endif // TI_PLATFORM_OSX
Expand Down Expand Up @@ -719,5 +791,9 @@ void KernelManager::synchronize() {
impl_->synchronize();
}

PrintStringTable *KernelManager::print_strtable() {
return impl_->print_strtable();
}

} // namespace metal
TLANG_NAMESPACE_END
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class KernelManager {
// Synchronize the memory content from Metal to host (x86_64).
void synchronize();

PrintStringTable *print_strtable();

private:
// Use Pimpl so that we can expose this interface without conditionally
// compiling on TI_PLATFORM_OSX
Expand Down
21 changes: 17 additions & 4 deletions taichi/backends/metal/kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,28 @@ TLANG_NAMESPACE_BEGIN

namespace metal {

int PrintStringTable::put(const std::string &str) {
int i = 0;
for (; i < strs_.size(); ++i) {
if (str == strs_[i]) {
return i;
}
}
strs_.push_back(str);
return i;
}

const std::string &PrintStringTable::get(int i) {
return strs_[i];
}

// static
std::string KernelAttributes::buffers_name(Buffers b) {
#define REGISTER_NAME(x) \
{ Buffers::x, #x }
const static std::unordered_map<Buffers, std::string> m = {
REGISTER_NAME(Root),
REGISTER_NAME(GlobalTmps),
REGISTER_NAME(Context),
REGISTER_NAME(Runtime),
REGISTER_NAME(Root), REGISTER_NAME(GlobalTmps), REGISTER_NAME(Context),
REGISTER_NAME(Runtime), REGISTER_NAME(Print),
};
#undef REGISTER_NAME
return m.find(b)->second;
Expand Down
15 changes: 15 additions & 0 deletions taichi/backends/metal/kernel_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <optional>
#include <string>
#include <vector>

Expand All @@ -21,13 +22,24 @@ class SNode;

namespace metal {

// TODO(k-ye): Share this between OpenGL and Metal?
class PrintStringTable {
public:
int put(const std::string &str);
const std::string &get(int i);

private:
std::vector<std::string> strs_;
};

// This struct holds the necessary information to launch a Metal kernel.
struct KernelAttributes {
enum class Buffers {
Root,
GlobalTmps,
Context,
Runtime,
Print,
};
std::string name;
int num_threads;
Expand Down Expand Up @@ -59,6 +71,9 @@ struct KernelAttributes {
// clear_list + listgen
RuntimeListOpAttributes runtime_list_op_attribs;

// Whether print() is called inside this kernel.
bool uses_print = false;

static std::string buffers_name(Buffers b);
std::string debug_string() const;
};
Expand Down
4 changes: 1 addition & 3 deletions taichi/backends/metal/shaders/print.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ STR(
}

[[maybe_unused]] inline int mtl_compute_print_msg_bytes(int num_entries) {
// First int32: stores |num_entries|
// Then follows the number of int32's to store the type masks.
// Finally the print data. 4 byte for each entry.
// See PrintMsg's layout for how this is computed.
const int sz =
sizeof(int32_t) *
(1 + mtl_compute_num_print_msg_typemasks(num_entries) + num_entries);
Expand Down
Loading