Minor fixes/improvements to optimizer docstrings #108

Merged · 1 commit · Mar 8, 2023
28 changes: 20 additions & 8 deletions kfac_jax/_src/optimizer.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""K-FAC optimizer."""

import functools
from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Union

@@ -122,7 +124,7 @@ def __init__(
A note on damping:

One of the main complications of using second-order optimizers like K-FAC is
the "damping" parameter. This parameter is multiplied by th identity matrix
the "damping" parameter. This parameter is multiplied by the identity matrix
and (approximately) added to the curvature matrix (i.e. the Fisher or GGN)
before it is inverted and multiplied by the gradient when computing the
update (before any learning rate scaling). The damping should follow the
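
For illustration, a minimal dense-matrix sketch of the damped update described in that note (kfac_jax actually uses a factored, block-wise curvature approximation; the function and variable names below are hypothetical, not part of this PR):

    import jax.numpy as jnp

    def damped_precondition(curvature, grad, damping):
        # Add damping * I to the (approximate) curvature matrix, i.e. the Fisher or GGN.
        damped = curvature + damping * jnp.eye(curvature.shape[0])
        # Invert it against the gradient; learning-rate scaling is applied afterwards.
        return jnp.linalg.solve(damped, grad)
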
@@ -708,6 +710,7 @@ def _coefficients_and_quad_change(
func_args: Optional[FuncArgsVariants] = None,
) -> Tuple[Tuple[chex.Array, ...], Optional[chex.Array]]:
"""The correct update coefficients and corresponding quadratic change."""

# Compute the coefficients of the update vectors
# The learning rate is defined as the negative of the coefficient by which
# we multiply the gradients, while the momentum is the coefficient by
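
As a rough sketch of that sign convention (illustrative names only, not the library's internals): the update is a linear combination of the preconditioned gradient and the previous update vector, with coefficients (-learning_rate, momentum).

    import jax

    def combine_update_vectors(precond_grad, prev_update, learning_rate, momentum):
        # The learning rate is the negative of the coefficient multiplying the
        # (preconditioned) gradient; the momentum multiplies the previous update.
        return jax.tree_util.tree_map(
            lambda g, u: -learning_rate * g + momentum * u,
            precond_grad, prev_update)
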
@@ -748,6 +751,7 @@ def _update_damping(
new_func_args: FuncArgsVariants,
) -> Tuple[chex.Array, chex.Array, chex.Array]:
"""Updates the damping parameter."""

new_loss = self.compute_loss_value(new_func_args)

# Sync
@@ -965,8 +969,11 @@ def _step(

# Compute per-device and total batch size
batch_size = self._batch_size_extractor(func_args[-1])

if self.multi_device:
total_batch_size = batch_size * jax.device_count()
else:
total_batch_size = batch_size

# Update data seen and step counter
state.data_seen = state.data_seen + total_batch_size
@@ -1031,11 +1038,13 @@ def step(
"""Performs a single update step using the optimizer.

Args:
params: The parameters of the model.
state: The state of the optimizer.
rng: A Jax PRNG key.
data_iterator: A data iterator.
batch: A single batch.
params: The current parameters of the model.
state: The current state of the optimizer.
rng: A Jax PRNG key. Should be different for each iteration and
each Jax process/host.
data_iterator: A data iterator to use (if not passing ``batch``).
batch: A single batch used to compute the update. Should only pass one
of ``data_iterator`` or ``batch``.
func_state: Any function state that gets passed in and returned.
learning_rate: Learning rate to use if the optimizer was created with
``use_adaptive_learning_rate=True``, ``None`` otherwise.
@@ -1047,17 +1056,20 @@
damping.
global_step_int: The global step as a python int. Note that this must
match the step internal to the optimizer that is part of its state.

Returns:
(params, state, stats) or (params, state, func_state, stats), where
(params, state, stats) if ``value_func_has_state=False`` and
(params, state, func_state, stats) otherwise, where

* params is the updated model parameters.

* state is the updated optimizer state.

* func_state is the updated function state.

* stats is a dictionary of key statistics provided to be logged.
* stats is a dictionary of useful statistics including the loss.
"""

if (data_iterator is None) == (batch is None):
raise ValueError("Exactly one of the arguments ``data_iterator`` and "
"``batch`` must be provided.")
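
A hypothetical training-loop call illustrating this check and the updated return description, assuming ``value_func_has_state=False`` and that the learning rate, momentum and damping schedules were fixed at construction time (``optimizer``, ``train_iterator``, ``step_rng`` and ``step`` are assumed names, not part of this PR):

    # Exactly one of `batch` or `data_iterator` may be supplied.
    params, opt_state, stats = optimizer.step(
        params=params,
        state=opt_state,
        rng=step_rng,                  # should differ per iteration and per host
        batch=next(train_iterator),    # alternatively: data_iterator=train_iterator
        global_step_int=step,
    )
    # `stats` includes the loss, per the updated docstring.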