Skip to content

Commit

Permalink
[Metal] Add Runtime shaders to support sparse SNode (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Mar 18, 2020
1 parent a03f70d commit e93da7d
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 0 deletions.
95 changes: 95 additions & 0 deletions taichi/platform/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#ifndef TI_METAL_NESTED_INCLUDE

#define TI_METAL_NESTED_INCLUDE
#include "taichi/platform/metal/shaders/runtime_utils.metal.h"
#undef TI_METAL_NESTED_INCLUDE

#else
#include "taichi/platform/metal/shaders/runtime_utils.metal.h"
#endif // TI_METAL_NESTED_INCLUDE

#include "taichi/platform/metal/shaders/prolog.h"

#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
#define METAL_BEGIN_RUNTIME_KERNELS_DEF \
constexpr auto kMetalRuntimeKernelsSourceCode =
#define METAL_END_RUNTIME_KERNELS_DEF ;
#else
#define METAL_BEGIN_RUNTIME_KERNELS_DEF
#define METAL_END_RUNTIME_KERNELS_DEF
#endif // TI_METAL_NESTED_INCLUDE

#else

static_assert(false, "Do not include");

// Just a mock to illustrate what the Runtime looks like, do not use.
// The actual Runtime struct has to be emitted by codegen, because it depends
// on the number of SNodes.
struct Runtime {
SNodeMeta *snode_metas;
SNodeExtractors *snode_extractors;
ListManager *snode_lists;
};

#define METAL_BEGIN_RUNTIME_KERNELS_DEF
#define METAL_END_RUNTIME_KERNELS_DEF

#endif // TI_INSIDE_METAL_CODEGEN

METAL_BEGIN_RUNTIME_KERNELS_DEF
STR(
kernel void clear_list(device byte *runtime_addr [[buffer(0)]],
device int *args [[buffer(1)]],
const uint utid_ [[thread_position_in_grid]]) {
if (utid_ > 0) return;
int child_snode_id = args[1];
device ListManager *child_list =
&(reinterpret_cast<device Runtime *>(runtime_addr)
->snode_lists[child_snode_id]);
clear(child_list);
}

kernel void element_listgen(device byte *runtime_addr [[buffer(0)]],
device byte *root_addr [[buffer(1)]],
device int *args [[buffer(2)]],
const uint utid_ [[thread_position_in_grid]]) {
device Runtime *runtime =
reinterpret_cast<device Runtime *>(runtime_addr);
device byte *list_data_addr =
reinterpret_cast<device byte *>(runtime + 1);

int parent_snode_id = args[0];
int child_snode_id = args[1];
device ListManager *parent_list =
&(runtime->snode_lists[parent_snode_id]);
device ListManager *child_list = &(runtime->snode_lists[child_snode_id]);
const SNodeMeta child_meta = runtime->snode_metas[child_snode_id];
const int child_stride = child_meta.element_stride;
const int num_slots = child_meta.num_slots;
if ((int)utid_ >= num_active(parent_list)) {
return;
}
const auto parent_elem =
get<ListgenElement>(parent_list, utid_, list_data_addr);
for (int i = 0; i < num_slots; ++i) {
ListgenElement child_elem;
child_elem.root_mem_offset = parent_elem.root_mem_offset +
i * child_stride +
child_meta.mem_offset_in_parent;
if (is_active(root_addr + child_elem.root_mem_offset, child_meta, i)) {
refine_coordinates(parent_elem,
runtime->snode_extractors[child_snode_id], i,
&child_elem);
append(child_list, child_elem, list_data_addr);
}
}
})
METAL_END_RUNTIME_KERNELS_DEF

#undef METAL_BEGIN_RUNTIME_KERNELS_DEF
#undef METAL_END_RUNTIME_KERNELS_DEF

#include "taichi/platform/metal/shaders/epilog.h"
67 changes: 67 additions & 0 deletions taichi/platform/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "taichi/platform/metal/shaders/prolog.h"

#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF \
constexpr auto kMetalRuntimeStructsSourceCode =
#define METAL_END_RUNTIME_STRUCTS_DEF ;
#else
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF
#define METAL_END_RUNTIME_STRUCTS_DEF
#endif // TI_METAL_NESTED_INCLUDE

#else

#include <cstdint>

#include "taichi/inc/constants.h"

static_assert(taichi_max_num_indices == 8,
"Please update kTaichiMaxNumIndices");

#define METAL_BEGIN_RUNTIME_STRUCTS_DEF
#define METAL_END_RUNTIME_STRUCTS_DEF

#endif // TI_INSIDE_METAL_CODEGEN

METAL_BEGIN_RUNTIME_STRUCTS_DEF
STR(
constant constexpr int kTaichiMaxNumIndices = 8;

struct ListgenElement {
int32_t coords[kTaichiMaxNumIndices];
int32_t root_mem_offset = 0;
};

struct ListManager {
int32_t element_stride = 0;
int32_t max_num_elems = 0;
int32_t next = 0;
int32_t mem_begin = 0;
};

struct SNodeMeta {
enum Type { Root = 0, Dense = 1, DenseBitmask = 2 };
int32_t element_stride = 0;
int32_t num_slots = 0;
int32_t mem_offset_in_parent = 0;
int32_t type = 0;
};

struct SNodeExtractors {
struct Extractor {
int32_t start = 0;
int32_t num_bits = 0;
int32_t acc_offset = 0;
int32_t num_elements = 0;
};

Extractor extractors[kTaichiMaxNumIndices];
};)
METAL_END_RUNTIME_STRUCTS_DEF

#undef METAL_BEGIN_RUNTIME_STRUCTS_DEF
#undef METAL_END_RUNTIME_STRUCTS_DEF

#include "taichi/platform/metal/shaders/epilog.h"
115 changes: 115 additions & 0 deletions taichi/platform/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#ifndef TI_METAL_NESTED_INCLUDE

#define TI_METAL_NESTED_INCLUDE
#include "taichi/platform/metal/shaders/runtime_structs.metal.h"
#undef TI_METAL_NESTED_INCLUDE

#else
#include "taichi/platform/metal/shaders/runtime_structs.metal.h"
#endif // TI_METAL_NESTED_INCLUDE

#include "taichi/platform/metal/shaders/prolog.h"

#ifdef TI_INSIDE_METAL_CODEGEN

#ifndef TI_METAL_NESTED_INCLUDE
#define METAL_BEGIN_RUNTIME_UTILS_DEF \
constexpr auto kMetalRuntimeUtilsSourceCode =
#define METAL_END_RUNTIME_UTILS_DEF ;
#else
#define METAL_BEGIN_RUNTIME_UTILS_DEF
#define METAL_END_RUNTIME_UTILS_DEF
#endif // TI_METAL_NESTED_INCLUDE

#else

#define METAL_BEGIN_RUNTIME_UTILS_DEF
#define METAL_END_RUNTIME_UTILS_DEF

#endif // TI_INSIDE_METAL_CODEGEN

METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
int num_active(device const ListManager *list) { return list->next; }

template <typename T>
int append(device ListManager *list, thread const T &elem,
device byte *data_addr) {
thread char *elem_ptr = (thread char *)(&elem);
int me = atomic_fetch_add_explicit(
reinterpret_cast<device atomic_int *>(&(list->next)), 1,
metal::memory_order_relaxed);
device byte *ptr =
data_addr + list->mem_begin + (me * list->element_stride);
for (int i = 0; i < list->element_stride; ++i) {
*ptr = *elem_ptr;
++ptr;
++elem_ptr;
}
return me;
}

template <typename T>
T get(const device ListManager *list, int i, device const byte *data_addr) {
return *reinterpret_cast<device const T *>(data_addr + list->mem_begin +
(i * list->element_stride));
}

void clear(device ListManager *list) {
atomic_store_explicit(
reinterpret_cast<device atomic_int *>(&(list->next)), 0,
metal::memory_order_relaxed);
}

int is_active(device byte *addr, SNodeMeta meta, int i) {
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
return true;
}
device auto *ptr =
reinterpret_cast<device atomic_uint *>(
addr + ((meta.num_slots - i) * meta.element_stride)) +
(i / (sizeof(uint32_t) * 8));
uint32_t bits = atomic_load_explicit(ptr, metal::memory_order_relaxed);
return ((bits >> (i % (sizeof(uint32_t) * 8))) & 1);
}

