diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py
index 0d3d1717b28cf..311761a3dd904 100644
--- a/python/taichi/lang/kernel_impl.py
+++ b/python/taichi/lang/kernel_impl.py
@@ -90,8 +90,8 @@ def __init__(self, func, classfunc=False, pyfunc=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         self.taichi_functions = {}  # The |Function| class in C++
 
     def __call__(self, *args):
@@ -207,7 +207,7 @@ def extract_arguments(self):
             self.argument_names.append(param.name)
 
 
-class KernelTemplateMapper:
+class TaichiCallableTemplateMapper:
     def __init__(self, annotations, template_slot_locations):
         self.annotations = annotations
         # Make sure extractors's size is the same as the number of args
@@ -286,8 +286,8 @@ def __init__(self, func, is_grad, classkernel=False):
         for i in range(len(self.argument_annotations)):
             if isinstance(self.argument_annotations[i], template):
                 self.template_slot_locations.append(i)
-        self.mapper = KernelTemplateMapper(self.argument_annotations,
-                                           self.template_slot_locations)
+        self.mapper = TaichiCallableTemplateMapper(
+            self.argument_annotations, self.template_slot_locations)
         impl.get_runtime().kernels.append(self)
         self.reset()
diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp
index 2bea5f776f167..586a97005061c 100644
--- a/taichi/backends/cc/codegen_cc.cpp
+++ b/taichi/backends/cc/codegen_cc.cpp
@@ -45,7 +45,7 @@ class CCTransformer : public IRVisitor {
 
   void lower_ast() {
     auto ir = kernel->ir.get();
-    auto config = kernel->program.config;
+    auto config = kernel->program->config;
     config.demote_dense_struct_fors = true;
     irpass::compile_to_executable(ir, config, kernel,
                                   /*vectorize=*/false, kernel->grad,
@@ -88,7 +88,7 @@ class CCTransformer : public IRVisitor {
   }
 
   void visit(GetRootStmt *stmt) override {
-    auto root = kernel->program.snode_root.get();
+    auto root = kernel->program->snode_root.get();
     emit("{} = ti_ctx->root;",
          define_var(get_node_ptr_name(root), stmt->raw_name()));
     root_stmt = stmt;
@@ -598,7 +598,7 @@ class CCTransformer : public IRVisitor {
 };  // namespace cccp
 
 std::unique_ptr<CCKernel> CCKernelGen::compile() {
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   auto layout = program->get_layout();
   CCTransformer tran(kernel, layout);
 
@@ -613,7 +613,7 @@ FunctionType compile_kernel(Kernel *kernel) {
   CCKernelGen codegen(kernel);
   auto ker = codegen.compile();
   auto ker_ptr = ker.get();
-  auto program = kernel->program.cc_program.get();
+  auto program = kernel->program->cc_program.get();
   program->add_kernel(std::move(ker));
   return [ker_ptr](Context &ctx) { return ker_ptr->launch(&ctx); };
 }
diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp
index 3bdbe27fd399e..e28037f4a9803 100644
--- a/taichi/backends/cuda/codegen_cuda.cpp
+++ b/taichi/backends/cuda/codegen_cuda.cpp
@@ -40,9 +40,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
       tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
     }
 
-    auto jit = kernel->program.llvm_context_device->jit.get();
+    auto jit = kernel->program->llvm_context_device->jit.get();
     auto cuda_module =
-        jit->add_module(std::move(module), kernel->program.config.gpu_max_reg);
+        jit->add_module(std::move(module), kernel->program->config.gpu_max_reg);
 
     return [offloaded_local, cuda_module,
             kernel = this->kernel](Context &context) {
diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp
index 49636413c8bc6..00b69ed4247b9 100644
--- a/taichi/backends/metal/codegen_metal.cpp
+++ b/taichi/backends/metal/codegen_metal.cpp
@@ -1507,7 +1507,7 @@ FunctionType compile_to_metal_executable(
   cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();
 
   KernelCodegen codegen(
-      taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
+      taichi_kernel_name, kernel->program->snode_root->node_type_name, kernel,
       compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);
 
   const auto source_code = codegen.run();
diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp
index c0cf3a7d52a93..3214d70ecbd34 100644
--- a/taichi/backends/opengl/codegen_opengl.cpp
+++ b/taichi/backends/opengl/codegen_opengl.cpp
@@ -1110,7 +1110,7 @@ FunctionType OpenglCodeGen::gen(void) {
 
 void OpenglCodeGen::lower() {
   auto ir = kernel_->ir.get();
-  auto &config = kernel_->program.config;
+  auto &config = kernel_->program->config;
   config.demote_dense_struct_fors = true;
   irpass::compile_to_executable(ir, config, kernel_,
                                 /*vectorize=*/false, kernel_->grad,
diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp
index aff59e947167a..eced172d2226f 100644
--- a/taichi/codegen/codegen.cpp
+++ b/taichi/codegen/codegen.cpp
@@ -13,7 +13,7 @@ TLANG_NAMESPACE_BEGIN
 
 KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
-    : prog(&kernel->program), kernel(kernel), ir(ir) {
+    : prog(kernel->program), kernel(kernel), ir(ir) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();
 
diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp
index 485b556118bb3..971c6431b4e98 100644
--- a/taichi/codegen/codegen_llvm.cpp
+++ b/taichi/codegen/codegen_llvm.cpp
@@ -270,12 +270,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,
 
 CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
     // TODO: simplify LLVMModuleBuilder ctor input
-    : LLVMModuleBuilder(
-          kernel->program.get_llvm_context(kernel->arch)->clone_struct_module(),
-          kernel->program.get_llvm_context(kernel->arch)),
+    : LLVMModuleBuilder(kernel->program->get_llvm_context(kernel->arch)
+                            ->clone_struct_module(),
+                        kernel->program->get_llvm_context(kernel->arch)),
       kernel(kernel),
       ir(ir),
-      prog(&kernel->program) {
+      prog(kernel->program) {
   if (ir == nullptr)
     this->ir = kernel->ir.get();
   initialize_context();
diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp
index 388370a9faf54..4ff574a749c4b 100644
--- a/taichi/ir/expr.cpp
+++ b/taichi/ir/expr.cpp
@@ -55,11 +55,8 @@ Expr Expr::operator[](const ExprGroup &indices) const {
 }
 
 Expr &Expr::operator=(const Expr &o) {
-  if ((std::holds_alternative<Kernel *>(
-           get_current_program().current_kernel_or_function) &&
-       std::get<Kernel *>(get_current_program().current_kernel_or_function)) ||
-      (std::holds_alternative<Function *>(
-          get_current_program().current_kernel_or_function))) {
+  if (get_current_program().current_callable) {
+    // Inside a kernel or a function
     // Create an assignment in the IR
     if (expr == nullptr) {
       set(o.eval());
diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp
index 6f09457f1963b..440f40336b617 100644
--- a/taichi/program/async_engine.cpp
+++ b/taichi/program/async_engine.cpp
@@ -150,7 +150,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
       // Final lowering
 
       using namespace irpass;
-      auto config = kernel->program.config;
+      auto config = kernel->program->config;
       auto ir = stmt;
       offload_to_executable(
           ir, config, kernel, /*verbose=*/false,
diff --git a/taichi/program/callable.cpp b/taichi/program/callable.cpp
new file mode 100644
index 0000000000000..095c04fc8e94a
--- /dev/null
+++ b/taichi/program/callable.cpp
@@ -0,0 +1,29 @@
+#include "taichi/program/callable.h"
+#include "taichi/program/program.h"
+
+namespace taichi {
+namespace lang {
+
+int Callable::insert_arg(const DataType &dt, bool is_external_array) {
+  args.emplace_back(dt->get_compute_type(), is_external_array, /*size=*/0);
+  return (int)args.size() - 1;
+}
+
+int Callable::insert_ret(const DataType &dt) {
+  rets.emplace_back(dt->get_compute_type());
+  return (int)rets.size() - 1;
+}
+
+Callable::CurrentCallableGuard::CurrentCallableGuard(Program *program,
+                                                     Callable *callable)
+    : program(program) {
+  old_callable = program->current_callable;
+  program->current_callable = callable;
+}
+
+Callable::CurrentCallableGuard::~CurrentCallableGuard() {
+  program->current_callable = old_callable;
+}
+
+}  // namespace lang
+}  // namespace taichi
diff --git a/taichi/program/callable.h b/taichi/program/callable.h
new file mode 100644
index 0000000000000..4187f700dadc6
--- /dev/null
+++ b/taichi/program/callable.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "taichi/lang_util.h"
+
+namespace taichi {
+namespace lang {
+
+class Program;
+class IRNode;
+
+class Callable {
+ public:
+  Program *program;
+  std::unique_ptr<IRNode> ir;
+
+  struct Arg {
+    DataType dt;
+    bool is_external_array;
+    std::size_t size;
+
+    explicit Arg(const DataType &dt = PrimitiveType::unknown,
+                 bool is_external_array = false,
+                 std::size_t size = 0)
+        : dt(dt), is_external_array(is_external_array), size(size) {
+    }
+  };
+
+  struct Ret {
+    DataType dt;
+
+    explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) {
+    }
+  };
+
+  std::vector<Arg> args;
+  std::vector<Ret> rets;
+
+  virtual ~Callable() = default;
+
+  int insert_arg(const DataType &dt, bool is_external_array);
+
+  int insert_ret(const DataType &dt);
+
+  [[nodiscard]] virtual std::string get_name() const = 0;
+
+  class CurrentCallableGuard {
+    Callable *old_callable;
+    Program *program;
+
+   public:
+    CurrentCallableGuard(Program *program, Callable *callable);
+
+    ~CurrentCallableGuard();
+  };
+};
+
+}  // namespace lang
+}  // namespace taichi
diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp
index 28e720901c913..e18c25805f0ea 100644
--- a/taichi/program/function.cpp
+++ b/taichi/program/function.cpp
@@ -5,25 +5,9 @@
 namespace taichi {
 namespace lang {
 
-namespace {
-class CurrentFunctionGuard {
-  std::variant<Kernel *, Function *> old_kernel_or_function;
-  Program *program;
-
- public:
-  CurrentFunctionGuard(Program *program, Function *func) : program(program) {
-    old_kernel_or_function = program->current_kernel_or_function;
-    program->current_kernel_or_function = func;
-  }
-
-  ~CurrentFunctionGuard() {
-    program->current_kernel_or_function = old_kernel_or_function;
-  }
-};
-}  // namespace
-
 Function::Function(Program *program, const FunctionKey &func_key)
-    : program(program), func_key(func_key) {
+    : func_key(func_key) {
+  this->program = program;
 }
 
 void Function::set_function_body(const std::function<void()> &func) {
@@ -34,7 +18,7 @@ void Function::set_function_body(const std::function<void()> &func) {
   ir = taichi::lang::context->get_root();
   {
     // Note: this is not a mutex
-    CurrentFunctionGuard _(program, this);
+    CurrentCallableGuard _(program, this);
     func();
   }
   irpass::compile_inline_function(ir.get(), program->config, this,
@@ -53,14 +37,8 @@ void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
                                   /*start_from_ast=*/false);
 }
 
-int Function::insert_arg(DataType dt, bool is_external_array) {
-  args.push_back(Arg{dt->get_compute_type(), is_external_array, /*size=*/0});
-  return args.size() - 1;
-}
-
-int Function::insert_ret(DataType dt) {
-  rets.push_back(Ret{dt->get_compute_type()});
-  return rets.size() - 1;
+std::string Function::get_name() const {
+  return func_key.get_full_name();
 }
 
 }  // namespace lang
diff --git a/taichi/program/function.h b/taichi/program/function.h
index 2600454e712ce..07b97f39bd83d 100644
--- a/taichi/program/function.h
+++ b/taichi/program/function.h
@@ -1,27 +1,14 @@
 #pragma once
 
-#include "taichi/lang_util.h"
-#include "taichi/ir/ir.h"
+#include "taichi/program/callable.h"
 #include "taichi/program/function_key.h"
-#include "taichi/program/kernel.h"
 
 namespace taichi {
 namespace lang {
 
-class Program;
-
-// TODO: Let Function and Kernel inherit from some class like "Callable"
-// and merge the common part?
-class Function {
+class Function : public Callable {
  public:
-  Program *program;
   FunctionKey func_key;
-  std::unique_ptr<IRNode> ir;
-  using Arg = Kernel::Arg;
-  using Ret = Kernel::Ret;
-
-  std::vector<Arg> args;
-  std::vector<Ret> rets;
 
   Function(Program *program, const FunctionKey &func_key);
@@ -32,9 +19,7 @@ class Function : public Callable {
   // Set the function body to a CHI IR.
   void set_function_body(std::unique_ptr<IRNode> func_body);
 
-  int insert_arg(DataType dt, bool is_external_array);
-
-  int insert_ret(DataType dt);
+  [[nodiscard]] std::string get_name() const override;
 };
 
 }  // namespace lang
diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp
index a224c63c30bc2..cfb7402d04fcb 100644
--- a/taichi/program/ir_bank.cpp
+++ b/taichi/program/ir_bank.cpp
@@ -112,8 +112,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
     }
   }
 
-  irpass::full_simplify(task_a, kernel->program.config,
-                        {/*after_lower_access=*/false, &kernel->program});
+  irpass::full_simplify(task_a, kernel->program->config,
+                        {/*after_lower_access=*/false, kernel->program});
 
   // For now, re_id is necessary for the hash to be correct.
   irpass::re_id(task_a);
diff --git a/taichi/program/ir_node_extended_impl.cpp b/taichi/program/ir_node_extended_impl.cpp
index 4d47a8e938574..aa27c0e0e1e71 100644
--- a/taichi/program/ir_node_extended_impl.cpp
+++ b/taichi/program/ir_node_extended_impl.cpp
@@ -15,7 +15,7 @@ Kernel *IRNode::get_kernel() const {
 }
 
 CompileConfig &IRNode::get_config() const {
-  return get_kernel()->program.config;
+  return get_kernel()->program->config;
 }
 
 }  // namespace lang
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index 4399423d54749..e5346e97e6a27 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -1,4 +1,4 @@
-#include "kernel.h"
+#include "taichi/program/kernel.h"
 
 #include "taichi/backends/cuda/cuda_driver.h"
 #include "taichi/codegen/codegen.h"
@@ -15,28 +15,13 @@ TLANG_NAMESPACE_BEGIN
 
 class Function;
 
-namespace {
-class CurrentKernelGuard {
-  std::variant<Kernel *, Function *> old_kernel_or_function;
-  Program &program;
-
- public:
-  CurrentKernelGuard(Program &program_, Kernel *kernel) : program(program_) {
-    old_kernel_or_function = program.current_kernel_or_function;
-    program.current_kernel_or_function = kernel;
-  }
-
-  ~CurrentKernelGuard() {
-    program.current_kernel_or_function = old_kernel_or_function;
-  }
-};
-}  // namespace
-
 Kernel::Kernel(Program &program,
                const std::function<void()> &func,
                const std::string &primal_name,
                bool grad)
-    : program(program), lowered(false), grad(grad) {
+    : lowered(false), grad(grad) {
+  this->program = &program;
+
   // Do not corrupt the context calling this kernel here -- maybe unnecessary
   auto backup_context = std::move(taichi::lang::context);
@@ -52,7 +37,7 @@
     // Note: this is NOT a mutex. If we want to call Kernel::Kernel()
     // concurrently, we need to lock this block of code together with
     // taichi::lang::context with a mutex.
-    CurrentKernelGuard _(program, this);
+    CurrentCallableGuard _(this->program, this);
     program.start_kernel_definition(this);
     func();
     program.end_kernel_definition();
@@ -77,7 +62,9 @@ Kernel::Kernel(Program &program,
               std::unique_ptr<IRNode> &&ir,
               const std::string &primal_name,
               bool grad)
-    : ir(std::move(ir)), program(program), lowered(false), grad(grad) {
+    : lowered(false), grad(grad) {
+  this->ir = std::move(ir);
+  this->program = &program;
   is_accessor = false;
   is_evaluator = false;
   compiled = nullptr;
@@ -97,16 +84,16 @@ Kernel::Kernel(Program &program,
 }
 
 void Kernel::compile() {
-  CurrentKernelGuard _(program, this);
-  compiled = program.compile(*this);
+  CurrentCallableGuard _(program, this);
+  compiled = program->compile(*this);
 }
 
 void Kernel::lower(bool to_executable) {  // TODO: is a "Lowerer" class
                                           // necessary for each backend?
   TI_ASSERT(!lowered);
   if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::metal) {
-    CurrentKernelGuard _(program, this);
-    auto config = program.config;
+    CurrentCallableGuard _(program, this);
+    auto config = program->config;
     bool verbose = config.print_ir;
     if ((is_accessor && !config.print_accessor_ir) ||
         (is_evaluator && !config.print_evaluator_ir))
@@ -134,7 +121,7 @@ void Kernel::lower(bool to_executable) {  // TODO: is a "Lowerer" class
 }
 
 void Kernel::operator()(LaunchContextBuilder &ctx_builder) {
-  if (!program.config.async_mode || this->is_evaluator) {
+  if (!program->config.async_mode || this->is_evaluator) {
     if (!compiled) {
       compile();
     }
@@ -145,19 +132,19 @@ void Kernel::operator()(LaunchContextBuilder &ctx_builder) {
 
     compiled(ctx_builder.get_context());
 
-    program.sync = (program.sync && arch_is_cpu(arch));
+    program->sync = (program->sync && arch_is_cpu(arch));
     // Note that Kernel::arch may be different from program.config.arch
-    if (program.config.debug && (arch_is_cpu(program.config.arch) ||
-                                 program.config.arch == Arch::cuda)) {
-      program.check_runtime_error();
+    if (program->config.debug && (arch_is_cpu(program->config.arch) ||
+                                  program->config.arch == Arch::cuda)) {
+      program->check_runtime_error();
     }
   } else {
-    program.sync = false;
-    program.async_engine->launch(this, ctx_builder.get_context());
+    program->sync = false;
+    program->async_engine->launch(this, ctx_builder.get_context());
     // Note that Kernel::arch may be different from program.config.arch
-    if (program.config.debug && arch_is_cpu(arch) &&
-        arch_is_cpu(program.config.arch)) {
-      program.check_runtime_error();
+    if (program->config.debug && arch_is_cpu(arch) &&
+        arch_is_cpu(program->config.arch)) {
+      program->check_runtime_error();
     }
   }
 }
@@ -285,7 +272,7 @@ void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) {
 }
 
 Context &Kernel::LaunchContextBuilder::get_context() {
-  ctx_->runtime = static_cast<LLVMRuntime *>(kernel_->program.llvm_runtime);
+  ctx_->runtime = static_cast<LLVMRuntime *>(kernel_->program->llvm_runtime);
   return *ctx_;
 }
 
@@ -348,16 +335,6 @@ void Kernel::set_arch(Arch arch) {
   this->arch = arch;
 }
 
-int Kernel::insert_arg(DataType dt, bool is_external_array) {
-  args.push_back(Arg{dt->get_compute_type(), is_external_array, /*size=*/0});
-  return args.size() - 1;
-}
-
-int Kernel::insert_ret(DataType dt) {
-  rets.push_back(Ret{dt->get_compute_type()});
-  return rets.size() - 1;
-}
-
 void Kernel::account_for_offloaded(OffloadedStmt *stmt) {
   if (is_evaluator || is_accessor)
     return;
@@ -382,4 +359,8 @@ void Kernel::account_for_offloaded(OffloadedStmt *stmt) {
   }
 }
 
+std::string Kernel::get_name() const {
+  return name;
+}
+
 TLANG_NAMESPACE_END
diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h
index 331b56a911be4..d1c9ef4155e8d 100644
--- a/taichi/program/kernel.h
+++ b/taichi/program/kernel.h
@@ -4,6 +4,7 @@
 #include "taichi/ir/snode.h"
 #include "taichi/ir/ir.h"
 #include "taichi/program/arch.h"
+#include "taichi/program/callable.h"
 #define TI_RUNTIME_HOST
 #include "taichi/program/context.h"
@@ -13,11 +14,9 @@ TLANG_NAMESPACE_BEGIN
 
 class Program;
 
-class Kernel {
+class Kernel : public Callable {
  public:
-  std::unique_ptr<IRNode> ir;
   bool ir_is_ast;
-  Program &program;
   FunctionType compiled;
   std::string name;
   std::vector<SNode *> no_activate;
@@ -25,27 +24,6 @@ class Kernel : public Callable {
   bool lowered;  // lower inital AST all the way down to a bunch of
                  // OffloadedStmt for async execution
 
-  struct Arg {
-    DataType dt;
-    bool is_external_array;
-    std::size_t size;
-
-    Arg(DataType dt = PrimitiveType::unknown,
-        bool is_external_array = false,
-        std::size_t size = 0)
-        : dt(dt), is_external_array(is_external_array), size(size) {
-    }
-  };
-
-  struct Ret {
-    DataType dt;
-
-    explicit Ret(DataType dt = PrimitiveType::unknown) : dt(dt) {
-    }
-  };
-
-  std::vector<Arg> args;
-  std::vector<Ret> rets;
   bool is_accessor;
   bool is_evaluator;
   bool grad;
@@ -103,10 +81,6 @@ class Kernel : public Callable {
 
   LaunchContextBuilder make_launch_context();
 
-  int insert_arg(DataType dt, bool is_external_array);
-
-  int insert_ret(DataType dt);
-
   float64 get_ret_float(int i);
 
   int64 get_ret_int(int i);
@@ -114,6 +88,8 @@ class Kernel : public Callable {
   void set_arch(Arch arch);
 
   void account_for_offloaded(OffloadedStmt *stmt);
+
+  [[nodiscard]] std::string get_name() const override;
 };
 
 TLANG_NAMESPACE_END
diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp
index b09516351ffa4..6c5714262f24a 100644
--- a/taichi/program/program.cpp
+++ b/taichi/program/program.cpp
@@ -160,7 +160,7 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
 #endif
 
   result_buffer = nullptr;
-  current_kernel_or_function = static_cast<Kernel *>(nullptr);
+  current_callable = nullptr;
   sync = true;
   llvm_runtime = nullptr;
   finalized = false;
diff --git a/taichi/program/program.h b/taichi/program/program.h
index 5e5b847a4ba69..7988e5440e5a3 100644
--- a/taichi/program/program.h
+++ b/taichi/program/program.h
@@ -15,6 +15,7 @@
 #include "taichi/backends/metal/kernel_manager.h"
 #include "taichi/backends/opengl/opengl_kernel_launcher.h"
 #include "taichi/backends/cc/cc_program.h"
+#include "taichi/program/callable.h"
 #include "taichi/program/function.h"
 #include "taichi/program/kernel.h"
 #include "taichi/program/kernel_profiler.h"
@@ -82,7 +83,7 @@ class AsyncEngine;
 class Program {
  public:
   using Kernel = taichi::lang::Kernel;
-  std::variant<Kernel *, Function *> current_kernel_or_function;
+  Callable *current_callable;
   std::unique_ptr<SNode> snode_root;  // pointer to the data structure.
   void *llvm_runtime;
   CompileConfig config;
@@ -186,7 +187,7 @@ class Program {
   }
 
   void start_kernel_definition(Kernel *kernel) {
-    current_kernel_or_function = kernel;
+    current_callable = kernel;
   }
 
   void end_kernel_definition() {
@@ -208,16 +209,15 @@ class Program {
 
   void check_runtime_error();
 
-  inline Kernel &get_current_kernel() {
-    TI_ASSERT(std::holds_alternative<Kernel *>(current_kernel_or_function));
-    auto *kernel = std::get<Kernel *>(current_kernel_or_function);
+  // TODO(#2193): Remove get_current_kernel() and get_current_function()?
+  inline Kernel &get_current_kernel() const {
+    auto *kernel = dynamic_cast<Kernel *>(current_callable);
     TI_ASSERT(kernel);
     return *kernel;
   }
 
-  inline Function *get_current_function() {
-    TI_ASSERT(std::holds_alternative<Function *>(current_kernel_or_function));
-    auto *func = std::get<Function *>(current_kernel_or_function);
+  inline Function *get_current_function() const {
+    auto *func = dynamic_cast<Function *>(current_callable);
     return func;
   }
diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp
index a901a039090dc..7cb879238d963 100644
--- a/taichi/program/state_flow_graph.cpp
+++ b/taichi/program/state_flow_graph.cpp
@@ -1290,7 +1290,7 @@ bool StateFlowGraph::optimize_dead_store() {
       // *****************************
       // Erase the state s output.
       if (!store_eliminable_snodes.empty()) {
-        const bool verbose = task->rec.kernel->program.config.verbose;
+        const bool verbose = task->rec.kernel->program->config.verbose;
 
         const auto dse_result = ir_bank_->optimize_dse(
             task->rec.ir_handle, store_eliminable_snodes, verbose);
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index 318b7a8702c98..27ead5a78053c 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -703,24 +703,12 @@ void export_lang(py::module &m) {
         std::make_unique<FrontendPrintStmt>(contents));
   });
 
-  m.def("decl_arg", [&](DataType dt, bool is_nparray) {
-    if (std::holds_alternative<Kernel *>(
-            get_current_program().current_kernel_or_function)) {
-      return get_current_program().get_current_kernel().insert_arg(dt,
-                                                                   is_nparray);
-    } else {
-      return get_current_program().get_current_function()->insert_arg(
-          dt, is_nparray);
-    }
+  m.def("decl_arg", [&](const DataType &dt, bool is_nparray) {
+    return get_current_program().current_callable->insert_arg(dt, is_nparray);
   });
 
-  m.def("decl_ret", [&](DataType dt) {
-    if (std::holds_alternative<Kernel *>(
-            get_current_program().current_kernel_or_function)) {
-      return get_current_program().get_current_kernel().insert_ret(dt);
-    } else {
-      return get_current_program().get_current_function()->insert_ret(dt);
-    }
+  m.def("decl_ret", [&](const DataType &dt) {
+    return get_current_program().current_callable->insert_ret(dt);
   });
 
   m.def("test_throw", [] {
@@ -743,6 +731,7 @@ void export_lang(py::module &m) {
   m.def("insert_snode_access_flag", insert_snode_access_flag);
 
   m.def("no_activate", [](SNode *snode) {
+    // TODO(#2193): Also apply to @ti.func?
     get_current_program().get_current_kernel().no_activate.push_back(snode);
   });
 
   m.def("stop_grad",
diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp
index fc1be6e5d5c35..67bea20d2797b 100644
--- a/taichi/transforms/compile_to_offloads.cpp
+++ b/taichi/transforms/compile_to_offloads.cpp
@@ -39,7 +39,7 @@ void compile_to_offloads(IRNode *ir,
                          bool start_from_ast) {
   TI_AUTO_PROF;
 
-  auto print = make_pass_printer(verbose, kernel->name, ir);
+  auto print = make_pass_printer(verbose, kernel->get_name(), ir);
   print("Initial IR");
 
   if (grad) {
@@ -87,7 +87,7 @@ void compile_to_offloads(IRNode *ir,
     irpass::analysis::verify(ir);
   }
 
-  irpass::full_simplify(ir, config, {false, &kernel->program});
+  irpass::full_simplify(ir, config, {false, kernel->program});
   print("Simplified I");
   irpass::analysis::verify(ir);
 
@@ -100,15 +100,15 @@ void compile_to_offloads(IRNode *ir,
     // Remove local atomics here so that we don't have to handle their gradients
     irpass::demote_atomics(ir, config);
 
-    irpass::full_simplify(ir, config, {false, &kernel->program});
+    irpass::full_simplify(ir, config, {false, kernel->program});
     irpass::auto_diff(ir, config, ad_use_stack);
-    irpass::full_simplify(ir, config, {false, &kernel->program});
+    irpass::full_simplify(ir, config, {false, kernel->program});
     print("Gradient");
     irpass::analysis::verify(ir);
   }
 
   if (config.check_out_of_bound) {
-    irpass::check_out_of_bound(ir, config, {kernel->name});
+    irpass::check_out_of_bound(ir, config, {kernel->get_name()});
     print("Bound checked");
     irpass::analysis::verify(ir);
   }
@@ -117,7 +117,7 @@ void compile_to_offloads(IRNode *ir,
   print("Access flagged I");
   irpass::analysis::verify(ir);
 
-  irpass::full_simplify(ir, config, {false, &kernel->program});
+  irpass::full_simplify(ir, config, {false, kernel->program});
   print("Simplified II");
   irpass::analysis::verify(ir);
 
@@ -136,7 +136,7 @@ void compile_to_offloads(IRNode *ir,
   irpass::flag_access(ir);
   print("Access flagged II");
 
-  irpass::full_simplify(ir, config, {false, &kernel->program});
+  irpass::full_simplify(ir, config, {false, kernel->program});
   print("Simplified III");
   irpass::analysis::verify(ir);
 }
@@ -150,7 +150,7 @@ void offload_to_executable(IRNode *ir,
                            bool make_block_local) {
   TI_AUTO_PROF;
 
-  auto print = make_pass_printer(verbose, kernel->name, ir);
+  auto print = make_pass_printer(verbose, kernel->get_name(), ir);
 
   // TODO: This is just a proof that we can demote struct-fors after offloading.
   // Eventually we might want the order to be TLS/BLS -> demote struct-for.
@@ -184,7 +184,7 @@ void offload_to_executable(IRNode *ir,
   }
 
   if (make_block_local) {
-    irpass::make_block_local(ir, config, {kernel->name});
+    irpass::make_block_local(ir, config, {kernel->get_name()});
     print("Make block local");
   }
 
@@ -221,7 +221,7 @@ void offload_to_executable(IRNode *ir,
   irpass::demote_operations(ir, config);
   print("Operations demoted");
 
-  irpass::full_simplify(ir, config, {lower_global_access, &kernel->program});
+  irpass::full_simplify(ir, config, {lower_global_access, kernel->program});
   print("Simplified IV");
 
   if (is_extension_supported(config.arch, Extension::quant)) {
@@ -262,7 +262,7 @@ void compile_inline_function(IRNode *ir,
                              bool start_from_ast) {
   TI_AUTO_PROF;
 
-  auto print = make_pass_printer(verbose, func->func_key.get_full_name(), ir);
+  auto print = make_pass_printer(verbose, func->get_name(), ir);
   print("Initial IR");
 
   if (grad) {
diff --git a/taichi/transforms/inlining.cpp b/taichi/transforms/inlining.cpp
index 451a2a5ae132c..0ed35dd9a4f05 100644
--- a/taichi/transforms/inlining.cpp
+++ b/taichi/transforms/inlining.cpp
@@ -39,7 +39,7 @@ class Inliner : public BasicStmtVisitor {
             }).size() > 1) {
       TI_WARN(
           "Multiple returns in function \"{}\" may not be handled properly.",
-          func->func_key.get_full_name());
+          func->get_name());
     }
     // Use a local variable to store the return value
     auto *return_address = inlined_ir->as<Block>()->insert(
diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp
index 422d805aa6e6e..d8a0d4652f1a0 100644
--- a/taichi/transforms/ir_printer.cpp
+++ b/taichi/transforms/ir_printer.cpp
@@ -288,7 +288,7 @@ class IRPrinter : public IRVisitor {
       args.push_back(arg->name());
     }
     print("{}{} = call \"{}\", args = {{{}}}", stmt->type_hint(), stmt->name(),
-          stmt->func->func_key.get_full_name(), fmt::join(args, ", "));
+          stmt->func->get_name(), fmt::join(args, ", "));
   }
 
   void visit(FrontendFuncDefStmt *stmt) override {
diff --git a/tests/cpp/transforms/alg_simp_test.cpp b/tests/cpp/transforms/alg_simp_test.cpp
index 7e41853ce5794..75e0aef5b2ed1 100644
--- a/tests/cpp/transforms/alg_simp_test.cpp
+++ b/tests/cpp/transforms/alg_simp_test.cpp
@@ -97,7 +97,7 @@ TEST_F(AlgebraicSimplicationTest, SimplifyMultiplyZeroFastMath) {
 
   CompileConfig config_without_fast_math;
   config_without_fast_math.fast_math = false;
-  kernel->program.config = config_without_fast_math;
+  kernel->program->config = config_without_fast_math;
   irpass::type_check(block.get(), config_without_fast_math);
 
   EXPECT_EQ(block->size(), 8);
@@ -125,9 +125,8 @@
   irpass::type_check(block.get(),
                      config_without_fast_math);  // insert 2 casts
   EXPECT_EQ(block->size(), 10);
-  irpass::constant_fold(
-      block.get(), config_without_fast_math,
-      {&kernel->program});  // should change 2 casts into const
+  irpass::constant_fold(block.get(), config_without_fast_math,
+                        {kernel->program});  // should change 2 casts into const
   irpass::alg_simp(block.get(),
                    config_without_fast_math);  // should not eliminate
   irpass::die(block.get());  // should eliminate 2 const
@@ -135,7 +134,7 @@
 
   CompileConfig config_with_fast_math;
   config_with_fast_math.fast_math = true;
-  kernel->program.config = config_with_fast_math;
+  kernel->program->config = config_with_fast_math;
 
   irpass::alg_simp(block.get(),
                    config_with_fast_math);  // should eliminate mul, add
diff --git a/tests/cpp/transforms/simplify_test.cpp b/tests/cpp/transforms/simplify_test.cpp
index 9ba2581ea4c30..f18808131b680 100644
--- a/tests/cpp/transforms/simplify_test.cpp
+++ b/tests/cpp/transforms/simplify_test.cpp
@@ -33,20 +33,20 @@
   [[maybe_unused]] auto lookup2 = block->push_back<SNodeLookupStmt>(
      root.ch[0].get(), get_child, linearized_zero, true);
 
-  irpass::type_check(block.get(), kernel->program.config);
+  irpass::type_check(block.get(), kernel->program->config);
   EXPECT_EQ(block->size(), 7);
 
   irpass::simplify(block.get(),
-                   kernel->program.config);  // should lower linearized
+                   kernel->program->config);  // should lower linearized
   // EXPECT_EQ(block->size(), 11);  // not required to check size here
-  irpass::constant_fold(block.get(), kernel->program.config,
-                        {&kernel->program});
-  irpass::alg_simp(block.get(), kernel->program.config);
+  irpass::constant_fold(block.get(), kernel->program->config,
+                        {kernel->program});
+  irpass::alg_simp(block.get(), kernel->program->config);
   irpass::die(block.get());  // should eliminate consts
-  irpass::simplify(block.get(), kernel->program.config);
+  irpass::simplify(block.get(), kernel->program->config);
   irpass::whole_kernel_cse(block.get());
-  if (kernel->program.config.advanced_optimization) {
+  if (kernel->program->config.advanced_optimization) {
     // get root, const 0, lookup, get child, lookup
     EXPECT_EQ(block->size(), 5);
   }
diff --git a/tests/python/test_kernel_template_mapper.py b/tests/python/test_callable_template_mapper.py
similarity index 74%
rename from tests/python/test_kernel_template_mapper.py
rename to tests/python/test_callable_template_mapper.py
index 3b68c36e84f00..3183e6c282b84 100644
--- a/tests/python/test_kernel_template_mapper.py
+++ b/tests/python/test_callable_template_mapper.py
@@ -1,15 +1,15 @@
-from taichi.lang.kernel_impl import KernelTemplateMapper
+from taichi.lang.kernel_impl import TaichiCallableTemplateMapper
 
 import taichi as ti
 
 
 @ti.all_archs
-def test_kernel_template_mapper():
+def test_callable_template_mapper():
     x = ti.field(ti.i32)
     y = ti.field(ti.f32)
 
     ti.root.place(x, y)
 
-    mapper = KernelTemplateMapper(
+    mapper = TaichiCallableTemplateMapper(
        (ti.template(), ti.template(), ti.template()),
        template_slot_locations=(0, 1, 2))
     assert mapper.lookup((0, 0, 0))[0] == 0
@@ -18,21 +18,22 @@ def test_kernel_template_mapper():
     assert mapper.lookup((0, 0, 1))[0] == 2
     assert mapper.lookup((0, 1, 0))[0] == 1
 
-    mapper = KernelTemplateMapper((ti.i32, ti.i32, ti.i32), ())
+    mapper = TaichiCallableTemplateMapper((ti.i32, ti.i32, ti.i32), ())
     assert mapper.lookup((0, 0, 0))[0] == 0
     assert mapper.lookup((0, 1, 0))[0] == 0
     assert mapper.lookup((0, 0, 0))[0] == 0
     assert mapper.lookup((0, 0, 1))[0] == 0
     assert mapper.lookup((0, 1, 0))[0] == 0
 
-    mapper = KernelTemplateMapper((ti.i32, ti.template(), ti.i32), (1, ))
+    mapper = TaichiCallableTemplateMapper((ti.i32, ti.template(), ti.i32),
+                                          (1,))
     assert mapper.lookup((0, x, 0))[0] == 0
     assert mapper.lookup((0, y, 0))[0] == 1
     assert mapper.lookup((0, x, 1))[0] == 0
 
 
 @ti.all_archs
-def test_kernel_template_mapper_numpy():
+def test_callable_template_mapper_numpy():
     x = ti.field(ti.i32)
     y = ti.field(ti.f32)
 
@@ -42,7 +43,7 @@
 
     import numpy as np
 
-    mapper = KernelTemplateMapper(annotations, (0, 1, 2))
+    mapper = TaichiCallableTemplateMapper(annotations, (0, 1, 2))
     assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 3),
                                         dtype=np.float32)))[0] == 0
     assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 4),
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 96938360d86dc..84393c3ba3324 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -6,7 +6,7 @@ def test_function_without_return():
     x = ti.field(ti.i32, shape=())
 
     @ti.func
-    def foo(val: ti.i32) -> ti.i32:
+    def foo(val: ti.i32):
         x[None] += val
 
     @ti.kernel