Skip to content

Commit

Permalink
* Revamping the graph matching code to be able to detect layers and r…
Browse files Browse the repository at this point in the history
…egister tag in arbitrary higher-order Jax primitives.

* A few minor typos.
* Bumping jax and jaxlib version to 0.3.15

PiperOrigin-RevId: 460947146
  • Loading branch information
botev authored and KfacJaxDev committed Aug 11, 2022
1 parent 698fcc6 commit cc2cd86
Show file tree
Hide file tree
Showing 9 changed files with 1,219 additions and 541 deletions.
2 changes: 1 addition & 1 deletion kfac_jax/_src/curvature_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def _finalize(self, func_args: utils.FuncArgs):

def init(
self,
rng: chex.Array,
rng: chex.PRNGKey,
func_args: utils.FuncArgs,
exact_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence],
approx_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence],
Expand Down
Loading

0 comments on commit cc2cd86

Please sign in to comment.