Skip to content

Commit

Permalink
[metal] Pass kernel name and is_evalutator to the runtime (#1430)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Jul 7, 2020
1 parent 4bfe530 commit c75e5ed
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
6 changes: 4 additions & 2 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,22 @@ class KernelCodegen : public IRVisitor {

public:
// TODO(k-ye): Create a Params to hold these ctor params.
KernelCodegen(const std::string &mtl_kernel_prefix,
KernelCodegen(const std::string &taichi_kernel_name,
const std::string &root_snode_type_name,
Kernel *kernel,
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab,
const CodeGen::Config &config)
: mtl_kernel_prefix_(mtl_kernel_prefix),
: mtl_kernel_prefix_(taichi_kernel_name),
root_snode_type_name_(root_snode_type_name),
kernel_(kernel),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
ctx_attribs_(*kernel_),
print_strtab_(print_strtab),
cgen_config_(config) {
ti_kernel_attribus_.name = taichi_kernel_name;
ti_kernel_attribus_.is_jit_evaluator = kernel->is_evaluator;
// allow_undefined_visitor = true;
for (const auto s : kAllSections) {
section_appenders_[s] = LineAppender();
Expand Down
30 changes: 18 additions & 12 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ class CompiledMtlKernelBase {
public:
struct Params {
const KernelAttributes *kernel_attribs;
bool is_jit_evaluator;
MTLDevice *device;
MTLFunction *mtl_func;
};

explicit CompiledMtlKernelBase(Params &params)
: kernel_attribs_(*params.kernel_attribs),
: is_jit_evalutor_(params.is_jit_evaluator),
kernel_attribs_(*params.kernel_attribs),
pipeline_state_(
new_compute_pipeline_state_with_function(params.device,
params.mtl_func)) {
Expand Down Expand Up @@ -132,6 +134,7 @@ class CompiledMtlKernelBase {
end_encoding(encoder.get());
}

const bool is_jit_evalutor_;
KernelAttributes kernel_attribs_;
nsobj_unique_ptr<MTLComputePipelineState> pipeline_state_;
};
Expand Down Expand Up @@ -215,7 +218,6 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
class CompiledTaichiKernel {
public:
struct Params {
std::string_view taichi_kernel_name;
std::string mtl_source_code;
const TaichiKernelAttributes *ti_kernel_attribs;
const KernelContextAttributes *ctx_attribs;
Expand All @@ -226,16 +228,17 @@ class CompiledTaichiKernel {
};

CompiledTaichiKernel(Params params)
: ctx_attribs(*params.ctx_attribs),
used_features(params.ti_kernel_attribs->used_features) {
: ti_kernel_attribs(*params.ti_kernel_attribs),
ctx_attribs(*params.ctx_attribs) {
auto *const device = params.device;
auto kernel_lib = new_library_with_source(device, params.mtl_source_code,
infer_msl_version(used_features));
auto kernel_lib = new_library_with_source(
device, params.mtl_source_code,
infer_msl_version(ti_kernel_attribs.used_features));
if (kernel_lib == nullptr) {
TI_ERROR("Failed to compile Metal kernel! Generated code:\n\n{}",
params.mtl_source_code);
}
for (const auto &ka : params.ti_kernel_attribs->mtl_kernels_attribs) {
for (const auto &ka : ti_kernel_attribs.mtl_kernels_attribs) {
auto mtl_func = new_function_with_name(kernel_lib.get(), ka.name);
TI_ASSERT(mtl_func != nullptr);
// Note that CompiledMtlKernel doesn't own |kernel_func|.
Expand All @@ -245,6 +248,7 @@ class CompiledTaichiKernel {
ktype == KernelTaskType::listgen) {
RuntimeListOpsMtlKernel::Params kparams;
kparams.kernel_attribs = &ka;
kparams.is_jit_evaluator = ti_kernel_attribs.is_jit_evaluator;
kparams.device = device;
kparams.mtl_func = mtl_func.get();
kparams.mem_pool = params.mem_pool;
Expand All @@ -253,6 +257,7 @@ class CompiledTaichiKernel {
} else {
UserMtlKernel::Params kparams;
kparams.kernel_attribs = &ka;
kparams.is_jit_evaluator = ti_kernel_attribs.is_jit_evaluator;
kparams.device = device;
kparams.mtl_func = mtl_func.get();
kernel = std::make_unique<UserMtlKernel>(kparams);
Expand All @@ -261,7 +266,7 @@ class CompiledTaichiKernel {
TI_ASSERT(kernel != nullptr);
compiled_mtl_kernels.push_back(std::move(kernel));
TI_DEBUG("Added {} for Taichi kernel {}", ka.debug_string(),
params.taichi_kernel_name);
ti_kernel_attribs.name);
}
if (!ctx_attribs.empty()) {
ctx_mem = std::make_unique<BufferMemoryView>(ctx_attribs.total_bytes(),
Expand All @@ -274,16 +279,17 @@ class CompiledTaichiKernel {
// Have to be exposed as public for Impl to use. We cannot friend the Impl
// class because it is private.
std::vector<std::unique_ptr<CompiledMtlKernelBase>> compiled_mtl_kernels;
TaichiKernelAttributes ti_kernel_attribs;
KernelContextAttributes ctx_attribs;
std::unique_ptr<BufferMemoryView> ctx_mem;
nsobj_unique_ptr<MTLBuffer> ctx_buffer;
TaichiKernelAttributes::UsedFeatures used_features;
};

class HostMetalCtxBlitter {
public:
HostMetalCtxBlitter(const CompiledTaichiKernel &kernel, Context *host_ctx)
: ctx_attribs_(&kernel.ctx_attribs),
: ti_kernel_attribs_(&kernel.ti_kernel_attribs),
ctx_attribs_(&kernel.ctx_attribs),
host_ctx_(host_ctx),
kernel_ctx_mem_(kernel.ctx_mem.get()),
kernel_ctx_buffer_(kernel.ctx_buffer.get()) {
Expand Down Expand Up @@ -396,6 +402,7 @@ class HostMetalCtxBlitter {
}

private:
const TaichiKernelAttributes *const ti_kernel_attribs_;
const KernelContextAttributes *const ctx_attribs_;
Context *const host_ctx_;
BufferMemoryView *const kernel_ctx_mem_;
Expand Down Expand Up @@ -476,7 +483,6 @@ class KernelManager::Impl {
mtl_kernel_source_code);
}
CompiledTaichiKernel::Params params;
params.taichi_kernel_name = taichi_kernel_name;
params.mtl_source_code = mtl_kernel_source_code;
params.ti_kernel_attribs = &ti_kernel_attribs;
params.ctx_attribs = &ctx_attribs;
Expand Down Expand Up @@ -511,7 +517,7 @@ class KernelManager::Impl {
for (const auto &mk : ctk.compiled_mtl_kernels) {
mk->launch(input_buffers, cur_command_buffer_.get());
}
const bool used_print = ctk.used_features.print;
const bool used_print = ctk.ti_kernel_attribs.used_features.print;
if (ctx_blitter || used_print) {
// TODO(k-ye): One optimization is to synchronize only when we absolutely
// need to transfer the data back to host. This includes the cases where
Expand Down
4 changes: 4 additions & 0 deletions taichi/backends/metal/kernel_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class KernelManager {
// * |mtl_kernel_source_code| is the complete source code compiled from a
// Taichi kernel. It may include one or more Metal compute kernels. Each
// Metal kernel is identified by one item in |kernels_attribs|.
//
// TODO(k-ye): Remove |taichi_kernel_name| now that it's part of
// |ti_kernel_attribs|. Return a handle that will be passed to
// launch_taichi_kernel(), instead of using kernel name as the identifier.
void register_taichi_kernel(const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
const TaichiKernelAttributes &ti_kernel_attribs,
Expand Down
4 changes: 3 additions & 1 deletion taichi/backends/metal/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ struct TaichiKernelAttributes {
// since MSL 2.1
bool simdgroup = false;
};

std::string name;
// Is this kernel for evaluating the constant fold result?
bool is_jit_evaluator = false;
// Attributes of all the Metal kernels produced from this Taichi kernel.
std::vector<KernelAttributes> mtl_kernels_attribs;
UsedFeatures used_features;
Expand Down

0 comments on commit c75e5ed

Please sign in to comment.