From 135488e660e5e5db4d538b5704518e50e898052e Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Wed, 8 Jun 2022 21:06:06 +0800 Subject: [PATCH] [aot][bug] Use cached compiled kernel pointer when it's added to graph multiple times This bug was triggered when we tried to port stable_fluid demo so this PR also added a cgraph based stable fluid demo. Note it's not ideal to save both `FunctionType compiled_` as well as `aot::Kernel compiled_aot_kernel_` inside C++ `Kernel` class. But we plan to clean that up (likely by getting rid of `FunctionType compiled_`) in #5114. --- .../examples/graph/stable_fluid_graph.py | 342 ++++++++++++++++++ python/taichi/graph/_graph.py | 6 +- python/taichi/lang/kernel_impl.py | 4 + taichi/program/graph_builder.cpp | 6 +- taichi/program/graph_builder.h | 1 - taichi/program/kernel.cpp | 4 +- taichi/program/kernel.h | 11 +- 7 files changed, 364 insertions(+), 10 deletions(-) create mode 100644 python/taichi/examples/graph/stable_fluid_graph.py diff --git a/python/taichi/examples/graph/stable_fluid_graph.py b/python/taichi/examples/graph/stable_fluid_graph.py new file mode 100644 index 00000000000000..23fafe36032490 --- /dev/null +++ b/python/taichi/examples/graph/stable_fluid_graph.py @@ -0,0 +1,342 @@ +# References: +# http://developer.download.nvidia.com/books/HTML/gpugems/gpugems_ch38.html +# https://github.com/PavelDoGreat/WebGL-Fluid-Simulation +# https://www.bilibili.com/video/BV1ZK411H7Hc?p=4 +# https://github.com/ShaneFX/GAMES201/tree/master/HW01 + +import argparse + +import numpy as np +import taichi as ti + +ti.init(arch=ti.vulkan) + +res = 512 +dt = 0.03 +p_jacobi_iters = 500 # 40 for a quicker but less accurate result +f_strength = 10000.0 +curl_strength = 0 +time_c = 2 +maxfps = 60 +dye_decay = 1 - 1 / (maxfps * time_c) +force_radius = res / 2.0 +gravity = True +paused = False + + + +class TexPair: + def __init__(self, cur, nxt): + self.cur = cur + self.nxt = nxt + + def swap(self): + self.cur, self.nxt = self.nxt, self.cur + + + +@ti.func +def sample(qf: ti.template(), u, v): + I = ti.Vector([int(u), int(v)]) + I = max(0, min(res - 1, I)) + return qf[I] + + +@ti.func +def lerp(vl, vr, frac): + # frac: [0.0, 1.0] + return vl + frac * (vr - vl) + + +@ti.func +def bilerp(vf: ti.template(), p): + u, v = p + s, t = u - 0.5, v - 0.5 + # floor + iu, iv = ti.floor(s), ti.floor(t) + # fract + fu, fv = s - iu, t - iv + a = sample(vf, iu, iv) + b = sample(vf, iu + 1, iv) + c = sample(vf, iu, iv + 1) + d = sample(vf, iu + 1, iv + 1) + return lerp(lerp(a, b, fu), lerp(c, d, fu), fv) + + +# 3rd order Runge-Kutta +@ti.func +def backtrace(vf: ti.template(), p, dt: ti.template()): + v1 = bilerp(vf, p) + p1 = p - 0.5 * dt * v1 + v2 = bilerp(vf, p1) + p2 = p - 0.75 * dt * v2 + v3 = bilerp(vf, p2) + p -= dt * ((2 / 9) * v1 + (1 / 3) * v2 + (4 / 9) * v3) + return p + + +@ti.kernel +def advect(vf: ti.types.ndarray(field_dim=2), + qf: ti.types.ndarray(field_dim=2), + new_qf: ti.types.ndarray(field_dim=2)): + for i, j in vf: + p = ti.Vector([i, j]) + 0.5 + p = backtrace(vf, p, dt) + new_qf[i, j] = bilerp(qf, p) * dye_decay + + +@ti.kernel +def apply_impulse(vf: ti.types.ndarray(field_dim=2), + dyef: ti.types.ndarray(field_dim=2), + imp_data: ti.types.ndarray(field_dim=1)): + g_dir = -ti.Vector([0, 9.8]) * 300 + for i, j in vf: + omx, omy = imp_data[2], imp_data[3] + mdir = ti.Vector([imp_data[0], imp_data[1]]) + dx, dy = (i + 0.5 - omx), (j + 0.5 - omy) + d2 = dx * dx + dy * dy + # dv = F * dt + factor = ti.exp(-d2 / force_radius) + + dc = dyef[i, j] + a = dc.norm() + + momentum = (mdir * f_strength * factor + g_dir * a / (1 + a)) * dt + + v = vf[i, j] + vf[i, j] = v + momentum + # add dye + if mdir.norm() > 0.5: + dc += ti.exp(-d2 * (4 / (res / 15)**2)) * ti.Vector( + [imp_data[4], imp_data[5], imp_data[6]]) + + dyef[i, j] = dc + + +@ti.kernel +def divergence(vf: ti.types.ndarray(field_dim=2), + velocity_divs: ti.types.ndarray(field_dim=2)): + for i, j in vf: + vl = sample(vf, i - 1, j) + vr = sample(vf, i + 1, j) + vb = sample(vf, i, j - 1) + vt = sample(vf, i, j + 1) + vc = sample(vf, i, j) + if i == 0: + vl.x = -vc.x + if i == res - 1: + vr.x = -vc.x + if j == 0: + vb.y = -vc.y + if j == res - 1: + vt.y = -vc.y + velocity_divs[i, j] = (vr.x - vl.x + vt.y - vb.y) * 0.5 + + +@ti.kernel +def pressure_jacobi(pf: ti.types.ndarray(field_dim=2), + new_pf: ti.types.ndarray(field_dim=2), + velocity_divs: ti.types.ndarray(field_dim=2)): + for i, j in pf: + pl = sample(pf, i - 1, j) + pr = sample(pf, i + 1, j) + pb = sample(pf, i, j - 1) + pt = sample(pf, i, j + 1) + div = velocity_divs[i, j] + new_pf[i, j] = (pl + pr + pb + pt - div) * 0.25 + + +@ti.kernel +def subtract_gradient(vf: ti.types.ndarray(field_dim=2), + pf: ti.types.ndarray(field_dim=2)): + for i, j in vf: + pl = sample(pf, i - 1, j) + pr = sample(pf, i + 1, j) + pb = sample(pf, i, j - 1) + pt = sample(pf, i, j + 1) + vf[i, j] -= 0.5 * ti.Vector([pr - pl, pt - pb]) + + +def solve_pressure_jacobi(): + for _ in range(p_jacobi_iters): + pressure_jacobi(pressures_pair.cur, pressures_pair.nxt, _velocity_divs) + pressures_pair.swap() + + +def step_orig(mouse_data): + advect(velocities_pair.cur, velocities_pair.cur, velocities_pair.nxt) + advect(velocities_pair.cur, dyes_pair.cur, dyes_pair.nxt) + velocities_pair.swap() + dyes_pair.swap() + + apply_impulse(velocities_pair.cur, dyes_pair.cur, mouse_data) + + divergence(velocities_pair.cur, _velocity_divs) + + solve_pressure_jacobi() + + subtract_gradient(velocities_pair.cur, pressures_pair.cur) + + +mouse_data_ti = ti.ndarray(ti.f32, shape=(8, )) + + +class MouseDataGen(object): + def __init__(self): + self.prev_mouse = None + self.prev_color = None + + def __call__(self, gui): + # [0:2]: normalized delta direction + # [2:4]: current mouse xy + # [4:7]: color + mouse_data = np.zeros(8, dtype=np.float32) + if gui.is_pressed(ti.GUI.LMB): + mxy = np.array(gui.get_cursor_pos(), dtype=np.float32) * res + if self.prev_mouse is None: + self.prev_mouse = mxy + # Set lower bound to 0.3 to prevent too dark colors + self.prev_color = (np.random.rand(3) * 0.7) + 0.3 + else: + mdir = mxy - self.prev_mouse + mdir = mdir / (np.linalg.norm(mdir) + 1e-5) + mouse_data[0], mouse_data[1] = mdir[0], mdir[1] + mouse_data[2], mouse_data[3] = mxy[0], mxy[1] + mouse_data[4:7] = self.prev_color + self.prev_mouse = mxy + else: + self.prev_mouse = None + self.prev_color = None + mouse_data_ti.from_numpy(mouse_data) + return mouse_data_ti + + +def reset(): + velocities_pair.cur.fill(0) + pressures_pair.cur.fill(0) + dyes_pair.cur.fill(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--baseline', + action='store_true') + args, unknown = parser.parse_known_args() + + gui = ti.GUI('Stable Fluid', (res, res)) + md_gen = MouseDataGen() + + _velocities = ti.Vector.ndarray(2, float, shape=(res, res)) + _new_velocities = ti.Vector.ndarray(2, float, shape=(res, res)) + _velocity_divs = ti.ndarray(float, shape=(res, res)) + velocity_curls = ti.ndarray(float, shape=(res, res)) + _pressures = ti.ndarray(float, shape=(res, res)) + _new_pressures = ti.ndarray(float, shape=(res, res)) + _dye_buffer = ti.Vector.ndarray(3, float, shape=(res, res)) + _new_dye_buffer = ti.Vector.ndarray(3, float, shape=(res, res)) + + if args.baseline: + velocities_pair = TexPair(_velocities, _new_velocities) + pressures_pair = TexPair(_pressures, _new_pressures) + dyes_pair = TexPair(_dye_buffer, _new_dye_buffer) + else: + print('running in graph mode') + velocities_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'velocities_pair_cur', + ti.f32, + element_shape=(2, )) + velocities_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'velocities_pair_nxt', + ti.f32, + element_shape=(2, )) + dyes_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'dyes_pair_cur', + ti.f32, + element_shape=(3, )) + dyes_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'dyes_pair_nxt', + ti.f32, + element_shape=(3, )) + pressures_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'pressures_pair_cur', ti.f32) + pressures_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, + 'pressures_pair_nxt', ti.f32) + velocity_divs = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'velocity_divs', + ti.f32) + mouse_data = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'mouse_data', ti.f32) + + g1_builder = ti.graph.GraphBuilder() + g1_builder.dispatch(advect, velocities_pair_cur, velocities_pair_cur, + velocities_pair_nxt) + g1_builder.dispatch(advect, velocities_pair_cur, dyes_pair_cur, dyes_pair_nxt) + g1_builder.dispatch(apply_impulse, velocities_pair_nxt, dyes_pair_nxt, mouse_data) + g1_builder.dispatch(divergence, velocities_pair_nxt, velocity_divs) + # swap is unrolled in the loop so we only need p_jacobi_iters // 2 iterations. + for _ in range(p_jacobi_iters // 2): + g1_builder.dispatch(pressure_jacobi, pressures_pair_cur, pressures_pair_nxt, + velocity_divs) + g1_builder.dispatch(pressure_jacobi, pressures_pair_nxt, pressures_pair_cur, + velocity_divs) + g1_builder.dispatch(subtract_gradient, velocities_pair_nxt, pressures_pair_cur) + g1 = g1_builder.compile() + + g2_builder = ti.graph.GraphBuilder() + g2_builder.dispatch(advect, velocities_pair_nxt, velocities_pair_nxt, + velocities_pair_cur) + g2_builder.dispatch(advect, velocities_pair_nxt, dyes_pair_nxt, dyes_pair_cur) + g2_builder.dispatch(apply_impulse, velocities_pair_cur, dyes_pair_cur, mouse_data) + g2_builder.dispatch(divergence, velocities_pair_cur, velocity_divs) + for _ in range(p_jacobi_iters // 2): + g2_builder.dispatch(pressure_jacobi, pressures_pair_cur, pressures_pair_nxt, + velocity_divs) + g2_builder.dispatch(pressure_jacobi, pressures_pair_nxt, pressures_pair_cur, + velocity_divs) + g2_builder.dispatch(subtract_gradient, velocities_pair_cur, pressures_pair_cur) + g2 = g2_builder.compile() + + + swap = True + + while gui.running: + if gui.get_event(ti.GUI.PRESS): + e = gui.event + if e.key == ti.GUI.ESCAPE: + break + elif e.key == 'r': + paused = False + reset() + elif e.key == 's': + if curl_strength: + curl_strength = 0 + else: + curl_strength = 7 + elif e.key == 'g': + gravity = not gravity + elif e.key == 'p': + paused = not paused + + if not paused: + _mouse_data = md_gen(gui) + if args.baseline: + step_orig(_mouse_data) + gui.set_image(dyes_pair.cur.to_numpy()) + else: + invoke_args = { + 'mouse_data': _mouse_data, + 'velocities_pair_cur': _velocities, + 'velocities_pair_nxt': _new_velocities, + 'dyes_pair_cur': _dye_buffer, + 'dyes_pair_nxt': _new_dye_buffer, + 'pressures_pair_cur': _pressures, + 'pressures_pair_nxt': _new_pressures, + 'velocity_divs': _velocity_divs + } + if swap: + g1.run(invoke_args) + gui.set_image(_dye_buffer.to_numpy()) + swap = False + else: + g2.run(invoke_args) + gui.set_image(_new_dye_buffer.to_numpy()) + swap = True + gui.show() diff --git a/python/taichi/graph/_graph.py b/python/taichi/graph/_graph.py index f9f06b32c8aa09..f5f476fe69a252 100644 --- a/python/taichi/graph/_graph.py +++ b/python/taichi/graph/_graph.py @@ -12,8 +12,8 @@ def gen_cpp_kernel(kernel_fn, args): kernel = kernel_fn._primal assert isinstance(kernel, kernel_impl.Kernel) injected_args = produce_injected_args(kernel, symbolic_args=args) - kernel.ensure_compiled(*injected_args) - return kernel.kernel_cpp + key = kernel.ensure_compiled(*injected_args) + return kernel.compiled_kernels[key] class Sequential: @@ -64,7 +64,7 @@ def run(self, args): arg_floats[k] = v else: raise TaichiRuntimeError( - 'Only python int, float and ti.Ndarray are supported as runtime arguments' + f'Only python int, float and ti.Ndarray are supported as runtime arguments but got {type(v)}' ) self._compiled_graph.run(arg_ptrs, arg_ints, arg_floats) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 884298e9f99c7c..d97d104f8c138f 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -419,6 +419,9 @@ def __init__(self, _func, is_grad, _classkernel=False): impl.get_runtime().kernels.append(self) self.reset() self.kernel_cpp = None + # TODO[#5114]: get rid of compiled_functions and use compiled_kernels instead. + # Main motivation is that compiled_kernels can be potentially serialized in the AOT scenario. + self.compiled_kernels = {} def reset(self): self.runtime = impl.get_runtime() @@ -532,6 +535,7 @@ def taichi_ast_generator(kernel_cxx): assert key not in self.compiled_functions self.compiled_functions[key] = self.get_function_body(taichi_kernel) + self.compiled_kernels[key] = taichi_kernel def get_torch_callbacks(self, v, has_torch, is_ndarray=True): callbacks = [] diff --git a/taichi/program/graph_builder.cpp b/taichi/program/graph_builder.cpp index 76c579d70e74b6..8aba27cf3087e2 100644 --- a/taichi/program/graph_builder.cpp +++ b/taichi/program/graph_builder.cpp @@ -6,11 +6,11 @@ namespace taichi { namespace lang { void Dispatch::compile( std::vector &compiled_dispatches) { - if (!compiled_kernel_) { - compiled_kernel_ = kernel_->compile_to_aot_kernel(); + if (kernel_->compiled_aot_kernel() == nullptr) { + kernel_->compile_to_aot_kernel(); } aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, - compiled_kernel_.get()}; + kernel_->compiled_aot_kernel()}; compiled_dispatches.push_back(std::move(dispatch)); } diff --git a/taichi/program/graph_builder.h b/taichi/program/graph_builder.h index 129e5adc7b94ba..8f9c8a341109be 100644 --- a/taichi/program/graph_builder.h +++ b/taichi/program/graph_builder.h @@ -36,7 +36,6 @@ class Dispatch : public Node { private: mutable bool serialized_{false}; Kernel *kernel_{nullptr}; - std::unique_ptr compiled_kernel_{nullptr}; std::vector symbolic_args_; }; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 3b12af94ab526a..93f6ec88a12854 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -64,8 +64,8 @@ void Kernel::compile() { compiled_ = program->compile(*this); } -std::unique_ptr Kernel::compile_to_aot_kernel() { - return program->make_aot_kernel(*this); +void Kernel::compile_to_aot_kernel() { + compiled_aot_kernel_ = program->make_aot_kernel(*this); } void Kernel::lower(bool to_executable) { diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 0d98252b8c4994..fc54784c56746d 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -87,7 +87,12 @@ class TI_DLL_EXPORT Kernel : public Callable { void compile(); - std::unique_ptr compile_to_aot_kernel(); + void compile_to_aot_kernel(); + + aot::Kernel *compiled_aot_kernel() { + return compiled_aot_kernel_.get(); + } + /** * Lowers |ir| to CHI IR level * @@ -142,6 +147,10 @@ class TI_DLL_EXPORT Kernel : public Callable { bool ir_is_ast_{false}; // The closure that, if invoked, lauches the backend kernel (shader) FunctionType compiled_{nullptr}; + // TODO[#5114]: It's kinda redundant to keep both compiled_ (used for JIT + // execution) as well as compiled_aot_kernel_. In fact we'd better unify + // everything around compiled_aot_kernel and rename it. + std::unique_ptr compiled_aot_kernel_{nullptr}; // A flag to record whether |ir| has been fully lowered. // lower inital AST all the way down to a bunch of // OffloadedStmt for async execution