
Commit d15b304
minor refactoring
PiperOrigin-RevId: 611099938
james-martens authored and KfacJaxDev committed Feb 28, 2024
1 parent 5c64ae7 commit d15b304
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions kfac_jax/_src/optimizer.py
@@ -568,9 +568,18 @@ def should_update_inverse_cache(
       self, state: "Optimizer.State"
   ) -> Union[Array, bool]:
     """Whether at the current step the optimizer should update the inverse curvature approximation."""
-    if self._inverse_update_period == 1:
-      return True
-    return state.step_counter % self._inverse_update_period == 0
+    return self._use_cached_inverses and (
+        state.step_counter % self._inverse_update_period == 0)
+
+  def should_sync_estimator(
+      self, state: "Optimizer.State"
+  ) -> Union[Array, bool]:
+    """Whether at the current step the optimizer should sync the curvature estimates."""
+
+    if self._use_cached_inverses:
+      return self.should_update_inverse_cache(state)
+
+    return True
 
   @functools.partial(utils.staged, static_argnums=1)
   def _rng_split(
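The refactored predicates are short, but their interaction is easy to misread. Below is a minimal, self-contained sketch (not part of the commit; the class and attribute names are stand-ins for the optimizer's configuration) of the resulting sync schedule: with cached inverses, the estimator is synced only on steps where the inverse cache is refreshed; without caching, it is synced on every step.

# sync_schedule_sketch.py -- hypothetical illustration, not from the repo.

class SyncLogicSketch:
  """Mirrors the two predicates above with plain Python ints/bools."""

  def __init__(self, use_cached_inverses: bool, inverse_update_period: int):
    self._use_cached_inverses = use_cached_inverses
    self._inverse_update_period = inverse_update_period

  def should_update_inverse_cache(self, step_counter: int) -> bool:
    # Always False when inverses are not cached, matching the new behavior.
    return self._use_cached_inverses and (
        step_counter % self._inverse_update_period == 0)

  def should_sync_estimator(self, step_counter: int) -> bool:
    # With caching, sync exactly when the cache is about to be refreshed.
    if self._use_cached_inverses:
      return self.should_update_inverse_cache(step_counter)
    # Without caching, inverses are recomputed every step, so sync every step.
    return True

cached = SyncLogicSketch(use_cached_inverses=True, inverse_update_period=3)
assert [cached.should_sync_estimator(t) for t in range(4)] == [
    True, False, False, True]
uncached = SyncLogicSketch(use_cached_inverses=False, inverse_update_period=3)
assert all(uncached.should_sync_estimator(t) for t in range(4))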
@@ -970,6 +979,9 @@ def _init(
   ) -> "Optimizer.State":
     """A staged function to initialize the optimizer state."""
 
+    # Note that we can reuse the rng in the func_args construction below, as
+    # these are just dummy values used to perform the tracing.
+
     return Optimizer.State(
         velocities=jax.tree_util.tree_map(jnp.zeros_like, params),
         estimator_state=self.estimator.init(
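The new comment is about JAX tracing semantics. As a quick illustration (my own example, not from the repo): under jax.jit, arguments are abstract tracers during tracing, so only their shapes and dtypes matter, and reusing one dummy value for several arguments does not tie the runtime values together.

import jax
import jax.numpy as jnp

@jax.jit
def f(a, b):
  # a and b are tracers here; their concrete values are irrelevant to tracing.
  return a + 2.0 * b

dummy = jnp.zeros([3])
f(dummy, dummy)  # traced and compiled with the shared dummy value
print(f(jnp.ones([3]), jnp.full([3], 2.0)))  # reuses the compiled function: [5. 5. 5.]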
Expand Down Expand Up @@ -1014,7 +1026,8 @@ def _burnin(
rng: Array,
batch: Batch,
func_state: Optional[FuncState],
accumulator: utils.MultiChunkAccumulator
accumulator: utils.MultiChunkAccumulator,
sync: Union[Array, bool],
) -> Tuple["Optimizer.State", utils.MultiChunkAccumulator]:
"""A single burnin step, updating only the curvature estimate."""

@@ -1026,7 +1039,7 @@
 
     # Update curvature estimate
     state.estimator_state = self._update_estimator_curvature(
-        state.estimator_state, func_args, rng, 1.0, 1.0)
+        state.estimator_state, func_args, rng, 1.0, 1.0, sync=sync)
 
     # Optionally update func_state
     if func_state is not None:
@@ -1050,16 +1063,18 @@ def burnin(
     """Runs all burnin steps required."""
 
     if num_steps > 0:
+
       rng = self._rng_split(rng, num_steps)
 
       accumulator = utils.MultiChunkAccumulator.zeros_like(
           func_state, self.multi_device)
 
-      for rng_i in rng:
+      for i, rng_i in enumerate(rng):
         batch = next(data_iterator)
 
         state, accumulator = self._burnin(
-            params, state, rng_i, batch, func_state, accumulator)
+            params, state, rng_i, batch, func_state, accumulator,
+            i == num_steps - 1)
 
       func_state = accumulator.value_and_clear()
 
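The enumerate-based loop requests a sync only on the final burnin step. A one-line sketch of the resulting schedule (the rationale, as I read it, is that intermediate cross-device syncs would be overwritten by subsequent local updates anyway):

num_steps = 4
sync_schedule = [i == num_steps - 1 for i in range(num_steps)]
assert sync_schedule == [False, False, False, True]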
@@ -1099,9 +1114,7 @@ def _step(
         rng,
         self._curvature_ema,
         1.0,
-        sync=self.should_update_inverse_cache(
-            state
-        ),  # sync curvature estimates only before inverses are updated.
+        sync=self.should_sync_estimator(state),
     )
 
     del rng  # should not be used after this point!
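For reference, a sketch (my reading of the change, not text from the commit) contrasting the sync schedule used in _step before and after: with use_cached_inverses=False, the old predicate synced only every inverse_update_period steps, while should_sync_estimator now syncs on every step; with caching enabled the two schedules coincide.

def old_sync(step, inverse_update_period=5):
  # Pre-commit logic of should_update_inverse_cache, passed directly as `sync`.
  if inverse_update_period == 1:
    return True
  return step % inverse_update_period == 0

def new_sync(step, use_cached_inverses=False, inverse_update_period=5):
  # Post-commit should_sync_estimator.
  if use_cached_inverses:
    return step % inverse_update_period == 0
  return True

print([old_sync(t) for t in range(5)])  # [True, False, False, False, False]
print([new_sync(t) for t in range(5)])  # [True, True, True, True, True]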
