Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] Add 3-stage GC Metal kernels #2268

Merged
merged 2 commits into from
Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions taichi/backends/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,54 @@ STR(
child_list.append(child_elem);
}
}
}

// GC kernel entry point: forwards every thread of the grid to
// run_gc_compact_free_list() for a single SNode allocator.
//
// buffer(0): runtime bytes. A Runtime struct sits at the start and the
//            MemoryAllocator header is read from directly behind it.
// buffer(1): kernel arguments. args[0] is the SNode id whose allocator
//            is processed.
// utid_ and grid_size are passed through as the thread index and the
// total thread count.
kernel void gc_compact_free_list(
    device byte *runtime_addr [[buffer(0)]],
    device int *args [[buffer(1)]],
    const uint utid_ [[thread_position_in_grid]],
    const uint grid_size [[threads_per_grid]]) {
  device Runtime *rtm = reinterpret_cast<device Runtime *>(runtime_addr);
  // The allocator header starts one Runtime past the buffer base.
  device MemoryAllocator *allocator =
      reinterpret_cast<device MemoryAllocator *>(rtm + 1);
  const int sid = args[0];
  run_gc_compact_free_list(&(rtm->snode_allocators[sid]), allocator, utid_,
                           grid_size);
}

// GC kernel entry point: calls run_gc_reset_free_list() for a single SNode
// allocator. Only the thread at grid position 0 does any work. All other
// threads return immediately.
//
// buffer(0): runtime bytes. A Runtime struct sits at the start and the
//            MemoryAllocator header is read from directly behind it.
// buffer(1): kernel arguments. args[0] is the SNode id whose allocator
//            is processed.
kernel void gc_reset_free_list(
    device byte *runtime_addr [[buffer(0)]],
    device int *args [[buffer(1)]],
    const uint utid_ [[thread_position_in_grid]]) {
  // Serial stage: everything except thread 0 bails out.
  if (utid_ != 0) {
    return;
  }

  device Runtime *rtm = reinterpret_cast<device Runtime *>(runtime_addr);
  // The allocator header starts one Runtime past the buffer base.
  device MemoryAllocator *allocator =
      reinterpret_cast<device MemoryAllocator *>(rtm + 1);
  const int sid = args[0];
  run_gc_reset_free_list(&(rtm->snode_allocators[sid]), allocator);
}

// GC kernel entry point: packs the threadgroup geometry of the calling
// thread into a GCMoveRecycledToFreeThreadParams and hands it to
// run_gc_move_recycled_to_free() for a single SNode allocator.
//
// buffer(0): runtime bytes. A Runtime struct sits at the start and the
//            MemoryAllocator header is read from directly behind it.
// buffer(1): kernel arguments. args[0] is the SNode id whose allocator
//            is processed.
kernel void gc_move_recycled_to_free(
    device byte *runtime_addr [[buffer(0)]],
    device int *args [[buffer(1)]],
    const uint utid_in_tg_ [[thread_position_in_threadgroup]],
    const uint utgid_ [[threadgroup_position_in_grid]],
    const uint tg_per_grid [[threadgroups_per_grid]],
    const uint threads_per_tg [[threads_per_threadgroup]]) {
  device Runtime *rtm = reinterpret_cast<device Runtime *>(runtime_addr);
  // The allocator header starts one Runtime past the buffer base.
  device MemoryAllocator *allocator =
      reinterpret_cast<device MemoryAllocator *>(rtm + 1);
  const int sid = args[0];

  // The uint attributes are stored into the int fields of the params struct.
  GCMoveRecycledToFreeThreadParams params;
  params.threadgroups_per_grid = tg_per_grid;
  params.threads_per_threadgroup = threads_per_tg;
  params.threadgroup_position_in_grid = utgid_;
  params.thread_position_in_threadgroup = utid_in_tg_;

  run_gc_move_recycled_to_free(&(rtm->snode_allocators[sid]), allocator,
                               params);
})
METAL_END_RUNTIME_KERNELS_DEF
// clang-format on
Expand Down
139 changes: 130 additions & 9 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ struct Runtime {

#endif // TI_INSIDE_METAL_CODEGEN

// clang-format off
METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator *ma,
Expand All @@ -50,10 +49,10 @@ STR(
metal::memory_order_relaxed);
}

