Skip to content

Commit

Permalink
Making ScaleAndShift blocks begin capable of having parameters that a…
Browse files Browse the repository at this point in the history
…re broadcast by construction, e.g. batch norm with scale parameters [1, 1, 1, d].

PiperOrigin-RevId: 456070961
  • Loading branch information
botev authored and KfacJaxDev committed Jul 14, 2022
1 parent 1ace327 commit 433f838
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,25 @@ def update_curvature_matrix_estimate(
#


def compatible_shapes(ref_shape, target_shape):
if len(target_shape) > len(ref_shape):
raise ValueError("Target shape should be smaller.")
for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)):
if ref_d != target_d and target_d != 1:
raise ValueError(f"{target_shape} is incompatible with {ref_shape}.")


def compatible_sum(tensor, target_shape, skip_axes):
compatible_shapes(tensor.shape, target_shape)
n = tensor.ndim - len(target_shape)
axis = [i + n for i, t in enumerate(target_shape)
if t == 1 and i + n not in skip_axes]
tensor = jnp.sum(tensor, axis=axis, keepdims=True)
axis = [i for i in range(tensor.ndim - len(target_shape))
if i not in skip_axes]
return jnp.sum(tensor, axis=axis)


class ScaleAndShiftDiagonal(Diagonal):
"""A diagonal approximation specifically for a scale and shift layers."""

Expand Down Expand Up @@ -1539,18 +1558,20 @@ def _update_curvature_matrix_estimate(
assert (state.diagonal_factors[0].raw_value.shape ==
self.parameters_shapes[0])
scale_shape = estimation_data["params"][0].shape
axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
d_scale = jnp.sum(x * dy, axis=tuple(axis))
scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
scale_diag_update = jnp.sum(
d_scale * d_scale,
axis=0, keepdims=d_scale.ndim == len(scale_shape)
) / batch_size
state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new)

if self.has_shift:
assert (state.diagonal_factors[-1].raw_value.shape ==
self.parameters_shapes[-1])
shift_shape = estimation_data["params"][-1].shape
axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
d_shift = jnp.sum(dy, axis=tuple(axis))
shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
shift_diag_update = jnp.sum(
d_shift * d_shift,
axis=0, keepdims=d_shift.ndim == len(shift_shape)
) / batch_size
state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new)

return state
Expand Down Expand Up @@ -1587,16 +1608,14 @@ def update_curvature_matrix_estimate(
if self._has_scale:
# Scale tangent
scale_shape = estimation_data["params"][0].shape
axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
d_scale = jnp.sum(x * dy, axis=tuple(axis))
d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
d_scale = d_scale.reshape([batch_size, -1])
tangents.append(d_scale)

if self._has_shift:
# Shift tangent
shift_shape = estimation_data["params"][-1].shape
axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
d_shift = jnp.sum(dy, axis=tuple(axis))
d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
d_shift = d_shift.reshape([batch_size, -1])
tangents.append(d_shift)

Expand Down

0 comments on commit 433f838

Please sign in to comment.