diff --git a/kfac_jax/_src/layers_and_loss_tags.py b/kfac_jax/_src/layers_and_loss_tags.py
index 5bc0145..d63408b 100644
--- a/kfac_jax/_src/layers_and_loss_tags.py
+++ b/kfac_jax/_src/layers_and_loss_tags.py
@@ -93,7 +93,7 @@ def get_outputs(
       raise ValueError("Inputs to the tag are too many.")
     if self.num_inputs < len(args) < self.num_inputs + self.num_targets:
       raise ValueError("Inputs to the tag are not quite enough.")
-    return args
+    return args[:self.num_inputs]
 
   def impl(self, *operands: chex.Array, **_: Any) -> Tuple[chex.Array, ...]:
     return self.get_outputs(*operands)
@@ -128,7 +128,7 @@ def _jvp(
     if len(arg_values) != len(arg_tangents):
       raise ValueError("Values and tangents are not the same length.")
     primal_output = self.bind(*arg_values, **kwargs)
-    return primal_output, tuple(arg_tangents)
+    return primal_output, tuple(self.get_outputs(*arg_tangents))
 
   def _batching(
       self,
diff --git a/kfac_jax/_src/loss_functions.py b/kfac_jax/_src/loss_functions.py
index e004fe2..a25cc31 100644
--- a/kfac_jax/_src/loss_functions.py
+++ b/kfac_jax/_src/loss_functions.py
@@ -1057,7 +1057,7 @@ def register_normal_predictive_distribution(
     targets: Optional[chex.Array] = None,
     variance: float = 0.5,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a normal predictive distribution.
 
   This corresponds to a squared error loss of the form
@@ -1088,14 +1088,14 @@ def register_normal_predictive_distribution(
   if targets is None:
     targets = jnp.zeros_like(mean)
   return NormalMeanNegativeLogProbLoss_tag.bind(
-      mean, targets, variance=variance, weight=weight)
+      mean, targets, variance=variance, weight=weight)[0]
 
 
 def register_squared_error_loss(
     prediction: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a squared error loss function.
 
   This assumes the squared error loss of the form ``||target - prediction||^2``,
@@ -1122,7 +1122,7 @@ def register_multi_bernoulli_predictive_distribution(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a multi-Bernoulli predictive distribution.
 
   Note that this is distinct from
@@ -1147,14 +1147,14 @@ def register_multi_bernoulli_predictive_distribution(
   if targets is None:
     targets = jnp.zeros_like(logits)
   return MultiBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]
 
 
 def register_sigmoid_cross_entropy_loss(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a sigmoid cross-entropy loss function.
 
   Note that this is distinct from :func:`~register_softmax_cross_entropy_loss`
@@ -1181,7 +1181,7 @@ def register_categorical_predictive_distribution(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a categorical predictive distribution.
 
   Note that this is distinct from
@@ -1207,10 +1207,10 @@ def register_categorical_predictive_distribution(
     targets = jnp.zeros_like(logits[..., 0])
   if targets.ndim == logits.ndim:
     return OneHotCategoricalLogitsNegativeLogProbLoss_tag.bind(
-        logits, targets, weight=weight)
+        logits, targets, weight=weight)[0]
   elif targets.ndim == logits.ndim - 1:
     return CategoricalLogitsNegativeLogProbLoss_tag.bind(
-        logits, targets, weight=weight)
+        logits, targets, weight=weight)[0]
   else:
     raise ValueError(f"The logits rank is {logits.ndim} and the targets rank "
                      f"must be either equal or one less than it, but is "
@@ -1221,7 +1221,7 @@ def register_softmax_cross_entropy_loss(
     logits: chex.Array,
     targets: Optional[chex.Array] = None,
     weight: float = 1.0,
-) -> Tuple[chex.Array, ...]:
+) -> chex.Array:
   """Registers a softmax cross-entropy loss function.
 
   Note that this is distinct from :func:`~register_sigmoid_cross_entropy_loss`
diff --git a/tests/models.py b/tests/models.py
index 26566b4..16195b9 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -273,9 +273,9 @@ def _register_deterministic_bernoulli(
     logits: chex.Array,
     targets: chex.Array,
     weight=1.0
-) -> Tuple[chex.Array, chex.Array]:
+) -> chex.Array:
   return DeterministicBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]
 
 
 class _DeterministicCategorical(distrax.Categorical):
@@ -302,9 +302,9 @@ def _register_deterministic_categorical(
     logits: chex.Array,
     targets: chex.Array,
     weight=1.0
-) -> Tuple[chex.Array, chex.Array]:
+) -> chex.Array:
   return DeterministicBernoulliNegativeLogProbLoss_tag.bind(
-      logits, targets, weight=weight)
+      logits, targets, weight=weight)[0]
 
 
 def squared_error_loss(
@@ -313,7 +313,7 @@ def squared_error_loss(
     model_func: Callable[..., hk.Transformed],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
 ) -> LossOutputs:
   """A squared error loss computed for the given model function."""
@@ -327,10 +327,10 @@ def squared_error_loss(
   y = y.reshape((-1, y.shape[-1]))
   y_hat = y_hat.reshape((-1, y_hat.shape[-1]))
 
-  y_hat, y = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)
+  y_hat = loss_functions.register_squared_error_loss(y_hat, y, weight=0.5)
 
-  if return_registered_losses_inputs:
-    return [[y_hat, y]]
+  if return_losses_outputs:
+    return [[y_hat]]
 
   loss = jnp.mean(jnp.sum((y_hat - y) ** 2, axis=-1)) / 2
   loss = loss + l2_reg * utils.norm(params)
@@ -373,7 +373,7 @@ def linear_squared_error_autoencoder_loss(
     layer_widths: Sequence[int],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
 ) -> LossOutputs:
   """A linear autoencoder with squared error."""
@@ -387,7 +387,7 @@ def linear_squared_error_autoencoder_loss(
       model_func=model_func,
       l2_reg=l2_reg,
       explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )
 
@@ -405,7 +405,7 @@ def autoencoder_deterministic_loss(
   logits, _ = autoencoder(
       layer_widths, x.shape[-1], explicit_tagging, activation=activation,
   ).apply(params, x)
-  logits, _ = _register_deterministic_bernoulli(logits, x)
+  logits = _register_deterministic_bernoulli(logits, x)
   loss = - distrax.Bernoulli(logits=logits).log_prob(x)
   loss = jnp.mean(jnp.sum(loss, axis=-1)).astype(logits.dtype)
   return loss + l2_reg * utils.norm(params)
@@ -417,7 +417,7 @@ def autoencoder_with_two_losses(
     layer_widths: Sequence[int],
     aux: Optional[Tuple[chex.Array, ...]] = None,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -429,13 +429,13 @@ def autoencoder_with_two_losses(
   ).apply(params, x, aux)
 
   # Register both losses in KFAC
-  logits1, img1 = loss_functions.register_multi_bernoulli_predictive_distribution(
+  logits1 = loss_functions.register_multi_bernoulli_predictive_distribution(
       logits, x)
-  logits2, img2 = loss_functions.register_normal_predictive_distribution(
+  logits2 = loss_functions.register_normal_predictive_distribution(
       logits, x, weight=0.1)
 
-  if return_registered_losses_inputs:
-    return [[logits1, img1], [logits2, img2]]
+  if return_losses_outputs:
+    return [[logits1], [logits2]]
 
   loss_1 = - distrax.Bernoulli(logits=logits1).log_prob(x)
   scale_diag = jnp.ones_like(logits2) * jnp.sqrt(0.5)
@@ -456,7 +456,7 @@ def conv_classifier(
     stride: int = 2,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> hk.Transformed:
-  """Constructs a Haiku transformed object of the autoencoder network."""
+  """Constructs a Haiku transformed object of a convolutional classifier."""
   def func(
       batch: Union[chex.Array, Mapping[str, chex.Array]],
       aux: Optional[Tuple[chex.Array, ...]] = None,
@@ -510,11 +510,11 @@ def conv_classifier_deterministic_loss(
     explicit_tagging: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> chex.Array:
-  """Evaluate the autoencoder with a deterministic loss."""
+  """Evaluate the convolutional classifier with a deterministic loss."""
   logits, _ = conv_classifier(
       num_classes, layer_channels, explicit_tagging, activation=activation
   ).apply(params, batch["images"])
-  logits, _ = _register_deterministic_categorical(logits, batch["labels"])
+  logits = _register_deterministic_categorical(logits, batch["labels"])
   loss = - distrax.Categorical(logits=logits).log_prob(batch["labels"])
   loss = jnp.mean(jnp.sum(loss, axis=-1)).astype(logits.dtype)
   return loss + l2_reg * utils.norm(params)
@@ -528,19 +528,19 @@ def conv_classifier_loss(
     aux: Optional[Tuple[chex.Array, ...]] = None,
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
-  """Evaluates the autoencoder with a deterministic loss."""
+  """Evaluates the convolutional classifier loss."""
   logits, layer_values = conv_classifier(
       num_classes, layer_channels, explicit_tagging, activation=activation
   ).apply(params, batch["images"], aux=aux)
-  logits, labels = loss_functions.register_categorical_predictive_distribution(
+  logits = loss_functions.register_categorical_predictive_distribution(
       logits, batch["labels"])
 
-  if return_registered_losses_inputs:
-    return [[logits, labels]]
+  if return_losses_outputs:
+    return [[logits]]
 
   loss = - distrax.Categorical(logits=logits).log_prob(batch["labels"])
   loss = loss + l2_reg * utils.norm(params)
@@ -602,7 +602,7 @@ def layer_stack_mlp_loss(
     layer_widths: Sequence[int],
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -617,7 +617,7 @@ def layer_stack_mlp_loss(
       ),
       l2_reg=l2_reg,
      explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )
 
@@ -666,7 +666,7 @@ def vanilla_rnn_with_scan_loss(
     hidden_size: int,
     l2_reg: float = 0.0,
     explicit_tagging: bool = False,
-    return_registered_losses_inputs: bool = False,
+    return_losses_outputs: bool = False,
     return_layer_values: bool = False,
     activation: Callable[[LayerInputs], LayerInputs] = _special_tanh,
 ) -> LossOutputs:
@@ -681,7 +681,7 @@ def vanilla_rnn_with_scan_loss(
       ),
       l2_reg=l2_reg,
       explicit_tagging=explicit_tagging,
-      return_registered_losses_inputs=return_registered_losses_inputs,
+      return_losses_outputs=return_losses_outputs,
       return_layer_values=return_layer_values,
   )
 
diff --git a/tests/test_graph_matcher.py b/tests/test_graph_matcher.py
index 6ceaa17..a942c0c 100644
--- a/tests/test_graph_matcher.py
+++ b/tests/test_graph_matcher.py
@@ -79,9 +79,7 @@ def check_jaxpr_equal(self, jaxpr_1, jaxpr_2, map_output_vars: bool):
     self.assertEqual(len(jaxpr_1.invars), len(jaxpr_2.invars))
     self.assertEqual(len(jaxpr_1.constvars), len(jaxpr_2.constvars))
     self.assertEqual(len(jaxpr_1.outvars), len(jaxpr_2.outvars))
-    # Note that since the auto registered computation finishes at the loss tags
-    # it will always have the same or less number of equations.
-    self.assertLessEqual(len(jaxpr_1.eqns), len(jaxpr_2.eqns))
+    self.assertEqual(len(jaxpr_1.eqns), len(jaxpr_2.eqns))
 
     # Extract all loss tags from both jax expressions
     l1_eqns = []
@@ -153,7 +151,7 @@ def test_auto_register_tags_jaxpr(
     tagged_func = functools.partial(
         model_func,
         explicit_tagging=True,
-        return_registered_losses_inputs=True,
+        return_losses_outputs=True,
     )
     tagged_jaxpr = jax.make_jaxpr(tagged_func)(params, data).jaxpr
     self.check_jaxpr_equal(jaxpr, tagged_jaxpr, False)
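
Usage note (not part of the patch): after this change every register_* loss function returns only the registered prediction array, the first output of the underlying loss tag, instead of echoing back the full (prediction, targets, ...) tuple, and the test models expose the tag outputs through the renamed return_losses_outputs flag. The sketch below shows the new calling convention; it assumes register_squared_error_loss is re-exported at the package top level (otherwise import it from kfac_jax._src.loss_functions as in the patch), and the linear model is a made-up placeholder.

import jax.numpy as jnp
import kfac_jax


def model_apply(params, x):
  # Placeholder linear model; any function producing predictions works here.
  return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
  y_hat = model_apply(params, x)
  # The registration call now returns a single array (the registered
  # prediction) rather than a (prediction, targets) tuple.
  y_hat = kfac_jax.register_squared_error_loss(y_hat, y, weight=0.5)
  # The loss value itself is still computed explicitly by the caller.
  return 0.5 * jnp.mean(jnp.sum((y_hat - y) ** 2, axis=-1))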