Commit

- Adding "modifiable_attribute_exceptions" argument to optimizer
- Renaming "preprocess_rng" to "seed_rng" in examples (since it's used for more than just preprocessing)

PiperOrigin-RevId: 444946712
james-martens authored and KfacJaxDev committed May 10, 2022
1 parent f8b6405 commit 9a1be66
Showing 2 changed files with 12 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/training.py
@@ -271,17 +271,18 @@ def create_optimizer(self) -> Union[
 
   def initialize_state(self):
     """Initializes all of the experiment's state variables."""
-    init_rng, preprocess_rng = jax.random.split(self.init_rng)
+    init_rng, seed_rng = jax.random.split(self.init_rng)
     init_rng = kfac_jax.utils.replicate_all_local_devices(init_rng)
-    preprocess_rng = jax.random.fold_in(preprocess_rng, jax.process_index())
+    seed_rng = jax.random.fold_in(seed_rng, jax.process_index())
+    seed = int(seed_rng[0])
 
     # Initialize and load dataset
     if self.mode == "train":
       self._train_input = pipe_utils.py_prefetch(
           datasets.dataset_as_generator(
               self._build_train_input,
               split="train",
-              seed=int(preprocess_rng[0]),
+              seed=seed,
               device_batch_size=self.train_per_device_batch_size,
           )
       )
@@ -293,12 +294,12 @@ def initialize_state(self):
       self._eval_input = dict(
           train=self._build_eval_input(
               split="train",
-              seed=int(preprocess_rng[0]),
+              seed=seed,
               device_batch_size=self.eval_per_device_batch_size
           ),
           test=self._build_eval_input(
               split="test",
-              seed=int(preprocess_rng[0]),
+              seed=seed,
               device_batch_size=self.eval_per_device_batch_size
           ),
       )
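
The renamed key is converted to a single Python int once and reused by the train and eval pipelines, instead of recomputing int(preprocess_rng[0]) at each call site. Below is a minimal, self-contained sketch of that seeding pattern, written as a standalone function rather than the experiment method above; the name derive_seed is illustrative and not part of the kfac_jax examples.

import jax

def derive_seed(init_rng):
  # Split off a dedicated key for dataset seeding, keeping init_rng for
  # parameter and optimizer initialization.
  init_rng, seed_rng = jax.random.split(init_rng)
  # Fold in the host index so each process in a multi-host job gets a
  # distinct seed.
  seed_rng = jax.random.fold_in(seed_rng, jax.process_index())
  # A raw PRNG key is a small uint32 array; its first element is a simple
  # way to obtain a plain Python int seed for non-JAX data pipelines.
  return init_rng, int(seed_rng[0])

init_rng, seed = derive_seed(jax.random.PRNGKey(0))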
6 changes: 6 additions & 0 deletions kfac_jax/_src/optimizer.py
@@ -121,6 +121,7 @@ def __init__(
           default_batch_size_extractor,
       pmap_axis_name: str = "kfac_axis",
       forbid_setting_attributes_after_finalize: bool = True,
+      modifiable_attribute_exceptions: Sequence[str] = (),
       include_norms_in_stats: bool = False,
   ):
     """Initializes the K-FAC optimizer with the provided settings.
@@ -266,6 +267,10 @@ def __init__(
         they have been compiled. However, if you are extending this class, and
         clearly understand the risks of modifying attributes, setting this to
         ``False`` will remove the restriction. (Default: ``True``)
+      modifiable_attribute_exceptions: Sequence of strings. Gives a list
+        of names for attributes that can be modified after finalization even
+        when ``forbid_setting_attributes_after_finalize`` is ``True``.
+        (Default: ``()``)
       include_norms_in_stats: Boolean. It True, the vector norms of the
         gradient, preconditioned gradient, and parameter update are included in
         the statistics returned by the step function. (Default: ``False``)
@@ -276,6 +281,7 @@
         debug=debug,
         forbid_setting_attributes_after_finalize=
         forbid_setting_attributes_after_finalize,
+        excluded_attribute_names=modifiable_attribute_exceptions,
     )
     if use_adaptive_damping and initial_damping is None:
       raise ValueError("When use_adaptive_damping is True you must provide a "
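
The new argument is forwarded to the base class as excluded_attribute_names, so any attribute named in it stays writable even after the optimizer has been finalized. The snippet below is an illustrative sketch of that mechanism only, using a simplified stand-in class (Finalizable is not kfac_jax's actual base class); it shows how a post-finalization write lock with an exception list can be enforced in __setattr__.

from typing import Sequence

class Finalizable:
  """Toy stand-in: forbids attribute writes after finalize(), with exceptions."""

  def __init__(
      self,
      forbid_setting_attributes_after_finalize: bool = True,
      excluded_attribute_names: Sequence[str] = (),
  ):
    self._finalized = False
    self._forbid = forbid_setting_attributes_after_finalize
    self._excluded = tuple(excluded_attribute_names)

  def finalize(self):
    self._finalized = True

  def __setattr__(self, name, value):
    # Block writes once finalized, unless the attribute is explicitly excluded.
    if (getattr(self, "_finalized", False)
        and getattr(self, "_forbid", True)
        and name not in getattr(self, "_excluded", ())):
      raise AttributeError(
          f"Can't set attribute {name!r} after the object was finalized.")
    super().__setattr__(name, value)

# Usage: only "damping" remains writable after finalization.
obj = Finalizable(excluded_attribute_names=("damping",))
obj.damping = 1e-2
obj.finalize()
obj.damping = 1e-3        # allowed: listed as an exception
# obj.learning_rate = 0.1 # would raise AttributeError

With the real optimizer, a caller would pass something like modifiable_attribute_exceptions=("damping",) alongside forbid_setting_attributes_after_finalize=True; the attribute name "damping" here is only an example, not a statement about which kfac_jax attributes are safe to modify after compilation.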
