Skip to content

Commit

Permalink
Adding data seen to the reported statistics on the evaluator in the e…
Browse files Browse the repository at this point in the history
…xamples.

PiperOrigin-RevId: 508676532
  • Loading branch information
botev authored and KfacJaxDev committed Feb 10, 2023
1 parent c4e5765 commit d6f14ad
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _evaluate_single_batch(
) -> Dict[str, chex.Array]:
"""Evaluates a single batch."""

del global_step, opt_state # This might be used in subclasses
del global_step # This might be used in subclasses

func_args = kfac_jax.optimizer.make_func_args(
params=params,
Expand All @@ -447,6 +447,7 @@ def _evaluate_single_batch(
)
loss, stats = self.eval_model_func(*func_args)
stats["loss"] = loss
stats["data_seen"] = opt_state.data_seen

return kfac_jax.utils.pmean_if_pmap(stats, "eval_axis")

Expand Down

0 comments on commit d6f14ad

Please sign in to comment.