diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index db4bc3e48425b..be3709e2226ad 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -17,10 +17,13 @@ namespace shaders { #define TI_INSIDE_METAL_CODEGEN #include "taichi/backends/metal/shaders/ad_stack.metal.h" #include "taichi/backends/metal/shaders/helpers.metal.h" +#include "taichi/backends/metal/shaders/print.metal.h" #include "taichi/backends/metal/shaders/runtime_kernels.metal.h" #undef TI_INSIDE_METAL_CODEGEN +#include "taichi/backends/metal/shaders/print.metal.h" #include "taichi/backends/metal/shaders/runtime_structs.metal.h" + } // namespace shaders using BuffersEnum = KernelAttributes::Buffers; @@ -33,6 +36,8 @@ constexpr char kContextBufferName[] = "ctx_addr"; constexpr char kContextVarName[] = "kernel_ctx_"; constexpr char kRuntimeBufferName[] = "runtime_addr"; constexpr char kRuntimeVarName[] = "runtime_"; +constexpr char kPrintBufferName[] = "print_addr"; +constexpr char kPrintAllocVarName[] = "print_alloc_"; constexpr char kLinearLoopIndexName[] = "linear_loop_idx_"; constexpr char kListgenElemVarName[] = "listgen_elem_"; constexpr char kRandStateVarName[] = "rand_state_"; @@ -49,6 +54,8 @@ std::string buffer_to_name(BuffersEnum b) { return kContextBufferName; case BuffersEnum::Runtime: return kRuntimeBufferName; + case BuffersEnum::Print: + return kPrintBufferName; default: TI_NOT_IMPLEMENTED; break; @@ -77,16 +84,19 @@ class KernelCodegen : public IRVisitor { }; public: + // TODO(k-ye): Create a Params to hold these ctor params. KernelCodegen(const std::string &mtl_kernel_prefix, const std::string &root_snode_type_name, Kernel *kernel, - const CompiledStructs *compiled_structs) + const CompiledStructs *compiled_structs, + PrintStringTable *print_strtab) : mtl_kernel_prefix_(mtl_kernel_prefix), root_snode_type_name_(root_snode_type_name), kernel_(kernel), compiled_structs_(compiled_structs), needs_root_buffer_(compiled_structs_->root_size > 0), - ctx_attribs_(*kernel_) { + ctx_attribs_(*kernel_), + print_strtab_(print_strtab) { // allow_undefined_visitor = true; for (const auto s : kAllSections) { section_appenders_[s] = LineAppender(); @@ -548,8 +558,34 @@ class KernelCodegen : public IRVisitor { } void visit(PrintStmt *stmt) override { - // TODO: Add a flag to control whether ignoring print() stmt is allowed. - TI_WARN("Cannot print inside Metal kernel, ignored"); + mark_print_used(); + const auto &contents = stmt->contents; + const int num_entries = contents.size(); + const std::string msgbuf_var_name = stmt->raw_name() + "_msgbuf_"; + emit("device auto* {} = mtl_print_alloc_buf({}, {});", msgbuf_var_name, + kPrintAllocVarName, num_entries); + // Check for buffer overflow + emit("if ({}) {{", msgbuf_var_name); + { + ScopedIndent s(current_appender()); + const std::string msg_var_name = stmt->raw_name() + "_msg_"; + emit("PrintMsg {}({}, {});", msg_var_name, msgbuf_var_name, num_entries); + for (int i = 0; i < num_entries; ++i) { + const auto &entry = contents[i]; + if (std::holds_alternative(entry)) { + auto *arg_stmt = std::get(entry); + const auto dt = arg_stmt->element_type(); + TI_ASSERT_INFO(dt == DataType::i32 || dt == DataType::f32, + "print() only supports i32 or f32 scalars for now."); + emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_short_name(dt), + i, arg_stmt->raw_name()); + } else { + const int str_id = print_strtab_->put(std::get(entry)); + emit("{}.pm_set_str({}, {});", msg_var_name, i, str_id); + } + } + } + emit("}}"); } void visit(StackAllocaStmt *stmt) override { @@ -630,6 +666,8 @@ class KernelCodegen : public IRVisitor { emit(""); current_appender().append_raw(shaders::kMetalAdStackSourceCode); emit(""); + current_appender().append_raw(shaders::kMetalPrintSourceCode); + emit(""); emit_kernel_args_struct(); } @@ -697,6 +735,8 @@ class KernelCodegen : public IRVisitor { result.push_back(BuffersEnum::Context); } result.push_back(BuffersEnum::Runtime); + // TODO(k-ye): Bind this buffer only when print() is used. + result.push_back(BuffersEnum::Print); return result; } @@ -955,6 +995,9 @@ class KernelCodegen : public IRVisitor { fmt::arg("rtm", kRuntimeVarName), fmt::arg("lidx", kLinearLoopIndexName), fmt::arg("nums", kNumRandSeeds)); + // Init PrintMsgAllocator + emit("device auto* {} = reinterpret_cast({});", + kPrintAllocVarName, kPrintBufferName); } // We do not need additional indentation, because |func_ir| itself is a // block, which will be indented automatically. @@ -1043,6 +1086,11 @@ class KernelCodegen : public IRVisitor { return kernel_name + "_func"; } + void mark_print_used() { + TI_ASSERT(current_kernel_attribs_ != nullptr); + current_kernel_attribs_->uses_print = true; + } + class SectionGuard { public: SectionGuard(KernelCodegen *kg, Section new_sec) @@ -1079,6 +1127,7 @@ class KernelCodegen : public IRVisitor { const CompiledStructs *const compiled_structs_; const bool needs_root_buffer_; const KernelContextAttributes ctx_attribs_; + PrintStringTable *const print_strtab_; bool is_top_level_{true}; int mtl_kernel_count_{0}; @@ -1111,7 +1160,7 @@ FunctionType CodeGen::compile() { KernelCodegen codegen(taichi_kernel_name_, kernel_->program.snode_root->node_type_name, kernel_, - compiled_structs_); + compiled_structs_, kernel_mgr_->print_strtable()); const auto source_code = codegen.run(); kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code, codegen.kernels_attribs(), diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp index f26dd001c1086..ccd63af75f48d 100644 --- a/taichi/backends/metal/kernel_manager.cpp +++ b/taichi/backends/metal/kernel_manager.cpp @@ -26,8 +26,9 @@ namespace metal { namespace { namespace shaders { +#include "taichi/backends/metal/shaders/print.metal.h" #include "taichi/backends/metal/shaders/runtime_utils.metal.h" -} +} // namespace shaders using KernelTaskType = OffloadedStmt::TaskType; using BufferEnum = KernelAttributes::Buffers; @@ -433,7 +434,15 @@ class KernelManager::Impl { runtime_buffer_ != nullptr, "Failed to allocate Metal runtime buffer, requested {} bytes", runtime_mem_->size()); + print_mem_ = std::make_unique( + sizeof(shaders::PrintMsgAllocator) + shaders::kMetalPrintBufferSize, + mem_pool_); + print_buffer_ = new_mtl_buffer_no_copy(device_.get(), print_mem_->ptr(), + print_mem_->size()); + TI_ASSERT(print_buffer_ != nullptr); + init_runtime(params.root_id); + init_print_buffer(); } void register_taichi_kernel( @@ -477,20 +486,30 @@ class KernelManager::Impl { {BufferEnum::Root, root_buffer_.get()}, {BufferEnum::GlobalTmps, global_tmps_buffer_.get()}, {BufferEnum::Runtime, runtime_buffer_.get()}, + {BufferEnum::Print, print_buffer_.get()}, }; if (ctx_blitter) { ctx_blitter->host_to_metal(); input_buffers[BufferEnum::Context] = ctk.ctx_buffer.get(); } + + bool uses_print = false; for (const auto &mk : ctk.compiled_mtl_kernels) { mk->launch(input_buffers, cur_command_buffer_.get()); + uses_print = (uses_print || mk->kernel_attribs()->uses_print); } - if (ctx_blitter) { + + if (ctx_blitter || uses_print) { // TODO(k-ye): One optimization is to synchronize only when we absolutely // need to transfer the data back to host. This includes the cases where // an arg is 1) an array, or 2) used as return value. synchronize(); - ctx_blitter->metal_to_host(); + if (ctx_blitter) { + ctx_blitter->metal_to_host(); + } + if (uses_print) { + flush_print_buffers(); + } } } @@ -502,6 +521,10 @@ class KernelManager::Impl { profiler_->stop(); } + PrintStringTable *print_strtable() { + return &print_strtable_; + } + private: void init_runtime(int root_id) { using namespace shaders; @@ -625,6 +648,47 @@ class KernelManager::Impl { } } + void init_print_buffer() { + // This includes setting PrintMsgAllocator::next to zero. + std::memset(print_mem_->ptr(), 0, print_mem_->size()); + } + + void flush_print_buffers() { + auto *pa = + reinterpret_cast(print_mem_->ptr()); + const int used_sz = std::min(pa->next, shaders::kMetalPrintBufferSize); + using MsgType = shaders::PrintMsg::Type; + char *buf = reinterpret_cast(pa + 1); + const char *buf_end = buf + used_sz; + + while (buf < buf_end) { + int32_t *msg_ptr = reinterpret_cast(buf); + const int num_entries = *msg_ptr; + ++msg_ptr; + shaders::PrintMsg msg(msg_ptr, num_entries); + for (int i = 0; i < num_entries; ++i) { + const auto dt = msg.pm_get_type(i); + const int32_t x = msg.pm_get_data(i); + if (dt == MsgType::I32) { + std::cout << x; + } else if (dt == MsgType::F32) { + std::cout << *reinterpret_cast(&x); + } else if (dt == MsgType::Str) { + std::cout << print_strtable_.get(x); + } else { + TI_ERROR("Unexecpted data type={}", dt); + } + } + buf += shaders::mtl_compute_print_msg_bytes(num_entries); + } + + if (pa->next >= shaders::kMetalPrintBufferSize) { + std::cout << "...(maximum print buffer reached)\n"; + } + + pa->next = 0; + } + static int compute_num_elems_per_chunk(int n) { const int lb = (n + shaders::kTaichiNumChunks - 1) / shaders::kTaichiNumChunks; @@ -662,8 +726,11 @@ class KernelManager::Impl { nsobj_unique_ptr global_tmps_buffer_; std::unique_ptr runtime_mem_; nsobj_unique_ptr runtime_buffer_; + std::unique_ptr print_mem_; + nsobj_unique_ptr print_buffer_; std::unordered_map> compiled_taichi_kernels_; + PrintStringTable print_strtable_; }; #else @@ -690,6 +757,11 @@ class KernelManager::Impl { void synchronize() { TI_ERROR("Metal not supported on the current OS"); } + + PrintStringTable *print_strtable() { + TI_ERROR("Metal not supported on the current OS"); + return nullptr; + } }; #endif // TI_PLATFORM_OSX @@ -719,5 +791,9 @@ void KernelManager::synchronize() { impl_->synchronize(); } +PrintStringTable *KernelManager::print_strtable() { + return impl_->print_strtable(); +} + } // namespace metal TLANG_NAMESPACE_END diff --git a/taichi/backends/metal/kernel_manager.h b/taichi/backends/metal/kernel_manager.h index 46e6ecf2c41e8..dc3502577bdd2 100644 --- a/taichi/backends/metal/kernel_manager.h +++ b/taichi/backends/metal/kernel_manager.h @@ -54,6 +54,8 @@ class KernelManager { // Synchronize the memory content from Metal to host (x86_64). void synchronize(); + PrintStringTable *print_strtable(); + private: // Use Pimpl so that we can expose this interface without conditionally // compiling on TI_PLATFORM_OSX diff --git a/taichi/backends/metal/kernel_util.cpp b/taichi/backends/metal/kernel_util.cpp index 912bfa8da8e11..a4d3cc67190dd 100644 --- a/taichi/backends/metal/kernel_util.cpp +++ b/taichi/backends/metal/kernel_util.cpp @@ -11,15 +11,28 @@ TLANG_NAMESPACE_BEGIN namespace metal { +int PrintStringTable::put(const std::string &str) { + int i = 0; + for (; i < strs_.size(); ++i) { + if (str == strs_[i]) { + return i; + } + } + strs_.push_back(str); + return i; +} + +const std::string &PrintStringTable::get(int i) { + return strs_[i]; +} + // static std::string KernelAttributes::buffers_name(Buffers b) { #define REGISTER_NAME(x) \ { Buffers::x, #x } const static std::unordered_map m = { - REGISTER_NAME(Root), - REGISTER_NAME(GlobalTmps), - REGISTER_NAME(Context), - REGISTER_NAME(Runtime), + REGISTER_NAME(Root), REGISTER_NAME(GlobalTmps), REGISTER_NAME(Context), + REGISTER_NAME(Runtime), REGISTER_NAME(Print), }; #undef REGISTER_NAME return m.find(b)->second; diff --git a/taichi/backends/metal/kernel_util.h b/taichi/backends/metal/kernel_util.h index 0787f2d4c8181..a638702c149e2 100644 --- a/taichi/backends/metal/kernel_util.h +++ b/taichi/backends/metal/kernel_util.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -21,6 +22,16 @@ class SNode; namespace metal { +// TODO(k-ye): Share this between OpenGL and Metal? +class PrintStringTable { + public: + int put(const std::string &str); + const std::string &get(int i); + + private: + std::vector strs_; +}; + // This struct holds the necessary information to launch a Metal kernel. struct KernelAttributes { enum class Buffers { @@ -28,6 +39,7 @@ struct KernelAttributes { GlobalTmps, Context, Runtime, + Print, }; std::string name; int num_threads; @@ -59,6 +71,11 @@ struct KernelAttributes { // clear_list + listgen RuntimeListOpAttributes runtime_list_op_attribs; + // Whether print() is called inside this kernel. + // TODO(k-ye): Encapsulate this inside a UsedFeatures. However, we need a + // TaichiKernelAttributes before we can do this. + bool uses_print = false; + static std::string buffers_name(Buffers b); std::string debug_string() const; }; diff --git a/taichi/backends/metal/shaders/print.metal.h b/taichi/backends/metal/shaders/print.metal.h index b6444daae303a..1c8d2c74fc7d6 100644 --- a/taichi/backends/metal/shaders/print.metal.h +++ b/taichi/backends/metal/shaders/print.metal.h @@ -44,9 +44,7 @@ STR( } [[maybe_unused]] inline int mtl_compute_print_msg_bytes(int num_entries) { - // First int32: stores |num_entries| - // Then follows the number of int32's to store the type masks. - // Finally the print data. 4 byte for each entry. + // See PrintMsg's layout for how this is computed. const int sz = sizeof(int32_t) * (1 + mtl_compute_num_print_msg_typemasks(num_entries) + num_entries); diff --git a/tests/python/test_print.py b/tests/python/test_print.py index 4dd98e373d323..f06195d42aa20 100644 --- a/tests/python/test_print.py +++ b/tests/python/test_print.py @@ -21,7 +21,7 @@ def func(): # TODO: As described by @k-ye above, what we want to ensure # is that, the content shows on console is *correct*. -@ti.archs_excluding(ti.metal) +@ti.all_archs def test_multi_print(): @ti.kernel def func(x: ti.i32, y: ti.f32): @@ -31,7 +31,7 @@ def func(x: ti.i32, y: ti.f32): ti.sync() -@ti.archs_excluding(ti.metal) +@ti.all_archs def test_print_string(): @ti.kernel def func(x: ti.i32, y: ti.f32): @@ -43,13 +43,15 @@ def func(x: ti.i32, y: ti.f32): ti.sync() -@ti.archs_excluding(ti.metal) +@ti.all_archs def test_print_matrix(): x = ti.Matrix(2, 3, dt=ti.f32, shape=()) y = ti.Vector(3, dt=ti.f32, shape=3) @ti.kernel def func(k: ti.f32): + x[None][0, 0] = -1.0 + y[2] += 1.0 print('hello', x[None], 'world!') print(y[2] * k, x[None] / k, y[2]) @@ -57,7 +59,7 @@ def func(k: ti.f32): ti.sync() -@ti.archs_excluding(ti.metal) +@ti.all_archs def test_print_sep(): @ti.kernel def func(): @@ -71,3 +73,19 @@ def func(): func() ti.sync() + + +@ti.all_archs +def test_print_multiple_threads(): + x = ti.var(dt=ti.f32, shape=(128, )) + + @ti.kernel + def func(k: ti.f32): + for i in x: + x[i] = i * k + print('x[', i, ']=', x[i]) + + func(0.1) + ti.sync() + func(10.0) + ti.sync()