diff --git a/examples/training.py b/examples/training.py
index da6f9cd..5f30bbc 100644
--- a/examples/training.py
+++ b/examples/training.py
@@ -271,9 +271,10 @@ def create_optimizer(self) -> Union[
 
   def initialize_state(self):
     """Initializes all of the experiment's state variables."""
-    init_rng, preprocess_rng = jax.random.split(self.init_rng)
+    init_rng, seed_rng = jax.random.split(self.init_rng)
     init_rng = kfac_jax.utils.replicate_all_local_devices(init_rng)
-    preprocess_rng = jax.random.fold_in(preprocess_rng, jax.process_index())
+    seed_rng = jax.random.fold_in(seed_rng, jax.process_index())
+    seed = int(seed_rng[0])
 
     # Initialize and load dataset
     if self.mode == "train":
@@ -281,7 +282,7 @@ def initialize_state(self):
           datasets.dataset_as_generator(
               self._build_train_input,
               split="train",
-              seed=int(preprocess_rng[0]),
+              seed=seed,
               device_batch_size=self.train_per_device_batch_size,
           )
       )
@@ -293,12 +294,12 @@ def initialize_state(self):
       self._eval_input = dict(
           train=self._build_eval_input(
              split="train",
-             seed=int(preprocess_rng[0]),
+             seed=seed,
              device_batch_size=self.eval_per_device_batch_size
          ),
          test=self._build_eval_input(
              split="test",
-             seed=int(preprocess_rng[0]),
+             seed=seed,
              device_batch_size=self.eval_per_device_batch_size
          ),
      )
diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py
index e028e11..5de28e0 100644
--- a/kfac_jax/_src/optimizer.py
+++ b/kfac_jax/_src/optimizer.py
@@ -121,6 +121,7 @@ def __init__(
           default_batch_size_extractor,
       pmap_axis_name: str = "kfac_axis",
       forbid_setting_attributes_after_finalize: bool = True,
+      modifiable_attribute_exceptions: Sequence[str] = (),
       include_norms_in_stats: bool = False,
   ):
     """Initializes the K-FAC optimizer with the provided settings.
@@ -266,6 +267,10 @@ def __init__(
         they have been compiled. However, if you are extending this class, and
         clearly understand the risks of modifying attributes, setting this to
         ``False`` will remove the restriction. (Default: ``True``)
+      modifiable_attribute_exceptions: Sequence of strings. Gives a list
+        of names for attributes that can be modified after finalization even
+        when ``forbid_setting_attributes_after_finalize`` is ``True``.
+        (Default: ``()``)
       include_norms_in_stats: Boolean. If True, the vector norms of the
         gradient, preconditioned gradient, and parameter update are included
         in the statistics returned by the step function. (Default: ``False``)
@@ -276,6 +281,7 @@ def __init__(
         debug=debug,
         forbid_setting_attributes_after_finalize=
         forbid_setting_attributes_after_finalize,
+        excluded_attribute_names=modifiable_attribute_exceptions,
     )
     if use_adaptive_damping and initial_damping is None:
       raise ValueError("When use_adaptive_damping is True you must provide a "
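
The `examples/training.py` change above folds the host index into the key used for seed derivation, so every process draws a distinct but deterministic integer seed from the same root key, and that seed is computed once instead of repeated at each call site. A minimal standalone sketch of the same pattern (`root_rng` and the seed value `42` are illustrative, not from the diff):

```python
import jax

# Split a root key so model initialization and dataset seeding use
# independent randomness.
root_rng = jax.random.PRNGKey(42)  # illustrative root key
init_rng, seed_rng = jax.random.split(root_rng)

# Fold the host index into the key: each process gets a different key,
# but all keys are deterministic functions of the same root.
seed_rng = jax.random.fold_in(seed_rng, jax.process_index())

# A raw PRNG key is a uint32 array; taking one word yields a plain
# Python int that dataset pipelines can consume as a seed.
seed = int(seed_rng[0])
print(seed)
```

On a single host `jax.process_index()` is `0`, so the fold is a no-op in effect but keeps the code identical across single- and multi-host runs.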
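
The `optimizer.py` change exposes the base class's `excluded_attribute_names` as a user-facing constructor argument. A sketch of how the new argument might be passed, assuming the usual `kfac_jax.Optimizer` construction; the loss function and the attribute name `"damping"` in the exception tuple are purely illustrative:

```python
import jax
import jax.numpy as jnp
import kfac_jax

# Illustrative scalar loss; in a real model the loss must be registered
# with kfac_jax so the curvature estimator knows its structure.
def loss_fn(params, batch):
  preds = batch["x"] @ params["w"]
  kfac_jax.register_squared_error_loss(preds, batch["y"])
  return jnp.mean((preds - batch["y"]) ** 2)

optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=0.0,
    use_adaptive_damping=True,
    initial_damping=1.0,
    # New argument from this change: attributes named here may still be
    # assigned after the optimizer is finalized, even while
    # forbid_setting_attributes_after_finalize remains True.
    modifiable_attribute_exceptions=("damping",),
)
```

This keeps the safety default (`forbid_setting_attributes_after_finalize=True`) intact while letting subclasses opt specific attributes out of the restriction, rather than forcing them to disable the check wholesale.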