diff --git a/taichi/platform/metal/shaders/runtime_kernels.metal.h b/taichi/platform/metal/shaders/runtime_kernels.metal.h new file mode 100644 index 0000000000000..31a1bb7554ea2 --- /dev/null +++ b/taichi/platform/metal/shaders/runtime_kernels.metal.h @@ -0,0 +1,95 @@ +#ifndef TI_METAL_NESTED_INCLUDE + +#define TI_METAL_NESTED_INCLUDE +#include "taichi/platform/metal/shaders/runtime_utils.metal.h" +#undef TI_METAL_NESTED_INCLUDE + +#else +#include "taichi/platform/metal/shaders/runtime_utils.metal.h" +#endif // TI_METAL_NESTED_INCLUDE + +#include "taichi/platform/metal/shaders/prolog.h" + +#ifdef TI_INSIDE_METAL_CODEGEN + +#ifndef TI_METAL_NESTED_INCLUDE +#define METAL_BEGIN_RUNTIME_KERNELS_DEF \ + constexpr auto kMetalRuntimeKernelsSourceCode = +#define METAL_END_RUNTIME_KERNELS_DEF ; +#else +#define METAL_BEGIN_RUNTIME_KERNELS_DEF +#define METAL_END_RUNTIME_KERNELS_DEF +#endif // TI_METAL_NESTED_INCLUDE + +#else + +static_assert(false, "Do not include"); + +// Just a mock to illustrate what the Runtime looks like, do not use. +// The actual Runtime struct has to be emitted by codegen, because it depends +// on the number of SNodes. +struct Runtime { + SNodeMeta *snode_metas; + SNodeExtractors *snode_extractors; + ListManager *snode_lists; +}; + +#define METAL_BEGIN_RUNTIME_KERNELS_DEF +#define METAL_END_RUNTIME_KERNELS_DEF + +#endif // TI_INSIDE_METAL_CODEGEN + +METAL_BEGIN_RUNTIME_KERNELS_DEF +STR( + kernel void clear_list(device byte *runtime_addr [[buffer(0)]], + device int *args [[buffer(1)]], + const uint utid_ [[thread_position_in_grid]]) { + if (utid_ > 0) return; + int child_snode_id = args[1]; + device ListManager *child_list = + &(reinterpret_cast(runtime_addr) + ->snode_lists[child_snode_id]); + clear(child_list); + } + + kernel void element_listgen(device byte *runtime_addr [[buffer(0)]], + device byte *root_addr [[buffer(1)]], + device int *args [[buffer(2)]], + const uint utid_ [[thread_position_in_grid]]) { + device Runtime *runtime = + reinterpret_cast(runtime_addr); + device byte *list_data_addr = + reinterpret_cast(runtime + 1); + + int parent_snode_id = args[0]; + int child_snode_id = args[1]; + device ListManager *parent_list = + &(runtime->snode_lists[parent_snode_id]); + device ListManager *child_list = &(runtime->snode_lists[child_snode_id]); + const SNodeMeta child_meta = runtime->snode_metas[child_snode_id]; + const int child_stride = child_meta.element_stride; + const int num_slots = child_meta.num_slots; + if ((int)utid_ >= num_active(parent_list)) { + return; + } + const auto parent_elem = + get(parent_list, utid_, list_data_addr); + for (int i = 0; i < num_slots; ++i) { + ListgenElement child_elem; + child_elem.root_mem_offset = parent_elem.root_mem_offset + + i * child_stride + + child_meta.mem_offset_in_parent; + if (is_active(root_addr + child_elem.root_mem_offset, child_meta, i)) { + refine_coordinates(parent_elem, + runtime->snode_extractors[child_snode_id], i, + &child_elem); + append(child_list, child_elem, list_data_addr); + } + } + }) +METAL_END_RUNTIME_KERNELS_DEF + +#undef METAL_BEGIN_RUNTIME_KERNELS_DEF +#undef METAL_END_RUNTIME_KERNELS_DEF + +#include "taichi/platform/metal/shaders/epilog.h" diff --git a/taichi/platform/metal/shaders/runtime_structs.metal.h b/taichi/platform/metal/shaders/runtime_structs.metal.h new file mode 100644 index 0000000000000..0a0d2c7093b46 --- /dev/null +++ b/taichi/platform/metal/shaders/runtime_structs.metal.h @@ -0,0 +1,67 @@ +#include "taichi/platform/metal/shaders/prolog.h" + +#ifdef TI_INSIDE_METAL_CODEGEN + +#ifndef TI_METAL_NESTED_INCLUDE +#define METAL_BEGIN_RUNTIME_STRUCTS_DEF \ + constexpr auto kMetalRuntimeStructsSourceCode = +#define METAL_END_RUNTIME_STRUCTS_DEF ; +#else +#define METAL_BEGIN_RUNTIME_STRUCTS_DEF +#define METAL_END_RUNTIME_STRUCTS_DEF +#endif // TI_METAL_NESTED_INCLUDE + +#else + +#include + +#include "taichi/inc/constants.h" + +static_assert(taichi_max_num_indices == 8, + "Please update kTaichiMaxNumIndices"); + +#define METAL_BEGIN_RUNTIME_STRUCTS_DEF +#define METAL_END_RUNTIME_STRUCTS_DEF + +#endif // TI_INSIDE_METAL_CODEGEN + +METAL_BEGIN_RUNTIME_STRUCTS_DEF +STR( + constant constexpr int kTaichiMaxNumIndices = 8; + + struct ListgenElement { + int32_t coords[kTaichiMaxNumIndices]; + int32_t root_mem_offset = 0; + }; + + struct ListManager { + int32_t element_stride = 0; + int32_t max_num_elems = 0; + int32_t next = 0; + int32_t mem_begin = 0; + }; + + struct SNodeMeta { + enum Type { Root = 0, Dense = 1, DenseBitmask = 2 }; + int32_t element_stride = 0; + int32_t num_slots = 0; + int32_t mem_offset_in_parent = 0; + int32_t type = 0; + }; + + struct SNodeExtractors { + struct Extractor { + int32_t start = 0; + int32_t num_bits = 0; + int32_t acc_offset = 0; + int32_t num_elements = 0; + }; + + Extractor extractors[kTaichiMaxNumIndices]; + };) +METAL_END_RUNTIME_STRUCTS_DEF + +#undef METAL_BEGIN_RUNTIME_STRUCTS_DEF +#undef METAL_END_RUNTIME_STRUCTS_DEF + +#include "taichi/platform/metal/shaders/epilog.h" diff --git a/taichi/platform/metal/shaders/runtime_utils.metal.h b/taichi/platform/metal/shaders/runtime_utils.metal.h new file mode 100644 index 0000000000000..481be923b856d --- /dev/null +++ b/taichi/platform/metal/shaders/runtime_utils.metal.h @@ -0,0 +1,115 @@ +#ifndef TI_METAL_NESTED_INCLUDE + +#define TI_METAL_NESTED_INCLUDE +#include "taichi/platform/metal/shaders/runtime_structs.metal.h" +#undef TI_METAL_NESTED_INCLUDE + +#else +#include "taichi/platform/metal/shaders/runtime_structs.metal.h" +#endif // TI_METAL_NESTED_INCLUDE + +#include "taichi/platform/metal/shaders/prolog.h" + +#ifdef TI_INSIDE_METAL_CODEGEN + +#ifndef TI_METAL_NESTED_INCLUDE +#define METAL_BEGIN_RUNTIME_UTILS_DEF \ + constexpr auto kMetalRuntimeUtilsSourceCode = +#define METAL_END_RUNTIME_UTILS_DEF ; +#else +#define METAL_BEGIN_RUNTIME_UTILS_DEF +#define METAL_END_RUNTIME_UTILS_DEF +#endif // TI_METAL_NESTED_INCLUDE + +#else + +#define METAL_BEGIN_RUNTIME_UTILS_DEF +#define METAL_END_RUNTIME_UTILS_DEF + +#endif // TI_INSIDE_METAL_CODEGEN + +METAL_BEGIN_RUNTIME_UTILS_DEF +STR( + int num_active(device const ListManager *list) { return list->next; } + + template + int append(device ListManager *list, thread const T &elem, + device byte *data_addr) { + thread char *elem_ptr = (thread char *)(&elem); + int me = atomic_fetch_add_explicit( + reinterpret_cast(&(list->next)), 1, + metal::memory_order_relaxed); + device byte *ptr = + data_addr + list->mem_begin + (me * list->element_stride); + for (int i = 0; i < list->element_stride; ++i) { + *ptr = *elem_ptr; + ++ptr; + ++elem_ptr; + } + return me; + } + + template + T get(const device ListManager *list, int i, device const byte *data_addr) { + return *reinterpret_cast(data_addr + list->mem_begin + + (i * list->element_stride)); + } + + void clear(device ListManager *list) { + atomic_store_explicit( + reinterpret_cast(&(list->next)), 0, + metal::memory_order_relaxed); + } + + int is_active(device byte *addr, SNodeMeta meta, int i) { + if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { + return true; + } + device auto *ptr = + reinterpret_cast( + addr + ((meta.num_slots - i) * meta.element_stride)) + + (i / (sizeof(uint32_t) * 8)); + uint32_t bits = atomic_load_explicit(ptr, metal::memory_order_relaxed); + return ((bits >> (i % (sizeof(uint32_t) * 8))) & 1); + } + + void activate(device byte *addr, SNodeMeta meta, int i) { + if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { + return; + } + device auto *ptr = + reinterpret_cast( + addr + ((meta.num_slots - i) * meta.element_stride)) + + (i / (sizeof(uint32_t) * 8)); + const uint32_t mask = (1 << (i % (sizeof(uint32_t) * 8))); + atomic_fetch_or_explicit(ptr, mask, metal::memory_order_relaxed); + } + + void deactivate(device byte *addr, SNodeMeta meta, int i) { + if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) { + return; + } + device auto *ptr = + reinterpret_cast( + addr + ((meta.num_slots - i) * meta.element_stride)) + + (i / (sizeof(uint32_t) * 8)); + const uint32_t mask = ~(1 << (i % (sizeof(uint32_t) * 8))); + atomic_fetch_and_explicit(ptr, mask, metal::memory_order_relaxed); + } + + void refine_coordinates(thread const ListgenElement &parent_elem, + device const SNodeExtractors &child_extrators, + int l, thread ListgenElement *child_elem) { + for (int i = 0; i < kTaichiMaxNumIndices; ++i) { + device const auto &ex = child_extrators.extractors[i]; + const int mask = ((1 << ex.num_bits) - 1); + const int addition = (((l >> ex.acc_offset) & mask) << ex.start); + child_elem->coords[i] = (parent_elem.coords[i] | addition); + } + }) +METAL_END_RUNTIME_UTILS_DEF + +#undef METAL_BEGIN_RUNTIME_UTILS_DEF +#undef METAL_END_RUNTIME_UTILS_DEF + +#include "taichi/platform/metal/shaders/epilog.h" \ No newline at end of file