-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Metal] Add Runtime shaders to support sparse SNode (#614)
- Loading branch information
Showing
3 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#ifndef TI_METAL_NESTED_INCLUDE | ||
|
||
#define TI_METAL_NESTED_INCLUDE | ||
#include "taichi/platform/metal/shaders/runtime_utils.metal.h" | ||
#undef TI_METAL_NESTED_INCLUDE | ||
|
||
#else | ||
#include "taichi/platform/metal/shaders/runtime_utils.metal.h" | ||
#endif // TI_METAL_NESTED_INCLUDE | ||
|
||
#include "taichi/platform/metal/shaders/prolog.h" | ||
|
||
#ifdef TI_INSIDE_METAL_CODEGEN | ||
|
||
#ifndef TI_METAL_NESTED_INCLUDE | ||
#define METAL_BEGIN_RUNTIME_KERNELS_DEF \ | ||
constexpr auto kMetalRuntimeKernelsSourceCode = | ||
#define METAL_END_RUNTIME_KERNELS_DEF ; | ||
#else | ||
#define METAL_BEGIN_RUNTIME_KERNELS_DEF | ||
#define METAL_END_RUNTIME_KERNELS_DEF | ||
#endif // TI_METAL_NESTED_INCLUDE | ||
|
||
#else | ||
|
||
static_assert(false, "Do not include"); | ||
|
||
// Just a mock to illustrate what the Runtime looks like, do not use. | ||
// The actual Runtime struct has to be emitted by codegen, because it depends | ||
// on the number of SNodes. | ||
struct Runtime { | ||
SNodeMeta *snode_metas; | ||
SNodeExtractors *snode_extractors; | ||
ListManager *snode_lists; | ||
}; | ||
|
||
#define METAL_BEGIN_RUNTIME_KERNELS_DEF | ||
#define METAL_END_RUNTIME_KERNELS_DEF | ||
|
||
#endif // TI_INSIDE_METAL_CODEGEN | ||
|
||
METAL_BEGIN_RUNTIME_KERNELS_DEF | ||
STR( | ||
kernel void clear_list(device byte *runtime_addr [[buffer(0)]], | ||
device int *args [[buffer(1)]], | ||
const uint utid_ [[thread_position_in_grid]]) { | ||
if (utid_ > 0) return; | ||
int child_snode_id = args[1]; | ||
device ListManager *child_list = | ||
&(reinterpret_cast<device Runtime *>(runtime_addr) | ||
->snode_lists[child_snode_id]); | ||
clear(child_list); | ||
} | ||
|
||
kernel void element_listgen(device byte *runtime_addr [[buffer(0)]], | ||
device byte *root_addr [[buffer(1)]], | ||
device int *args [[buffer(2)]], | ||
const uint utid_ [[thread_position_in_grid]]) { | ||
device Runtime *runtime = | ||
reinterpret_cast<device Runtime *>(runtime_addr); | ||
device byte *list_data_addr = | ||
reinterpret_cast<device byte *>(runtime + 1); | ||
|
||
int parent_snode_id = args[0]; | ||
int child_snode_id = args[1]; | ||
device ListManager *parent_list = | ||
&(runtime->snode_lists[parent_snode_id]); | ||
device ListManager *child_list = &(runtime->snode_lists[child_snode_id]); | ||
const SNodeMeta child_meta = runtime->snode_metas[child_snode_id]; | ||
const int child_stride = child_meta.element_stride; | ||
const int num_slots = child_meta.num_slots; | ||
if ((int)utid_ >= num_active(parent_list)) { | ||
return; | ||
} | ||
const auto parent_elem = | ||
get<ListgenElement>(parent_list, utid_, list_data_addr); | ||
for (int i = 0; i < num_slots; ++i) { | ||
ListgenElement child_elem; | ||
child_elem.root_mem_offset = parent_elem.root_mem_offset + | ||
i * child_stride + | ||
child_meta.mem_offset_in_parent; | ||
if (is_active(root_addr + child_elem.root_mem_offset, child_meta, i)) { | ||
refine_coordinates(parent_elem, | ||
runtime->snode_extractors[child_snode_id], i, | ||
&child_elem); | ||
append(child_list, child_elem, list_data_addr); | ||
} | ||
} | ||
}) | ||
METAL_END_RUNTIME_KERNELS_DEF | ||
|
||
#undef METAL_BEGIN_RUNTIME_KERNELS_DEF | ||
#undef METAL_END_RUNTIME_KERNELS_DEF | ||
|
||
#include "taichi/platform/metal/shaders/epilog.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "taichi/platform/metal/shaders/prolog.h" | ||
|
||
#ifdef TI_INSIDE_METAL_CODEGEN | ||
|
||
#ifndef TI_METAL_NESTED_INCLUDE | ||
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF \ | ||
constexpr auto kMetalRuntimeStructsSourceCode = | ||
#define METAL_END_RUNTIME_STRUCTS_DEF ; | ||
#else | ||
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF | ||
#define METAL_END_RUNTIME_STRUCTS_DEF | ||
#endif // TI_METAL_NESTED_INCLUDE | ||
|
||
#else | ||
|
||
#include <cstdint> | ||
|
||
#include "taichi/inc/constants.h" | ||
|
||
static_assert(taichi_max_num_indices == 8, | ||
"Please update kTaichiMaxNumIndices"); | ||
|
||
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF | ||
#define METAL_END_RUNTIME_STRUCTS_DEF | ||
|
||
#endif // TI_INSIDE_METAL_CODEGEN | ||
|
||
METAL_BEGIN_RUNTIME_STRUCTS_DEF | ||
STR( | ||
constant constexpr int kTaichiMaxNumIndices = 8; | ||
|
||
struct ListgenElement { | ||
int32_t coords[kTaichiMaxNumIndices]; | ||
int32_t root_mem_offset = 0; | ||
}; | ||
|
||
struct ListManager { | ||
int32_t element_stride = 0; | ||
int32_t max_num_elems = 0; | ||
int32_t next = 0; | ||
int32_t mem_begin = 0; | ||
}; | ||
|
||
struct SNodeMeta { | ||
enum Type { Root = 0, Dense = 1, DenseBitmask = 2 }; | ||
int32_t element_stride = 0; | ||
int32_t num_slots = 0; | ||
int32_t mem_offset_in_parent = 0; | ||
int32_t type = 0; | ||
}; | ||
|
||
struct SNodeExtractors { | ||
struct Extractor { | ||
int32_t start = 0; | ||
int32_t num_bits = 0; | ||
int32_t acc_offset = 0; | ||
int32_t num_elements = 0; | ||
}; | ||
|
||
Extractor extractors[kTaichiMaxNumIndices]; | ||
};) | ||
METAL_END_RUNTIME_STRUCTS_DEF | ||
|
||
#undef METAL_BEGIN_RUNTIME_STRUCTS_DEF | ||
#undef METAL_END_RUNTIME_STRUCTS_DEF | ||
|
||
#include "taichi/platform/metal/shaders/epilog.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#ifndef TI_METAL_NESTED_INCLUDE | ||
|
||
#define TI_METAL_NESTED_INCLUDE | ||
#include "taichi/platform/metal/shaders/runtime_structs.metal.h" | ||
#undef TI_METAL_NESTED_INCLUDE | ||
|
||
#else | ||
#include "taichi/platform/metal/shaders/runtime_structs.metal.h" | ||
#endif // TI_METAL_NESTED_INCLUDE | ||
|
||
#include "taichi/platform/metal/shaders/prolog.h" | ||
|
||
#ifdef TI_INSIDE_METAL_CODEGEN | ||
|
||
#ifndef TI_METAL_NESTED_INCLUDE | ||
#define METAL_BEGIN_RUNTIME_UTILS_DEF \ | ||
constexpr auto kMetalRuntimeUtilsSourceCode = | ||
#define METAL_END_RUNTIME_UTILS_DEF ; | ||
#else | ||
#define METAL_BEGIN_RUNTIME_UTILS_DEF | ||
#define METAL_END_RUNTIME_UTILS_DEF | ||
#endif // TI_METAL_NESTED_INCLUDE | ||
|
||
#else | ||
|
||
#define METAL_BEGIN_RUNTIME_UTILS_DEF | ||
#define METAL_END_RUNTIME_UTILS_DEF | ||
|
||
#endif // TI_INSIDE_METAL_CODEGEN | ||
|
||
METAL_BEGIN_RUNTIME_UTILS_DEF | ||
STR( | ||
int num_active(device const ListManager *list) { return list->next; } | ||
|
||
template <typename T> | ||
int append(device ListManager *list, thread const T &elem, | ||
device byte *data_addr) { | ||
thread char *elem_ptr = (thread char *)(&elem); | ||
int me = atomic_fetch_add_explicit( | ||
reinterpret_cast<device atomic_int *>(&(list->next)), 1, | ||
metal::memory_order_relaxed); | ||
device byte *ptr = | ||
data_addr + list->mem_begin + (me * list->element_stride); | ||
for (int i = 0; i < list->element_stride; ++i) { | ||
*ptr = *elem_ptr; | ||
++ptr; | ||
++elem_ptr; | ||
} | ||
return me; | ||
} | ||
|
||
template <typename T> | ||
T get(const device ListManager *list, int i, device const byte *data_addr) { | ||
return *reinterpret_cast<device const T *>(data_addr + list->mem_begin + | ||
(i * list->element_stride)); | ||
} | ||
|
||
void clear(device ListManager *list) { | ||
atomic_store_explicit( | ||
reinterpret_cast<device atomic_int *>(&(list->next)), 0, | ||
metal::memory_order_relaxed); | ||
} | ||
|
||
int is_active(device byte *addr, SNodeMeta meta, int i) { | ||
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { | ||
return true; | ||
} | ||
device auto *ptr = | ||
reinterpret_cast<device atomic_uint *>( | ||
addr + ((meta.num_slots - i) * meta.element_stride)) + | ||
(i / (sizeof(uint32_t) * 8)); | ||
uint32_t bits = atomic_load_explicit(ptr, metal::memory_order_relaxed); | ||
return ((bits >> (i % (sizeof(uint32_t) * 8))) & 1); | ||
} | ||
|
||
void activate(device byte *addr, SNodeMeta meta, int i) { | ||
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { | ||
return; | ||
} | ||
device auto *ptr = | ||
reinterpret_cast<device atomic_uint *>( | ||
addr + ((meta.num_slots - i) * meta.element_stride)) + | ||
(i / (sizeof(uint32_t) * 8)); | ||
const uint32_t mask = (1 << (i % (sizeof(uint32_t) * 8))); | ||
atomic_fetch_or_explicit(ptr, mask, metal::memory_order_relaxed); | ||
} | ||
|
||
void deactivate(device byte *addr, SNodeMeta meta, int i) { | ||
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { | ||
return; | ||
} | ||
device auto *ptr = | ||
reinterpret_cast<device atomic_uint *>( | ||
addr + ((meta.num_slots - i) * meta.element_stride)) + | ||
(i / (sizeof(uint32_t) * 8)); | ||
const uint32_t mask = ~(1 << (i % (sizeof(uint32_t) * 8))); | ||
atomic_fetch_and_explicit(ptr, mask, metal::memory_order_relaxed); | ||
} | ||
|
||
void refine_coordinates(thread const ListgenElement &parent_elem, | ||
device const SNodeExtractors &child_extrators, | ||
int l, thread ListgenElement *child_elem) { | ||
for (int i = 0; i < kTaichiMaxNumIndices; ++i) { | ||
device const auto &ex = child_extrators.extractors[i]; | ||
const int mask = ((1 << ex.num_bits) - 1); | ||
const int addition = (((l >> ex.acc_offset) & mask) << ex.start); | ||
child_elem->coords[i] = (parent_elem.coords[i] | addition); | ||
} | ||
}) | ||
METAL_END_RUNTIME_UTILS_DEF | ||
|
||
#undef METAL_BEGIN_RUNTIME_UTILS_DEF | ||
#undef METAL_END_RUNTIME_UTILS_DEF | ||
|
||
#include "taichi/platform/metal/shaders/epilog.h" |