Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] Use grid-stride loop to implement listgen kernels #682

Merged
merged 4 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ class CompiledTaichiKernel {

TI_ASSERT(kernel != nullptr);
compiled_mtl_kernels.push_back(std::move(kernel));
TI_DEBUG("Added {} for Taichi kernel {}", ka.debug_string(),
params.taichi_kernel_name);
}
if (args_attribs.has_args()) {
args_mem = std::make_unique<BufferMemoryView>(args_attribs.total_bytes(),
Expand Down
34 changes: 34 additions & 0 deletions taichi/backends/metal/kernel_util.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "taichi/backends/metal/kernel_util.h"

#include <unordered_map>

#define TI_RUNTIME_HOST
#include "taichi/runtime/llvm/context.h"
#undef TI_RUNTIME_HOST
Expand All @@ -8,6 +10,38 @@ TLANG_NAMESPACE_BEGIN

namespace metal {

// static
std::string KernelAttributes::buffers_name(Buffers b) {
#define REGISTER_NAME(x) \
{ Buffers::x, #x }
const static std::unordered_map<Buffers, std::string> m = {
REGISTER_NAME(Root),
REGISTER_NAME(GlobalTmps),
REGISTER_NAME(Args),
REGISTER_NAME(Runtime),
};
#undef REGISTER_NAME
return m.find(b)->second;
}

std::string KernelAttributes::debug_string() const {
std::string result;
result += fmt::format(
"<KernelAttributes name={} num_threads={} task_type={} buffers=[ ", name,
num_threads, OffloadedStmt::task_type_name(task_type));
for (auto b : buffers) {
result += buffers_name(b) + " ";
}
result += "]"; // closes |buffers|
// TODO(k-ye): show range_for
if (task_type == OffloadedStmt::TaskType::clear_list ||
task_type == OffloadedStmt::TaskType::listgen) {
result += fmt::format(" snode={}", runtime_list_op_attribs.snode->id);
}
result += ">";
return result;
}

KernelArgsAttributes::KernelArgsAttributes(const std::vector<Kernel::Arg> &args)
: args_bytes_(0), extra_args_bytes_(Context::extra_args_size) {
arg_attribs_vec_.reserve(args.size());
Expand Down
3 changes: 3 additions & 0 deletions taichi/backends/metal/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ struct KernelAttributes {
RangeForAttributes range_for_attribs;
// clear_list + listgen
RuntimeListOpAttributes runtime_list_op_attribs;

static std::string buffers_name(Buffers b);
std::string debug_string() const;
};

// Note that all Metal kernels belonging to the same Taichi kernel will share
Expand Down
44 changes: 27 additions & 17 deletions taichi/backends/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ struct Runtime {
METAL_BEGIN_RUNTIME_KERNELS_DEF
STR(
// clang-format on
kernel void clear_list(device byte *runtime_addr [[buffer(0)]],
device int *args [[buffer(1)]],
const uint utid_ [[thread_position_in_grid]]) {
kernel void clear_list(device byte *runtime_addr[[buffer(0)]],
device int *args[[buffer(1)]],
const uint utid_[[thread_position_in_grid]]) {
if (utid_ > 0)
return;
int child_snode_id = args[1];
Expand All @@ -55,10 +55,11 @@ STR(
clear(child_list);
}

kernel void element_listgen(device byte *runtime_addr [[buffer(0)]],
device byte *root_addr [[buffer(1)]],
device int *args [[buffer(2)]],
const uint utid_ [[thread_position_in_grid]]) {
kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
device byte *root_addr[[buffer(1)]],
device int *args[[buffer(2)]],
const uint utid_[[thread_position_in_grid]],
const uint grid_size[[threads_per_grid]]) {
device Runtime *runtime =
reinterpret_cast<device Runtime *>(runtime_addr);
device byte *list_data_addr =
Expand All @@ -72,20 +73,29 @@ STR(
const SNodeMeta child_meta = runtime->snode_metas[child_snode_id];
const int child_stride = child_meta.element_stride;
const int num_slots = child_meta.num_slots;
if ((int)utid_ >= num_active(parent_list)) {
return;
}
const auto parent_elem =
get<ListgenElement>(parent_list, utid_, list_data_addr);
for (int i = 0; i < num_slots; ++i) {
const int range = max(
(int)((child_list->max_num_elems + grid_size - 1) / grid_size), 1);
const int begin = range * (int)utid_;

for (int ii = begin; ii < (begin + range); ++ii) {
const int parent_idx = (ii / num_slots);
if (parent_idx >= num_active(parent_list)) {
// Since |parent_idx| increases monotonically, we can return directly
// once it goes beyond the number of active parent elements.
return;
}
const int child_idx = (ii % num_slots);
const auto parent_elem =
get<ListgenElement>(parent_list, parent_idx, list_data_addr);
ListgenElement child_elem;
child_elem.root_mem_offset = parent_elem.root_mem_offset +
i * child_stride +
child_idx * child_stride +
child_meta.mem_offset_in_parent;
if (is_active(root_addr + child_elem.root_mem_offset, child_meta, i)) {
if (is_active(root_addr + child_elem.root_mem_offset, child_meta,
child_idx)) {
refine_coordinates(parent_elem,
runtime->snode_extractors[child_snode_id], i,
&child_elem);
runtime->snode_extractors[child_snode_id],
child_idx, &child_elem);
append(child_list, child_elem, list_data_addr);
}
}
Expand Down
11 changes: 6 additions & 5 deletions taichi/codegen/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,11 +680,12 @@ class KernelCodegen : public IRVisitor {
ka.num_threads = 1;
ka.buffers = {BuffersEnum::Runtime, BuffersEnum::Args};
} else if (type == Type::listgen) {
// This launches |total_num_elems_from_root| number of threads, which
// could be a huge waste of GPU resources.
// TODO(k-ye): use grid-stride loop to reduce #threads.
ka.num_threads = compiled_structs_->snode_descriptors.find(sn->id)
->second.total_num_elems_from_root;
// listgen kernels use grid-stride loops, so that we can cap its maximum
// number of threads at 1M.
ka.num_threads =
std::min(compiled_structs_->snode_descriptors.find(sn->id)
->second.total_num_elems_from_root,
64 * 1024);
ka.buffers = {BuffersEnum::Runtime, BuffersEnum::Root, BuffersEnum::Args};
} else {
TI_ERROR("Unsupported offload task type {}", stmt->task_name());
Expand Down
22 changes: 19 additions & 3 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Intermediate representations

#include "ir.h"
#include <thread>
#include "taichi/ir/ir.h"

#include <numeric>
#include "frontend.h"
#include <thread>
#include <unordered_map>

#include "taichi/ir/frontend.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -531,4 +534,17 @@ std::string OffloadedStmt::task_name() const {
}
}

// static
std::string OffloadedStmt::task_type_name(TaskType tt) {
#define REGISTER_NAME(x) \
{ TaskType::x, #x }
const static std::unordered_map<TaskType, std::string> m = {
REGISTER_NAME(serial), REGISTER_NAME(range_for),
REGISTER_NAME(struct_for), REGISTER_NAME(clear_list),
REGISTER_NAME(listgen), REGISTER_NAME(gc),
};
#undef REGISTER_NAME
return m.find(tt)->second;
}

TLANG_NAMESPACE_END
5 changes: 4 additions & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "ir.h"

#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -190,6 +191,8 @@ class OffloadedStmt : public Stmt {

std::string task_name() const;

static std::string task_type_name(TaskType tt);

bool has_body() const {
return task_type != clear_list && task_type != listgen && task_type != gc;
}
Expand Down
27 changes: 27 additions & 0 deletions tests/python/test_bitmasked.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,30 @@ def func():

func()
assert s[None] == 4


@archs_support_bitmasked
def test_huge_bitmasked():
# Mainly for testing Metal listgen's grid-stride loop implementation.
x = ti.var(ti.f32)
s = ti.var(ti.i32)

n = 1024

ti.root.bitmasked(ti.i, n).bitmasked(ti.i, 2 * n).place(x)
ti.root.place(s)

@ti.kernel
def func():
for i in range(n * n * 2):
if i % 32 == 0:
x[i] = 1.0

@ti.kernel
def count():
for i in x:
s[None] += 1

func()
count()
assert s[None] == (n * n * 2) // 32