Skip to content

Commit

Permalink
[Perf] [metal] Support TLS and SIMD group reduction for range-for ker…
Browse files Browse the repository at this point in the history
…nels (#1358)

* [Perf] [metal] Support TLS and SIMD group reduction

* newline

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Jul 1, 2020
1 parent ee0281f commit 6100ee2
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 27 deletions.
11 changes: 8 additions & 3 deletions taichi/backends/metal/api.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "taichi/backends/metal/api.h"

#include "taichi/backends/metal/constants.h"
#include "taichi/util/environ_config.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -53,15 +55,18 @@ nsobj_unique_ptr<MTLComputeCommandEncoder> new_compute_command_encoder(
return wrap_as_nsobj_unique_ptr(encoder);
}

nsobj_unique_ptr<MTLLibrary> new_library_with_source(
MTLDevice *device,
const std::string &source) {
nsobj_unique_ptr<MTLLibrary> new_library_with_source(MTLDevice *device,
const std::string &source,
int msl_version) {
auto source_str = mac::wrap_string_as_ns_string(source);

id options = clscall("MTLCompileOptions", "alloc");
options = call(options, "init");
auto options_cleanup = wrap_as_nsobj_unique_ptr(options);
call(options, "setFastMathEnabled:", false);
if (msl_version != kMslVersionNone) {
call(options, "setLanguageVersion:", msl_version);
}

auto *lib = cast_call<MTLLibrary *>(
device, "newLibraryWithSource:options:error:", source_str.get(), options,
Expand Down
9 changes: 6 additions & 3 deletions taichi/backends/metal/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
// Reference implementation:
// https://github.com/halide/Halide/blob/master/src/runtime/metal.cpp

#include <string>

#include "taichi/common/trait.h"
#include "taichi/lang_util.h"
#include "taichi/platform/mac/objc_api.h"

#include <string>

TLANG_NAMESPACE_BEGIN

namespace metal {
Expand Down Expand Up @@ -42,8 +42,11 @@ nsobj_unique_ptr<MTLCommandBuffer> new_command_buffer(MTLCommandQueue *queue);
nsobj_unique_ptr<MTLComputeCommandEncoder> new_compute_command_encoder(
MTLCommandBuffer *buffer);

// msl_version: Metal Shading Language version. 0 means not set.
// See https://developer.apple.com/documentation/metal/mtllanguageversion
nsobj_unique_ptr<MTLLibrary> new_library_with_source(MTLDevice *device,
const std::string &source);
const std::string &source,
int msl_version);

nsobj_unique_ptr<MTLFunction> new_function_with_name(MTLLibrary *library,
const std::string &name);
Expand Down
118 changes: 101 additions & 17 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/transforms.h"
#include "taichi/util/line_appender.h"
#include "taichi/math/arithmetic.h"
#include "taichi/backends/metal/api.h"

TLANG_NAMESPACE_BEGIN
namespace metal {
Expand All @@ -30,6 +32,7 @@ using BuffersEnum = KernelAttributes::Buffers;

constexpr char kKernelThreadIdName[] = "utid_"; // 'u' for unsigned
constexpr char kKernelGridSizeName[] = "ugrid_size_"; // 'u' for unsigned
constexpr char kKernelTidInSimdgroupName[] = "utid_in_simdg_";
constexpr char kRootBufferName[] = "root_addr";
constexpr char kGlobalTmpsBufferName[] = "global_tmps_addr";
constexpr char kContextBufferName[] = "ctx_addr";
Expand All @@ -43,6 +46,7 @@ constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kRandStateVarName[] = "rand_state_";
constexpr char kSNodeMetaVarName[] = "sn_meta_";
constexpr char kMemAllocVarName[] = "mem_alloc_";
constexpr char kTlsBufferName[] = "tls_buffer_";

std::string buffer_to_name(BuffersEnum b) {
switch (b) {
Expand Down Expand Up @@ -85,14 +89,16 @@ class KernelCodegen : public IRVisitor {
const std::string &root_snode_type_name,
Kernel *kernel,
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab)
PrintStringTable *print_strtab,
const CodeGen::Config &config)
: mtl_kernel_prefix_(mtl_kernel_prefix),
root_snode_type_name_(root_snode_type_name),
kernel_(kernel),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
ctx_attribs_(*kernel_),
print_strtab_(print_strtab) {
print_strtab_(print_strtab),
cgen_config_(config) {
// allow_undefined_visitor = true;
for (const auto s : kAllSections) {
section_appenders_[s] = LineAppender();
Expand Down Expand Up @@ -336,6 +342,13 @@ class KernelCodegen : public IRVisitor {
stmt->raw_name(), dt, kGlobalTmpsBufferName, stmt->offset);
}

// Emits a `thread`-address-space pointer named after |stmt| that points
// |stmt->offset| bytes into the TLS buffer (kTlsBufferName), typed as the
// Metal data type of the statement's element type.
void visit(ThreadLocalPtrStmt *stmt) override {
  // Only scalar (width 1) statements are handled here.
  TI_ASSERT(stmt->width() == 1);

  const auto ptr_name = stmt->raw_name();
  const auto elem_type_name = metal_data_type_name(stmt->element_type());
  emit("thread auto* {} = reinterpret_cast<thread {}*>({} + {});", ptr_name,
       elem_type_name, kTlsBufferName, stmt->offset);
}

void visit(LoopIndexStmt *stmt) override {
const auto stmt_name = stmt->raw_name();
if (stmt->loop->is<OffloadedStmt>()) {
Expand Down Expand Up @@ -447,31 +460,48 @@ class KernelCodegen : public IRVisitor {
TI_NOT_IMPLEMENTED;
}

std::string val_var = stmt->val->raw_name();
// TODO(k-ye): This is not a very reliable way to detect if we're in TLS
// xlogues...
const bool is_tls_reduction =
(inside_tls_epilogue_ && (op_type == AtomicOpType::add));
const bool use_simd_in_tls_reduction =
(is_tls_reduction && cgen_config_.allow_simdgroup);
if (use_simd_in_tls_reduction) {
val_var += "_simd_val_";
emit("const auto {} = simd_sum({});", val_var, stmt->val->raw_name());
emit("if ({} == 0) {{", kKernelTidInSimdgroupName);
current_appender().push_indent();
}

if (dt == DataType::i32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(),
stmt->val->raw_name());
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
} else if (dt == DataType::u32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_uint*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(),
stmt->val->raw_name());
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
} else if (dt == DataType::f32) {
if (handle_float) {
emit("const float {} = fatomic_fetch_{}({}, {});", stmt->raw_name(),
op_name, stmt->dest->raw_name(), stmt->val->raw_name());
op_name, stmt->dest->raw_name(), val_var);
} else {
TI_ERROR("Metal does not support atomic {} for floating points",
op_name);
}
} else {
TI_ERROR("Metal only supports 32-bit atomic data types");
}

if (use_simd_in_tls_reduction) {
current_appender().pop_indent();
emit("}}"); // closes `if (kKernelTidInSimdgroupName == 0) {`
}
}

void visit(IfStmt *if_stmt) override {
Expand Down Expand Up @@ -651,6 +681,7 @@ class KernelCodegen : public IRVisitor {
// Emits the include / using-directive lines of the generated Metal source
// while a SectionGuard for Section::Headers is alive.
void emit_headers() {
  SectionGuard headers_guard(this, Section::Headers);

  // Emitted verbatim, one line each, in this order.
  static constexpr const char *kHeaderLines[] = {
      "#include <metal_stdlib>",
      "#include <metal_compute>",
      "using namespace metal;",
  };
  for (const char *line : kHeaderLines) {
    emit(line);
  }
}

Expand Down Expand Up @@ -776,7 +807,13 @@ class KernelCodegen : public IRVisitor {
ka.task_type = stmt->task_type;
ka.buffers = get_common_buffers();

emit_mtl_kernel_sig(mtl_kernel_name, ka.buffers);
const bool used_tls = (stmt->prologue != nullptr);
KernelSigExtensions kernel_exts;
kernel_exts.use_simdgroup = (used_tls && cgen_config_.allow_simdgroup);
used_features()->simdgroup =
used_features()->simdgroup || kernel_exts.use_simdgroup;

emit_mtl_kernel_sig(mtl_kernel_name, ka.buffers, kernel_exts);

ka.range_for_attribs = KernelAttributes::RangeForAttributes();
auto &range_for_attribs = ka.range_for_attribs.value();
Expand Down Expand Up @@ -818,17 +855,45 @@ class KernelCodegen : public IRVisitor {
emit("const int begin_ = {} + {};", kKernelThreadIdName, begin_expr);
// end_ = total_elems + begin_expr
emit("const int end_ = {} + {};", total_elems_name, begin_expr);

if (used_tls) {
// Using |int32_t| because it aligns to 4bytes.
emit("// TLS prologue");
const std::string tls_bufi32_name = "tls_bufi32_";
emit("int32_t {}[{}];", tls_bufi32_name, (stmt->tls_size + 3) / 4);
emit("thread char* {} = reinterpret_cast<thread char*>({});",
kTlsBufferName, tls_bufi32_name);
stmt->prologue->accept(this);
}

emit("for (int ii = begin_; ii < end_; ii += {}) {{", kKernelGridSizeName);
{
ScopedIndent s2(current_appender());

current_kernel_attribs_ = &ka;
const auto mtl_func_name = mtl_kernel_func_name(mtl_kernel_name);
emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, stmt->body.get());
emit_call_mtl_kernel_func(mtl_func_name, ka.buffers,
std::vector<FuncParamLiteral> extra_func_params;
std::vector<std::string> extra_args;
if (used_tls) {
extra_func_params.push_back({"thread char*", kTlsBufferName});
extra_args.push_back(kTlsBufferName);
}
emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, extra_func_params,
stmt->body.get());
emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, extra_args,
/*loop_index_expr=*/"ii");
}
emit("}}"); // closes for loop

if (used_tls) {
TI_ASSERT(stmt->epilogue != nullptr);
inside_tls_epilogue_ = true;
emit("{{ // TLS epilogue");
stmt->epilogue->accept(this);
inside_tls_epilogue_ = false;
emit("}}");
}

current_appender().pop_indent();
// Close kernel
emit("}}\n");
Expand Down Expand Up @@ -1042,15 +1107,28 @@ class KernelCodegen : public IRVisitor {
loop_index_expr);
}

// Optional additions to the signature of an emitted Metal kernel.
struct KernelSigExtensions {
  // The hand-written (non-defaulted) empty constructor is deliberate; see
  // https://stackoverflow.com/a/44693603/12003165 for the rationale.
  KernelSigExtensions() noexcept {}

  // When true, the kernel signature additionally declares the
  // [[thread_index_in_simdgroup]] parameter.
  bool use_simdgroup{false};
};

// Emits the opening of a Metal kernel definition named |kernel_name|:
//   - one `device byte*` parameter per entry of |buffers|, bound to
//     [[buffer(i)]] where i is the entry's index in |buffers|;
//   - the grid size parameter ([[threads_per_grid]]);
//   - if |exts.use_simdgroup| is set, the [[thread_index_in_simdgroup]]
//     parameter;
//   - the thread id parameter ([[thread_position_in_grid]]), followed by the
//     opening brace of the kernel body.
// The caller is responsible for emitting the body and the closing brace.
// NOTE(review): this rendering appears to show both the pre-change and the
// post-change form of the |buffers| parameter line.
// NOTE(review): `i < buffers.size()` compares an int with an unsigned size.
void emit_mtl_kernel_sig(
const std::string &kernel_name,
const std::vector<KernelAttributes::Buffers> &buffers) {
const std::vector<KernelAttributes::Buffers> &buffers,
const KernelSigExtensions &exts = {}) {
emit("kernel void {}(", kernel_name);
for (int i = 0; i < buffers.size(); ++i) {
emit(" device byte* {} [[buffer({})]],", buffer_to_name(buffers[i]),
i);
}
emit(" const uint {} [[threads_per_grid]],", kKernelGridSizeName);
if (exts.use_simdgroup) {
emit(" const uint {} [[thread_index_in_simdgroup]],",
kKernelTidInSimdgroupName);
}
emit(" const uint {} [[thread_position_in_grid]]) {{",
kKernelThreadIdName);
}
Expand Down Expand Up @@ -1134,12 +1212,14 @@ class KernelCodegen : public IRVisitor {
const bool needs_root_buffer_;
const KernelContextAttributes ctx_attribs_;
PrintStringTable *const print_strtab_;
const CodeGen::Config &cgen_config_;

bool is_top_level_{true};
int mtl_kernel_count_{0};
TaichiKernelAttributes ti_kernel_attribus_;
GetRootStmt *root_stmt_{nullptr};
KernelAttributes *current_kernel_attribs_{nullptr};
bool inside_tls_epilogue_{false};
Section code_section_{Section::Structs};
std::unordered_map<Section, LineAppender> section_appenders_;
};
Expand All @@ -1148,24 +1228,28 @@ class KernelCodegen : public IRVisitor {

// Constructs the Metal code generator for |kernel|.
// Stores the given pointers as-is, obtains an id from
// Program::get_kernel_id(), and derives the Taichi kernel name as
// "mtl_k<id, zero-padded to 4 digits>_<kernel name>". |config| is copied
// into |config_|.
// NOTE(review): this rendering appears to interleave the pre-change and
// post-change lines of the parameter list and initializer list.
CodeGen::CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs)
const CompiledStructs *compiled_structs,
const Config &config)
: kernel_(kernel),
kernel_mgr_(kernel_mgr),
compiled_structs_(compiled_structs),
id_(Program::get_kernel_id()),
taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_->name)) {
taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_->name)),
config_(config) {
}

FunctionType CodeGen::compile() {
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_offloads(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/true, config.print_ir);
/*ad_use_stack=*/true, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/true);

KernelCodegen codegen(taichi_kernel_name_,
kernel_->program.snode_root->node_type_name, kernel_,
compiled_structs_, kernel_mgr_->print_strtable());
KernelCodegen codegen(
taichi_kernel_name_, kernel_->program.snode_root->node_type_name, kernel_,
compiled_structs_, kernel_mgr_->print_strtable(), config_);
const auto source_code = codegen.run();
kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code,
codegen.ti_kernels_attribs(),
Expand Down
8 changes: 7 additions & 1 deletion taichi/backends/metal/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ namespace metal {

class CodeGen {
public:
// Options controlling Metal code generation.
struct Config {
  // When true, the code generator may use SIMD-group features (for example
  // a simd_sum() reduction in the TLS epilogue of range-for kernels).
  bool allow_simdgroup{true};
};

CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs);
const CompiledStructs *compiled_structs,
const Config &config);

FunctionType compile();

Expand All @@ -33,6 +38,7 @@ class CodeGen {
const CompiledStructs *const compiled_structs_;
const int id_;
const std::string taichi_kernel_name_;
const Config config_;
};

} // namespace metal
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/metal/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace metal {

// NOTE(review): judging by the name, a cap on the number of threads used for
// grid-stride loops -- confirm at the use sites (not visible here).
inline constexpr int kMaxNumThreadsGridStrideLoop = 64 * 1024;
inline constexpr int kNumRandSeeds = 64 * 1024; // 256 KB is nothing
// Sentinel for "no Metal Shading Language version requested": when a caller
// passes this as msl_version, the compile options' language version is left
// unset.
inline constexpr int kMslVersionNone = 0;

} // namespace metal
TLANG_NAMESPACE_END
TLANG_NAMESPACE_END
21 changes: 21 additions & 0 deletions taichi/backends/metal/env_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "taichi/backends/metal/env_config.h"

#include "taichi/lang_util.h"
#include "taichi/util/environ_config.h"

TLANG_NAMESPACE_BEGIN
namespace metal {

// Reads the TI_USE_METAL_SIMDGROUP setting through get_environ_config(),
// falling back to 1 when it is not provided.
EnvConfig::EnvConfig()
    : simdgroup_enabled_(get_environ_config("TI_USE_METAL_SIMDGROUP",
                                            /*default_value=*/1)) {
}

// Returns the shared, immutable EnvConfig. It is a function-local static, so
// it is constructed the first time this function is called.
const EnvConfig &EnvConfig::instance() {
  static const EnvConfig singleton;
  return singleton;
}

} // namespace metal

TLANG_NAMESPACE_END
Loading

0 comments on commit 6100ee2

Please sign in to comment.