From 6831396b99faeff9a574afbe4f777528753c5b65 Mon Sep 17 00:00:00 2001 From: Alex Botev Date: Sun, 14 Aug 2022 13:04:52 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 467544278 --- kfac_jax/_src/curvature_estimator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index c9e5b26..bb243b0 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -48,6 +48,7 @@ and GGN matrices and how to compute matrix-vector products. """ import abc +import functools from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import chex @@ -859,7 +860,7 @@ def __init__( self._blocks: Tuple[curvature_blocks.CurvatureBlock] = None def _create_blocks(self): - """Creates all of the curvature blocks instances in ``self._blocks``.""" + """Creates all the curvature blocks instances in ``self._blocks``.""" blocks_list = [] counters = dict() for tag_eqn, idx in zip(self._jaxpr.layer_tags, self._jaxpr.layer_indices): @@ -878,7 +879,10 @@ def _create_blocks(self): assert block_name not in counters counters[block_name] = 1 else: - block_name = cls.__name__ + if isinstance(cls, functools.partial): + block_name = cls.func.__name__ + else: + block_name = cls.__name__ c = counters.get(block_name, 0) counters[block_name] = c + 1 block_name += "__" + str(c)