Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 467544278
  • Loading branch information
botev authored and KfacJaxDev committed Aug 14, 2022
1 parent 183732b commit 6831396
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions kfac_jax/_src/curvature_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
and GGN matrices and how to compute matrix-vector products.
"""
import abc
import functools
from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union

import chex
Expand Down Expand Up @@ -859,7 +860,7 @@ def __init__(
self._blocks: Tuple[curvature_blocks.CurvatureBlock] = None

def _create_blocks(self):
"""Creates all of the curvature blocks instances in ``self._blocks``."""
"""Creates all the curvature blocks instances in ``self._blocks``."""
blocks_list = []
counters = dict()
for tag_eqn, idx in zip(self._jaxpr.layer_tags, self._jaxpr.layer_indices):
Expand All @@ -878,7 +879,10 @@ def _create_blocks(self):
assert block_name not in counters
counters[block_name] = 1
else:
block_name = cls.__name__
if isinstance(cls, functools.partial):
block_name = cls.func.__name__
else:
block_name = cls.__name__
c = counters.get(block_name, 0)
counters[block_name] = c + 1
block_name += "__" + str(c)
Expand Down

0 comments on commit 6831396

Please sign in to comment.