Skip to content

Commit

Permalink
[metal] Misc tweaks to make taichi-dev#1480 easier to review
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Jul 12, 2020
1 parent abc7e6b commit 2be1af0
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 55 deletions.
34 changes: 21 additions & 13 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,16 @@ class KernelCodegen : public IRVisitor {
}

void visit(GetChStmt *stmt) override {
// E.g. `parent.get*()`
const auto get_call =
fmt::format("{}.get{}()", stmt->input_ptr->raw_name(), stmt->chid);
if (stmt->output_snode->is_place()) {
emit(R"(device {}* {} = {}.get{}().val;)",
emit(R"(device {}* {} = {}.val;)",
metal_data_type_name(stmt->output_snode->dt), stmt->raw_name(),
stmt->input_ptr->raw_name(), stmt->chid);
get_call);
} else {
emit(R"({} {} = {}.get{}();)", stmt->output_snode->node_type_name,
stmt->raw_name(), stmt->input_ptr->raw_name(), stmt->chid);
emit(R"({} {} = {};)", stmt->output_snode->node_type_name,
stmt->raw_name(), get_call);
}
}

Expand Down Expand Up @@ -689,7 +692,7 @@ class KernelCodegen : public IRVisitor {

void generate_structs() {
SectionGuard sg(this, Section::Structs);
emit("using byte = uchar;");
emit("using byte = char;");
emit("");
current_appender().append_raw(shaders::kMetalHelpersSourceCode);
emit("");
Expand Down Expand Up @@ -859,6 +862,9 @@ class KernelCodegen : public IRVisitor {
emit("const int end_ = {} + {};", total_elems_name, begin_expr);

if (used_tls) {
// Using TLS means we will access some SNodes within this kernel. The
// struct of an SNode needs Runtime and MemoryAllocator to construct.
emit_runtime_and_memalloc_def();
// Using |int32_t| because it aligns to 4bytes.
emit("// TLS prologue");
const std::string tls_bufi32_name = "tls_bufi32_";
Expand Down Expand Up @@ -925,12 +931,7 @@ class KernelCodegen : public IRVisitor {

current_appender().push_indent();
emit("// struct_for");
emit("device Runtime *{} = reinterpret_cast<device Runtime *>({});",
kRuntimeVarName, kRuntimeBufferName);
emit(
"device MemoryAllocator *{} = reinterpret_cast<device MemoryAllocator "
"*>({} + 1);",
kMemAllocVarName, kRuntimeVarName);
emit_runtime_and_memalloc_def();
emit("ListManager parent_list;");
emit("parent_list.lm_data = ({}->snode_lists + {});", kRuntimeVarName,
sn_id);
Expand Down Expand Up @@ -1051,8 +1052,7 @@ class KernelCodegen : public IRVisitor {

{
ScopedIndent s(current_appender());
emit("device Runtime *{} = reinterpret_cast<device Runtime *>({});",
kRuntimeVarName, kRuntimeBufferName);
emit_runtime_and_memalloc_def();
if (!ctx_attribs_.empty()) {
emit("{} {}({});", kernel_args_classname(), kContextVarName,
kContextBufferName);
Expand Down Expand Up @@ -1156,6 +1156,14 @@ class KernelCodegen : public IRVisitor {
return la.lines();
}

void emit_runtime_and_memalloc_def() {
emit("device auto *{} = reinterpret_cast<device Runtime *>({});",
kRuntimeVarName, kRuntimeBufferName);
emit(
"device auto *{} = reinterpret_cast<device MemoryAllocator *>({} + 1);",
kMemAllocVarName, kRuntimeVarName);
}

std::string make_kernel_name() {
return fmt::format("{}_{}", mtl_kernel_prefix_, mtl_kernel_count_++);
}
Expand Down
29 changes: 19 additions & 10 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BufferMemoryView {
size_ = iroundup(size, taichi_page_size);
ptr_ = mem_pool->allocate(size_, /*alignment=*/taichi_page_size);
TI_ASSERT(ptr_ != nullptr);
std::memset(ptr_, 0, size_);
}
// Move only
BufferMemoryView(BufferMemoryView &&) = default;
Expand Down Expand Up @@ -211,6 +212,12 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
mem[1] = child_snode_id_;
const auto &sn_descs = *params.snode_descriptors;
mem[2] = total_num_self_from_root(sn_descs, child_snode_id_);
TI_DEBUG(
"Registered RuntimeListOpsMtlKernel: name={} num_threads={} "
"parent_snode={} "
"child_snode={} max_num_elems={} ",
params.kernel_attribs->name, params.kernel_attribs->num_threads, mem[0],
mem[1], mem[2]);
did_modify_range(args_buffer_.get(), /*location=*/0, args_mem_->size());
}

Expand Down Expand Up @@ -651,8 +658,6 @@ class KernelManager::Impl {
rtm_meta->element_stride = sn_meta.element_stride;
rtm_meta->num_slots = sn_meta.num_slots;
rtm_meta->mem_offset_in_parent = sn_meta.mem_offset_in_parent;
TI_DEBUG("SnodeMeta\n id={}\n element_stride={}\n num_slots={}\n", i,
rtm_meta->element_stride, rtm_meta->num_slots);
switch (sn_meta.snode->type) {
case SNodeType::dense:
rtm_meta->type = SNodeMeta::Dense;
Expand All @@ -671,6 +676,11 @@ class KernelManager::Impl {
snode_type_name(sn_meta.snode->type));
break;
}
TI_DEBUG(
"SnodeMeta\n id={}\n type={}\n element_stride={}\n "
"num_slots={}\n",
i, snode_type_name(sn_meta.snode->type), rtm_meta->element_stride,
rtm_meta->num_slots);
}
size_t addr_offset = sizeof(SNodeMeta) * max_snodes;
addr += addr_offset;
Expand Down Expand Up @@ -701,7 +711,7 @@ class KernelManager::Impl {
TI_DEBUG("Initialized SNodeExtractors, size={} accumulated={}", addr_offset,
(addr - addr_begin));
// init snode_lists
ListManagerData *const rtm_list_head =
ListManagerData *const rtm_list_begin =
reinterpret_cast<ListManagerData *>(addr);
for (int i = 0; i < max_snodes; ++i) {
auto iter = snode_descriptors.find(i);
Expand All @@ -721,7 +731,7 @@ class KernelManager::Impl {
}
addr_offset = sizeof(ListManagerData) * max_snodes;
addr += addr_offset;
TI_DEBUG("Initialized ListManagerData, size={} accumuated={}", addr_offset,
TI_DEBUG("Initialized ListManagerData, size={} accumulated={}", addr_offset,
(addr - addr_begin));
// init rand_seeds
// TODO(k-ye): Provide a way to use a fixed seed in dev mode.
Expand All @@ -740,18 +750,18 @@ class KernelManager::Impl {
kNumRandSeeds * sizeof(uint32_t), (addr - addr_begin));

if (compiled_structs_.need_snode_lists_data) {
auto *alloc = reinterpret_cast<MemoryAllocator *>(addr);
auto *mem_alloc = reinterpret_cast<MemoryAllocator *>(addr);
// Make sure the retured memory address is always greater than 1.
alloc->next = shaders::kAlignment;
mem_alloc->next = shaders::kAlignment;
// root list data are static
ListgenElement root_elem;
root_elem.root_mem_offset = 0;
for (int i = 0; i < taichi_max_num_indices; ++i) {
root_elem.coords[i] = 0;
}
ListManager root_lm;
root_lm.lm_data = rtm_list_head + root_id;
root_lm.mem_alloc = alloc;
root_lm.lm_data = rtm_list_begin + root_id;
root_lm.mem_alloc = mem_alloc;
root_lm.append(root_elem);
}

Expand All @@ -760,8 +770,7 @@ class KernelManager::Impl {
}

void init_print_buffer() {
// This includes setting PrintMsgAllocator::next to zero.
std::memset(print_mem_->ptr(), 0, print_mem_->size());
// TODO(k-ye): Do we need this at all?
did_modify_range(print_buffer_.get(), /*location=*/0, print_mem_->size());
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ STR(
int32_t root_mem_offset = 0;
};

// ListManagerData manages the activeness of its associated SNode.
// ListManagerData manages a list of elements with adjustable size.
struct ListManagerData {
int32_t element_stride = 0;

Expand Down
20 changes: 12 additions & 8 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
// clang-format off
METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
using PtrOffset = int32_t; constant constexpr int kAlignment = 8;
using PtrOffset = int32_t;
constant constexpr int kAlignment = 8;

[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator * ma,
[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator *ma,
int32_t size) {
size = ((size + kAlignment - 1) / kAlignment) * kAlignment;
return atomic_fetch_add_explicit(&ma->next, size,
Expand Down Expand Up @@ -93,19 +94,22 @@ STR(
}
}

template <typename T>
T get(int i) {
device char *get_ptr(int i) {
const int chunk_idx = i >> lm_data->log2_num_elems_per_chunk;
const PtrOffset chunk_ptr_offs = atomic_load_explicit(
lm_data->chunks + chunk_idx, metal::memory_order_relaxed);
return *reinterpret_cast<device T *>(
get_elem_from_chunk(i, chunk_ptr_offs));
return get_elem_from_chunk(i, chunk_ptr_offs);
}

template <typename T>
T get(int i) {
return *reinterpret_cast<device T *>(get_ptr(i));
}

private:
PtrOffset ensure_chunk(int i) {
PtrOffset offs = 0;
const int kChunkBytes =
const int chunk_bytes =
(lm_data->element_stride << lm_data->log2_num_elems_per_chunk);

while (true) {
Expand All @@ -117,7 +121,7 @@ STR(
lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
if (is_me) {
offs = mtl_memalloc_alloc(mem_alloc, kChunkBytes);
offs = mtl_memalloc_alloc(mem_alloc, chunk_bytes);
atomic_store_explicit(lm_data->chunks + i, offs,
metal::memory_order_relaxed);
break;
Expand Down
39 changes: 16 additions & 23 deletions taichi/backends/metal/struct_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ class StructCompiler {
++max_snodes_;
}

for (auto &n : snodes_rev) {
generate_types(*n);
}
CompiledStructs result;
result.root_size = compute_snode_size(&root);
line_appender_.dump(&result.snode_structs_source_code);
emit_runtime_structs(&root);
emit_runtime_structs();
line_appender_.dump(&result.runtime_utils_source_code);
result.runtime_size = compute_runtime_size();
for (auto &n : snodes_rev) {
generate_types(*n);
}
line_appender_.dump(&result.snode_structs_source_code);
result.need_snode_lists_data = has_sparse_snode_;
result.max_snodes = max_snodes_;
result.snode_descriptors = std::move(snode_descriptors_);
Expand All @@ -97,34 +97,27 @@ class StructCompiler {
void generate_types(const SNode &snode) {
const bool is_place = snode.is_place();
if (!is_place) {
// Generate {snode}_ch
const std::string class_name = snode.node_type_name + "_ch";
emit("class {} {{", class_name);
emit(" public:");
emit(" {}(device byte* a) : addr_(a) {{}}", class_name);
emit(" {}(device byte *a) : addr_(a) {{}}", class_name);

std::string stride_str;
std::string stride_str = "0";
for (int i = 0; i < (int)snode.ch.size(); i++) {
const auto &ch_node_name = snode.ch[i]->node_type_name;
emit(" {} get{}() {{", ch_node_name, i);
if (stride_str.empty()) {
emit(" return {{addr_}};");
stride_str = ch_node_name + "::stride";
} else {
emit(" return {{addr_ + ({})}};", stride_str);
stride_str += " + " + ch_node_name + "::stride";
}
emit(" return {{addr_ + ({})}};", stride_str);
stride_str += " + " + ch_node_name + "::stride";
emit(" }}");
emit("");
}
emit(" device byte* addr() {{ return addr_; }}");
emit(" device byte *addr() {{ return addr_; }}");
emit("");
if (stride_str.empty()) {
// Is it possible for this to have no children?
stride_str = "0";
}
// Is it possible for this to have no children?
emit(" constant static constexpr int stride = {};", stride_str);
emit(" private:");
emit(" device byte* addr_;");
emit(" device byte *addr_;");
emit("}};");
}
emit("");
Expand Down Expand Up @@ -212,17 +205,17 @@ class StructCompiler {
return sn_desc.stride;
}

void emit_runtime_structs(const SNode *root) {
void emit_runtime_structs() {
line_appender_.append_raw(shaders::kMetalRuntimeStructsSourceCode);
emit("");
line_appender_.append_raw(shaders::kMetalRuntimeUtilsSourceCode);
emit("");
emit("struct Runtime {{");
emit(" SNodeMeta snode_metas[{}];", max_snodes_);
emit(" SNodeExtractors snode_extractors[{}];", max_snodes_);
emit(" ListManagerData snode_lists[{}];", max_snodes_);
emit(" uint32_t rand_seeds[{}];", kNumRandSeeds);
emit("}};");
line_appender_.append_raw(shaders::kMetalRuntimeUtilsSourceCode);
emit("");
}

size_t compute_runtime_size() {
Expand Down

0 comments on commit 2be1af0

Please sign in to comment.