Skip to content

Commit

Permalink
Allowing the pi-adjusted psd inverse to accept diagonal factors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 471014334
  • Loading branch information
botev authored and KfacJaxDev committed Aug 30, 2022
1 parent 4958cef commit d11a5b3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
4 changes: 1 addition & 3 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,14 +959,12 @@ def _multiply_matpower_unscaled(
if not use_cached:
s_i, q_i = utils.safe_psd_eigh(state.inputs_factor.value)
s_o, q_o = utils.safe_psd_eigh(state.outputs_factor.value)
eigenvalues = jnp.outer(s_i, s_o)
else:
s_i = state.cache["inputs_factor_eigenvalues"]
q_i = state.cache["inputs_factor_eigen_vectors"]
s_o = state.cache["outputs_factor_eigenvalues"]
q_o = state.cache["outputs_factor_eigen_vectors"]
eigenvalues = jnp.outer(s_i, s_o)
eigenvalues = eigenvalues + identity_weight
eigenvalues = jnp.outer(s_i, s_o) + identity_weight
eigenvalues = jnp.power(eigenvalues, power)
result = utils.kronecker_eigen_basis_mul_v(q_o, q_i, eigenvalues, vector)
else:
Expand Down
33 changes: 24 additions & 9 deletions kfac_jax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,13 @@ def pi_adjusted_inverse(
inverse of `(a kron b + damping * I)`.
"""
# Compute the nuclear/trace norms of each factor
a_norm = jnp.trace(a)
b_norm = jnp.trace(b)
assert a.ndim <= 2 and b.ndim <= 2
a_norm = jnp.sum(a) if a.ndim == 1 else jnp.trace(a)
b_norm = jnp.sum(b) if b.ndim == 1 else jnp.trace(b)

# We need to sync the norms here, because reduction can be non-deterministic.
# They specifically are on GPUs by default for better performance.
# Hence although factor_0 and factor_1 are synced, the trace operation above
# Hence, although factor_0 and factor_1 are synced, the trace operation above
# can still produce different answers on different devices.
a_norm, b_norm = pmean_if_pmap((a_norm, b_norm), pmap_axis_name)

Expand All @@ -667,7 +668,10 @@ def regular_inverse() -> Tuple[chex.Array, chex.Array]:
# since `scale = a * ||b||`.
b_normalized = b / b_norm
b_damping = damping / scale
b_inv = psd_inv_cholesky(b_normalized, b_damping)
if b.ndim == 1:
b_inv = 1.0 / (b_normalized + b_damping)
else:
b_inv = psd_inv_cholesky(b_normalized, b_damping)
return jnp.full_like(a, 1.0 / scale), b_inv

elif b.size == 1:
Expand All @@ -676,7 +680,10 @@ def regular_inverse() -> Tuple[chex.Array, chex.Array]:
# since `scale = ||a|| * b`.
a_normalized = a / a_norm
a_damping = damping / scale
a_inv = psd_inv_cholesky(a_normalized, a_damping)
if a.ndim == 1:
a_inv = 1.0 / (a_normalized + a_damping)
else:
a_inv = psd_inv_cholesky(a_normalized, a_damping)
return a_inv, jnp.full_like(b, 1.0 / scale)

else:
Expand All @@ -685,17 +692,25 @@ def regular_inverse() -> Tuple[chex.Array, chex.Array]:
# Invert first factor
a_normalized = a / a_norm
a_damping = jnp.sqrt(damping * b.shape[0] / (scale * a.shape[0]))
a_inv = psd_inv_cholesky(a_normalized, a_damping) / jnp.sqrt(scale)
if a.ndim == 1:
a_inv = 1.0 / (a_normalized + a_damping) / jnp.sqrt(scale)
else:
a_inv = psd_inv_cholesky(a_normalized, a_damping) / jnp.sqrt(scale)

# Invert second factor
b_normalized = b / b_norm
b_damping = jnp.sqrt(damping * a.shape[0] / (scale * b.shape[0]))
b_inv = psd_inv_cholesky(b_normalized, b_damping) / jnp.sqrt(scale)
if b.ndim == 1:
b_inv = 1.0 / (b_normalized + b_damping) / jnp.sqrt(scale)
else:
b_inv = psd_inv_cholesky(b_normalized, b_damping) / jnp.sqrt(scale)

return a_inv, b_inv

def zero_inverse() -> Tuple[chex.Array, chex.Array]:
return (jnp.eye(a[0].shape[0]) / jnp.sqrt(damping),
jnp.eye(b[1].shape[0]) / jnp.sqrt(damping))
a_inv = jnp.ones_like(a) if a.ndim == 1 else jnp.eye(a.shape[0])
b_inv = jnp.ones_like(b) if b.ndim == 1 else jnp.eye(b.shape[0])
return a_inv / jnp.sqrt(damping), b_inv / jnp.sqrt(damping)

if get_special_case_zero_inv():
# In the special case where for some reason one of the factors is zero, then
Expand Down

0 comments on commit d11a5b3

Please sign in to comment.