Skip to content

Commit

Permalink
[metal] Refactor runtime ListManager utils (#1444)
Browse files Browse the repository at this point in the history
* [metal] Refactor runtime ListManager utils

* rm NodeManagerData

* fix runtime size
  • Loading branch information
k-ye authored Jul 11, 2020
1 parent 0d2e34f commit f6cf282
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 91 deletions.
6 changes: 3 additions & 3 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,11 +946,11 @@ class KernelCodegen : public IRVisitor {
{
ScopedIndent s2(current_appender());
emit("const int parent_idx_ = (ii / child_num_slots);");
emit("if (parent_idx_ >= num_active(&parent_list)) return;");
emit("if (parent_idx_ >= parent_list.num_active()) return;");
emit("const int child_idx_ = (ii % child_num_slots);");
emit(
"const auto parent_elem_ = get<ListgenElement>(&parent_list, "
"parent_idx_);");
"const auto parent_elem_ = "
"parent_list.get<ListgenElement>(parent_idx_);");

emit("ListgenElement {};", kListgenElemVarName);
// No need to add mem_offset_in_parent, because place() always starts at 0
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ class KernelManager::Impl {
ListManager root_lm;
root_lm.lm_data = rtm_list_head + root_id;
root_lm.mem_alloc = alloc;
append(&root_lm, root_elem);
root_lm.append(root_elem);
}

did_modify_range(runtime_buffer_.get(), /*location=*/0,
Expand Down
8 changes: 4 additions & 4 deletions taichi/backends/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ STR(
child_list.lm_data =
(reinterpret_cast<device Runtime *>(runtime_addr)->snode_lists +
child_snode_id);
clear(&child_list);
child_list.clear();
}

kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
Expand Down Expand Up @@ -83,13 +83,13 @@ STR(
const int max_num_elems = args[2];
for (int ii = utid_; ii < max_num_elems; ii += grid_size) {
const int parent_idx = (ii / num_slots);
if (parent_idx >= num_active(&parent_list)) {
if (parent_idx >= parent_list.num_active()) {
// Since |parent_idx| increases monotonically, we can return directly
// once it goes beyond the number of active parent elements.
return;
}
const int child_idx = (ii % num_slots);
const auto parent_elem = get<ListgenElement>(&parent_list, parent_idx);
const auto parent_elem = parent_list.get<ListgenElement>(parent_idx);
ListgenElement child_elem;
child_elem.root_mem_offset = parent_elem.root_mem_offset +
child_idx * child_stride +
Expand All @@ -99,7 +99,7 @@ STR(
refine_coordinates(parent_elem,
runtime->snode_extractors[parent_snode_id],
child_idx, &child_elem);
append(&child_list, child_elem);
child_list.append(child_elem);
}
}
}
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ STR(
atomic_int chunks[kTaichiNumChunks];
};

struct ListManager {
device ListManagerData *lm_data;
device MemoryAllocator *mem_alloc;
};

// This class is very similar to metal::SNodeDescriptor
struct SNodeMeta {
enum Type { Root = 0, Dense = 1, Bitmasked = 2, Dynamic = 3 };
Expand Down
160 changes: 85 additions & 75 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
// clang-format off
METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
using PtrOffset = int32_t;
constant constexpr int kAlignment = 8;
using PtrOffset = int32_t; constant constexpr int kAlignment = 8;

[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator * ma,
int32_t size) {
Expand All @@ -46,86 +45,98 @@ STR(
return reinterpret_cast<device char *>(ma + 1) + offs;
}

[[maybe_unused]] int num_active(thread ListManager *l) {
return atomic_load_explicit(&(l->lm_data->next),
metal::memory_order_relaxed);
}
struct ListManager {
device ListManagerData *lm_data;
device MemoryAllocator *mem_alloc;

[[maybe_unused]] void clear(thread ListManager *l) {
atomic_store_explicit(&(l->lm_data->next), 0,
metal::memory_order_relaxed);
}
// Returns the current value of the |next| counter in |lm_data|, read with a
// relaxed atomic load.
inline int num_active() {
  const int count =
      atomic_load_explicit(&lm_data->next, metal::memory_order_relaxed);
  return count;
}

[[maybe_unused]] PtrOffset mtl_listmgr_ensure_chunk(thread ListManager *l,
int i) {
device ListManagerData *list = l->lm_data;
PtrOffset offs = 0;
const int kChunkBytes =
(list->element_stride << list->log2_num_elems_per_chunk);

while (true) {
int stored = 0;
// If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
// from requesting memory again. Once allocated, set chunks[i] to the
// actual address offset, which is guaranteed to be greater than 1.
const bool is_me = atomic_compare_exchange_weak_explicit(
list->chunks + i, &stored, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
if (is_me) {
offs = mtl_memalloc_alloc(l->mem_alloc, kChunkBytes);
atomic_store_explicit(list->chunks + i, offs,
metal::memory_order_relaxed);
break;
} else if (stored > 1) {
offs = stored;
break;
}
// |stored| == 1, just spin
// Overwrites the |next| counter in |lm_data| with |sz| using a relaxed atomic
// store. Only the counter is touched; chunk slots are left as they are.
inline void resize(int sz) {
  atomic_store_explicit(&lm_data->next, sz, metal::memory_order_relaxed);
}
return offs;
}

[[maybe_unused]] device char *mtl_listmgr_get_elem_from_chunk(
thread ListManager *l,
int i,
PtrOffset chunk_ptr_offs) {
device ListManagerData *list = l->lm_data;
device char *chunk_ptr = reinterpret_cast<device char *>(
mtl_memalloc_to_ptr(l->mem_alloc, chunk_ptr_offs));
const uint32_t mask = ((1 << list->log2_num_elems_per_chunk) - 1);
return chunk_ptr + ((i & mask) * list->element_stride);
}
// Empties the list by resizing it to zero elements.
inline void clear() { resize(0); }

[[maybe_unused]] device char *append(thread ListManager *l) {
device ListManagerData *list = l->lm_data;
const int elem_idx = atomic_fetch_add_explicit(
&list->next, 1, metal::memory_order_relaxed);
const int chunk_idx = elem_idx >> list->log2_num_elems_per_chunk;
const PtrOffset chunk_ptr_offs = mtl_listmgr_ensure_chunk(l, chunk_idx);
return mtl_listmgr_get_elem_from_chunk(l, elem_idx, chunk_ptr_offs);
}
// Result of reserve_new_elem(): identifies a freshly reserved element slot.
// Field order matters, because reserve_new_elem() brace-initialises this
// struct positionally.
struct ReserveElemResult {
// Index of the reserved element (the value of |next| before the increment).
int elem_idx;
// Offset produced by ensure_chunk() for the chunk that holds |elem_idx|.
PtrOffset chunk_ptr_offs;
};

// Claims the next element slot: atomically bumps |next| by one (relaxed), makes
// sure the chunk covering the claimed index exists, and returns both the index
// and that chunk's offset.
ReserveElemResult reserve_new_elem() {
  ReserveElemResult reserved;
  reserved.elem_idx = atomic_fetch_add_explicit(&lm_data->next, 1,
                                                metal::memory_order_relaxed);
  // Each chunk holds 2^log2_num_elems_per_chunk elements, so shifting the
  // element index right yields the index of its chunk.
  reserved.chunk_ptr_offs =
      ensure_chunk(reserved.elem_idx >> lm_data->log2_num_elems_per_chunk);
  return reserved;
}

// Reserves one new element slot and returns a pointer to its storage. The
// caller is responsible for writing the element contents.
device char *append() {
  const auto slot = reserve_new_elem();
  return get_elem_from_chunk(slot.elem_idx, slot.chunk_ptr_offs);
}

template <typename T>
[[maybe_unused]] void append(thread ListManager *l, thread const T &elem) {
device char *ptr = append(l);
thread char *elem_ptr = (thread char *)(&elem);
template <typename T>
void append(thread const T &elem) {
device char *ptr = append();
thread char *elem_ptr = (thread char *)(&elem);

for (int i = 0; i < l->lm_data->element_stride; ++i) {
*ptr = *elem_ptr;
++ptr;
++elem_ptr;
for (int i = 0; i < lm_data->element_stride; ++i) {
*ptr = *elem_ptr;
++ptr;
++elem_ptr;
}
}
}

template <typename T>
[[maybe_unused]] T get(thread ListManager *l, int i) {
device ListManagerData *list = l->lm_data;
const int chunk_idx = i >> list->log2_num_elems_per_chunk;
const PtrOffset chunk_ptr_offs = atomic_load_explicit(
list->chunks + chunk_idx, metal::memory_order_relaxed);
return *reinterpret_cast<device T *>(
mtl_listmgr_get_elem_from_chunk(l, i, chunk_ptr_offs));
}
// Returns a copy of element |i|, reinterpreted as a T. The chunk offset is read
// with a relaxed atomic load from |chunks|; this function never allocates.
template <typename T>
T get(int i) {
  const PtrOffset chunk_ptr_offs = atomic_load_explicit(
      lm_data->chunks + (i >> lm_data->log2_num_elems_per_chunk),
      metal::memory_order_relaxed);
  device char *elem_addr = get_elem_from_chunk(i, chunk_ptr_offs);
  return *reinterpret_cast<device T *>(elem_addr);
}

private:
// Returns the offset stored in chunks[i], allocating the chunk first if the
// slot still holds 0. Slot values as used by this function:
//   0   -> unallocated
//   1   -> another caller has claimed the slot and is allocating
//   > 1 -> offset of the allocated chunk
// All atomic operations here use relaxed memory ordering.
PtrOffset ensure_chunk(int i) {
PtrOffset offs = 0;
// Bytes per chunk: element_stride * 2^log2_num_elems_per_chunk.
const int kChunkBytes =
(lm_data->element_stride << lm_data->log2_num_elems_per_chunk);

while (true) {
// Expected value for the compare-exchange. It is reset on every iteration
// because the branches below read the observed slot value back from it.
int stored = 0;
// If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
// from requesting memory again. Once allocated, set chunks[i] to the
// actual address offset, which is guaranteed to be greater than 1.
const bool is_me = atomic_compare_exchange_weak_explicit(
lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
if (is_me) {
// This caller won the slot: allocate the chunk and publish its offset.
offs = mtl_memalloc_alloc(mem_alloc, kChunkBytes);
atomic_store_explicit(lm_data->chunks + i, offs,
metal::memory_order_relaxed);
break;
} else if (stored > 1) {
// The slot already holds a published offset; reuse it.
offs = stored;
break;
}
// |stored| == 1, just spin
}
return offs;
}

// Computes the address of element |i| inside the chunk located at
// |chunk_ptr_offs| (an offset resolved through |mem_alloc|).
device char *get_elem_from_chunk(int i, PtrOffset chunk_ptr_offs) {
  // Keep only the low log2_num_elems_per_chunk bits of |i|: the position of the
  // element within its chunk.
  const uint32_t mask = ((1 << lm_data->log2_num_elems_per_chunk) - 1);
  const uint32_t idx_in_chunk = (i & mask);
  device char *chunk_base = reinterpret_cast<device char *>(
      mtl_memalloc_to_ptr(mem_alloc, chunk_ptr_offs));
  return chunk_base + (idx_in_chunk * lm_data->element_stride);
}
};

[[maybe_unused]] int is_active(device byte *addr, SNodeMeta meta, int i) {
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
Expand Down Expand Up @@ -207,8 +218,7 @@ STR(
device auto *n_ptr = reinterpret_cast<device atomic_int *>(
addr + (meta.num_slots * meta.element_stride));
return atomic_load_explicit(n_ptr, metal::memory_order_relaxed);
}
)
})
METAL_END_RUNTIME_UTILS_DEF
// clang-format on

Expand Down
6 changes: 3 additions & 3 deletions taichi/backends/metal/struct_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace shaders {

} // namespace shaders

constexpr size_t kListManagerSize = sizeof(shaders::ListManager);
constexpr size_t kListManagerDataSize = sizeof(shaders::ListManagerData);
constexpr size_t kSNodeMetaSize = sizeof(shaders::SNodeMeta);
constexpr size_t kSNodeExtractorsSize = sizeof(shaders::SNodeExtractors);

Expand Down Expand Up @@ -226,8 +226,8 @@ class StructCompiler {
}

size_t compute_runtime_size() {
size_t result = (max_snodes_) *
(kSNodeMetaSize + kSNodeExtractorsSize + kListManagerSize);
size_t result = (max_snodes_) * (kSNodeMetaSize + kSNodeExtractorsSize +
kListManagerDataSize);
result += sizeof(uint32_t) * kNumRandSeeds;
return result;
}
Expand Down

0 comments on commit f6cf282

Please sign in to comment.