// Converts an allocator-relative offset into a device pointer. The result
// is |offs| bytes past the end of the MemoryAllocator header at |ma|.
[[maybe_unused]] device char *mtl_memalloc_to_ptr(
device MemoryAllocator *ma, PtrOffset offs) {
return reinterpret_cast<device char *>(ma + 1) + offs;
}
// Converts an allocator-relative offset into a device pointer. The result
// is |offs| bytes past the end of the MemoryAllocator header at |ma|.
[[maybe_unused]] device char *mtl_memalloc_to_ptr(device MemoryAllocator *ma,
                                                  PtrOffset offs) {
  // First byte after the allocator header.
  device char *pool_begin = reinterpret_cast<device char *>(ma + 1);
  return pool_begin + offs;
}

struct ListManager {
using ReservedElemPtrOffset = ListManagerData::ReservedElemPtrOffset;
Expand Down Expand Up @@ -206,7 +205,9 @@ STR(
// * init(), instead of doing initialization in the constructor.
class SNodeRep_dense {
public:
// Stores |addr| in addr_. Used in place of a constructor.
void init(device byte * addr) { addr_ = addr; }
// Stores |addr| in addr_. Used in place of a constructor.
void init(device byte * addr) {
addr_ = addr;
}

inline device byte *addr() {
return addr_;
Expand Down Expand Up @@ -418,7 +419,8 @@ STR(

[[maybe_unused]] void refine_coordinates(
thread const ElementCoords &parent,
device const SNodeExtractors &child_extrators, int l,
device const SNodeExtractors &child_extrators,
int l,
thread ElementCoords *child) {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
Expand All @@ -430,8 +432,10 @@ STR(

// Gets the address of an SNode cell identified by |lgen|.
[[maybe_unused]] device byte *mtl_lgen_snode_addr(
thread const ListgenElement &lgen, device byte *root_addr,
device Runtime *rtm, device MemoryAllocator *mem_alloc) {
thread const ListgenElement &lgen,
device byte *root_addr,
device Runtime *rtm,
device MemoryAllocator *mem_alloc) {
if (lgen.in_root_buffer()) {
return root_addr + lgen.mem_offset;
}
Expand All @@ -440,6 +444,123 @@ STR(
nm.mem_alloc = mem_alloc;
device byte *addr = nm.get(lgen.belonged_nodemgr.elem_idx);
return addr + lgen.mem_offset;
}

// GC utils
// Copies entries from the tail of the free list of |nm_data| to its front.
//
// With size = number of active free-list entries and used = free_list_used:
//   * used * 2 > size : copy the (size - used) entries that start at index
//                       used to the front.
//   * otherwise       : copy the last used entries to the front.
// In both cases the source range is the last num_to_copy entries, and the
// destination range is the first num_to_copy entries.
// A negative count simply results in no copies.
//
// This function only moves entries. It does not change the list size or
// free_list_used.
//
// Work split: thread |tid| handles indices tid, tid + grid_size, ...
[[maybe_unused]] void run_gc_compact_free_list(
    device NodeManagerData *nm_data,
    device MemoryAllocator *mem_alloc,
    const int tid,
    const int grid_size) {
  ListManager free_list;
  free_list.lm_data = &(nm_data->free_list);
  free_list.mem_alloc = mem_alloc;

  const int total = free_list.num_active();
  const int used = atomic_load_explicit(&(nm_data->free_list_used),
                                        metal::memory_order_relaxed);

  const bool mostly_used = (used * 2 > total);
  const int num_to_copy = mostly_used ? (total - used) : used;
  // Index of the first source entry.
  const int src_begin = total - num_to_copy;

  using ElemIndex = NodeManager::ElemIndex;
  for (int dst = tid; dst < num_to_copy; dst += grid_size) {
    device ElemIndex *slot =
        reinterpret_cast<device ElemIndex *>(free_list.get_ptr(dst));
    *slot = free_list.get<ElemIndex>(dst + src_begin);
  }
}

// Shrinks the free list of |nm_data| to the entries that were not handed
// out, and snapshots the recycled list.
//
// Steps, in order:
//   1. Read the number of active free-list entries.
//   2. Atomically swap free_list_used with 0, keeping the old value.
//   3. Resize the free list to (active - old used), clamped at 0.
//   4. Atomically swap recycled_list.next with 0 and store the old value in
//      recycled_list_size_backup.
//
// The caller in this file runs it from a single thread.
[[maybe_unused]] void run_gc_reset_free_list(
    device NodeManagerData *nm_data,
    device MemoryAllocator *mem_alloc) {
  ListManager free_list;
  free_list.lm_data = &(nm_data->free_list);
  free_list.mem_alloc = mem_alloc;

  const int size_before = free_list.num_active();
  const int used = atomic_exchange_explicit(&(nm_data->free_list_used), 0,
                                            metal::memory_order_relaxed);

  // free_list_used may exceed the list size, so clamp at zero.
  int remaining = size_before - used;
  if (remaining < 0) {
    remaining = 0;
  }
  free_list.resize(remaining);

  nm_data->recycled_list_size_backup = atomic_exchange_explicit(
      &(nm_data->recycled_list.next), 0, metal::memory_order_relaxed);
}

// Thread geometry consumed by run_gc_move_recycled_to_free(). The field
// names mirror the Metal kernel attributes of the same names.
struct GCMoveRecycledToFreeThreadParams {
// Index of the thread inside its threadgroup. Selects which int32 words of
// an element this thread zeroes. Thread 0 also does the serial work.
int thread_position_in_threadgroup;
// Index of the threadgroup. First recycled-list index this group handles.
int threadgroup_position_in_grid;
// Number of threadgroups. Stride between recycled-list indices of a group.
int threadgroups_per_grid;
// Threads per threadgroup. Stride, in int32 words, between the words one
// thread zeroes inside an element.
int threads_per_threadgroup;
};

// Zero-fills every element recorded in the recycled list of |nm_data| and
// appends its index to the free list.
//
// The number of recycled entries is read from recycled_list_size_backup.
//
// Work split:
//   * Each threadgroup handles recycled indices
//     threadgroup_position_in_grid, + threadgroups_per_grid, ...
//   * Inside a threadgroup the element bytes are cleared cooperatively:
//       - thread 0 clears the unaligned head bytes, if any, byte by byte
//       - every thread clears int32 words, strided by threads_per_threadgroup
//       - whichever thread ends up inside the last partial word clears the
//         remaining tail bytes
//   * Thread 0 of the threadgroup appends the element index to the free list.
//
// Fix over the previous version: the unaligned head span is clamped to the
// element size. Previously an element smaller than the alignment pad made
// thread 0 write zeros past the end of the element.
[[maybe_unused]] void run_gc_move_recycled_to_free(
    device NodeManagerData *nm_data,
    device MemoryAllocator *mem_alloc,
    thread const GCMoveRecycledToFreeThreadParams &thparams) {
  NodeManager nm;
  nm.nm_data = nm_data;
  nm.mem_alloc = mem_alloc;

  ListManager free_list;
  free_list.lm_data = &(nm.nm_data->free_list);
  free_list.mem_alloc = nm.mem_alloc;

  ListManager recycled_list;
  recycled_list.lm_data = &(nm.nm_data->recycled_list);
  recycled_list.mem_alloc = nm.mem_alloc;

  ListManager data_list;
  data_list.lm_data = &(nm.nm_data->data_list);
  data_list.mem_alloc = nm.mem_alloc;

  const int kInt32Stride = sizeof(int32_t);

  // Loop invariants: size of one element and number of recycled entries.
  const int elem_stride = data_list.lm_data->element_stride;
  const int recycled_size = nm.nm_data->recycled_list_size_backup;
  const bool is_first_thread = (thparams.thread_position_in_threadgroup == 0);

  using ElemIndex = NodeManager::ElemIndex;
  for (int ii = thparams.threadgroup_position_in_grid; ii < recycled_size;
       ii += thparams.threadgroups_per_grid) {
    const auto elem_idx = recycled_list.get<ElemIndex>(ii);
    device char *ptr = nm.get(elem_idx);
    device const char *ptr_end = ptr + elem_stride;

    // Bring ptr up to the next int32 boundary so that the word-sized stores
    // below are aligned.
    const int ptr_mod = ((int64_t)(ptr) % kInt32Stride);
    if (ptr_mod) {
      int head_bytes = kInt32Stride - ptr_mod;
      // Never step past the end of the element.
      if (head_bytes > elem_stride) {
        head_bytes = elem_stride;
      }
      device char *new_ptr = ptr + head_bytes;
      if (is_first_thread) {
        for (device char *p = ptr; p < new_ptr; ++p) {
          *p = 0;
        }
      }
      ptr = new_ptr;
    }

    // Word-sized clearing. Thread t starts at word t and then strides by
    // the threadgroup size.
    ptr += (thparams.thread_position_in_threadgroup * kInt32Stride);
    while ((ptr + kInt32Stride) <= ptr_end) {
      *reinterpret_cast<device int32_t *>(ptr) = 0;
      ptr += (kInt32Stride * thparams.threads_per_threadgroup);
    }
    // Tail shorter than one word. Only a thread whose ptr stopped inside
    // the element has anything left to clear here.
    while (ptr < ptr_end) {
      *ptr = 0;
      ++ptr;
    }

    if (is_first_thread) {
      free_list.append(elem_idx);
    }
  }
})
METAL_END_RUNTIME_UTILS_DEF
// clang-format on
Expand Down