[Refactor] Add a class Callable to unify Kernel and Function #2338

Merged
12 commits merged on May 18, 2021
10 changes: 5 additions & 5 deletions python/taichi/lang/kernel_impl.py
@@ -90,8 +90,8 @@ def __init__(self, func, classfunc=False, pyfunc=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         self.taichi_functions = {}  # The |Function| class in C++

     def __call__(self, *args):
@@ -207,7 +207,7 @@ def extract_arguments(self):
             self.argument_names.append(param.name)


-class KernelTemplateMapper:
+class TaichiCallableTemplateMapper:
    def __init__(self, annotations, template_slot_locations):
        self.annotations = annotations
        # Make sure extractors's size is the same as the number of args
@@ -286,8 +286,8 @@ def __init__(self, func, is_grad, classkernel=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         impl.get_runtime().kernels.append(self)
         self.reset()
8 changes: 4 additions & 4 deletions taichi/backends/cc/codegen_cc.cpp
@@ -45,7 +45,7 @@ class CCTransformer : public IRVisitor {

   void lower_ast() {
     auto ir = kernel->ir.get();
-    auto config = kernel->program.config;
+    auto config = kernel->program->config;
     config.demote_dense_struct_fors = true;
     irpass::compile_to_executable(ir, config, kernel,
                                   /*vectorize=*/false, kernel->grad,
@@ -88,7 +88,7 @@ class CCTransformer : public IRVisitor {
   }

   void visit(GetRootStmt *stmt) override {
-    auto root = kernel->program.snode_root.get();
+    auto root = kernel->program->snode_root.get();
     emit("{} = ti_ctx->root;",
          define_var(get_node_ptr_name(root), stmt->raw_name()));
     root_stmt = stmt;
@@ -598,7 +598,7 @@
 };  // namespace cccp

 std::unique_ptr<CCKernel> CCKernelGen::compile() {
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   auto layout = program->get_layout();
   CCTransformer tran(kernel, layout);

@@ -613,7 +613,7 @@ FunctionType compile_kernel(Kernel *kernel) {
   CCKernelGen codegen(kernel);
   auto ker = codegen.compile();
   auto ker_ptr = ker.get();
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   program->add_kernel(std::move(ker));
   return [ker_ptr](Context &ctx) { return ker_ptr->launch(&ctx); };
 }
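Note on the recurring `program.` → `program->` edits in this and the following backend files: `Kernel` previously exposed the `Program` by reference, while the new `Callable` base class (added in `taichi/program/callable.h` below) stores `Program *program`, so every call site switches from member access to pointer dereference. A minimal sketch of the shape of the change, using simplified stand-in types rather than the real Taichi headers:

```cpp
// Simplified stand-ins; the real types live in taichi/program/ and
// carry far more state.
struct CompileConfig {
  bool demote_dense_struct_fors = false;
};
struct Program {
  CompileConfig config;
};

// After this PR the Program pointer lives on the Callable base class:
struct Callable {
  Program *program = nullptr;
};
struct Kernel : Callable {};

CompileConfig get_config(Kernel *kernel) {
  return kernel->program->config;  // before the PR: kernel->program.config
}
```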
4 changes: 2 additions & 2 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -40,9 +40,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
       tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
     }

-    auto jit = kernel->program.llvm_context_device->jit.get();
+    auto jit = kernel->program->llvm_context_device->jit.get();
     auto cuda_module =
-        jit->add_module(std::move(module), kernel->program.config.gpu_max_reg);
+        jit->add_module(std::move(module), kernel->program->config.gpu_max_reg);

     return [offloaded_local, cuda_module,
             kernel = this->kernel](Context &context) {
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
@@ -1507,7 +1507,7 @@ FunctionType compile_to_metal_executable(
   cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

   KernelCodegen codegen(
-      taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
+      taichi_kernel_name, kernel->program->snode_root->node_type_name, kernel,
       compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);

   const auto source_code = codegen.run();
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
@@ -1110,7 +1110,7 @@ FunctionType OpenglCodeGen::gen(void) {

 void OpenglCodeGen::lower() {
   auto ir = kernel_->ir.get();
-  auto &config = kernel_->program.config;
+  auto &config = kernel_->program->config;
   config.demote_dense_struct_fors = true;
   irpass::compile_to_executable(ir, config, kernel_,
                                 /*vectorize=*/false, kernel_->grad,
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.cpp
@@ -13,7 +13,7 @@
 TLANG_NAMESPACE_BEGIN

 KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
-    : prog(&kernel->program), kernel(kernel), ir(ir) {
+    : prog(kernel->program), kernel(kernel), ir(ir) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();

8 changes: 4 additions & 4 deletions taichi/codegen/codegen_llvm.cpp
@@ -270,12 +270,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,

 CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
     // TODO: simplify LLVMModuleBuilder ctor input
-    : LLVMModuleBuilder(
-          kernel->program.get_llvm_context(kernel->arch)->clone_struct_module(),
-          kernel->program.get_llvm_context(kernel->arch)),
+    : LLVMModuleBuilder(kernel->program->get_llvm_context(kernel->arch)
+                            ->clone_struct_module(),
+                        kernel->program->get_llvm_context(kernel->arch)),
       kernel(kernel),
       ir(ir),
-      prog(&kernel->program) {
+      prog(kernel->program) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();
   initialize_context();
7 changes: 2 additions & 5 deletions taichi/ir/expr.cpp
@@ -55,11 +55,8 @@ Expr Expr::operator[](const ExprGroup &indices) const {
 }

 Expr &Expr::operator=(const Expr &o) {
-  if ((std::holds_alternative<Kernel *>(
-           get_current_program().current_kernel_or_function) &&
-       std::get<Kernel *>(get_current_program().current_kernel_or_function)) ||
-      (std::holds_alternative<Function *>(
-          get_current_program().current_kernel_or_function))) {
+  if (get_current_program().current_callable) {
     // Inside a kernel or a function
     // Create an assignment in the IR
     if (expr == nullptr) {
       set(o.eval());
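The old guard condition had to unpack a `std::variant<Kernel *, Function *>` and null-check the kernel alternative (the function alternative was never null-checked); with the shared base class, a single pointer test covers both cases. A self-contained sketch of the two checks, with the types reduced to forward declarations:

```cpp
#include <variant>

struct Kernel;
struct Function;
struct Callable;

// Before: unpack the variant and null-check the Kernel alternative.
// (Faithful to the removed code above: the Function alternative was
// accepted without a null check.)
bool inside_body_old(const std::variant<Kernel *, Function *> &v) {
  return (std::holds_alternative<Kernel *>(v) && std::get<Kernel *>(v)) ||
         std::holds_alternative<Function *>(v);
}

// After: Kernel and Function both derive from Callable, so one nullable
// pointer expresses "currently inside some callable body".
bool inside_body_new(const Callable *current_callable) {
  return current_callable != nullptr;
}
```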
2 changes: 1 addition & 1 deletion taichi/program/async_engine.cpp
@@ -150,7 +150,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
   // Final lowering
   using namespace irpass;

-  auto config = kernel->program.config;
+  auto config = kernel->program->config;
   auto ir = stmt;
   offload_to_executable(
       ir, config, kernel, /*verbose=*/false,
29 changes: 29 additions & 0 deletions taichi/program/callable.cpp
@@ -0,0 +1,29 @@
+#include "taichi/program/callable.h"
+#include "taichi/program/program.h"
+
+namespace taichi {
+namespace lang {
+
+int Callable::insert_arg(const DataType &dt, bool is_external_array) {
+  args.emplace_back(dt->get_compute_type(), is_external_array, /*size=*/0);
+  return (int)args.size() - 1;
+}
+
+int Callable::insert_ret(const DataType &dt) {
+  rets.emplace_back(dt->get_compute_type());
+  return (int)rets.size() - 1;
+}
+
+Callable::CurrentCallableGuard::CurrentCallableGuard(Program *program,
+                                                     Callable *callable)
+    : program(program) {
+  old_callable = program->current_callable;
+  program->current_callable = callable;
+}
+
+Callable::CurrentCallableGuard::~CurrentCallableGuard() {
+  program->current_callable = old_callable;
+}
+
+}  // namespace lang
+}  // namespace taichi
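`insert_arg`/`insert_ret` append a slot (storing the argument's compute type) and return its index, which callers keep to address the slot later. A toy model of that index-return contract; the real `Arg` carries a `DataType`, and the `size` field is filled in separately for external arrays:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for Callable::Arg (the real one stores a DataType).
struct Arg {
  int dt;
  bool is_external_array;
  std::size_t size;
};

std::vector<Arg> args;

int insert_arg(int dt, bool is_external_array) {
  args.push_back({dt, is_external_array, /*size=*/0});
  return (int)args.size() - 1;
}

int main() {
  int a = insert_arg(/*dt=*/0, /*is_external_array=*/false);
  int b = insert_arg(/*dt=*/1, /*is_external_array=*/true);
  assert(a == 0 && b == 1);  // each call returns the new slot's index
}
```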
58 changes: 58 additions & 0 deletions taichi/program/callable.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "taichi/lang_util.h"
+
+namespace taichi {
+namespace lang {
+
+class Program;
+class IRNode;
+
+class Callable {
+ public:
+  Program *program;
+  std::unique_ptr<IRNode> ir;
+
+  struct Arg {
+    DataType dt;
+    bool is_external_array;
+    std::size_t size;
+
+    explicit Arg(const DataType &dt = PrimitiveType::unknown,
+                 bool is_external_array = false,
+                 std::size_t size = 0)
+        : dt(dt), is_external_array(is_external_array), size(size) {
+    }
+  };
+
+  struct Ret {
+    DataType dt;
+
+    explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) {
+    }
+  };
+
+  std::vector<Arg> args;
+  std::vector<Ret> rets;
+
+  virtual ~Callable() = default;
+
+  int insert_arg(const DataType &dt, bool is_external_array);
+
+  int insert_ret(const DataType &dt);
+
+  [[nodiscard]] virtual std::string get_name() const = 0;
+
+  class CurrentCallableGuard {
+    Callable *old_callable;
+    Program *program;
+
+   public:
+    CurrentCallableGuard(Program *program, Callable *callable);
+
+    ~CurrentCallableGuard();
+  };
+};
+
+}  // namespace lang
+}  // namespace taichi
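`CurrentCallableGuard` is a scope guard: its constructor swaps `program->current_callable` to the callable whose body is being built, and the destructor restores the previous value, so guarded regions nest correctly (see its use in `Function::set_function_body` below). A self-contained sketch of the same RAII pattern with simplified stand-in types:

```cpp
struct Callable;  // only used through a pointer here

struct Program {
  Callable *current_callable = nullptr;
};

class CurrentCallableGuard {
  Callable *old_callable_;
  Program *program_;

 public:
  CurrentCallableGuard(Program *program, Callable *callable)
      : old_callable_(program->current_callable), program_(program) {
    program->current_callable = callable;  // enter: point at the new callable
  }
  ~CurrentCallableGuard() {
    program_->current_callable = old_callable_;  // exit: restore the old one
  }
};

// Usage mirrors Function::set_function_body:
//   {
//     CurrentCallableGuard _(program, this);
//     func();  // code built here sees program->current_callable == this
//   }          // previous value restored when the guard leaves scope
```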
32 changes: 5 additions & 27 deletions taichi/program/function.cpp
@@ -5,25 +5,9 @@
 namespace taichi {
 namespace lang {

-namespace {
-class CurrentFunctionGuard {
-  std::variant<Kernel *, Function *> old_kernel_or_function;
-  Program *program;
-
- public:
-  CurrentFunctionGuard(Program *program, Function *func) : program(program) {
-    old_kernel_or_function = program->current_kernel_or_function;
-    program->current_kernel_or_function = func;
-  }
-
-  ~CurrentFunctionGuard() {
-    program->current_kernel_or_function = old_kernel_or_function;
-  }
-};
-}  // namespace
-
 Function::Function(Program *program, const FunctionKey &func_key)
-    : program(program), func_key(func_key) {
+    : func_key(func_key) {
+  this->program = program;
 }

 void Function::set_function_body(const std::function<void()> &func) {
@@ -34,7 +18,7 @@ void Function::set_function_body(const std::function<void()> &func) {
   ir = taichi::lang::context->get_root();
   {
     // Note: this is not a mutex
-    CurrentFunctionGuard _(program, this);
+    CurrentCallableGuard _(program, this);
     func();
   }
   irpass::compile_inline_function(ir.get(), program->config, this,
@@ -53,14 +37,8 @@ void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
                                   /*start_from_ast=*/false);
 }

-int Function::insert_arg(DataType dt, bool is_external_array) {
-  args.push_back(Arg{dt->get_compute_type(), is_external_array, /*size=*/0});
-  return args.size() - 1;
-}
-
-int Function::insert_ret(DataType dt) {
-  rets.push_back(Ret{dt->get_compute_type()});
-  return rets.size() - 1;
+std::string Function::get_name() const {
+  return func_key.get_full_name();
 }

 }  // namespace lang
21 changes: 3 additions & 18 deletions taichi/program/function.h
@@ -1,27 +1,14 @@
 #pragma once

-#include "taichi/lang_util.h"
-#include "taichi/ir/ir.h"
+#include "taichi/program/callable.h"
 #include "taichi/program/function_key.h"
-#include "taichi/program/kernel.h"

 namespace taichi {
 namespace lang {

-class Program;
-
-// TODO: Let Function and Kernel inherit from some class like "Callable"
-// and merge the common part?
-class Function {
+class Function : public Callable {
  public:
-  Program *program;
   FunctionKey func_key;
-  std::unique_ptr<IRNode> ir;
-  using Arg = Kernel::Arg;
-  using Ret = Kernel::Ret;
-
-  std::vector<Arg> args;
-  std::vector<Ret> rets;

   Function(Program *program, const FunctionKey &func_key);

@@ -32,9 +19,7 @@ class Function {
   // Set the function body to a CHI IR.
   void set_function_body(std::unique_ptr<IRNode> func_body);

-  int insert_arg(DataType dt, bool is_external_array);
-
-  int insert_ret(DataType dt);
+  [[nodiscard]] std::string get_name() const override;
 };

 }  // namespace lang
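`kernel.h` itself is not included in this excerpt, but the PR title and the `Function` changes above imply the symmetric edit on the `Kernel` side. A guess at its shape, with a simplified stand-in for `Callable` and a hypothetical `name` member:

```cpp
#include <string>

// Simplified stand-in for taichi/program/callable.h.
class Callable {
 public:
  virtual ~Callable() = default;
  [[nodiscard]] virtual std::string get_name() const = 0;
};

// Presumed Kernel-side counterpart of "class Function : public Callable";
// the member shown is a guess for illustration only.
class Kernel : public Callable {
 public:
  std::string name;  // hypothetical

  [[nodiscard]] std::string get_name() const override {
    return name;
  }
};
```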
4 changes: 2 additions & 2 deletions taichi/program/ir_bank.cpp
@@ -112,8 +112,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
     }
   }

-  irpass::full_simplify(task_a, kernel->program.config,
-                        {/*after_lower_access=*/false, &kernel->program});
+  irpass::full_simplify(task_a, kernel->program->config,
+                        {/*after_lower_access=*/false, kernel->program});
   // For now, re_id is necessary for the hash to be correct.
   irpass::re_id(task_a);

2 changes: 1 addition & 1 deletion taichi/program/ir_node_extended_impl.cpp
@@ -15,7 +15,7 @@ Kernel *IRNode::get_kernel() const {
 }

 CompileConfig &IRNode::get_config() const {
-  return get_kernel()->program.config;
+  return get_kernel()->program->config;
 }

 }  // namespace lang