Commit 3fa4efb
Make LossTag return only the parameter-dependent arrays.
PiperOrigin-RevId: 467958268
botev authored and KfacJaxDev committed Aug 16, 2022
1 parent 6831396 commit 3fa4efb
Showing 4 changed files with 42 additions and 44 deletions.
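
The change in one line: loss-registration tags used to return every array passed to them, predictions and targets alike; they now return only the arrays that depend on the model parameters. A minimal before/after sketch of a call site (the import path, array shapes, and values are illustrative assumptions, not part of this commit):

    import jax.numpy as jnp
    from kfac_jax import loss_functions  # import path assumed for illustration

    y_hat = jnp.ones((8, 10))   # model output, depends on the parameters
    y = jnp.zeros((8, 10))      # fixed targets, parameter independent

    # Before this commit the tag echoed back all of its inputs:
    #   y_hat, y = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)
    # After it, only the parameter-dependent array is returned:
    y_hat = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)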
4 changes: 2 additions & 2 deletions kfac_jax/_src/layers_and_loss_tags.py
@@ -93,7 +93,7 @@ def get_outputs(
       raise ValueError("Inputs to the tag are too many.")
     if self.num_inputs < len(args) < self.num_inputs + self.num_targets:
       raise ValueError("Inputs to the tag are not quite enough.")
-    return args
+    return args[:self.num_inputs]

   def impl(self, *operands: chex.Array, **_: Any) -> Tuple[chex.Array, ...]:
     return self.get_outputs(*operands)
@@ -128,7 +128,7 @@ def _jvp(
     if len(arg_values) != len(arg_tangents):
       raise ValueError("Values and tangents are not the same length.")
     primal_output = self.bind(*arg_values, **kwargs)
-    return primal_output, tuple(arg_tangents)
+    return primal_output, tuple(self.get_outputs(*arg_tangents))

   def _batching(
       self,
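
To make the new slicing concrete, here is a self-contained toy stand-in (not the library's LossTag; the "too many" bound is an assumption, since that branch's condition lies outside the hunk shown above):

    from typing import Any, Tuple

    class ToyLossTag:
      """Toy stand-in for LossTag's output slicing; not the real class."""

      def __init__(self, num_inputs: int, num_targets: int):
        self.num_inputs = num_inputs
        self.num_targets = num_targets

      def get_outputs(self, *args: Any) -> Tuple[Any, ...]:
        # The "too many" bound below is an assumption; the real condition
        # sits outside the hunk shown above.
        if len(args) > self.num_inputs + self.num_targets:
          raise ValueError("Inputs to the tag are too many.")
        if self.num_inputs < len(args) < self.num_inputs + self.num_targets:
          raise ValueError("Inputs to the tag are not quite enough.")
        # Targets trail the inputs, so the slice keeps only the
        # parameter-dependent arrays.
        return args[:self.num_inputs]

    tag = ToyLossTag(num_inputs=1, num_targets=1)
    assert tag.get_outputs("mean", "targets") == ("mean",)

The same slicing is applied to the tangents in _jvp above, so the JVP's output structure matches the primal's.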
20 changes: 10 additions & 10 deletions kfac_jax/_src/loss_functions.py
@@ -1057,7 +1057,7 @@ def register_normal_predictive_distribution(
     targets: Optional[chex.Array] = None,
     variance: float = 0.5,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a normal predictive distribution.

   This corresponds to a squared error loss of the form
@@ -1088,14 +1088,14 @@ def register_normal_predictive_distribution(
   if targets is None:
     targets = jnp.zeros_like(mean)
   return NormalMeanNegativeLogProbLoss_tag.bind(
-      mean, targets, variance=variance, weight=weight)
+      mean, targets, variance=variance, weight=weight)[0]


 def register_squared_error_loss(
     prediction: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a squared error loss function.

   This assumes the squared error loss of the form ``||target - prediction||^2``,
@@ -1122,7 +1122,7 @@ def register_multi_bernoulli_predictive_distribution(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a multi-Bernoulli predictive distribution.

   Note that this is distinct from
@@ -1147,14 +1147,14 @@ def register_multi_bernoulli_predictive_distribution(
   if targets is None:
     targets = jnp.zeros_like(logits)
   return MultiBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]


 def register_sigmoid_cross_entropy_loss(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a sigmoid cross-entropy loss function.

   Note that this is distinct from :func:`~register_softmax_cross_entropy_loss`
@@ -1181,7 +1181,7 @@ def register_categorical_predictive_distribution(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a categorical predictive distribution.

   Note that this is distinct from
@@ -1207,10 +1207,10 @@ def register_categorical_predictive_distribution(
     targets = jnp.zeros_like(logits[..., 0])
   if targets.ndim == logits.ndim:
     return OneHotCategoricalLogitsNegativeLogProbLoss_tag.bind(
-        logits, targets, weight=weight)
+        logits, targets, weight=weight)[0]
   elif targets.ndim == logits.ndim - 1:
     return CategoricalLogitsNegativeLogProbLoss_tag.bind(
-        logits, targets, weight=weight)
+        logits, targets, weight=weight)[0]
   else:
     raise ValueError(f"The logits rank is {logits.ndim} and the targets rank "
                      f"must be either equal or one less than it, but is "
@@ -1221,7 +1221,7 @@ def register_softmax_cross_entropy_loss(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a softmax cross-entropy loss function.

   Note that this is distinct from :func:`~register_sigmoid_cross_entropy_loss`
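
A usage note on the categorical case: the function still dispatches on the rank of targets; only its return value changes. A hedged sketch (the import path and shapes are illustrative assumptions):

    import jax
    import jax.numpy as jnp
    from kfac_jax import loss_functions  # import path assumed for illustration

    logits = jnp.zeros((8, 10))
    int_labels = jnp.zeros((8,), dtype=jnp.int32)

    # Integer labels (rank logits.ndim - 1) bind
    # CategoricalLogitsNegativeLogProbLoss_tag:
    out = loss_functions.register_categorical_predictive_distribution(
        logits, int_labels)

    # One-hot labels (rank logits.ndim) bind
    # OneHotCategoricalLogitsNegativeLogProbLoss_tag:
    out = loss_functions.register_categorical_predictive_distribution(
        logits, jax.nn.one_hot(int_labels, 10))

    # In both cases only the parameter-dependent logits array comes back.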
56 changes: 28 additions & 28 deletions tests/models.py
@@ -273,9 +273,9 @@ def _register_deterministic_bernoulli(
     logits: chex.Array,
     targets: chex.Array,
     weight=1.0
-) -> Tuple[chex.Array, chex.Array]:
+) -> chex.Array:
   return DeterministicBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]


 class _DeterministicCategorical(distrax.Categorical):
@@ -302,9 +302,9 @@ def _register_deterministic_categorical(
     logits: chex.Array,
     targets: chex.Array,
     weight=1.0
-) -> Tuple[chex.Array, chex.Array]:
+) -> chex.Array:
   return DeterministicBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]


 def squared_error_loss(
@@ -313,7 +313,7 @@ def squared_error_loss(
     model_func: Callable[..., hk.Transformed],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
 ) -> LossOutputs:
   """A squared error loss computed for the given model function."""
@@ -327,10 +327,10 @@ def squared_error_loss(
   y = y.reshape((-1, y.shape[-1]))
   y_hat = y_hat.reshape((-1, y_hat.shape[-1]))

-  y_hat, y = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)
+  y_hat = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)

-  if return_registered_losses_inputs:
-    return [[y_hat, y]]
+  if return_losses_outputs:
+    return [[y_hat]]

   loss = jnp.mean(jnp.sum((y_hat - y) ** 2, axis=-1)) / 2
   loss = loss + l2_reg * utils.norm(params)
@@ -373,7 +373,7 @@ def linear_squared_error_autoencoder_loss(
     layer_widths: Sequence[int],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
 ) -> LossOutputs:
   """A linear autoencoder with squared error."""
@@ -387,7 +387,7 @@ def linear_squared_error_autoencoder_loss(
       model_func=model_func,
       l2_reg=l2_reg,
       explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )

@@ -405,7 +405,7 @@ def autoencoder_deterministic_loss(
   logits, _ = autoencoder(
       layer_widths, x.shape[-1], explicit_tagging, activation=activation,
   ).apply(params, x)
-  logits, _ = _register_deterministic_bernoulli(logits, x)
+  logits = _register_deterministic_bernoulli(logits, x)
   loss = - distrax.Bernoulli(logits=logits).log_prob(x)
   loss = jnp.mean(jnp.sum(loss, axis=-1)).astype(logits.dtype)
   return loss + l2_reg * utils.norm(params)
@@ -417,7 +417,7 @@ def autoencoder_with_two_losses(
     layer_widths: Sequence[int],
     aux: Optional[Tuple[chex.Array, ...]] = None,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -429,13 +429,13 @@ def autoencoder_with_two_losses(
   ).apply(params, x, aux)

   # Register both losses in KFAC
-  logits1, img1 = loss_functions.register_multi_bernoulli_predictive_distribution(
+  logits1 = loss_functions.register_multi_bernoulli_predictive_distribution(
       logits, x)
-  logits2, img2 = loss_functions.register_normal_predictive_distribution(
+  logits2 = loss_functions.register_normal_predictive_distribution(
       logits, x, weight=0.1)

-  if return_registered_losses_inputs:
-    return [[logits1, img1], [logits2, img2]]
+  if return_losses_outputs:
+    return [[logits1], [logits2]]

   loss_1 = - distrax.Bernoulli(logits=logits1).log_prob(x)
   scale_diag = jnp.ones_like(logits2) * jnp.sqrt(0.5)
@@ -456,7 +456,7 @@ def conv_classifier(
     stride: int = 2,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> hk.Transformed:
-  """Constructs a Haiku transformed object of the autoencoder network."""
+  """Constructs a Haiku transformed object of a convolutional classifier."""
   def func(
       batch: Union[chex.Array, Mapping[str, chex.Array]],
       aux: Optional[Tuple[chex.Array, ...]] = None,
@@ -510,11 +510,11 @@ def conv_classifier_deterministic_loss(
     explicit_tagging: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> chex.Array:
-  """Evaluate the autoencoder with a deterministic loss."""
+  """Evaluate the convolutional classifier with a deterministic loss."""
   logits, _ = conv_classifier(
       num_classes, layer_channels, explicit_tagging, activation=activation
   ).apply(params, batch["images"])
-  logits, _ = _register_deterministic_categorical(logits, batch["labels"])
+  logits = _register_deterministic_categorical(logits, batch["labels"])
   loss = - distrax.Categorical(logits=logits).log_prob(batch["labels"])
   loss = jnp.mean(jnp.sum(loss, axis=-1)).astype(logits.dtype)
   return loss + l2_reg * utils.norm(params)
@@ -528,19 +528,19 @@ def conv_classifier_loss(
     aux: Optional[Tuple[chex.Array, ...]] = None,
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
-  """Evaluates the autoencoder with a deterministic loss."""
+  """Evaluates the convolutional classifier loss."""
   logits, layer_values = conv_classifier(
       num_classes, layer_channels, explicit_tagging, activation=activation
   ).apply(params, batch["images"], aux=aux)
-  logits, labels = loss_functions.register_categorical_predictive_distribution(
+  logits = loss_functions.register_categorical_predictive_distribution(
       logits, batch["labels"])

-  if return_registered_losses_inputs:
-    return [[logits, labels]]
+  if return_losses_outputs:
+    return [[logits]]

   loss = - distrax.Categorical(logits=logits).log_prob(batch["labels"])
   loss = loss + l2_reg * utils.norm(params)
@@ -602,7 +602,7 @@ def layer_stack_mlp_loss(
     layer_widths: Sequence[int],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -617,7 +617,7 @@ def layer_stack_mlp_loss(
       ),
       l2_reg=l2_reg,
       explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )

@@ -666,7 +666,7 @@ def vanilla_rnn_with_scan_loss(
     hidden_size: int,
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -681,7 +681,7 @@ def vanilla_rnn_with_scan_loss(
       ),
       l2_reg=l2_reg,
       explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )

6 changes: 2 additions & 4 deletions tests/test_graph_matcher.py
@@ -79,9 +79,7 @@ def check_jaxpr_equal(self, jaxpr_1, jaxpr_2, map_output_vars: bool):
     self.assertEqual(len(jaxpr_1.invars), len(jaxpr_2.invars))
     self.assertEqual(len(jaxpr_1.constvars), len(jaxpr_2.constvars))
     self.assertEqual(len(jaxpr_1.outvars), len(jaxpr_2.outvars))
-    # Note that since the auto registered computation finishes at the loss tags
-    # it will always have the same or less number of equations.
-    self.assertLessEqual(len(jaxpr_1.eqns), len(jaxpr_2.eqns))
+    self.assertEqual(len(jaxpr_1.eqns), len(jaxpr_2.eqns))

     # Extract all loss tags from both jax expressions
     l1_eqns = []
@@ -153,7 +151,7 @@ def test_auto_register_tags_jaxpr(
     tagged_func = functools.partial(
         model_func,
         explicit_tagging=True,
-        return_registered_losses_inputs=True,
+        return_losses_outputs=True,
     )
     tagged_jaxpr = jax.make_jaxpr(tagged_func)(params, data).jaxpr
     self.check_jaxpr_equal(jaxpr, tagged_jaxpr, False)
