
Commit

* Made tests smaller with the hope they don't break CI anymore.
* Updated documentation for state classes.
* Updated RTD to use latex for some parts.
* Updated with more detailed description of the optimizer.

PiperOrigin-RevId: 438791229
botev authored and KfacJaxDev committed Apr 1, 2022
1 parent 7c3fc63 commit 1daa555
Showing 7 changed files with 171 additions and 77 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -1,10 +1,10 @@
# KFAC JAX - Second Order Optimization with Approximate Curvature in JAX
# KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX

[**Installation**](#installation)
| [**Quickstart**](#quickstart)
| [**Documentation**](https://kfac-jax.readthedocs.io/)
| [**Examples**](https://github.com/deepmind/kfac_jax/tree/main/examples/)
| [**Citing KFAC JAX**](#citing-kfac-jax)
| [**Citing KFAC-JAX**](#citing-kfac-jax)

![CI status](https://github.com/deepmind/kfac_jax/workflows/ci/badge.svg)
![docs](https://readthedocs.org/projects/kfac_jax/badge/?version=latest)
@@ -196,14 +196,14 @@ parameters of the model to be part of dense layers.
For a high level overview of the optimizer, the different curvature
approximations, and the supported layers, please see the [documentation].

## Citing KFAC JAX<a id="citing-kfac-jax"></a>
## Citing KFAC-JAX<a id="citing-kfac-jax"></a>

To cite this repository:

```
@software{kfac_jax2022github,
author = {Aleksandar Botev and James Martens},
title = {{KFAC JAX}},
title = {{KFAC-JAX}},
url = {http://github.com/deepmind/kfac_jax},
version = {0.0.1},
year = {2022},
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -110,7 +110,7 @@ def _recursive_add_annotations_import():
autodoc_default_options = {
'member-order': 'bysource',
'special-members': True,
'exclude-members': '__repr__, __str__, __weakref__',
'exclude-members': '__repr__, __str__, __weakref__, __eq__, __hash__'
}

# -- Options for HTML output -------------------------------------------------
150 changes: 97 additions & 53 deletions docs/overview.rst
@@ -73,11 +73,12 @@ Supported layers
Currently, the library only includes support for the three most common types of
layers used in practice:

1. Dense layers, corresponding to ``y = Wx + b``.
2. 2D convolution layers, corresponding to ``y = W * x + b``.
3. Scale and shift layers, corresponding to ``y = w . x + b``.
1. Dense layers, corresponding to :math:`y = Wx + b`.
2. 2D convolution layers, corresponding to :math:`y = W \star x + b`.
3. Scale and shift layers, corresponding to :math:`y = w \odot x + b`.

Here ``*`` corresponds to convolution and ``.`` to elementwise product.
Here :math:`\star` corresponds to convolution and :math:`\odot` to elementwise
product.
Parameter reuse, such as in recurrent networks and attention layers, is
not currently supported.
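
As a purely illustrative sketch (with assumed parameter shapes and data
formats, not code from the library), these three operations can be written in
plain JAX as::

    import jax.numpy as jnp
    from jax import lax

    def dense(w, b, x):
        # y = W x + b, for a batch of input vectors x.
        return jnp.dot(x, w) + b

    def conv_2d(w, b, x):
        # y = W * x + b, with * denoting 2D convolution over NHWC inputs
        # and an HWIO kernel.
        y = lax.conv_general_dilated(
            x, w, window_strides=(1, 1), padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))
        return y + b

    def scale_and_shift(w, b, x):
        # y = w . x + b, with . denoting the elementwise product.
        return w * x + b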

@@ -118,68 +119,111 @@ checkout the relevant section in :doc:`advanced<advanced>` on how to do this.
Optimizer
=========

The optimization algorithm implement in :class:`kfac_jax.Optimizer` follows the
`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
Throughout optimization the optimizer instance keeps the following state::

C - the curvature estimator state.
velocity - velocity vectors of the parameters.
damping - weight of the additional damping added for inverting C.
counter - a step counter.
The optimization algorithm implemented in :class:`kfac_jax.Optimizer` follows
the `K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
Throughout optimization the Optimizer instance keeps the following persistent
state:

.. math::
\begin{aligned}
& \bm{v}_t - \text{the velocity vector, representing the last parameter
update.} \\
& \bm{C}_t - \text{the state of the curvature estimator at step } t. \\
& \lambda_t - \text{the weight of the additional damping added for
inverting } \bm{C}_t. \\
& t - \text{the step counter.}
\end{aligned}
If we denote the current minibatch of data by :math:`\bm{x}_t`, the current
parameters by :math:`\bm{\theta}_t`, the L2 regularization coefficient by
:math:`\gamma`, and the loss function (which includes the L2 regularizer) by
:math:`\mathcal{L}`, then high-level pseudocode for a single step of the
optimizer is:

.. math::
\begin{aligned}
&(1) \quad l_t, \bm{g}_t = \mathcal{L}(\bm{\theta}_t, \bm{x}_t),
\nabla_\theta \mathcal{L}(\bm{\theta}_t, \bm{x}_t)
\\
&(2) \quad \bm{C}_{t+1} = \text{update curvature}(\bm{C}_t,
\bm{\theta}_t, \bm{x}_t) \\
&(3) \quad \hat{\bm{g}}_t = (\bm{C}_{t+1} + (\lambda_t + \gamma) \bm{I}
)^{-1} \bm{g}_t \\
&(4) \quad \alpha_t, \beta_t = \text{update coefficients}(
\hat{\bm{g}}_t, \bm{x}_t, \bm{\theta}_t, \bm{v}_t) \\
&(5) \quad \bm{v}_{t+1} = \alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t \\
&(6) \quad \bm{\theta}_{t+1} = \bm{\theta}_t + \bm{v}_{t+1} \\
&(7) \quad \lambda_{t+1} = \text{update damping}(l_t, \bm{\theta}_{t+1},
\bm{C}_{t+1})
\end{aligned}
Steps 1, 2, 3, 5 and 6 are standard for any second-order optimization
algorithm.
Steps 4 and 7 are described in more detail below.
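
For intuition, the same seven steps can be written as a minimal Python sketch,
where ``loss_value_and_grad``, ``update_curvature``, ``solve_damped_curvature``,
``update_coefficients`` and ``update_damping`` are hypothetical stand-ins
rather than the library's actual API::

    import jax
    import jax.numpy as jnp

    def optimizer_step(params, velocity, curvature_state, damping, batch, l2_reg):
        loss, grads = loss_value_and_grad(params, batch)                    # (1)
        curvature_state = update_curvature(curvature_state, params, batch)  # (2)
        # (3) Preconditioned gradient: (C + (lambda + gamma) I)^{-1} g.
        ghat = solve_damped_curvature(curvature_state, grads, damping + l2_reg)
        # (4) Learning rate and momentum coefficients from the quadratic model.
        alpha, beta = update_coefficients(ghat, velocity, params, batch)
        # (5) New velocity and (6) parameter update, applied leaf-wise.
        velocity = jax.tree_util.tree_map(
            lambda g, v: alpha * g + beta * v, ghat, velocity)
        params = jax.tree_util.tree_map(jnp.add, params, velocity)
        damping = update_damping(loss, params, curvature_state, damping)    # (7)
        return params, velocity, curvature_state, damping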


Computing the update coefficients (4)
-------------------------------------

The update coefficients :math:`\alpha_t` and :math:`\beta_t` in step 4 can
either be provided manually by the user at each step, or can be computed
automatically from the local quadratic model.
This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
and ``use_adaptive_momentum``.
Note that these features don't currently work very well unless you use a very
large batch size, and/or increase the batch size dynamically during training
(as was done in the original K-FAC paper).
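
For example, a minimal sketch of constructing the optimizer with these two
flags enabled (the surrounding arguments are illustrative rather than an exact
or exhaustive constructor signature)::

    import jax
    import kfac_jax

    optimizer = kfac_jax.Optimizer(
        value_and_grad_func=jax.value_and_grad(loss_fn),  # loss_fn: user-defined loss
        l2_reg=1e-3,
        use_adaptive_learning_rate=True,  # alpha_t chosen from the quadratic model
        use_adaptive_momentum=True,       # beta_t chosen from the quadratic model
        use_adaptive_damping=True,
        initial_damping=1.0,
    )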

If we denote the current minibatch of data by ``x``, the current parameters by
``theta`` and the function that computes the value and gradient of the loss
by ``f``, a high level pseudocode for a single step of the optimizer
is::
Automatic selection of update coefficients
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 loss, gradient = f(theta, x)
2 C = update_curvature_estimate(C, theta, x)
3 preconditioned_gradient = compute_inverse(C) @ gradient
4 c1, c2 = compute_update_coefficients(theta, x, preconditioned_gradient, velocity)
5 velocity_new = c1 * preconditioned_gradient + c2 * velocity
6 theta_new = theta + velocity_new
7 damping = update_damping(loss, theta_new, C)
The procedure to automatically select the update coefficients uses the local
quadratic model defined as:

.. math::
q(\bm{\delta}) = l_t + \bm{g}_t^T \bm{\delta} + \frac{1}{2} \bm{\delta}^T
(\bm{C} + (\lambda_t + \gamma) \bm{I}) \bm{\delta},
Amortizing expensive computations
---------------------------------
where :math:`\bm{C}` is usually the exact curvature matrix.
To compute :math:`\alpha_t` and :math:`\beta_t`, we minimize
:math:`q(\alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t)` with respect to the two
scalars, treating :math:`\hat{\bm{g}}_t` and :math:`\bm{v}_t` as fixed vectors.
Since this is a simple two-dimensional quadratic problem that requires only
matrix-vector products with :math:`\bm{C}`, it can be solved efficiently.
For further details see Section 7 of the original
`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
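
As a sketch of that computation (hypothetical names; the parameters are
assumed to be flattened into single vectors, and ``damped_curvature_matvec``
stands for a matrix-vector product with
:math:`\bm{C} + (\lambda_t + \gamma) \bm{I}`), setting the gradient of the
quadratic to zero gives a :math:`2 \times 2` linear system::

    import jax.numpy as jnp

    def update_coefficients(grad, ghat, velocity, damped_curvature_matvec):
        # Minimize q(alpha * ghat + beta * velocity) over the two scalars,
        # where grad is g_t and ghat is the preconditioned gradient.
        m_ghat = damped_curvature_matvec(ghat)
        m_vel = damped_curvature_matvec(velocity)
        a = jnp.array([[ghat @ m_ghat, ghat @ m_vel],
                       [velocity @ m_ghat, velocity @ m_vel]])
        b = -jnp.array([grad @ ghat, grad @ velocity])
        # (At the very first step, when the velocity is zero, only alpha is
        # solved for in practice.)
        alpha, beta = jnp.linalg.solve(a, b)
        return alpha, beta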

When running the optimizer, several of the steps involved can have
a somewhat significant computational overhead.
For this reason, the optimizer class allows these to be performed every `K`
steps, and to cache these values across iterations.
This has been found to work well in practice without significant drawbacks in
training performance.
Specifically, this is applied to computing the inverse of the estimated
approximate curvature (step 3), and to the updates to the damping (step 7).

Computing the update coefficients
---------------------------------
Updating the damping (7)
------------------------

The update coefficients ``c1`` and ``c2`` in step 4 can either be provided
manually by the user at each step, or can be computed automatically using the
procedure described in Section 7 of the original
`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
and ``use_adaptive_momentum``.
Note that these features don't currently work very well unless you use a very
large batch size, and/or increase the batch size dynamically during training
(as was done in the original K-FAC paper).
The damping is updated via the Levenberg-Marquardt heuristic, which is based
on the reduction ratio:

Updating the damping
--------------------
.. math::
\rho = \frac{\mathcal{L}(\bm{\theta}_{t+1}) - \mathcal{L}(\bm{\theta}_{t})}
{q(\bm{v}_{t+1}) - q(\bm{0})}
The damping update is done via the Levenberg-Marquardt heuristic.
This is done by computing the reduction ratio
``(f(theta_new) - f(theta)) / (q(theta_new) - q_theta)``, where ``q`` is the
quadratic model value induced by either the exact or approximate curvature
matrix.
where :math:`q` is the quadratic model value induced by either the exact or
approximate curvature matrix.
If the optimizer uses either learning rate or momentum adaptation, or
``always_use_exact_qmodel_for_damping_adjustment`` is set to ``True``, the
optimizer will use the exact curvature matrix; otherwise it will use the
approximate curvature.
If this value deviates too much from ``1`` we either increase or decrease the
damping as described in Section 6.5 from the original
If the value of :math:`\rho` deviates too much from 1 we either increase or
decrease the damping :math:`\lambda` as described in Section 6.5 of the original
`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
Whether the damping is adapted or provided by the user at each step is
controlled by the optimizer argument ``use_adaptive_damping``.
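
A minimal sketch of such a Levenberg-Marquardt rule (the thresholds and decay
factor below are illustrative choices rather than the library's exact
constants)::

    import jax.numpy as jnp

    def update_damping(damping, rho, decay=0.95, lower=0.25, upper=0.75):
        # Small rho: the quadratic model was too optimistic -> increase damping.
        # Large rho: the model predicted the improvement well -> decrease damping.
        damping = jnp.where(rho < lower, damping / decay, damping)
        damping = jnp.where(rho > upper, damping * decay, damping)
        return damping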


Amortizing expensive computations
---------------------------------

When running the optimizer, several of the steps involved can have
a noticeable computational overhead.
For this reason, the optimizer class allows these to be performed every `K`
steps, and to cache the values across iterations.
This has been found to work well in practice without significant drawbacks in
training performance.
This is applied to computing the inverse of the estimated approximate curvature
(step 3), and to the updates of the damping (step 7).
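
A sketch of this amortization pattern, with hypothetical names and an
illustrative period of ``k = 50`` steps::

    import jax

    def maybe_refresh_cache(step, state, refresh_fn, k=50):
        # Recompute the expensive cached quantity (e.g. curvature inverses)
        # only once every k steps; otherwise reuse the value cached in `state`.
        # refresh_fn must return a state with the same structure as `state`.
        return jax.lax.cond(step % k == 0,
                            refresh_fn,     # recompute and return a new state
                            lambda s: s,    # keep the cached state unchanged
                            state)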
33 changes: 33 additions & 0 deletions kfac_jax/_src/curvature_blocks.py
@@ -126,6 +126,18 @@ class CurvatureBlock(utils.Finalizable):

@utils.pytree_dataclass
class State:
"""Persistent state of the block.
Any subclasses of :class:`~CurvatureBlock` should also internally extend
this class, with any attributes needed for the curvature estimation.
Attributes:
cache: A dictionary containing any state data that is updated at
irregular intervals, such as inverses, eigenvalues, etc. Elements of
this are updated via calls to :func:`~CurvatureBlock.update_cache`, and
do not necessarily correspond to the most up-to-date curvature
estimate.
"""
cache: Optional[Dict[str, Union[chex.Array, Dict[str, chex.Array]]]]

def __init__(self, layer_tag_eq: tags.LayerTagEqn, name: str):
@@ -566,6 +578,13 @@ class Diagonal(CurvatureBlock, abc.ABC):

@utils.pytree_dataclass
class State(CurvatureBlock.State):
"""Persistent state of the block.
Attributes:
diagonal_factors: A tuple of the moving averages of the estimated
diagonals of the curvature for each parameter that is part of the
associated layer.
"""
diagonal_factors: Tuple[utils.WeightedMovingAverage]

def _init(
@@ -659,6 +678,12 @@ class Full(CurvatureBlock, abc.ABC):

@utils.pytree_dataclass
class State(CurvatureBlock.State):
"""Persistent state of the block.
Attributes:
matrix: A moving average of the estimated curvature matrix for all
parameters that are part of the associated layer.
"""
matrix: utils.WeightedMovingAverage

def __init__(
@@ -836,6 +861,14 @@ class TwoKroneckerFactored(CurvatureBlock, abc.ABC):

@utils.pytree_dataclass
class State(CurvatureBlock.State):
"""Persistent state of the block.
Attributes:
inputs_factor: A moving average of the estimated second moment matrix of
the inputs to the associated layer.
outputs_factor: A moving average of the estimated second moment matrix of
the gradients w.r.t. the outputs of the associated layer.
"""
inputs_factor: utils.WeightedMovingAverage
outputs_factor: utils.WeightedMovingAverage

6 changes: 6 additions & 0 deletions kfac_jax/_src/curvature_estimator.py
@@ -727,6 +727,12 @@ class BlockDiagonalCurvature(CurvatureEstimator):

@utils.pytree_dataclass
class State:
"""Persistent state of the estimator.
Attributes:
blocks_states: A tuple containing the state of each of the estimator's
individual blocks.
"""
blocks_states: Tuple[curvature_blocks.CurvatureBlock.State, ...]

def __init__(
11 changes: 11 additions & 0 deletions kfac_jax/_src/optimizer.py
@@ -65,6 +65,17 @@ class Optimizer(utils.WithStagedMethods):

@utils.pytree_dataclass
class State:
r"""Persistent state of the optimizer.
Attributes:
velocities: The update to the parameters from the previous step, i.e.
:math:`\theta_t - \theta_{t-1}`.
estimator_state: The persistent state for the curvature estimator.
damping: When using damping adaptation, this will contain the current
value.
data_seen: The number of training cases that the optimizer has processed.
step_counter: An integer giving the current step number :math:`t`.
"""
velocities: utils.Params
estimator_state: curvature_estimator.BlockDiagonalCurvature.State
damping: Optional[chex.Array]
38 changes: 19 additions & 19 deletions tests/models.py
@@ -475,45 +475,45 @@ def conv_classifier_loss(
functools.partial(
autoencoder_with_two_losses,
layer_widths=[20, 10, 20]),
dict(images=(32,)),
dict(images=(8,)),
1231987,
),
# (
# conv_classifier(
# num_classes=10,
# layer_channels=[8, 16, 32]
# ).init,
# functools.partial(
# conv_classifier_loss,
# num_classes=10,
# layer_channels=[8, 16, 32]),
# dict(images=(32, 32, 3), labels=(10,)),
# 354649831,
# ),
(
conv_classifier(
num_classes=10,
layer_channels=[8, 16]
).init,
functools.partial(
conv_classifier_loss,
num_classes=10,
layer_channels=[8, 16]),
dict(images=(8, 8, 3), labels=(10,)),
354649831,
),
]


LINEAR_MODELS = [
(
autoencoder([100, 50, 100]).init,
autoencoder([20, 10, 20]).init,
functools.partial(
linear_squared_error_autoencoder_loss,
layer_widths=[100, 50, 100]),
dict(images=(64,)),
layer_widths=[20, 10, 20]),
dict(images=(8,)),
1240982837,
),
]


PIECEWISE_LINEAR_MODELS = [
(
autoencoder([100, 50, 100]).init,
autoencoder([20, 10, 20]).init,
functools.partial(
autoencoder_with_two_losses,
layer_widths=[100, 50, 100],
layer_widths=[20, 10, 20],
activation=_special_relu,
),
dict(images=(64,)),
dict(images=(8,)),
1231987,
),
]

