# Recovered from a unified diff against kfac_jax/_src/curvature_blocks.py
# (index b9a277a..ff7dd61) whose line breaks had been collapsed.
# Hunk "@@ -1509,6 +1509,25 @@" adds the two helpers below immediately before
# ``class ScaleAndShiftDiagonal(Diagonal)``.


def compatible_shapes(ref_shape, target_shape):
    """Check that ``target_shape`` can be reached by reducing ``ref_shape``.

    The shapes are compared right-aligned (trailing dimensions first). Each
    dimension of ``target_shape`` must either equal the matching dimension of
    ``ref_shape`` or be ``1``.

    Args:
      ref_shape: Shape of the tensor that is going to be reduced.
      target_shape: Shape the reduction should be compatible with. It must
        not have more dimensions than ``ref_shape``.

    Raises:
      ValueError: If ``target_shape`` has more dimensions than ``ref_shape``,
        or if a right-aligned pair of dimensions differs while the target
        dimension is not ``1``.
    """
    if len(target_shape) > len(ref_shape):
        raise ValueError("Target shape should be smaller.")
    # zip() stops at the shorter (target) shape, so the extra leading
    # dimensions of ``ref_shape`` are intentionally left unchecked.
    for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)):
        if ref_d != target_d and target_d != 1:
            raise ValueError(
                f"{target_shape} is incompatible with {ref_shape}.")


def compatible_sum(tensor, target_shape, skip_axes):
    """Sum ``tensor`` down to ``target_shape`` while preserving ``skip_axes``.

    The reduction happens in two steps:

    1. Trailing axes that line up with a size-1 entry of ``target_shape`` are
       summed with ``keepdims=True`` so they stay in the result as size 1.
    2. Leading axes that have no counterpart in ``target_shape`` are summed
       away entirely.

    Axes listed in ``skip_axes`` are excluded from both steps. A skipped
    leading axis therefore remains in the output in front of the dimensions
    described by ``target_shape``.

    Args:
      tensor: Array to reduce; must expose ``shape`` and ``ndim``.
      target_shape: Shape that the trailing dimensions of the result should
        have. Must satisfy ``compatible_shapes(tensor.shape, target_shape)``.
      skip_axes: Axes of ``tensor`` that must never be summed over. Negative
        indices are interpreted relative to ``tensor.ndim``.

    Returns:
      The reduced array.

    Raises:
      ValueError: If ``target_shape`` is incompatible with ``tensor.shape``.
    """
    compatible_shapes(tensor.shape, target_shape)

    ndim = tensor.ndim
    # Normalise negative indices so that e.g. -ndim protects axis 0; without
    # this a negative entry would never match and the axis would be summed.
    skipped = {axis + ndim if axis < 0 else axis for axis in skip_axes}

    # Number of leading axes of ``tensor`` with no counterpart in the target.
    n_leading = ndim - len(target_shape)

    # Step 1: collapse the axes the target wants as size 1 but keep them, so
    # the axis numbering used in step 2 is unchanged.
    broadcast_axes = tuple(
        i + n_leading
        for i, target_d in enumerate(target_shape)
        if target_d == 1 and i + n_leading not in skipped
    )
    tensor = jnp.sum(tensor, axis=broadcast_axes, keepdims=True)

    # Step 2: drop every leading axis that is not explicitly protected.
    leading_axes = tuple(i for i in range(n_leading) if i not in skipped)
    return jnp.sum(tensor, axis=leading_axes)


# ---------------------------------------------------------------------------
# Remaining hunks of the same patch, re-wrapped one diff line per comment line
# and kept for reference only. They are fragments of methods whose full
# bodies are not visible here, so they are not reproduced as code.
#
#   class ScaleAndShiftDiagonal(Diagonal):
#     """A diagonal approximation specifically for a scale and shift layers."""
#
#   @@ -1539,18 +1558,20 @@ def _update_curvature_matrix_estimate(
#     assert (state.diagonal_factors[0].raw_value.shape ==
#             self.parameters_shapes[0])
#     scale_shape = estimation_data["params"][0].shape
#   - axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
#   - d_scale = jnp.sum(x * dy, axis=tuple(axis))
#   - scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
#   + d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
#   + scale_diag_update = jnp.sum(
#   +     d_scale * d_scale,
#   +     axis=0, keepdims=d_scale.ndim == len(scale_shape)
#   + ) / batch_size
#     state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new)
#     if self.has_shift:
#   - assert (state.diagonal_factors[-1].raw_value.shape ==
#   -         self.parameters_shapes[-1])
#     shift_shape = estimation_data["params"][-1].shape
#   - axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
#   - d_shift = jnp.sum(dy, axis=tuple(axis))
#   - shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
#   + d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
#   + shift_diag_update = jnp.sum(
#   +     d_shift * d_shift,
#   +     axis=0, keepdims=d_shift.ndim == len(shift_shape)
#   + ) / batch_size
#     state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new)
#     return state
#
#   @@ -1587,16 +1608,14 @@ def update_curvature_matrix_estimate(
#     if self._has_scale:
#       # Scale tangent
#       scale_shape = estimation_data["params"][0].shape
#   -   axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
#   -   d_scale = jnp.sum(x * dy, axis=tuple(axis))
#   +   d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
#       d_scale = d_scale.reshape([batch_size, -1])
#       tangents.append(d_scale)
#     if self._has_shift:
#       # Shift tangent
#       shift_shape = estimation_data["params"][-1].shape
#   -   axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
#   -   d_shift = jnp.sum(dy, axis=tuple(axis))
#   +   d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
#       d_shift = d_shift.reshape([batch_size, -1])
#       tangents.append(d_shift)
# ---------------------------------------------------------------------------