From 9a1be665624b058a211a9ad11330b4ee7dfdd89f Mon Sep 17 00:00:00 2001 From: James Martens Date: Wed, 27 Apr 2022 13:13:41 -0700 Subject: [PATCH] - Adding "modifiable_attribute_exceptions" argument to optimizer - Renaming "preprocess_rng" to "seed_rng" in examples (since it's used for more than just preprocessing) PiperOrigin-RevId: 444946712 --- examples/training.py | 11 ++++++----- kfac_jax/_src/optimizer.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/training.py b/examples/training.py index da6f9cd..5f30bbc 100644 --- a/examples/training.py +++ b/examples/training.py @@ -271,9 +271,10 @@ def create_optimizer(self) -> Union[ def initialize_state(self): """Initializes all of the experiment's state variables.""" - init_rng, preprocess_rng = jax.random.split(self.init_rng) + init_rng, seed_rng = jax.random.split(self.init_rng) init_rng = kfac_jax.utils.replicate_all_local_devices(init_rng) - preprocess_rng = jax.random.fold_in(preprocess_rng, jax.process_index()) + seed_rng = jax.random.fold_in(seed_rng, jax.process_index()) + seed = int(seed_rng[0]) # Initialize and load dataset if self.mode == "train": @@ -281,7 +282,7 @@ def initialize_state(self): datasets.dataset_as_generator( self._build_train_input, split="train", - seed=int(preprocess_rng[0]), + seed=seed, device_batch_size=self.train_per_device_batch_size, ) ) @@ -293,12 +294,12 @@ def initialize_state(self): self._eval_input = dict( train=self._build_eval_input( split="train", - seed=int(preprocess_rng[0]), + seed=seed, device_batch_size=self.eval_per_device_batch_size ), test=self._build_eval_input( split="test", - seed=int(preprocess_rng[0]), + seed=seed, device_batch_size=self.eval_per_device_batch_size ), ) diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index e028e11..5de28e0 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -121,6 +121,7 @@ def __init__( default_batch_size_extractor, pmap_axis_name: str = "kfac_axis", forbid_setting_attributes_after_finalize: bool = True, + modifiable_attribute_exceptions: Sequence[str] = (), include_norms_in_stats: bool = False, ): """Initializes the K-FAC optimizer with the provided settings. @@ -266,6 +267,10 @@ def __init__( they have been compiled. However, if you are extending this class, and clearly understand the risks of modifying attributes, setting this to ``False`` will remove the restriction. (Default: ``True``) + modifiable_attribute_exceptions: Sequence of strings. Gives a list + of names for attributes that can be modified after finalization even + when ``forbid_setting_attributes_after_finalize`` is ``True``. + (Default: ``()``) include_norms_in_stats: Boolean. It True, the vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: ``False``) @@ -276,6 +281,7 @@ def __init__( debug=debug, forbid_setting_attributes_after_finalize= forbid_setting_attributes_after_finalize, + excluded_attribute_names=modifiable_attribute_exceptions, ) if use_adaptive_damping and initial_damping is None: raise ValueError("When use_adaptive_damping is True you must provide a "