diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index c9e5b26..bb243b0 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -48,6 +48,7 @@ and GGN matrices and how to compute matrix-vector products. """ import abc +import functools from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import chex @@ -859,7 +860,7 @@ def __init__( self._blocks: Tuple[curvature_blocks.CurvatureBlock] = None def _create_blocks(self): - """Creates all of the curvature blocks instances in ``self._blocks``.""" + """Creates all the curvature blocks instances in ``self._blocks``.""" blocks_list = [] counters = dict() for tag_eqn, idx in zip(self._jaxpr.layer_tags, self._jaxpr.layer_indices): @@ -878,7 +879,10 @@ def _create_blocks(self): assert block_name not in counters counters[block_name] = 1 else: - block_name = cls.__name__ + if isinstance(cls, functools.partial): + block_name = cls.func.__name__ + else: + block_name = cls.__name__ c = counters.get(block_name, 0) counters[block_name] = c + 1 block_name += "__" + str(c)