void activate(device byte *addr, SNodeMeta meta, int i) {
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
return;
}
device auto *ptr =
reinterpret_cast<device atomic_uint *>(
addr + ((meta.num_slots - i) * meta.element_stride)) +
(i / (sizeof(uint32_t) * 8));
const uint32_t mask = (1 << (i % (sizeof(uint32_t) * 8)));
atomic_fetch_or_explicit(ptr, mask, metal::memory_order_relaxed);
}

void deactivate(device byte *addr, SNodeMeta meta, int i) {
if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
return;
}
device auto *ptr =
reinterpret_cast<device atomic_uint *>(
addr + ((meta.num_slots - i) * meta.element_stride)) +
(i / (sizeof(uint32_t) * 8));
const uint32_t mask = ~(1 << (i % (sizeof(uint32_t) * 8)));
atomic_fetch_and_explicit(ptr, mask, metal::memory_order_relaxed);
}

void refine_coordinates(thread const ListgenElement &parent_elem,
device const SNodeExtractors &child_extrators,
int l, thread ListgenElement *child_elem) {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
const int mask = ((1 << ex.num_bits) - 1);
const int addition = (((l >> ex.acc_offset) & mask) << ex.start);
child_elem->coords[i] = (parent_elem.coords[i] | addition);
}
})
METAL_END_RUNTIME_UTILS_DEF

#undef METAL_BEGIN_RUNTIME_UTILS_DEF
#undef METAL_END_RUNTIME_UTILS_DEF

#include "taichi/platform/metal/shaders/epilog.h"

0 comments on commit e93da7d

Please sign in to comment.