diff --git a/README.md b/README.md
index 16e0cab..034af29 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-# KFAC JAX - Second Order Optimization with Approximate Curvature in JAX
+# KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX
 
 [**Installation**](#installation)
 | [**Quickstart**](#quickstart)
 | [**Documentation**](https://kfac-jax.readthedocs.io/)
 | [**Examples**](https://github.com/deepmind/kfac_jax/tree/main/examples/)
-| [**Citing KFAC JAX**](#citing-kfac-jax)
+| [**Citing KFAC-JAX**](#citing-kfac-jax)
 
 ![CI status](https://github.com/deepmind/kfac_jax/workflows/ci/badge.svg)
 ![docs](https://readthedocs.org/projects/kfac_jax/badge/?version=latest)
@@ -196,14 +196,14 @@ parameters of the model to be part of dense layers.
 For a high level overview of the optimizer, the different curvature
 approximations, and the supported layers, please see the [documentation].
 
-## Citing KFAC JAX
+## Citing KFAC-JAX
 
 To cite this repository:
 
 ```
 @software{kfac_jax2022github,
   author = {Aleksandar Botev and James Martens},
-  title = {{KFAC JAX}},
+  title = {{KFAC-JAX}},
   url = {http://github.com/deepmind/kfac_jax},
   version = {0.0.1},
   year = {2022},
diff --git a/docs/conf.py b/docs/conf.py
index 5e3ef0d..ce49cb5 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -110,7 +110,7 @@ def _recursive_add_annotations_import():
 autodoc_default_options = {
     'member-order': 'bysource',
     'special-members': True,
-    'exclude-members': '__repr__, __str__, __weakref__',
+    'exclude-members': '__repr__, __str__, __weakref__, __eq__, __hash__'
 }
 
 # -- Options for HTML output -------------------------------------------------
diff --git a/docs/overview.rst b/docs/overview.rst
index af57233..8670619 100644
--- a/docs/overview.rst
+++ b/docs/overview.rst
@@ -73,11 +73,12 @@ Supported layers
 Currently, the library only includes support for the three most common types of
 layers used in practice:
 
-1. Dense layers, corresponding to ``y = Wx + b``.
-2. 2D convolution layers, corresponding to ``y = W * x + b``.
-3. Scale and shift layers, corresponding to ``y = w . x + b``.
+1. Dense layers, corresponding to :math:`y = Wx + b`.
+2. 2D convolution layers, corresponding to :math:`y = W \star x + b`.
+3. Scale and shift layers, corresponding to :math:`y = w \odot x + b`.
 
-Here ``*`` corresponds to convolution and ``.`` to elementwise product.
+Here :math:`\star` corresponds to convolution and :math:`\odot` to elementwise
+product.
 
 Parameter reuse, such as in recurrent networks and attention layers, is not
 currently supported.
@@ -118,68 +119,111 @@ checkout the relevant section in :doc:`advanced` on how to do this.
 Optimizer
 =========
 
-The optimization algorithm implement in :class:`kfac_jax.Optimizer` follows the
-`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
-Throughout optimization the optimizer instance keeps the following state::
-
-    C - the curvature estimator state.
-    velocity - velocity vectors of the parameters.
-    damping - weight of the additional damping added for inverting C.
-    counter - a step counter.
+The optimization algorithm implemented in :class:`kfac_jax.Optimizer` follows
+the `K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
+Throughout optimization the Optimizer instance keeps the following persistent
+state:
+
+.. math::
+    \begin{aligned}
+    & \bm{v}_t - \text{the velocity vector, representing the last parameter
+    update.} \\
+    & \bm{C}_t - \text{the state of the curvature estimator on step } t. \\
+    & \lambda_t - \text{the weight of the additional damping added for
+    inverting } \bm{C}. \\
+    & t - \text{the step counter.}
+    \end{aligned}
+
+
+If we denote the current minibatch of data by :math:`\bm{x}_t`, the current
+parameters by :math:`\bm{\theta}_t`, the L2 regularizer by :math:`\gamma`, and
+the loss function (which includes the L2 regularizer) by :math:`\mathcal{L}`, a
+high-level pseudocode for a single step of the optimizer is:
+
+.. math::
+    \begin{aligned}
+    &(1) \quad l_t, \bm{g}_t = \mathcal{L}(\bm{\theta}_t, \bm{x}_t),
+    \nabla_\theta \mathcal{L}(\bm{\theta}_t, \bm{x}_t)
+    \\
+    &(2) \quad \bm{C}_{t+1} = \text{update curvature}(\bm{C}_t,
+    \bm{\theta}_t, \bm{x}_t) \\
+    &(3) \quad \hat{\bm{g}}_t = (\bm{C}_{t+1} + (\lambda_t + \gamma) \bm{I}
+    )^{-1} \bm{g}_t \\
+    &(4) \quad \alpha_t, \beta_t = \text{update coefficients}(
+    \hat{\bm{g}}_t, \bm{x}_t, \bm{\theta}_t, \bm{v}_t) \\
+    &(5) \quad \bm{v}_{t+1} = \alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t \\
+    &(6) \quad \bm{\theta}_{t+1} = \bm{\theta}_t + \bm{v}_{t+1} \\
+    &(7) \quad \lambda_{t+1} = \text{update damping}(l_t, \bm{\theta}_{t+1},
+    \bm{C}_{t+1})
+    \end{aligned}
+
+Steps 1, 2, 3, 5 and 6 are standard for any second order optimization algorithm.
+Steps 4 and 7 are described in more detail below.
+
+
+Computing the update coefficients (4)
+-------------------------------------
+
+The update coefficients :math:`\alpha_t` and :math:`\beta_t` in step 4 can
+either be provided manually by the user at each step, or can be computed
+automatically from the local quadratic model.
+This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
+and ``use_adaptive_momentum``.
+Note that these features don't currently work very well unless you use a very
+large batch size, and/or increase the batch size dynamically during training
+(as was done in the original K-FAC paper).
 
-If we denote the current minibatch of data by ``x``, the current parameters by
-``theta`` and the function that computes the value and gradient of the loss
-by ``f``, a high level pseudocode for a single step of the optimizer
-is::
 
+Automatic selection of update coefficients
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    1 loss, gradient = f(theta, x)
-    2 C = update_curvature_estimate(C, theta, x)
-    3 preconditioned_gradient = compute_inverse(C) @ gradient
-    4 c1, c2 = compute_update_coefficients(theta, x, preconditioned_gradient, velocity)
-    5 velocity_new = c1 * preconditioned_gradient + c2 * velocity
-    6 theta_new = theta + velocity_new
-    7 damping = update_damping(loss, theta_new, C)
 
+The procedure to automatically select the update coefficients uses the local
+quadratic model defined as:
+
+.. math::
+    q(\bm{\delta}) = l_t + \bm{g}_t^T \bm{\delta} + \frac{1}{2} \bm{\delta}^T
+    (\bm{C} + (\lambda_t + \gamma) \bm{I}) \bm{\delta},
 
-Amortizing expensive computations
---------------------------------
 
+where :math:`\bm{C}` is usually the exact curvature matrix.
+To compute :math:`\alpha_t` and :math:`\beta_t`, we minimize
+:math:`q(\alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t)` with respect to the two
+scalars, treating :math:`\hat{\bm{g}}_t` and :math:`\bm{v}_t` as fixed vectors.
+Since this is a simple two-dimensional quadratic problem, which requires only
+matrix-vector products with :math:`\bm{C}`, it can be solved efficiently.
+For further details see Section 7 of the original
+`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
 
-When running the optimizer, several of the steps involved can have
-a somewhat significant computational overhead.
-For this reason, the optimizer class allows these to be performed every `K`
-steps, and to cache these values across iterations.
-This has been found to work well in practice without significant drawbacks in
-training performance.
-Specifically, this is applied to computing the inverse of the estimated
-approximate curvature (step 3), and to the updates to the damping (step 7).
 
-Computing the update coefficients
----------------------------------
+Updating the damping (7)
+------------------------
 
-The update coefficients ``c1`` and ``c2`` in step 4 can either be provided
-manually by the user at each step, or can be computed automatically using the
-procedure described in Section 7 of the original
-`K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
-This is controlled by the optimizer arguments ``use_adaptive_learning_rate``
-and ``use_adaptive_momentum``.
-Note that these features don't currently work very well unless you use a very
-large batch size, and/or increase the batch size dynamically during training
-(as was done in the original K-FAC paper).
+The damping update is done via the Levenberg-Marquardt heuristic.
+This is done by computing the reduction ratio:
 
-Updating the damping
---------------------
+.. math::
+    \rho = \frac{\mathcal{L}(\bm{\theta}_{t+1}) - \mathcal{L}(\bm{\theta}_{t})}
+    {q(\bm{v}_{t+1}) - q(\bm{0})}
 
-The damping update is done via the Levenberg-Marquardt heuristic.
-This is done by computing the reduction ratio
-``(f(theta_new) - f(theta)) / (q(theta_new) - q_theta)``, where ``q`` is the
-quadratic model value induced by either the exact or approximate curvature
-matrix.
+where :math:`q` is the quadratic model value induced by either the exact or
+approximate curvature matrix.
 If the optimizer uses either learning rate or momentum adaptation, or
 ``always_use_exact_qmodel_for_damping_adjustment`` is set to ``True``, the
 optimizer will use the exact curvature matrix; otherwise it will use the
 approximate curvature.
-If this value deviates too much from ``1`` we either increase or decrease the
-damping as described in Section 6.5 from the original
+If the value of :math:`\rho` deviates too much from 1, we either increase or
+decrease the damping :math:`\lambda` as described in Section 6.5 of the original
 `K-FAC paper <https://arxiv.org/abs/1503.05671>`_.
 Whether the damping is adapted, or provided by the user at each single step, is
 controlled by the optimizer argument ``use_adaptive_damping``.
+
+
+Amortizing expensive computations
+---------------------------------
+
+When running the optimizer, several of the steps involved can have
+a noticeable computational overhead.
+For this reason, the optimizer class allows these to be performed every ``K``
+steps, and to cache the values across iterations.
+This has been found to work well in practice without significant drawbacks in
+training performance.
+This is applied to computing the inverse of the estimated approximate curvature
+(step 3), and to the updates of the damping (step 7).
diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py
index 836b39c..0d845c7 100644
--- a/kfac_jax/_src/curvature_blocks.py
+++ b/kfac_jax/_src/curvature_blocks.py
@@ -126,6 +126,18 @@ class CurvatureBlock(utils.Finalizable):
 
   @utils.pytree_dataclass
   class State:
+    """Persistent state of the block.
+
+    Any subclasses of :class:`~CurvatureBlock` should also internally extend
+    this class, with any attributes needed for the curvature estimation.
+
+    Attributes:
+      cache: A dictionary containing any state data that is updated on
+        irregular intervals, such as inverses, eigenvalues, etc. Elements of
+        this are updated via calls to :func:`~CurvatureBlock.update_cache`, and
+        do not necessarily correspond to the most up-to-date curvature
+        estimate.
+    """
     cache: Optional[Dict[str, Union[chex.Array, Dict[str, chex.Array]]]]
 
   def __init__(self, layer_tag_eq: tags.LayerTagEqn, name: str):
@@ -566,6 +578,13 @@ class Diagonal(CurvatureBlock, abc.ABC):
 
   @utils.pytree_dataclass
   class State(CurvatureBlock.State):
+    """Persistent state of the block.
+
+    Attributes:
+      diagonal_factors: A tuple of the moving averages of the estimated
+        diagonals of the curvature for each parameter that is part of the
+        associated layer.
+    """
     diagonal_factors: Tuple[utils.WeightedMovingAverage]
 
   def _init(
@@ -659,6 +678,12 @@ class Full(CurvatureBlock, abc.ABC):
 
   @utils.pytree_dataclass
   class State(CurvatureBlock.State):
+    """Persistent state of the block.
+
+    Attributes:
+      matrix: A moving average of the estimated curvature matrix for all
+        parameters that are part of the associated layer.
+    """
     matrix: utils.WeightedMovingAverage
 
   def __init__(
@@ -836,6 +861,14 @@ class TwoKroneckerFactored(CurvatureBlock, abc.ABC):
 
   @utils.pytree_dataclass
   class State(CurvatureBlock.State):
+    """Persistent state of the block.
+
+    Attributes:
+      inputs_factor: A moving average of the estimated second moment matrix of
+        the inputs to the associated layer.
+      outputs_factor: A moving average of the estimated second moment matrix of
+        the gradients w.r.t. the outputs of the associated layer.
+    """
     inputs_factor: utils.WeightedMovingAverage
     outputs_factor: utils.WeightedMovingAverage
 
diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py
index fab4bd6..3a0b4a9 100644
--- a/kfac_jax/_src/curvature_estimator.py
+++ b/kfac_jax/_src/curvature_estimator.py
@@ -727,6 +727,12 @@ class BlockDiagonalCurvature(CurvatureEstimator):
 
   @utils.pytree_dataclass
   class State:
+    """Persistent state of the estimator.
+
+    Attributes:
+      blocks_states: A tuple containing the state of each curvature block.
+    """
     blocks_states: Tuple[curvature_blocks.CurvatureBlock.State, ...]
 
   def __init__(
diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py
index 28ed748..ee01b3a 100644
--- a/kfac_jax/_src/optimizer.py
+++ b/kfac_jax/_src/optimizer.py
@@ -65,6 +65,17 @@ class Optimizer(utils.WithStagedMethods):
 
   @utils.pytree_dataclass
   class State:
+    r"""Persistent state of the optimizer.
+
+    Attributes:
+      velocities: The update to the parameters from the previous step -
+        :math:`\theta_t - \theta_{t-1}`.
+      estimator_state: The persistent state for the curvature estimator.
+      damping: When using damping adaptation, this will contain the current
+        value.
+      data_seen: The number of training cases that the optimizer has processed.
+      step_counter: An integer giving the current step number :math:`t`.
+    """
     velocities: utils.Params
     estimator_state: curvature_estimator.BlockDiagonalCurvature.State
     damping: Optional[chex.Array]
diff --git a/tests/models.py b/tests/models.py
index da3fba1..be27780 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -475,31 +475,31 @@ def conv_classifier_loss(
         functools.partial(
             autoencoder_with_two_losses,
             layer_widths=[20, 10, 20]),
-        dict(images=(32,)),
+        dict(images=(8,)),
         1231987,
     ),
-    # (
-    #     conv_classifier(
-    #         num_classes=10,
-    #         layer_channels=[8, 16, 32]
-    #     ).init,
-    #     functools.partial(
-    #         conv_classifier_loss,
-    #         num_classes=10,
-    #         layer_channels=[8, 16, 32]),
-    #     dict(images=(32, 32, 3), labels=(10,)),
-    #     354649831,
-    # ),
+    (
+        conv_classifier(
+            num_classes=10,
+            layer_channels=[8, 16]
+        ).init,
+        functools.partial(
+            conv_classifier_loss,
+            num_classes=10,
+            layer_channels=[8, 16]),
+        dict(images=(8, 8, 3), labels=(10,)),
+        354649831,
+    ),
 ]
 
 
 LINEAR_MODELS = [
     (
-        autoencoder([100, 50, 100]).init,
+        autoencoder([20, 10, 20]).init,
         functools.partial(
             linear_squared_error_autoencoder_loss,
-            layer_widths=[100, 50, 100]),
-        dict(images=(64,)),
+            layer_widths=[20, 10, 20]),
+        dict(images=(8,)),
         1240982837,
     ),
 ]
@@ -507,13 +507,13 @@ def conv_classifier_loss(
 
 PIECEWISE_LINEAR_MODELS = [
     (
-        autoencoder([100, 50, 100]).init,
+        autoencoder([20, 10, 20]).init,
         functools.partial(
             autoencoder_with_two_losses,
-            layer_widths=[100, 50, 100],
+            layer_widths=[20, 10, 20],
             activation=_special_relu,
         ),
-        dict(images=(64,)),
+        dict(images=(8,)),
         1231987,
     ),
 ]
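As a companion to the rewritten "Computing the update coefficients (4)" section of docs/overview.rst above: the adaptive choice of the coefficients `alpha` and `beta` amounts to minimizing the quadratic model on the plane spanned by the preconditioned gradient and the previous velocity, which reduces to a 2x2 linear system built from matrix-vector products with the regularized curvature. The sketch below is only an illustration of that calculation, not the library's implementation; the function and argument names (`quadratic_model_coefficients`, `curvature_matvec`, etc.) are hypothetical.

```python
import jax.numpy as jnp


def quadratic_model_coefficients(grad, precon_grad, velocity,
                                 curvature_matvec, damping, l2_reg):
  """Minimizes q(alpha * precon_grad + beta * velocity) over (alpha, beta).

  A sketch of step (4): assumes all inputs are flat 1D arrays and that the
  velocity is non-zero (the very first step would need separate handling).
  """
  def regularized_matvec(v):
    # (C + (lambda + gamma) I) v, using only a matrix-vector product with C.
    return curvature_matvec(v) + (damping + l2_reg) * v

  c_g = regularized_matvec(precon_grad)
  c_v = regularized_matvec(velocity)

  # Quadratic and linear terms of q restricted to span{precon_grad, velocity}.
  quad = jnp.array([[precon_grad @ c_g, precon_grad @ c_v],
                    [velocity @ c_g, velocity @ c_v]])
  lin = jnp.array([grad @ precon_grad, grad @ velocity])

  # Stationary point of 0.5 * x^T quad x + lin^T x, i.e. quad @ x = -lin.
  alpha, beta = jnp.linalg.solve(quad, -lin)
  return alpha, beta
```

For a quick sanity check one can pass an explicit small curvature matrix `C` via `lambda u: C @ u` as `curvature_matvec`; the returned pair then minimizes the two-dimensional quadratic exactly, mirroring the procedure from Section 7 of the K-FAC paper that the documentation references.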