Skip to content

Commit

Permalink
[async] Support constant folding in async mode
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Aug 25, 2020
1 parent cca8545 commit a4db06d
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 20 deletions.
12 changes: 7 additions & 5 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,18 @@ def call_back():

ret = None
ret_dt = self.return_type
if ret_dt is not None:
has_ret = ret_dt is not None

if has_external_arrays or has_ret:
import taichi as ti
ti.sync()

if has_ret:
if id(ret_dt) in integer_type_ids:
ret = t_kernel.get_ret_int(0)
else:
ret = t_kernel.get_ret_float(0)

if has_external_arrays:
import taichi as ti
ti.sync()

if callbacks:
for c in callbacks:
c()
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
auto config = kernel->program.config;
auto ir = stmt;
offload_to_executable(
ir, config, /*verbose=*/false,
ir, config, /*verbose=*/config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/true,
/*make_block_local=*/
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void Kernel::lower(bool to_executable) { // TODO: is a "Lowerer" class
}

void Kernel::operator()(LaunchContextBuilder &launch_ctx) {
if (!program.config.async_mode) {
if (!program.config.async_mode || this->is_evaluator) {
if (!compiled) {
compile();
}
Expand Down
24 changes: 12 additions & 12 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,16 +468,20 @@ void Program::synchronize() {
if (config.async_mode) {
async_engine->synchronize();
}
if (config.arch == Arch::cuda) {
device_synchronize();
sync = true;
}
}

void Program::device_synchronize() {
if (config.arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().stream_synchronize(nullptr);
CUDADriver::get_instance().stream_synchronize(nullptr);
#else
TI_ERROR("No CUDA support");
TI_ERROR("No CUDA support");
#endif
} else if (config.arch == Arch::metal) {
metal_kernel_mgr_->synchronize();
}
sync = true;
} else if (config.arch == Arch::metal) {
metal_kernel_mgr_->synchronize();
}
}

Expand Down Expand Up @@ -626,13 +630,9 @@ Kernel &Program::get_snode_writer(SNode *snode) {
}

uint64 Program::fetch_result_uint64(int i) {
// Precondition: caller must have already done a program synchronization.
uint64 ret;
auto arch = config.arch;
sync = false;
// Runtime calls that set result buffer don't execute sync=false, so we have
// to set it here otherwise synchronize() does nothing.
// TODO: systematically fix this.
synchronize();
if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
if (config.use_unified_memory) {
Expand Down
1 change: 1 addition & 0 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class Program {
void initialize_device_llvm_context();

void synchronize();
void device_synchronize();

void layout(std::function<void()> func) {
func();
Expand Down
8 changes: 7 additions & 1 deletion taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class ConstantFold : public BasicStmtVisitor {
launch_ctx.set_arg_raw(0, lhs.val_u64);
launch_ctx.set_arg_raw(1, rhs.val_u64);
(*ker)(launch_ctx);
// Constant folding kernel is always run in sync mode, therefore we call
// device_synchronize().
current_program.device_synchronize();
ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}
Expand All @@ -143,6 +146,9 @@ class ConstantFold : public BasicStmtVisitor {
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_raw(0, operand.val_u64);
(*ker)(launch_ctx);
// Constant folding kernel is always run in sync mode, therefore we call
// device_synchronize().
current_program.device_synchronize();
ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}
Expand Down Expand Up @@ -240,7 +246,7 @@ bool constant_fold(IRNode *root) {
TI_TRACE("config.debug enabled, ignoring constant fold");
return false;
}
if (!cfg.advanced_optimization || cfg.async_mode)
if (!cfg.advanced_optimization)
return false;
return ConstantFold::run(root);
}
Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_constant_fold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import taichi as ti


@ti.test(arch=ti.cpu, async_mode=True)
def test_constant_fold():
n = 100

@ti.kernel
def series() -> int:
s = 0
for i in ti.static(range(n)):
a = i + 1
s += a * a
return s

# \sum_{i=1}^n (i^2) = n * (n + 1) * (2n + 1) / 6
expected = n * (n + 1) * (2 * n + 1) // 6
assert series() == expected

0 comments on commit a4db06d

Please sign in to comment.