Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates the code to always create variables and computations of the same dtype as the its inputs. Previously, if float64 was enabled, some of the results would be (potentially incorrectly) promoted to higher precision. #82

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,16 @@ def inputs_shapes(self) -> Tuple[Shape, ...]:
@property
def parameters_shapes(self) -> Tuple[Shape, ...]:
"""The shapes of the parameter variables of the block's tag equation."""

return tuple(jax.tree_util.tree_map(
lambda x: tuple(x.aval.shape), self.parameter_variables))

@property
def dtype(self) -> chex.ArrayDType:
dtypes = set(p.aval.dtype for p in self.parameter_variables) # pytype: disable=attribute-error
if len(dtypes) > 1:
raise ValueError("Not all parameters are the same dtype.")
return dtypes.pop()

@property
def parameters_canonical_order(self) -> Tuple[int, ...]:
"""The canonical order of the parameter variables."""
Expand Down Expand Up @@ -666,8 +672,8 @@ def _init(

return Diagonal.State(
cache=None,
diagonal_factors=tuple(utils.WeightedMovingAverage.zero(s)
for s in self.parameters_shapes),
diagonal_factors=tuple(utils.WeightedMovingAverage.zero(
shape, self.dtype) for shape in self.parameters_shapes),
)

def _multiply_matpower_unscaled(
Expand Down Expand Up @@ -854,18 +860,20 @@ def _init(
cache = {}

if len(exact_powers_to_cache) > self._eigen_decomposition_threshold:
cache["eigenvalues"] = jnp.zeros([self.dim])
cache["eigen_vectors"] = jnp.zeros([self.dim, self.dim])
cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype)
cache["eigen_vectors"] = jnp.zeros([self.dim, self.dim], self.dtype)

elif cache_eigenvalues:
cache["eigenvalues"] = jnp.zeros([self.dim])
cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype)

if len(exact_powers_to_cache) <= self._eigen_decomposition_threshold:
for power in exact_powers_to_cache:
cache[str(power)] = jnp.zeros([self.dim, self.dim])
cache[str(power)] = jnp.zeros([self.dim, self.dim], self.dtype)

return Full.State(
cache=cache,
matrix=utils.WeightedMovingAverage.zero((self.dim, self.dim)),
matrix=utils.WeightedMovingAverage.zero(
[self.dim, self.dim], self.dtype),
)

def _multiply_matpower_unscaled(
Expand Down Expand Up @@ -978,8 +986,8 @@ def _update_cache(
else:

if eigenvalues:
state.cache["eigenvalues"] = (
scale * utils.safe_psd_eigh(state.matrix.value)[0])
state.cache["eigenvalues"] = scale * utils.safe_psd_eigh(
state.matrix.value)[0]

for power in exact_powers:

Expand Down Expand Up @@ -1078,26 +1086,29 @@ def _init(
cache = {}

if cache_eigenvalues or exact_powers_to_cache:
cache["inputs_factor_eigenvalues"] = jnp.zeros([d_in])
cache["outputs_factor_eigenvalues"] = jnp.zeros([d_out])
cache["inputs_factor_eigenvalues"] = jnp.zeros([d_in], self.dtype)
cache["outputs_factor_eigenvalues"] = jnp.zeros([d_out], self.dtype)

if exact_powers_to_cache:
cache["inputs_factor_eigen_vectors"] = jnp.zeros([d_in, d_in])
cache["outputs_factor_eigen_vectors"] = jnp.zeros([d_out, d_out])
cache["inputs_factor_eigen_vectors"] = jnp.zeros([d_in, d_in], self.dtype)
cache["outputs_factor_eigen_vectors"] = jnp.zeros(
[d_out, d_out], self.dtype)

for power in approx_powers_to_cache:
if power != -1:
raise NotImplementedError(f"Approximations for power {power} is not "
f"yet implemented.")
cache[str(power)] = dict(
inputs_factor=jnp.zeros([d_in, d_in]),
outputs_factor=jnp.zeros([d_out, d_out]),
inputs_factor=jnp.zeros([d_in, d_in], self.dtype),
outputs_factor=jnp.zeros([d_out, d_out], self.dtype),
)

return TwoKroneckerFactored.State(
cache=cache,
inputs_factor=utils.WeightedMovingAverage.zero((d_in, d_in)),
outputs_factor=utils.WeightedMovingAverage.zero((d_out, d_out)),
inputs_factor=utils.WeightedMovingAverage.zero(
[d_in, d_in], self.dtype),
outputs_factor=utils.WeightedMovingAverage.zero(
[d_out, d_out], self.dtype),
)

def _multiply_matpower_unscaled(
Expand Down Expand Up @@ -1226,7 +1237,6 @@ def _update_cache(
factor_scale = jnp.power(scale, 0.5)

if eigenvalues or exact_powers:

s_i, q_i = utils.safe_psd_eigh(state.inputs_factor.value)
s_o, q_o = utils.safe_psd_eigh(state.outputs_factor.value)

Expand Down
8 changes: 6 additions & 2 deletions kfac_jax/_src/utils/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,14 @@ def sync(self, pmap_axis_name: Optional[str]) -> None:
self.raw_value = parallel.pmean_if_pmap(self.raw_value, pmap_axis_name)

@classmethod
def zero(cls, shape: chex.Shape) -> "WeightedMovingAverage":
def zero(
cls,
shape: chex.Shape,
dtype: Optional[chex.ArrayDType] = None,
) -> "WeightedMovingAverage":
"""Initializes a `WeightedMovingAverage` with a single array of zeros."""
return WeightedMovingAverage(
weight=jnp.zeros([]), raw_value=jnp.zeros(shape))
weight=jnp.zeros([], dtype), raw_value=jnp.zeros(shape, dtype))

@classmethod
def zeros_like(cls, value: PyTree) -> "WeightedMovingAverage":
Expand Down
8 changes: 5 additions & 3 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def psd_inv_cholesky(matrix: chex.Array, damping: chex.Array) -> chex.Array:
if matrix.shape[:1] != matrix.shape[1:]:
raise ValueError(f"Expected square matrix, but got shape {matrix.shape}.")

identity = jnp.eye(matrix.shape[0])
identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)

return linalg.solve(matrix + damping * identity, identity, assume_a="pos")

Expand Down Expand Up @@ -377,6 +377,7 @@ def pi_adjusted_kronecker_inverse(

# kron(arrays) = c * kron(us)
c = jnp.exp(jnp.sum(jnp.log(jnp.stack(norms)) - jnp.log(jnp.stack(dims))))
damping = damping.astype(c.dtype)

def regular_inverse() -> Tuple[chex.Array, ...]:

Expand Down Expand Up @@ -417,7 +418,7 @@ def zero_inverse() -> Tuple[chex.Array, ...]:
for a in us:

if a.ndim == 2:
inv = jnp.eye(a.shape[0])
inv = jnp.eye(a.shape[0], dtype=a.dtype)

else:
inv = jnp.ones_like(a)
Expand Down Expand Up @@ -643,7 +644,8 @@ def safe_psd_eigh(
# of cuda and cudablas they can cause a runtime error.
s, q = lax.cond(
jnp.any(jnp.isnan(x)),
lambda _: (jnp.full([d], jnp.nan), jnp.full([d, d], jnp.nan)),
lambda _: (jnp.full([d], jnp.nan, dtype=x.dtype), # pylint: disable=g-long-lambda
jnp.full([d, d], jnp.nan, dtype=x.dtype)),
functools.partial(_eigh, force_on_host=force_on_host),
x,
)
Expand Down