[Refactor] Add a class Callable to unify Kernel and Function (#2338)
* Add Callable and change Kernel::program to Program *

* [skip ci] enforce code format

* Add Program::current_callable

* cleanup

* add get_name() and fix tests

* Rename: TaichiCallableTemplateMapper

* code format

* fix test

* Update taichi/program/callable.h

Co-authored-by: Ye Kuang <[email protected]>

* Remove the implementation of the pure virtual function

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Ye Kuang <[email protected]>
3 people authored May 18, 2021
1 parent 22e33a6 commit aabd87c
Showing 28 changed files with 198 additions and 205 deletions.
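The heart of this commit is visible in the two new files below (taichi/program/callable.{h,cpp}): Kernel and Function now share a Callable base class, Program tracks a single Callable *current_callable instead of a std::variant<Kernel *, Function *>, and an RAII guard swaps that pointer in and out. A minimal standalone sketch of the pattern, using simplified stand-in types rather than the real taichi classes:

#include <cassert>
#include <string>
#include <utility>

struct Callable;

// Stand-in for taichi::lang::Program; only the part relevant here.
struct Program {
  // Replaces std::variant<Kernel *, Function *> current_kernel_or_function.
  Callable *current_callable = nullptr;
};

struct Callable {
  Program *program = nullptr;
  virtual ~Callable() = default;
  virtual std::string get_name() const = 0;

  // RAII guard: installs `callable` as current, restores the old one on exit.
  class CurrentCallableGuard {
    Callable *old_callable;
    Program *program;

   public:
    CurrentCallableGuard(Program *program, Callable *callable)
        : old_callable(program->current_callable), program(program) {
      program->current_callable = callable;
    }
    ~CurrentCallableGuard() {
      program->current_callable = old_callable;
    }
  };
};

struct Named : Callable {
  std::string name;
  explicit Named(std::string n) : name(std::move(n)) {}
  std::string get_name() const override { return name; }
};

int main() {
  Program prog;
  Named kernel("k"), func("f");
  {
    Callable::CurrentCallableGuard outer(&prog, &kernel);
    {
      // Nested: e.g. a function body built while a kernel is current.
      Callable::CurrentCallableGuard inner(&prog, &func);
      assert(prog.current_callable->get_name() == "f");
    }
    assert(prog.current_callable->get_name() == "k");  // restored
  }
  assert(prog.current_callable == nullptr);  // fully unwound
}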
10 changes: 5 additions & 5 deletions python/taichi/lang/kernel_impl.py
@@ -90,8 +90,8 @@ def __init__(self, func, classfunc=False, pyfunc=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         self.taichi_functions = {}  # The |Function| class in C++

     def __call__(self, *args):
@@ -207,7 +207,7 @@ def extract_arguments(self):
             self.argument_names.append(param.name)


-class KernelTemplateMapper:
+class TaichiCallableTemplateMapper:
     def __init__(self, annotations, template_slot_locations):
         self.annotations = annotations
         # Make sure extractors's size is the same as the number of args
@@ -286,8 +286,8 @@ def __init__(self, func, is_grad, classkernel=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         impl.get_runtime().kernels.append(self)
         self.reset()
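Both renamed hunks above construct the same mapper, one from the Python-side Func wrapper and one from Kernel, which is why the class sheds its kernel-specific name: TaichiCallableTemplateMapper maps template arguments for any Taichi callable.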
8 changes: 4 additions & 4 deletions taichi/backends/cc/codegen_cc.cpp
@@ -45,7 +45,7 @@ class CCTransformer : public IRVisitor {

   void lower_ast() {
     auto ir = kernel->ir.get();
-    auto config = kernel->program.config;
+    auto config = kernel->program->config;
     config.demote_dense_struct_fors = true;
     irpass::compile_to_executable(ir, config, kernel,
                                   /*vectorize=*/false, kernel->grad,
@@ -88,7 +88,7 @@ class CCTransformer : public IRVisitor {
   }

   void visit(GetRootStmt *stmt) override {
-    auto root = kernel->program.snode_root.get();
+    auto root = kernel->program->snode_root.get();
     emit("{} = ti_ctx->root;",
          define_var(get_node_ptr_name(root), stmt->raw_name()));
     root_stmt = stmt;
@@ -598,7 +598,7 @@ class CCTransformer : public IRVisitor {
 };  // namespace cccp

 std::unique_ptr<CCKernel> CCKernelGen::compile() {
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   auto layout = program->get_layout();
   CCTransformer tran(kernel, layout);

@@ -613,7 +613,7 @@ FunctionType compile_kernel(Kernel *kernel) {
   CCKernelGen codegen(kernel);
   auto ker = codegen.compile();
   auto ker_ptr = ker.get();
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   program->add_kernel(std::move(ker));
   return [ker_ptr](Context &ctx) { return ker_ptr->launch(&ctx); };
 }
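The `.`-to-`->` edits here and in the backend files that follow all trace back to one change: Kernel::program used to be a Program & and is now the Program * inherited from Callable. A tiny illustration with hypothetical stand-in types:

struct Program {
  int config = 0;  // stand-in for the real CompileConfig member
};

// Before the refactor (sketch): struct Kernel { Program &program; };
// accessed as kernel->program.config. After, via the new base class:
struct Callable {
  Program *program = nullptr;
};
struct Kernel : Callable {};

int read_config(Kernel *kernel) {
  return kernel->program->config;  // pointer member now requires ->
}

int main() {
  Program prog;
  Kernel kernel;
  kernel.program = &prog;
  return read_config(&kernel);
}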
4 changes: 2 additions & 2 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -40,9 +40,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
       tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
     }

-    auto jit = kernel->program.llvm_context_device->jit.get();
+    auto jit = kernel->program->llvm_context_device->jit.get();
     auto cuda_module =
-        jit->add_module(std::move(module), kernel->program.config.gpu_max_reg);
+        jit->add_module(std::move(module), kernel->program->config.gpu_max_reg);

     return [offloaded_local, cuda_module,
             kernel = this->kernel](Context &context) {
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
@@ -1507,7 +1507,7 @@ FunctionType compile_to_metal_executable(
   cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

   KernelCodegen codegen(
-      taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
+      taichi_kernel_name, kernel->program->snode_root->node_type_name, kernel,
       compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);

   const auto source_code = codegen.run();
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
@@ -1110,7 +1110,7 @@ FunctionType OpenglCodeGen::gen(void) {

 void OpenglCodeGen::lower() {
   auto ir = kernel_->ir.get();
-  auto &config = kernel_->program.config;
+  auto &config = kernel_->program->config;
   config.demote_dense_struct_fors = true;
   irpass::compile_to_executable(ir, config, kernel_,
                                 /*vectorize=*/false, kernel_->grad,
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.cpp
@@ -13,7 +13,7 @@
 TLANG_NAMESPACE_BEGIN

 KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
-    : prog(&kernel->program), kernel(kernel), ir(ir) {
+    : prog(kernel->program), kernel(kernel), ir(ir) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();
8 changes: 4 additions & 4 deletions taichi/codegen/codegen_llvm.cpp
@@ -270,12 +270,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,

 CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
     // TODO: simplify LLVMModuleBuilder ctor input
-    : LLVMModuleBuilder(
-          kernel->program.get_llvm_context(kernel->arch)->clone_struct_module(),
-          kernel->program.get_llvm_context(kernel->arch)),
+    : LLVMModuleBuilder(kernel->program->get_llvm_context(kernel->arch)
+                            ->clone_struct_module(),
+                        kernel->program->get_llvm_context(kernel->arch)),
       kernel(kernel),
       ir(ir),
-      prog(&kernel->program) {
+      prog(kernel->program) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();
   initialize_context();
7 changes: 2 additions & 5 deletions taichi/ir/expr.cpp
@@ -55,11 +55,8 @@ Expr Expr::operator[](const ExprGroup &indices) const {
 }

 Expr &Expr::operator=(const Expr &o) {
-  if ((std::holds_alternative<Kernel *>(
-           get_current_program().current_kernel_or_function) &&
-       std::get<Kernel *>(get_current_program().current_kernel_or_function)) ||
-      (std::holds_alternative<Function *>(
-           get_current_program().current_kernel_or_function))) {
+  if (get_current_program().current_callable) {
     // Inside a kernel or a function
     // Create an assignment in the IR
     if (expr == nullptr) {
       set(o.eval());
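This hunk is the C++-side payoff of the unification: with Kernel and Function sharing the Callable base, the old two-armed std::variant check (a non-null Kernel *, or any Function *) collapses into a single null test on Program::current_callable.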
2 changes: 1 addition & 1 deletion taichi/program/async_engine.cpp
@@ -150,7 +150,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
   // Final lowering
   using namespace irpass;

-  auto config = kernel->program.config;
+  auto config = kernel->program->config;
   auto ir = stmt;
   offload_to_executable(
       ir, config, kernel, /*verbose=*/false,
29 changes: 29 additions & 0 deletions taichi/program/callable.cpp
@@ -0,0 +1,29 @@
+#include "taichi/program/callable.h"
+#include "taichi/program/program.h"
+
+namespace taichi {
+namespace lang {
+
+int Callable::insert_arg(const DataType &dt, bool is_external_array) {
+  args.emplace_back(dt->get_compute_type(), is_external_array, /*size=*/0);
+  return (int)args.size() - 1;
+}
+
+int Callable::insert_ret(const DataType &dt) {
+  rets.emplace_back(dt->get_compute_type());
+  return (int)rets.size() - 1;
+}
+
+Callable::CurrentCallableGuard::CurrentCallableGuard(Program *program,
+                                                     Callable *callable)
+    : program(program) {
+  old_callable = program->current_callable;
+  program->current_callable = callable;
+}
+
+Callable::CurrentCallableGuard::~CurrentCallableGuard() {
+  program->current_callable = old_callable;
+}
+
+}  // namespace lang
+}  // namespace taichi
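A design note on the guard above: as the "this is not a mutex" comment in function.cpp (further down) emphasizes, this is scope bookkeeping, not synchronization. Saving and restoring the previous current_callable, rather than clearing it, lets the construction of one callable nest inside another and unwind correctly, as in the sketch near the top of this page.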
58 changes: 58 additions & 0 deletions taichi/program/callable.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "taichi/lang_util.h"
+
+namespace taichi {
+namespace lang {
+
+class Program;
+class IRNode;
+
+class Callable {
+ public:
+  Program *program;
+  std::unique_ptr<IRNode> ir;
+
+  struct Arg {
+    DataType dt;
+    bool is_external_array;
+    std::size_t size;
+
+    explicit Arg(const DataType &dt = PrimitiveType::unknown,
+                 bool is_external_array = false,
+                 std::size_t size = 0)
+        : dt(dt), is_external_array(is_external_array), size(size) {
+    }
+  };
+
+  struct Ret {
+    DataType dt;
+
+    explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) {
+    }
+  };
+
+  std::vector<Arg> args;
+  std::vector<Ret> rets;
+
+  virtual ~Callable() = default;
+
+  int insert_arg(const DataType &dt, bool is_external_array);
+
+  int insert_ret(const DataType &dt);
+
+  [[nodiscard]] virtual std::string get_name() const = 0;
+
+  class CurrentCallableGuard {
+    Callable *old_callable;
+    Program *program;
+
+   public:
+    CurrentCallableGuard(Program *program, Callable *callable);
+
+    ~CurrentCallableGuard();
+  };
+};
+
+}  // namespace lang
+}  // namespace taichi
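For a concrete feel of the interface above, here is a minimal self-contained sketch of a subclass in the spirit of Function (next files). DataType is replaced by a plain string stand-in and get_compute_type() is elided, so this illustrates the API shape rather than the real taichi types:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in; the real DataType lives in taichi/lang_util.h.
using DataType = std::string;

class Callable {
 public:
  struct Arg {
    DataType dt;
    bool is_external_array;
    std::size_t size;
  };
  struct Ret {
    DataType dt;
  };

  std::vector<Arg> args;
  std::vector<Ret> rets;

  virtual ~Callable() = default;

  // Both return the index of the slot just appended, as in callable.cpp.
  int insert_arg(const DataType &dt, bool is_external_array) {
    args.push_back({dt, is_external_array, /*size=*/0});
    return (int)args.size() - 1;
  }
  int insert_ret(const DataType &dt) {
    rets.push_back({dt});
    return (int)rets.size() - 1;
  }

  [[nodiscard]] virtual std::string get_name() const = 0;
};

// A Function-like subclass only has to provide a name.
class MyFunc : public Callable {
 public:
  [[nodiscard]] std::string get_name() const override {
    return "my_func";
  }
};

int main() {
  MyFunc f;
  int a0 = f.insert_arg("i32", /*is_external_array=*/false);
  int r0 = f.insert_ret("f32");
  std::cout << f.get_name() << ": arg#" << a0 << " ret#" << r0 << "\n";
}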
32 changes: 5 additions & 27 deletions taichi/program/function.cpp
@@ -5,25 +5,9 @@
 namespace taichi {
 namespace lang {

-namespace {
-class CurrentFunctionGuard {
-  std::variant<Kernel *, Function *> old_kernel_or_function;
-  Program *program;
-
- public:
-  CurrentFunctionGuard(Program *program, Function *func) : program(program) {
-    old_kernel_or_function = program->current_kernel_or_function;
-    program->current_kernel_or_function = func;
-  }
-
-  ~CurrentFunctionGuard() {
-    program->current_kernel_or_function = old_kernel_or_function;
-  }
-};
-}  // namespace
-
 Function::Function(Program *program, const FunctionKey &func_key)
-    : program(program), func_key(func_key) {
+    : func_key(func_key) {
+  this->program = program;
 }

 void Function::set_function_body(const std::function<void()> &func) {
@@ -34,7 +18,7 @@ void Function::set_function_body(const std::function<void()> &func) {
   ir = taichi::lang::context->get_root();
   {
     // Note: this is not a mutex
-    CurrentFunctionGuard _(program, this);
+    CurrentCallableGuard _(program, this);
     func();
   }
   irpass::compile_inline_function(ir.get(), program->config, this,
@@ -53,14 +37,8 @@ void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
                                   /*start_from_ast=*/false);
 }

-int Function::insert_arg(DataType dt, bool is_external_array) {
-  args.push_back(Arg{dt->get_compute_type(), is_external_array, /*size=*/0});
-  return args.size() - 1;
-}
-
-int Function::insert_ret(DataType dt) {
-  rets.push_back(Ret{dt->get_compute_type()});
-  return rets.size() - 1;
+std::string Function::get_name() const {
+  return func_key.get_full_name();
 }

 }  // namespace lang
21 changes: 3 additions & 18 deletions taichi/program/function.h
@@ -1,27 +1,14 @@
 #pragma once

-#include "taichi/lang_util.h"
-#include "taichi/ir/ir.h"
+#include "taichi/program/callable.h"
 #include "taichi/program/function_key.h"
-#include "taichi/program/kernel.h"

 namespace taichi {
 namespace lang {

-class Program;
-
-// TODO: Let Function and Kernel inherit from some class like "Callable"
-// and merge the common part?
-class Function {
+class Function : public Callable {
  public:
-  Program *program;
   FunctionKey func_key;
-  std::unique_ptr<IRNode> ir;
-  using Arg = Kernel::Arg;
-  using Ret = Kernel::Ret;
-
-  std::vector<Arg> args;
-  std::vector<Ret> rets;

   Function(Program *program, const FunctionKey &func_key);
@@ -32,9 +19,7 @@ class Function {
   // Set the function body to a CHI IR.
   void set_function_body(std::unique_ptr<IRNode> func_body);

-  int insert_arg(DataType dt, bool is_external_array);
-
-  int insert_ret(DataType dt);
+  [[nodiscard]] std::string get_name() const override;
 };

 }  // namespace lang
4 changes: 2 additions & 2 deletions taichi/program/ir_bank.cpp
@@ -112,8 +112,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
     }
   }

-  irpass::full_simplify(task_a, kernel->program.config,
-                        {/*after_lower_access=*/false, &kernel->program});
+  irpass::full_simplify(task_a, kernel->program->config,
+                        {/*after_lower_access=*/false, kernel->program});
   // For now, re_id is necessary for the hash to be correct.
   irpass::re_id(task_a);
2 changes: 1 addition & 1 deletion taichi/program/ir_node_extended_impl.cpp
@@ -15,7 +15,7 @@ Kernel *IRNode::get_kernel() const {
 }

 CompileConfig &IRNode::get_config() const {
-  return get_kernel()->program.config;
+  return get_kernel()->program->config;
 }

 }  // namespace lang