Skip to content

Commit

Permalink
[metal] Revise NodeManager's implementation due to weak memory order (#…
Browse files Browse the repository at this point in the history
…2008)

* [metal] Revise NodeManager/ListManager's implementation

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Oct 31, 2020
1 parent dea88d0 commit a878232
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 83 deletions.
2 changes: 1 addition & 1 deletion taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ class KernelManager::Impl {
if (compiled_structs_.need_snode_lists_data) {
auto *mem_alloc = reinterpret_cast<MemoryAllocator *>(addr);
// Make sure the retured memory address is always greater than 1.
mem_alloc->next = shaders::kAlignment;
mem_alloc->next = shaders::MemoryAllocator::kInitOffset;
// root list data are static
ListgenElement root_elem;
root_elem.mem_offset = 0;
Expand Down
80 changes: 34 additions & 46 deletions taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@ STR(
// clang-format on
constant constexpr int kTaichiMaxNumIndices = 8;
constant constexpr int kTaichiNumChunks = 1024;
constant constexpr int kAlignment = 8;
using PtrOffset = int32_t;

struct MemoryAllocator { atomic_int next; };
struct MemoryAllocator {
atomic_int next;

constant constexpr static int kInitOffset = 8;

static inline bool is_valid(PtrOffset v) {
return v >= kInitOffset;
}
};

// ListManagerData manages a list of elements with adjustable size.
struct ListManagerData {
Expand All @@ -44,6 +54,28 @@ STR(
atomic_int next;

atomic_int chunks[kTaichiNumChunks];

struct ReservedElemPtrOffset {
public:
ReservedElemPtrOffset() = default;
explicit ReservedElemPtrOffset(PtrOffset v) : val_(v) {
}

inline bool is_valid() const {
return is_valid(val_);
}

inline static bool is_valid(PtrOffset v) {
return MemoryAllocator::is_valid(v);
}

inline PtrOffset value() const {
return val_;
}

private:
PtrOffset val_{0};
};
};

// NodeManagerData stores the actual data needed to implement NodeManager
Expand All @@ -54,6 +86,7 @@ STR(
// few lists (ListManagerData). In particular, |data_list| stores the actual
// data, while |free_list| and |recycle_list| are only meant for GC.
struct NodeManagerData {
using ElemIndex = ListManagerData::ReservedElemPtrOffset;
// Stores the actual data.
ListManagerData data_list;
// For GC
Expand All @@ -62,51 +95,6 @@ STR(
atomic_int free_list_used;
// Need this field to bookkeep some data during GC
int recycled_list_size_backup;

// Use this type instead of the raw index type (int32_t), because the
// raw value needs to be shifted by |kIndexOffset| in order for the
// spinning memory allocation algorithm to work.
struct ElemIndex {
// The first 8 index values are reserved to encode special status:
// * 0 : nullptr
// * 1 : spinning for allocation
// * 2-7: unused for now
//
/// For each allocated index, it is added by |index_offset| to skip over
/// these reserved values.
constant static constexpr int32_t kIndexOffset = 8;

ElemIndex() = default;

static ElemIndex from_index(int i) {
return ElemIndex(i + kIndexOffset);
}

static ElemIndex from_raw(int r) {
return ElemIndex(r);
}

inline int32_t index() const {
return raw_ - kIndexOffset;
}

inline int32_t raw() const {
return raw_;
}

inline bool is_valid() const {
return raw_ >= kIndexOffset;
}

inline static bool is_valid(int raw) {
return ElemIndex::from_raw(raw).is_valid();
}

private:
explicit ElemIndex(int r) : raw_(r) {
}
int32_t raw_ = 0;
};
};

// This class is very similar to metal::SNodeDescriptor
Expand Down
82 changes: 46 additions & 36 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ struct Runtime {
// clang-format off
METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
using PtrOffset = int32_t;
constant constexpr int kAlignment = 8;

[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator *ma,
int32_t size) {
size = ((size + kAlignment - 1) / kAlignment) * kAlignment;
Expand All @@ -57,6 +54,7 @@ STR(
}

struct ListManager {
using ReservedElemPtrOffset = ListManagerData::ReservedElemPtrOffset;
device ListManagerData *lm_data;
device MemoryAllocator *mem_alloc;

Expand All @@ -74,22 +72,19 @@ STR(
resize(0);
}

struct ReserveElemResult {
int elem_idx;
PtrOffset chunk_ptr_offs;
};

ReserveElemResult reserve_new_elem() {
ReservedElemPtrOffset reserve_new_elem() {
const int elem_idx = atomic_fetch_add_explicit(
&lm_data->next, 1, metal::memory_order_relaxed);
const int chunk_idx = elem_idx >> lm_data->log2_num_elems_per_chunk;
const int chunk_idx = get_chunk_index(elem_idx);
const PtrOffset chunk_ptr_offs = ensure_chunk(chunk_idx);
return {elem_idx, chunk_ptr_offs};
const auto offset =
get_elem_ptr_offs_from_chunk(elem_idx, chunk_ptr_offs);
return ReservedElemPtrOffset{offset};
}

device char *append() {
auto reserved = reserve_new_elem();
return get_elem_from_chunk(reserved.elem_idx, reserved.chunk_ptr_offs);
return get_ptr(reserved);
}

template <typename T>
Expand All @@ -104,8 +99,12 @@ STR(
}
}

device char *get_ptr(ReservedElemPtrOffset offs) {
return mtl_memalloc_to_ptr(mem_alloc, offs.value());
}

device char *get_ptr(int i) {
const int chunk_idx = i >> lm_data->log2_num_elems_per_chunk;
const int chunk_idx = get_chunk_index(i);
const PtrOffset chunk_ptr_offs = atomic_load_explicit(
lm_data->chunks + chunk_idx, metal::memory_order_relaxed);
return get_elem_from_chunk(i, chunk_ptr_offs);
Expand All @@ -117,7 +116,11 @@ STR(
}

private:
PtrOffset ensure_chunk(int i) {
inline int get_chunk_index(int elem_idx) const {
return elem_idx >> lm_data->log2_num_elems_per_chunk;
}

PtrOffset ensure_chunk(int chunk_idx) {
PtrOffset offs = 0;
const int chunk_bytes =
(lm_data->element_stride << lm_data->log2_num_elems_per_chunk);
Expand All @@ -128,11 +131,11 @@ STR(
// from requesting memory again. Once allocated, set chunks[i] to the
// actual address offset, which is guaranteed to be greater than 1.
const bool is_me = atomic_compare_exchange_weak_explicit(
lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
lm_data->chunks + chunk_idx, &stored, 1,
metal::memory_order_relaxed, metal::memory_order_relaxed);
if (is_me) {
offs = mtl_memalloc_alloc(mem_alloc, chunk_bytes);
atomic_store_explicit(lm_data->chunks + i, offs,
atomic_store_explicit(lm_data->chunks + chunk_idx, offs,
metal::memory_order_relaxed);
break;
} else if (stored > 1) {
Expand All @@ -144,11 +147,16 @@ STR(
return offs;
}

device char *get_elem_from_chunk(int i, PtrOffset chunk_ptr_offs) {
device char *chunk_ptr = reinterpret_cast<device char *>(
mtl_memalloc_to_ptr(mem_alloc, chunk_ptr_offs));
PtrOffset get_elem_ptr_offs_from_chunk(int elem_idx,
PtrOffset chunk_ptr_offs) {
const uint32_t mask = ((1 << lm_data->log2_num_elems_per_chunk) - 1);
return chunk_ptr + ((i & mask) * lm_data->element_stride);
return chunk_ptr_offs + ((elem_idx & mask) * lm_data->element_stride);
}

device char *get_elem_from_chunk(int elem_idx, PtrOffset chunk_ptr_offs) {
const auto offs =
get_elem_ptr_offs_from_chunk(elem_idx, chunk_ptr_offs);
return mtl_memalloc_to_ptr(mem_alloc, offs);
}
};

Expand All @@ -172,15 +180,15 @@ STR(
return free_list.get<ElemIndex>(cur_used);
}

return ElemIndex::from_index(data_list.reserve_new_elem().elem_idx);
return data_list.reserve_new_elem();
}

device byte *get(ElemIndex i) {
ListManager data_list;
data_list.lm_data = &(nm_data->data_list);
data_list.mem_alloc = mem_alloc;

return data_list.get_ptr(i.index());
return data_list.get_ptr(i);
}

void recycle(ElemIndex i) {
Expand Down Expand Up @@ -328,33 +336,35 @@ STR(

void activate(int i) {
device auto *nm_idx_ptr = to_nodemgr_idx_ptr(addr_, i);
auto nm_idx_raw =
auto nm_idx_val =
atomic_load_explicit(nm_idx_ptr, metal::memory_order_relaxed);
while (!ElemIndex::is_valid(nm_idx_raw)) {
nm_idx_raw = 0;
while (!ElemIndex::is_valid(nm_idx_val)) {
nm_idx_val = 0;
// See ListManager::ensure_chunk() for the allocation algorithm.
// See also https://github.com/taichi-dev/taichi/issues/1174.
const bool is_me = atomic_compare_exchange_weak_explicit(
nm_idx_ptr, &nm_idx_raw, 1, metal::memory_order_relaxed,
nm_idx_ptr, &nm_idx_val, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
if (is_me) {
nm_idx_raw = nm_.allocate().raw();
atomic_store_explicit(nm_idx_ptr, nm_idx_raw,
nm_idx_val = nm_.allocate().value();
atomic_store_explicit(nm_idx_ptr, nm_idx_val,
metal::memory_order_relaxed);
break;
} else if (ElemIndex::is_valid(nm_idx_raw)) {
} else if (ElemIndex::is_valid(nm_idx_val)) {
break;
}
// |nm_idx_raw| == 1, just spin
// |nm_idx_val| == 1, just spin
}
}

void deactivate(int i) {
device auto *nm_idx_ptr = to_nodemgr_idx_ptr(addr_, i);
const auto old_nm_idx_raw = atomic_exchange_explicit(
const auto old_nm_idx_val = atomic_exchange_explicit(
nm_idx_ptr, 0, metal::memory_order_relaxed);
const auto old_nm_idx = ElemIndex::from_raw(old_nm_idx_raw);
if (!old_nm_idx.is_valid()) return;
const auto old_nm_idx = ElemIndex(old_nm_idx_val);
if (!old_nm_idx.is_valid()) {
return;
}
nm_.recycle(old_nm_idx);
}

Expand All @@ -366,8 +376,8 @@ STR(

static inline ElemIndex to_nodemgr_idx(device byte * addr, int ch_i) {
device auto *ptr = to_nodemgr_idx_ptr(addr, ch_i);
const auto r = atomic_load_explicit(ptr, metal::memory_order_relaxed);
return ElemIndex::from_raw(r);
const auto v = atomic_load_explicit(ptr, metal::memory_order_relaxed);
return ElemIndex(v);
}

static bool is_active(device byte * addr, int ch_i) {
Expand Down

0 comments on commit a878232

Please sign in to comment.