Skip to content

Commit

Permalink
Fixing minor type error.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 484237101
  • Loading branch information
james-martens authored and KfacJaxDev committed Oct 27, 2022
1 parent c8015a0 commit 8a79a60
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion kfac_jax/_src/curvature_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,10 +1015,13 @@ def params_vector_to_blocks_vectors(
parameter_structured_vector: utils.Params,
) -> Tuple[Tuple[chex.Array, ...]]:
"""Splits the parameters to values for each corresponding block."""

params_values_flat = jax.tree_util.tree_leaves(parameter_structured_vector)
blocks_vectors: List[Tuple[chex.Array, ...]] = []

for indices in self.jaxpr.layer_indices:
blocks_vectors.append(tuple(params_values_flat[i] for i in indices))

return tuple(blocks_vectors)

def blocks_vectors_to_params_vector(
Expand Down Expand Up @@ -1498,7 +1501,8 @@ def params_vector_to_blocks_vectors(
self,
parameter_structured_vector: utils.Params,
) -> Tuple[Tuple[chex.Array, ...]]:
return jax.tree_util.tree_leaves(parameter_structured_vector),

return (tuple(jax.tree_util.tree_leaves(parameter_structured_vector)),)

def blocks_vectors_to_params_vector(
self,
Expand Down

0 comments on commit 8a79a60

Please sign in to comment.