diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py
index 99e3be7e07..9f8dd3ebcd 100644
--- a/tensorflow_probability/python/layers/dense_variational_v2.py
+++ b/tensorflow_probability/python/layers/dense_variational_v2.py
@@ -142,6 +142,25 @@ def call(self, inputs):
     return outputs
 
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer.
+
+    Args:
+      input_shape: `TensorShape` or `list` of `TensorShape`
+        (only last dim is used)
+    Returns:
+      The output shape.
+    Raises:
+      ValueError: If the innermost dimension of `input_shape` is not defined.
+    """
+    input_shape = tf.TensorShape(input_shape)
+    input_shape = input_shape.with_rank_at_least(2)
+    if input_shape[-1] is None:
+      raise ValueError(
+          f'The innermost dimension of input_shape must be defined, but saw: {input_shape}'
+      )
+    return input_shape[:-1].concatenate(self.units)
+
 
 
 def _make_kl_divergence_penalty(
     use_exact_kl=False,
diff --git a/tensorflow_probability/python/layers/dense_variational_v2_test.py b/tensorflow_probability/python/layers/dense_variational_v2_test.py
index 24216252a4..da46d75430 100644
--- a/tensorflow_probability/python/layers/dense_variational_v2_test.py
+++ b/tensorflow_probability/python/layers/dense_variational_v2_test.py
@@ -76,10 +76,12 @@ def test_end_to_end(self):
     # Get dataset.
     y, x, x_tst = create_dataset()
 
-    # Build model.
+    layer = tfp.layers.DenseVariational(1, posterior_mean_field,
+                                        prior_trainable)
+
     model = tf.keras.Sequential([
-        tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable),
-        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
+        layer,
+        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1))
     ])
 
     # Do inference.
@@ -96,6 +98,11 @@
     self.assertContainsSubsequence(posterior.name, '/posterior/')
     self.assertContainsSubsequence(prior.name, '/prior/')
 
+    # Check the output_shape.
+    expected_output_shape = layer.compute_output_shape(
+        (None, x.shape[-1])).as_list()
+    self.assertAllEqual(expected_output_shape, (None, 1))
+
     # Profit.
     yhat = model(x_tst)
     self.assertIsInstance(yhat, tfd.Distribution)