From 17e4ee83fb75dc3aedbc6532829bfe6a459ac83f Mon Sep 17 00:00:00 2001
From: Kaan <46622558+Frightera@users.noreply.github.com>
Date: Tue, 15 Feb 2022 01:31:34 +0300
Subject: [PATCH 1/2] Implement compute_output_shape method

---
 .../python/layers/dense_variational_v2.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py
index 87ea963cc1..8e9c386a34 100644
--- a/tensorflow_probability/python/layers/dense_variational_v2.py
+++ b/tensorflow_probability/python/layers/dense_variational_v2.py
@@ -140,6 +140,24 @@ def call(self, inputs):
 
     return outputs
 
+  def compute_output_shape(self, input_shape):
+    """
+    Computes the output shape of the layer.
+    Args:
+      input_shape: `TensorShape` or `list` of `TensorShape`
+        (only last dim is used)
+    Returns:
+      The output shape.
+    Raises:
+      ValueError: If the innermost dimension of `input_shape` is not defined.
+    """
+    input_shape = tf.TensorShape(input_shape)
+    input_shape = input_shape.with_rank_at_least(2)
+    if input_shape[-1] is None:
+      raise ValueError(
+          'The innermost dimension of input_shape must be defined, but saw: %s' % (
+              input_shape,))
+    return input_shape[:-1].concatenate(self.units)
 
 def _make_kl_divergence_penalty(
     use_exact_kl=False,

From 520d4c286b74e7a0b95b8760419f45904817b33c Mon Sep 17 00:00:00 2001
From: Kaan <46622558+Frightera@users.noreply.github.com>
Date: Tue, 15 Feb 2022 12:02:24 +0300
Subject: [PATCH 2/2] Update dense_variational_v2_test.py

---
 .../layers/dense_variational_v2_test.py | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/tensorflow_probability/python/layers/dense_variational_v2_test.py b/tensorflow_probability/python/layers/dense_variational_v2_test.py
index 66203db243..c5ddd920b0 100644
--- a/tensorflow_probability/python/layers/dense_variational_v2_test.py
+++ b/tensorflow_probability/python/layers/dense_variational_v2_test.py
@@ -72,25 +72,29 @@ def prior_trainable(kernel_size, bias_size=0, dtype=None):
 
 @test_util.test_all_tf_execution_regimes
 class DenseVariationalLayerTest(test_util.TestCase):
-  def test_end_to_end(self):
-    # Get dataset.
-    y, x, x_tst = create_dataset()
+  def test_end_to_end(self):
+    # Get dataset.
+    y, x, x_tst = create_dataset()
 
-    # Build model.
-    model = tf.keras.Sequential([
-        tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable),
-        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
-    ])
+    layer = tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable)
 
-    # Do inference.
-    model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.05),
-                  loss=negloglik)
-    model.fit(x, y, epochs=2, verbose=False)
+    model = tf.keras.Sequential([
+        layer,
+        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1))
+    ])
 
-    # Profit.
-    yhat = model(x_tst)
-    assert isinstance(yhat, tfd.Distribution)
+    # Do inference.
+    model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.05),
+                  loss=negloglik)
+    model.fit(x, y, epochs=2, verbose=False)
+    # Check the output_shape.
+    expected_output_shape = layer.compute_output_shape((None, x.shape[-1]))
+    self.assertAllEqual(expected_output_shape, (None, 1))
+
+    # Profit.
+    yhat = model(x_tst)
+    assert isinstance(yhat, tfd.Distribution)
 

 if __name__ == '__main__':
   test_util.main()