Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove the singleton Program::context #1799

Merged
merged 2 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ ExecutionQueue::ExecutionQueue()
: compilation_workers(4), launch_worker(1) { // TODO: remove 4
}

void AsyncEngine::launch(Kernel *kernel) {
void AsyncEngine::launch(Kernel *kernel, Context &context) {
if (!kernel->lowered)
kernel->lower(/*to_executable=*/false);

Expand Down Expand Up @@ -249,8 +249,8 @@ void AsyncEngine::launch(Kernel *kernel) {
TI_ASSERT(kmeta.offloaded_cached.size() == i);
kmeta.offloaded_cached.emplace_back(std::move(cloned_offs), h);
}
KernelLaunchRecord rec(kernel->program.get_context(), kernel, offl_template,
h, kmeta.dummy_root.get());
KernelLaunchRecord rec(context, kernel, offl_template, h,
kmeta.dummy_root.get());
enqueue(std::move(rec));
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class AsyncEngine {
queue.clear_cache();
}

void launch(Kernel *kernel);
void launch(Kernel *kernel, Context &context);

void enqueue(KernelLaunchRecord &&t);

Expand Down
19 changes: 14 additions & 5 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void Kernel::lower(bool to_executable) { // TODO: is a "Lowerer" class
lowered = true;
}

void Kernel::operator()(LaunchContextBuilder &launch_ctx) {
void Kernel::operator()(LaunchContextBuilder &ctx_builder) {
if (!program.config.async_mode || this->is_evaluator) {
if (!compiled) {
compile();
Expand All @@ -107,7 +107,7 @@ void Kernel::operator()(LaunchContextBuilder &launch_ctx) {
account_for_offloaded(offloaded->as<OffloadedStmt>());
}

compiled(launch_ctx.get_context());
compiled(ctx_builder.get_context());

program.sync = (program.sync && arch_is_cpu(arch));
// Note that Kernel::arch may be different from program.config.arch
Expand All @@ -117,7 +117,7 @@ void Kernel::operator()(LaunchContextBuilder &launch_ctx) {
}
} else {
program.sync = false;
program.async_engine->launch(this);
program.async_engine->launch(this, ctx_builder.get_context());
// Note that Kernel::arch may be different from program.config.arch
if (program.config.debug && arch_is_cpu(arch) &&
arch_is_cpu(program.config.arch)) {
Expand All @@ -127,7 +127,17 @@ void Kernel::operator()(LaunchContextBuilder &launch_ctx) {
}

Kernel::LaunchContextBuilder Kernel::make_launch_context() {
return LaunchContextBuilder(this, &(program.context));
return LaunchContextBuilder(this);
}

// Builds a LaunchContextBuilder around a Context supplied by the caller.
// |owned_ctx_| is left as nullptr, so this builder only stores the pointer
// |ctx| and does not allocate a Context of its own.
// NOTE(review): presumably |ctx| must be non-null and must outlive this
// builder -- confirm against the callers.
Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel, Context *ctx)
: kernel_(kernel), owned_ctx_(nullptr), ctx_(ctx) {
}

// Builds a LaunchContextBuilder that allocates its own Context with
// std::make_unique, stores it in |owned_ctx_|, and points |ctx_| at that
// allocation.
Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel)
: kernel_(kernel),
owned_ctx_(std::make_unique<Context>()),
ctx_(owned_ctx_.get()) {
}

void Kernel::LaunchContextBuilder::set_arg_float(int i, float64 d) {
Expand Down Expand Up @@ -232,7 +242,6 @@ void Kernel::LaunchContextBuilder::set_arg_raw(int i, uint64 d) {
}

// Returns a reference to the Context that |ctx_| points to. Every call first
// overwrites that context's |runtime| field with |kernel_->program.llvm_runtime|
// (cast to LLVMRuntime *) before returning.
Context &Kernel::LaunchContextBuilder::get_context() {
// See Program::get_context()
ctx_->runtime = static_cast<LLVMRuntime *>(kernel_->program.llvm_runtime);
return *ctx_;
}
Expand Down
29 changes: 19 additions & 10 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include "taichi/ir/snode.h"
#include "taichi/ir/ir.h"

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST

TLANG_NAMESPACE_BEGIN

class Program;
Expand Down Expand Up @@ -47,9 +51,13 @@ class Kernel {
// TODO: Give "Context" a more specific name.
class LaunchContextBuilder {
public:
LaunchContextBuilder(Kernel *kernel, Context *ctx)
: kernel_(kernel), ctx_(ctx) {
}
LaunchContextBuilder(Kernel *kernel, Context *ctx);
explicit LaunchContextBuilder(Kernel *kernel);

LaunchContextBuilder(LaunchContextBuilder &&) = default;
LaunchContextBuilder &operator=(LaunchContextBuilder &&) = default;
LaunchContextBuilder(const LaunchContextBuilder &) = delete;
LaunchContextBuilder &operator=(const LaunchContextBuilder &) = delete;

void set_arg_float(int i, float64 d);

Expand All @@ -66,12 +74,13 @@ class Kernel {
Context &get_context();

private:
Kernel *const kernel_;
// TODO: Right now |ctx_| is borrowed from other places: either the
// program's context, or the one in the CUDA launch function. In the future,
// this could *own* a Context (possibly through a std::unique_ptr, since we
// don't always need to own the Context.)
Context *const ctx_;
Kernel *kernel_;
std::unique_ptr<Context> owned_ctx_;
// |ctx_| *almost* always points to |owned_ctx_|. However, it is possible
// that the caller passes a Context pointer externally. In that case,
// |owned_ctx_| will be nullptr.
// Invariant: |ctx_| will never be nullptr.
Context *ctx_;
};

Kernel(Program &program,
Expand All @@ -83,7 +92,7 @@ class Kernel {

void lower(bool to_executable = true);

void operator()(LaunchContextBuilder &launch_ctx);
void operator()(LaunchContextBuilder &ctx_builder);

LaunchContextBuilder make_launch_context();

Expand Down
4 changes: 0 additions & 4 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,6 @@ void Program::finalize() {
TI_TRACE("Program ({}) finalized.", fmt::ptr(this));
}

// Forwards |kernel| to |async_engine|'s launch().
void Program::launch_async(Kernel *kernel) {
async_engine->launch(kernel);
}

int Program::default_block_dim() const {
if (arch_is_cpu(config.arch)) {
return config.default_cpu_block_dim;
Expand Down
8 changes: 0 additions & 8 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class Program {
std::unique_ptr<SNode> snode_root; // pointer to the data structure.
void *llvm_runtime;
CompileConfig config;
Context context;
std::unique_ptr<TaichiLLVMContext> llvm_context_host, llvm_context_device;
bool sync; // device/host synchronized?
bool finalized;
Expand Down Expand Up @@ -131,11 +130,6 @@ class Program {
return profiler.get();
}

// Returns this Program's |context| member. Every call first sets its
// |runtime| field to |llvm_runtime| cast to LLVMRuntime *.
Context &get_context() {
context.runtime = (LLVMRuntime *)llvm_runtime;
return context;
}

void initialize_device_llvm_context();

void synchronize();
Expand Down Expand Up @@ -240,8 +234,6 @@ class Program {
snode_root->print();
}

void launch_async(Kernel *kernel);

int default_block_dim() const;

void print_list_manager_info(void *list_manager);
Expand Down
22 changes: 2 additions & 20 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,6 @@ class ConstantFold : public BasicStmtVisitor {
}
}

// Saves and restores the argument slots of a Context: the constructor copies
// |ctx.args| into |old_args|, and the destructor copies |old_args| back into
// |ctx.args|. Any writes to |ctx.args| made while the guard is alive are
// therefore undone when it is destroyed.
class ContextArgSaveGuard {
// Context whose args are snapshotted; held by reference, not owned.
Context &ctx;
// Snapshot of |ctx.args| taken at construction time.
uint64 old_args[taichi_max_num_args];

public:
// Copies sizeof(old_args) bytes from |ctx_.args| into |old_args|.
explicit ContextArgSaveGuard(Context &ctx_) : ctx(ctx_) {
std::memcpy(old_args, ctx.args, sizeof(old_args));
}

// Copies the snapshot back into |ctx.args|.
~ContextArgSaveGuard() {
std::memcpy(ctx.args, old_args, sizeof(old_args));
}
};

static bool jit_evaluate_binary_op(TypedConstant &ret,
BinaryOpStmt *stmt,
const TypedConstant &lhs,
Expand All @@ -114,13 +100,11 @@ class ConstantFold : public BasicStmtVisitor {
rhs.dt,
true};
auto *ker = get_jit_evaluator_kernel(id);
auto &current_program = stmt->get_kernel()->program;
// save input args to prevent the current kernel from being overridden.
ContextArgSaveGuard _(current_program.get_context());
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_raw(0, lhs.val_u64);
launch_ctx.set_arg_raw(1, rhs.val_u64);
(*ker)(launch_ctx);
auto &current_program = stmt->get_kernel()->program;
ret.val_i64 = current_program.fetch_result<int64_t>(0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My future hope is:

Suggested change
ret.val_i64 = current_program.fetch_result<int64_t>(0);
ret.val_i64 = launch_ctx.get_arg<int64_t>(0);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. Unfortunately, reading the result is, in most cases, separated from calling the kernel. I guess a preliminary step toward this goal would be to stop using Program to fetch the result.

return true;
}
Expand All @@ -137,12 +121,10 @@ class ConstantFold : public BasicStmtVisitor {
stmt->cast_type,
false};
auto *ker = get_jit_evaluator_kernel(id);
auto &current_program = stmt->get_kernel()->program;
// save input args to prevent the current kernel from being overridden.
ContextArgSaveGuard _(current_program.get_context());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete class ContextArgSaveGuard from this file too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_raw(0, operand.val_u64);
(*ker)(launch_ctx);
auto &current_program = stmt->get_kernel()->program;
ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}
Expand Down