diff --git a/README.md b/README.md
index 16e0cab..034af29 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-# KFAC JAX - Second Order Optimization with Approximate Curvature in JAX
+# KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX
[**Installation**](#installation)
| [**Quickstart**](#quickstart)
| [**Documentation**](https://kfac-jax.readthedocs.io/)
| [**Examples**](https://github.com/deepmind/kfac_jax/tree/main/examples/)
-| [**Citing KFAC JAX**](#citing-kfac-jax)
+| [**Citing KFAC-JAX**](#citing-kfac-jax)
![CI status](https://github.com/deepmind/kfac_jax/workflows/ci/badge.svg)
![docs](https://readthedocs.org/projects/kfac_jax/badge/?version=latest)
@@ -196,14 +196,14 @@ parameters of the model to be part of dense layers.
For a high level overview of the optimizer, the different curvature
approximations, and the supported layers, please see the [documentation].
-## Citing KFAC JAX
+## Citing KFAC-JAX
To cite this repository:
```
@software{kfac_jax2022github,
author = {Aleksandar Botev and James Martens},
- title = {{KFAC JAX}},
+ title = {{KFAC-JAX}},
url = {http://github.com/deepmind/kfac_jax},
version = {0.0.1},
year = {2022},
diff --git a/docs/conf.py b/docs/conf.py
index 5e3ef0d..ce49cb5 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -110,7 +110,7 @@ def _recursive_add_annotations_import():
autodoc_default_options = {
'member-order': 'bysource',
'special-members': True,
- 'exclude-members': '__repr__, __str__, __weakref__',
+ 'exclude-members': '__repr__, __str__, __weakref__, __eq__, __hash__'
}
# -- Options for HTML output -------------------------------------------------
diff --git a/docs/overview.rst b/docs/overview.rst
index af57233..8670619 100644
--- a/docs/overview.rst
+++ b/docs/overview.rst
@@ -73,11 +73,12 @@ Supported layers
Currently, the library only includes support for the three most common types of
layers used in practice:
-1. Dense layers, corresponding to ``y = Wx + b``.
-2. 2D convolution layers, corresponding to ``y = W * x + b``.
-3. Scale and shift layers, corresponding to ``y = w . x + b``.
+1. Dense layers, corresponding to :math:`y = Wx + b`.
+2. 2D convolution layers, corresponding to :math:`y = W \star x + b`.
+3. Scale and shift layers, corresponding to :math:`y = w \odot x + b`.
-Here ``*`` corresponds to convolution and ``.`` to elementwise product.
+Here :math:`\star` corresponds to convolution and :math:`\odot` to elementwise
+product.
Parameter reuse, such as in recurrent networks and attention layers, is
not currently supported.
@@ -118,68 +119,111 @@ checkout the relevant section in :doc:`advanced` on how to do this.
Optimizer
=========
-The optimization algorithm implement in :class:`kfac_jax.Optimizer` follows the
-`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
-Throughout optimization the optimizer instance keeps the following state::
-
- C - the curvature estimator state.
- velocity - velocity vectors of the parameters.
- damping - weight of the additional damping added for inverting C.
- counter - a step counter.
+The optimization algorithm implemented in :class:`kfac_jax.Optimizer` follows
+the `K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
+Throughout optimization, the optimizer instance keeps the following persistent
+state:
+
+.. math::
+   \begin{aligned}
+   & \bm{v}_t - \text{the velocity vector, representing the last parameter update,} \\
+   & \bm{C}_t - \text{the state of the curvature estimator at step } t, \\
+   & \lambda_t - \text{the weight of the additional damping added for inverting } \bm{C}, \\
+   & t - \text{the step counter.}
+   \end{aligned}
+
+
+If we denote the current minibatch of data by :math:`\bm{x}_t`, the current
+parameters by :math:`\bm{\theta}_t`, the L2 regularization coefficient by
+:math:`\gamma`, and the loss function (which includes the L2 regularizer) by
+:math:`\mathcal{L}`, the high-level pseudocode for a single step of the
+optimizer is:
+
+.. math::
+ \begin{aligned}
+ &(1) \quad l_t, \bm{g}_t = \mathcal{L}(\bm{\theta}_t, \bm{x}_t),
+ \nabla_\theta \mathcal{L}(\bm{\theta}_t, \bm{x}_t)
+ \\
+ &(2) \quad \bm{C}_{t+1} = \text{update curvature}(\bm{C}_t,
+ \bm{\theta}_t, \bm{x}_t) \\
+ &(3) \quad \hat{\bm{g}}_t = (\bm{C}_{t+1} + (\lambda_t + \gamma) \bm{I}
+ )^{-1} \bm{g}_t \\
+ &(4) \quad \alpha_t, \beta_t = \text{update coefficients}(
+ \hat{\bm{g}}_t, \bm{x}_t, \bm{\theta}_t, \bm{v}_t) \\
+ &(5) \quad \bm{v}_{t+1} = \alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t \\
+ &(6) \quad \bm{\theta}_{t+1} = \bm{\theta}_t + \bm{v}_{t+1} \\
+ &(7) \quad \lambda_{t+1} = \text{update damping}(l_t, \bm{\theta}_{t+1},
+ \bm{C}_{t+1})
+ \end{aligned}
+
+Steps 1, 2, 3, 5 and 6 are standard for any second order optimization
+algorithm.
+Steps 4 and 7 are described in more detail below.
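+
+In code, a single call to the optimizer's ``step`` method corresponds to one
+pass through the pseudocode above. The sketch below mirrors the quickstart
+example from the README; ``loss_fn``, ``params``, ``dataset`` and
+``num_steps`` are assumed to be defined by the user, and the losses inside
+``loss_fn`` must be registered with the library:
+
+.. code-block:: python
+
+   import jax
+   import kfac_jax
+
+   optimizer = kfac_jax.Optimizer(
+       value_and_grad_func=jax.value_and_grad(loss_fn),
+       l2_reg=1e-5,                       # the coefficient gamma above
+       use_adaptive_learning_rate=True,   # automatic alpha_t (step 4)
+       use_adaptive_momentum=True,        # automatic beta_t (step 4)
+       use_adaptive_damping=True,         # automatic lambda_t (step 7)
+       initial_damping=1.0,
+   )
+
+   rng = jax.random.PRNGKey(42)
+   rng, key = jax.random.split(rng)
+   opt_state = optimizer.init(params, key, next(dataset))
+
+   for i in range(num_steps):
+     rng, key = jax.random.split(rng)
+     # One optimizer step, covering (1)-(7) above.
+     params, opt_state, stats = optimizer.step(
+         params, opt_state, key, batch=next(dataset), global_step_int=i)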
+
+
+Computing the update coefficients (4)
+-------------------------------------
+
+The update coefficients :math:`\alpha_t` and :math:`\beta_t` in step 4 can
+either be provided manually by the user at each step, or can be computed
+automatically from the local quadratic model.
+This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
+and ``use_adaptive_momentum``.
+Note that these features don't currently work very well unless you use a very
+large batch size, and/or increase the batch size dynamically during training
+(as was done in the original K-FAC paper).
+
-If we denote the current minibatch of data by ``x``, the current parameters by
-``theta`` and the function that computes the value and gradient of the loss
-by ``f``, a high level pseudocode for a single step of the optimizer
-is::
+Automatic selection of update coefficients
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- 1 loss, gradient = f(theta, x)
- 2 C = update_curvature_estimate(C, theta, x)
- 3 preconditioned_gradient = compute_inverse(C) @ gradient
- 4 c1, c2 = compute_update_coefficients(theta, x, preconditioned_gradient, velocity)
- 5 velocity_new = c1 * preconditioned_gradient + c2 * velocity
- 6 theta_new = theta + velocity_new
- 7 damping = update_damping(loss, theta_new, C)
+The procedure to automatically select the update coefficients uses the local
+quadratic model defined as:
+
+.. math::
+   q(\bm{\delta}) = l_t + \bm{g}_t^T \bm{\delta} + \frac{1}{2} \bm{\delta}^T
+   (\bm{C} + (\lambda_t + \gamma) \bm{I}) \bm{\delta},
+
-Amortizing expensive computations
----------------------------------
+where :math:`\bm{C}` is usually the exact curvature matrix.
+To compute :math:`\alpha_t` and :math:`\beta_t`, we minimize
+:math:`q(\alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t)` with respect to the two
+scalars, treating :math:`\hat{\bm{g}}_t` and :math:`\bm{v}_t` as fixed vectors.
+Since this is a simple two-dimensional quadratic problem that requires only
+matrix-vector products with :math:`\bm{C}`, it can be solved efficiently.
+For further details see Section 7 of the original
+`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
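+
+Below is a minimal sketch of this two-dimensional solve, assuming already
+flattened vectors and a user-supplied ``matvec`` that multiplies by
+:math:`\bm{C} + (\lambda_t + \gamma) \bm{I}`; the library's implementation
+instead works on parameter pytrees and computes the required curvature-vector
+products automatically, so this only illustrates the underlying algebra:
+
+.. code-block:: python
+
+   import jax.numpy as jnp
+
+   def optimal_coefficients(g, g_hat, v, matvec):
+     """Minimizes q(alpha * g_hat + beta * v) over the scalars alpha, beta.
+
+     g is the gradient, g_hat the preconditioned gradient, v the previous
+     velocity, and matvec(u) computes (C + (lambda + gamma) I) @ u.
+     """
+     m_g_hat, m_v = matvec(g_hat), matvec(v)
+     # Quadratic term of q restricted to the span of g_hat and v.
+     quad = jnp.array([[g_hat @ m_g_hat, g_hat @ m_v],
+                       [v @ m_g_hat, v @ m_v]])
+     # Linear term of q restricted to the same span.
+     lin = jnp.array([g @ g_hat, g @ v])
+     # Stationary point of the resulting two-dimensional quadratic.
+     alpha, beta = jnp.linalg.solve(quad, -lin)
+     return alpha, beta
+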
-When running the optimizer, several of the steps involved can have
-a somewhat significant computational overhead.
-For this reason, the optimizer class allows these to be performed every `K`
-steps, and to cache these values across iterations.
-This has been found to work well in practice without significant drawbacks in
-training performance.
-Specifically, this is applied to computing the inverse of the estimated
-approximate curvature (step 3), and to the updates to the damping (step 7).
-Computing the update coefficients
----------------------------------
+Updating the damping (7)
+------------------------
-The update coefficients ``c1`` and ``c2`` in step 4 can either be provided
-manually by the user at each step, or can be computed automatically using the
-procedure described in Section 7 of the original
-`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
-This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
-and ``use_adaptive_momentum``.
-Note that these features don't currently work very well unless you use a very
-large batch size, and/or increase the batch size dynamically during training
-(as was done in the original K-FAC paper).
+The damping is updated via the Levenberg-Marquardt heuristic, which computes
+the reduction ratio:
+
-Updating the damping
---------------------
+.. math::
+   \rho = \frac{\mathcal{L}(\bm{\theta}_{t+1}) - \mathcal{L}(\bm{\theta}_{t})}
+   {q(\bm{v}_{t+1}) - q(\bm{0})},
+
-The damping update is done via the Levenberg-Marquardt heuristic.
-This is done by computing the reduction ratio
-``(f(theta_new) - f(theta)) / (q(theta_new) - q_theta)``, where ``q`` is the
-quadratic model value induced by either the exact or approximate curvature
-matrix.
+where :math:`q` is the quadratic model value induced by either the exact or
+approximate curvature matrix.
If the optimizer uses either learning rate or momentum adaptation, or
``always_use_exact_qmodel_for_damping_adjustment`` is set to ``True``, the
optimizer will use the exact curvature matrix; otherwise it will use the
approximate curvature.
-If this value deviates too much from ``1`` we either increase or decrease the
-damping as described in Section 6.5 from the original
+If the value of :math:`\rho` deviates too much from 1, we either increase or
+decrease the damping :math:`\lambda` as described in Section 6.5 of the original
`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
Whether the damping is adapted, or provided by the user at each single step, is
controlled by the optimizer argument ``use_adaptive_damping``.
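+
+As a rough illustration, the adjustment uses the :math:`3/4` and :math:`1/4`
+thresholds from the paper; the decay constant and minimum value in the sketch
+below are placeholders rather than the library's exact defaults:
+
+.. code-block:: python
+
+   def adapt_damping(damping, rho, decay=0.9, min_damping=1e-8):
+     """Levenberg-Marquardt style damping adjustment (illustrative constants).
+
+     rho is the reduction ratio defined above: the actual change in the loss
+     divided by the change predicted by the quadratic model.
+     """
+     if rho > 0.75:    # The quadratic model matched the true loss well.
+       damping = damping * decay
+     elif rho < 0.25:  # The model over-predicted progress, so trust it less.
+       damping = damping / decay
+     return max(damping, min_damping)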
+
+
+Amortizing expensive computations
+---------------------------------
+
+When running the optimizer, several of the steps involved can have
+a noticeable computational overhead.
+For this reason, the optimizer class allows these to be performed every ``K``
+steps, and to cache the values across iterations.
+This has been found to work well in practice without significant drawbacks in
+training performance.
+This is applied to computing the inverse of the estimated approximate curvature
+(step 3), and to the updates of the damping (step 7).
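+
+For example, a configuration along the following lines amortizes both
+computations (the period arguments shown are assumed from the optimizer's API,
+so consult the :class:`kfac_jax.Optimizer` reference for the exact names and
+defaults; ``loss_fn`` is again assumed to be user-defined):
+
+.. code-block:: python
+
+   import jax
+   import kfac_jax
+
+   optimizer = kfac_jax.Optimizer(
+       value_and_grad_func=jax.value_and_grad(loss_fn),
+       l2_reg=1e-5,
+       use_adaptive_damping=True,
+       initial_damping=1.0,
+       # Recompute the cached inverse of the approximate curvature (step 3)
+       # only once every 5 steps.
+       inverse_update_period=5,
+       # Re-run the damping adjustment (step 7) only once every 10 steps.
+       damping_adaptation_interval=10,
+   )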
diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py
index 836b39c..0d845c7 100644
--- a/kfac_jax/_src/curvature_blocks.py
+++ b/kfac_jax/_src/curvature_blocks.py
@@ -126,6 +126,18 @@ class CurvatureBlock(utils.Finalizable):
@utils.pytree_dataclass
class State:
+ """Persistent state of the block.
+
+ Any subclasses of :class:`~CurvatureBlock` should also internally extend
+ this class, with any attributes needed for the curvature estimation.
+
+ Attributes:
+      cache: A dictionary containing any state data that is updated at
+        irregular intervals, such as inverses, eigenvalues, etc. Elements of
+        this are updated via calls to :func:`~CurvatureBlock.update_cache`, and
+        do not necessarily correspond to the most up to date curvature
+        estimate.
+ """
cache: Optional[Dict[str, Union[chex.Array, Dict[str, chex.Array]]]]
def __init__(self, layer_tag_eq: tags.LayerTagEqn, name: str):
@@ -566,6 +578,13 @@ class Diagonal(CurvatureBlock, abc.ABC):
@utils.pytree_dataclass
class State(CurvatureBlock.State):
+ """Persistent state of the block.
+
+ Attributes:
+ diagonal_factors: A tuple of the moving averages of the estimated
+ diagonals of the curvature for each parameter that is part of the
+ associated layer.
+ """
diagonal_factors: Tuple[utils.WeightedMovingAverage]
def _init(
@@ -659,6 +678,12 @@ class Full(CurvatureBlock, abc.ABC):
@utils.pytree_dataclass
class State(CurvatureBlock.State):
+ """Persistent state of the block.
+
+ Attributes:
+ matrix: A moving average of the estimated curvature matrix for all
+ parameters that are part of the associated layer.
+ """
matrix: utils.WeightedMovingAverage
def __init__(
@@ -836,6 +861,14 @@ class TwoKroneckerFactored(CurvatureBlock, abc.ABC):
@utils.pytree_dataclass
class State(CurvatureBlock.State):
+ """Persistent state of the block.
+
+ Attributes:
+ inputs_factor: A moving average of the estimated second moment matrix of
+ the inputs to the associated layer.
+      outputs_factor: A moving average of the estimated second moment matrix of
+        the gradients of the loss w.r.t. the outputs of the associated layer.
+ """
inputs_factor: utils.WeightedMovingAverage
outputs_factor: utils.WeightedMovingAverage
diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py
index fab4bd6..3a0b4a9 100644
--- a/kfac_jax/_src/curvature_estimator.py
+++ b/kfac_jax/_src/curvature_estimator.py
@@ -727,6 +727,12 @@ class BlockDiagonalCurvature(CurvatureEstimator):
@utils.pytree_dataclass
class State:
+ """Persistent state of the estimator.
+
+ Attributes:
+      blocks_states: A tuple containing the state of each of the estimator's
+        curvature blocks.
+ """
blocks_states: Tuple[curvature_blocks.CurvatureBlock.State, ...]
def __init__(
diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py
index 28ed748..ee01b3a 100644
--- a/kfac_jax/_src/optimizer.py
+++ b/kfac_jax/_src/optimizer.py
@@ -65,6 +65,17 @@ class Optimizer(utils.WithStagedMethods):
@utils.pytree_dataclass
class State:
+ r"""Persistent state of the optimizer.
+
+ Attributes:
+      velocities: The parameter update from the previous step,
+        :math:`\theta_t - \theta_{t-1}`.
+ estimator_state: The persistent state for the curvature estimator.
+ damping: When using damping adaptation, this will contain the current
+ value.
+ data_seen: The number of training cases that the optimizer has processed.
+ step_counter: An integer giving the current step number :math:`t`.
+ """
velocities: utils.Params
estimator_state: curvature_estimator.BlockDiagonalCurvature.State
damping: Optional[chex.Array]
diff --git a/tests/models.py b/tests/models.py
index da3fba1..be27780 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -475,31 +475,31 @@ def conv_classifier_loss(
functools.partial(
autoencoder_with_two_losses,
layer_widths=[20, 10, 20]),
- dict(images=(32,)),
+ dict(images=(8,)),
1231987,
),
- # (
- # conv_classifier(
- # num_classes=10,
- # layer_channels=[8, 16, 32]
- # ).init,
- # functools.partial(
- # conv_classifier_loss,
- # num_classes=10,
- # layer_channels=[8, 16, 32]),
- # dict(images=(32, 32, 3), labels=(10,)),
- # 354649831,
- # ),
+ (
+ conv_classifier(
+ num_classes=10,
+ layer_channels=[8, 16]
+ ).init,
+ functools.partial(
+ conv_classifier_loss,
+ num_classes=10,
+ layer_channels=[8, 16]),
+ dict(images=(8, 8, 3), labels=(10,)),
+ 354649831,
+ ),
]
LINEAR_MODELS = [
(
- autoencoder([100, 50, 100]).init,
+ autoencoder([20, 10, 20]).init,
functools.partial(
linear_squared_error_autoencoder_loss,
- layer_widths=[100, 50, 100]),
- dict(images=(64,)),
+ layer_widths=[20, 10, 20]),
+ dict(images=(8,)),
1240982837,
),
]
@@ -507,13 +507,13 @@ def conv_classifier_loss(
PIECEWISE_LINEAR_MODELS = [
(
- autoencoder([100, 50, 100]).init,
+ autoencoder([20, 10, 20]).init,
functools.partial(
autoencoder_with_two_losses,
- layer_widths=[100, 50, 100],
+ layer_widths=[20, 10, 20],
activation=_special_relu,
),
- dict(images=(64,)),
+ dict(images=(8,)),
1231987,
),
]