Skip to content

Commit

Permalink
Merge pull request #1515 from Frightera:frighterafix#1505
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 453335007
  • Loading branch information
tensorflower-gardener committed Jun 7, 2022
2 parents 6d04325 + bd53773 commit 0efbee9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
19 changes: 19 additions & 0 deletions tensorflow_probability/python/layers/dense_variational_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ def call(self, inputs):

return outputs

def compute_output_shape(self, input_shape):
  """Computes the output shape of the layer.

  Args:
    input_shape: `TensorShape` or `list` of `TensorShape`
      (only last dim is used)

  Returns:
    The output shape: `input_shape` with its innermost dimension
    replaced by this layer's number of units.

  Raises:
    ValueError: If the innermost dimension of `input_shape` is not defined.
  """
  # Coerce to a TensorShape of rank >= 2 in a single step.
  input_shape = tf.TensorShape(input_shape).with_rank_at_least(2)
  last_dim = input_shape[-1]
  if last_dim is None:
    raise ValueError(
        f'The innermost dimension of input_shape must be defined, but saw: {input_shape}'
    )
  # Keep all leading dims; the trailing dim becomes the unit count.
  return input_shape[:-1].concatenate(self.units)


def _make_kl_divergence_penalty(
use_exact_kl=False,
Expand Down
13 changes: 10 additions & 3 deletions tensorflow_probability/python/layers/dense_variational_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ def test_end_to_end(self):
# Get dataset.
y, x, x_tst = create_dataset()

# Build model.
layer = tfp.layers.DenseVariational(1, posterior_mean_field,
prior_trainable)

model = tf.keras.Sequential([
tfp.layers.DenseVariational(1, posterior_mean_field, prior_trainable),
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
layer,
tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1))
])

# Do inference.
Expand All @@ -96,6 +98,11 @@ def test_end_to_end(self):
self.assertContainsSubsequence(posterior.name, '/posterior/')
self.assertContainsSubsequence(prior.name, '/prior/')

# Check the output_shape.
expected_output_shape = layer.compute_output_shape(
(None, x.shape[-1])).as_list()
self.assertAllEqual(expected_output_shape, (None, 1))

# Profit.
yhat = model(x_tst)
self.assertIsInstance(yhat, tfd.Distribution)
Expand Down

0 comments on commit 0efbee9

Please sign in to comment.