Skip to content

Commit

Permalink
remove ensure_compile_time_eval context, fix progress bar in loop t…
Browse files Browse the repository at this point in the history
…ransformation, compatible with jax==0.4.36 (#46)
  • Loading branch information
chaoming0625 authored Dec 9, 2024
1 parent c5be092 commit 176ee7a
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 40 deletions.
12 changes: 5 additions & 7 deletions brainstate/compile/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ def cond(pred, true_fun, false_fun, *operands):
return false_fun(*operands)

# evaluate jaxpr
with jax.ensure_compile_time_eval():
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)

# state trace and state values
state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
Expand Down Expand Up @@ -175,10 +174,9 @@ def switch(index, branches, *operands):
return branches[int(index)](*operands)

# evaluate jaxpr
with jax.ensure_compile_time_eval():
wrapped_branches = [StatefulFunction(branch) for branch in branches]
for wrapped_branch in wrapped_branches:
wrapped_branch.make_jaxpr(*operands)
wrapped_branches = [StatefulFunction(branch) for branch in branches]
for wrapped_branch in wrapped_branches:
wrapped_branch.make_jaxpr(*operands)

# wrap the functions
state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
Expand Down
6 changes: 3 additions & 3 deletions brainstate/compile/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def jitted_fun(*args, **params):
return fun.fun(*args, **params)

# compile the function and get the state trace
with jax.ensure_compile_time_eval():
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
read_state_vals = state_trace.get_read_state_values(True)
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
read_state_vals = state_trace.get_read_state_values(True)

# call the jitted function
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
# write the state values back to the states
Expand Down
11 changes: 5 additions & 6 deletions brainstate/compile/_loop_collect_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,11 @@ def scan(f, init, xs, length=None):
# ------------------------------ #
xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
with jax.ensure_compile_time_eval():
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
state_trace = stateful_fun.get_state_trace()
all_writen_state_vals = state_trace.get_write_state_values(True)
all_read_state_vals = state_trace.get_read_state_values(True)
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
state_trace = stateful_fun.get_state_trace()
all_writen_state_vals = state_trace.get_write_state_values(True)
all_read_state_vals = state_trace.get_read_state_values(True)
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)

# scan
init = (all_writen_state_vals, init)
Expand Down
9 changes: 4 additions & 5 deletions brainstate/compile/_loop_no_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ def while_loop(cond_fun, body_fun, init_val):
pass

# evaluate jaxpr
with jax.ensure_compile_time_eval():
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
if len(stateful_cond.get_write_states()) != 0:
raise ValueError("while_loop: cond_fun should not have any write states.")
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
if len(stateful_cond.get_write_states()) != 0:
raise ValueError("while_loop: cond_fun should not have any write states.")

# state trace and state values
state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
Expand Down
39 changes: 20 additions & 19 deletions brainstate/compile/_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,26 @@ def _tqdm(self, is_init, is_print, is_final):
self.tqdm_bars[0].close()

def __call__(self, iter_num, *args, **kwargs):
jax.debug.callback(
self._tqdm,
iter_num == 0,
(iter_num + 1) % self.print_freq == 0,
iter_num == self.n - 1
)

# _ = jax.lax.cond(
# jax.debug.callback(
# self._tqdm,
# iter_num == 0,
# lambda: jax.debug.callback(self._define_tqdm, ordered=True),
# lambda: None,
# )
# _ = jax.lax.cond(
# (iter_num + 1) % self.print_freq == 0,
# lambda: jax.debug.callback(self._update_tqdm, ordered=True),
# lambda: None,
# )
# _ = jax.lax.cond(
# iter_num == self.n - 1,
# lambda: jax.debug.callback(self._close_tqdm, ordered=True),
# lambda: None,
# iter_num == self.n - 1
# )

_ = jax.lax.cond(
iter_num == 0,
lambda: jax.debug.callback(self._define_tqdm),
lambda: None,
)
_ = jax.lax.cond(
iter_num % self.print_freq == (self.print_freq - 1),
lambda: jax.debug.callback(self._update_tqdm),
lambda: None,
)
_ = jax.lax.cond(
iter_num == self.n - 1,
lambda: jax.debug.callback(self._close_tqdm),
lambda: None,
)

0 comments on commit 176ee7a

Please sign in to comment.