From 460afc749def66b5fcea98a36174a02329991797 Mon Sep 17 00:00:00 2001 From: st-- Date: Wed, 12 Sep 2018 20:01:15 +0100 Subject: [PATCH] Sample conditional returns mean and var as well as samples, and generates more than one sample (#836) --- gpflow/conditionals.py | 88 +++++++++++++++++++++++------- gpflow/multioutput/conditionals.py | 17 +++++- tests/test_multioutput.py | 50 +++++++++++------ 3 files changed, 115 insertions(+), 40 deletions(-) diff --git a/gpflow/conditionals.py b/gpflow/conditionals.py index fdd8209be..aa81d5196 100644 --- a/gpflow/conditionals.py +++ b/gpflow/conditionals.py @@ -131,7 +131,7 @@ def _conditional(Xnew, X, kern, f, *, full_cov=False, q_sqrt=None, white=False): @sample_conditional.register(object, InducingFeature, Kernel, object) @name_scope("sample_conditional") -def _sample_conditional(Xnew, feat, kern, f, *, full_output_cov=False, q_sqrt=None, white=False): +def _sample_conditional(Xnew, feat, kern, f, *, full_cov=False, full_output_cov=False, q_sqrt=None, white=False, num_samples=None): """ `sample_conditional` will return a sample from the conditional distribution. In most cases this means calculating the conditional mean m and variance v and then @@ -139,21 +139,50 @@ def _sample_conditional(Xnew, feat, kern, f, *, full_output_cov=False, q_sqrt=No However, for some combinations of Mok and Mof more efficient sampling routines exists. The dispatcher will make sure that we use the most efficient one. - :return: N x P (full_output_cov = False) or N x P x P (full_output_cov = True) + :return: samples, mean, cov + samples has shape [num_samples, N, P] or [N, P] if num_samples is None + mean and cov as for conditional() """ + if full_cov and full_output_cov: + raise NotImplementedError("The combination of both full_cov and full_output_cov is not " + "implemented for sample_conditional.") + logger.debug("sample conditional: InducingFeature Kernel") - mean, var = conditional(Xnew, feat, kern, f, full_cov=False, full_output_cov=full_output_cov, - q_sqrt=q_sqrt, white=white) # N x P, N x P (x P) - cov_structure = "full" if full_output_cov else "diag" - return _sample_mvn(mean, var, cov_structure) + mean, cov = conditional(Xnew, feat, kern, f, q_sqrt=q_sqrt, white=white, + full_cov=full_cov, full_output_cov=full_output_cov) + if full_cov: + # mean: N x P + # cov: P x N x N + mean = tf.matrix_transpose(mean) # now P x N + samples = _sample_mvn(mean, cov, 'full', num_samples=num_samples) # (S x) P x N + samples = tf.matrix_transpose(samples) # now (S x) N x P + + else: + cov_structure = "full" if full_output_cov else "diag" + samples = _sample_mvn(mean, cov, cov_structure, num_samples=num_samples) # [(S,), N, P] + + return samples, mean, cov @sample_conditional.register(object, object, Kernel, object) @name_scope("sample_conditional") -def _sample_conditional(Xnew, X, kern, f, *, q_sqrt=None, white=False): +def _sample_conditional(Xnew, X, kern, f, *, q_sqrt=None, white=False, full_cov=False, full_output_cov=False, num_samples=None): + if full_cov and full_output_cov: + raise NotImplementedError("The combination of both full_cov and full_output_cov is not " + "implemented for sample_conditional.") + logger.debug("sample conditional: Kernel") - mean, var = conditional(Xnew, X, kern, f, q_sqrt=q_sqrt, white=white, full_cov=False) # N x P, N x P - return _sample_mvn(mean, var, "diag") # N x P + if full_output_cov: + raise NotImplementedError("full_output_cov is not implemented") + + mean, cov = conditional(Xnew, X, kern, f, q_sqrt=q_sqrt, white=white, full_cov=full_cov) + if full_cov: + mean = tf.matrix_transpose(mean) + cov_structure = "full" if full_cov else "diag" + samples = _sample_mvn(mean, cov, cov_structure, num_samples=num_samples) + if full_cov: + samples = tf.matrix_transpose(samples) + return samples, mean, cov # ---------------------------------------------------------------------------- @@ -207,7 +236,7 @@ def base_conditional(Kmn, Kmm, Knn, f, *, full_cov=False, q_sqrt=None, white=Fal if q_sqrt.get_shape().ndims == 2: LTA = A * tf.expand_dims(tf.transpose(q_sqrt), 2) # R x M x N elif q_sqrt.get_shape().ndims == 3: - L = tf.matrix_band_part(q_sqrt, -1, 0) # R x M x M + L = q_sqrt A_tiled = tf.tile(tf.expand_dims(A, 0), tf.stack([num_func, 1, 1])) LTA = tf.matmul(L, A_tiled, transpose_a=True) # R x M x N else: # pragma: no cover @@ -334,7 +363,7 @@ def uncertain_conditional(Xnew_mu, Xnew_var, feat, kern, q_mu, q_sqrt, *, ########################## HELPERS ############################## # --------------------------------------------------------------- -def _sample_mvn(mean, cov, cov_structure): +def _sample_mvn(mean, cov, cov_structure=None, num_samples=None): """ Returns a sample from a D-dimensional Multivariate Normal distribution :param mean: N x D @@ -344,17 +373,34 @@ def _sample_mvn(mean, cov, cov_structure): - "full": cov holds the full covariance matrix (without jitter) :return: sample from the MVN of shape N x D """ - eps = tf.random_normal(tf.shape(mean), dtype=settings.float_type) # N x P - if cov_structure == "diag": - sample = mean + tf.sqrt(cov) * eps # N x P - elif cov_structure == "full": - cov = cov + (tf.eye(tf.shape(mean)[1], dtype=settings.float_type) * settings.numerics.jitter_level)[None, ...] # N x P x P - chol = tf.cholesky(cov) # N x P x P - return mean + (tf.matmul(chol, eps[..., None])[..., 0]) # N x P - else: - raise NotImplementedError # pragma: no cover + mean_shape = tf.shape(mean) + cov_shape = tf.shape(cov) + N, D = mean_shape[0], mean_shape[1] + S = num_samples if num_samples is not None else 1 + # assert shape(cov) == (N, D) or (N, D, D) + with tf.control_dependencies([ + tf.Assert(tf.equal(cov_shape[0], N) & tf.reduce_all(tf.equal(cov_shape[1:], D)), + data=[mean_shape, cov_shape]) + ]): + + if cov_structure == "diag": + with tf.control_dependencies([tf.assert_equal(tf.rank(mean), tf.rank(cov))]): + eps = tf.random_normal([S, N, D], dtype=settings.float_type) # S x N x D + samples = mean + tf.sqrt(cov) * eps # S x N x D + elif cov_structure == "full": + with tf.control_dependencies([tf.assert_equal(tf.rank(mean) + 1, tf.rank(cov))]): + jittermat = settings.numerics.jitter_level * \ + tf.eye(D, batch_shape=[N], dtype=settings.float_type) # N x D x D + eps = tf.random_normal([N, D, S], dtype=settings.float_type) # N x D x S + chol = tf.cholesky(cov + jittermat) # N x D x D + samples = mean[..., None] + tf.matmul(chol, eps) # N x D x S + samples = tf.transpose(samples, [2, 0, 1]) # S x N x D + else: + raise NotImplementedError # pragma: no cover - return sample # N x P + if num_samples is None: + return samples[0] # N x D + return samples # S x N x D def _expand_independent_outputs(fvar, full_cov, full_output_cov): """ diff --git a/gpflow/multioutput/conditionals.py b/gpflow/multioutput/conditionals.py index a8c9fdd0d..bf321377c 100644 --- a/gpflow/multioutput/conditionals.py +++ b/gpflow/multioutput/conditionals.py @@ -258,7 +258,7 @@ def _conditional(Xnew, feat, kern, f, *, full_cov=False, full_output_cov=False, @sample_conditional.register(object, (MixedKernelSharedMof, MixedKernelSeparateMof), SeparateMixedMok, object) @name_scope("sample_conditional") -def _sample_conditional(Xnew, feat, kern, f, *, full_output_cov=False, q_sqrt=None, white=False): +def _sample_conditional(Xnew, feat, kern, f, *, full_cov=False, full_output_cov=False, q_sqrt=None, white=False, num_samples=None): """ `sample_conditional` will return a sample from the conditinoal distribution. In most cases this means calculating the conditional mean m and variance v and then @@ -269,13 +269,24 @@ def _sample_conditional(Xnew, feat, kern, f, *, full_output_cov=False, q_sqrt=No :return: N x P (full_output_cov = False) or N x P x P (full_output_cov = True) """ logger.debug("sample conditional: (MixedKernelSharedMof, MixedKernelSeparateMof), SeparateMixedMok") + if full_cov: + raise NotImplementedError("full_cov not yet implemented") + if full_output_cov: + raise NotImplementedError("full_output_cov not yet implemented") independent_cond = conditional.dispatch(object, SeparateIndependentMof, SeparateIndependentMok, object) g_mu, g_var = independent_cond(Xnew, feat, kern, f, white=white, q_sqrt=q_sqrt, full_output_cov=False, full_cov=False) # N x L, N x L - g_sample = _sample_mvn(g_mu, g_var, "diag") # N x L + g_sample = _sample_mvn(g_mu, g_var, "diag", num_samples=num_samples) # N x L with params_as_tensors_for(kern): f_sample = tf.einsum("pl,nl->np", kern.W, g_sample) - return f_sample + f_mu = tf.einsum("pl,nl->np", kern.W, g_mu) + # W g_var W.T + # [P, L] @ [L, L] @ [L, P] + # \sum_l,l' W_pl g_var_ll' W_p'l' + # \sum_l W_pl g_var_nl W_p'l + # -> + f_var = tf.einsum("pl,nl,pl->np", kern.W, g_var, kern.W) + return f_sample, f_mu, f_var # ----------------- diff --git a/tests/test_multioutput.py b/tests/test_multioutput.py index 843cbbc0a..259fbac39 100644 --- a/tests/test_multioutput.py +++ b/tests/test_multioutput.py @@ -175,7 +175,8 @@ class DataMixedKernel(Data): @pytest.mark.parametrize("cov_structure", ["full", "diag"]) -def test_sample_mvn(session_tf, cov_structure): +@pytest.mark.parametrize("num_samples", [None, 1, 10]) +def test_sample_mvn(session_tf, cov_structure, num_samples): """ Draws 10,000 samples from a distribution with known mean and covariance. The test checks @@ -190,8 +191,15 @@ def test_sample_mvn(session_tf, cov_structure): elif cov_structure == "diag": covs = tf.ones((N, D), dtype=float_type) - samples = _sample_mvn(means, covs, cov_structure) + samples = _sample_mvn(means, covs, cov_structure, num_samples=num_samples) value = session_tf.run(samples) + + if num_samples is None: + assert value.shape == (N, D) + else: + assert value.shape == (num_samples, N, D) + value = value.reshape(-1, D) + samples_mean = np.mean(value, axis=0) samples_cov = np.cov(value, rowvar=False) np.testing.assert_array_almost_equal(samples_mean, [1., 1.], decimal=1) @@ -206,11 +214,12 @@ def _create_feed_dict(placeholders_dict, value_dict): @pytest.mark.parametrize("whiten", [True, False]) -def test_sample_conditional(session_tf, whiten): +@pytest.mark.parametrize("full_cov,full_output_cov", [(False, False), (False, True), (True, False)]) +def test_sample_conditional(session_tf, whiten, full_cov, full_output_cov): q_mu = np.random.randn(Data.M , Data.P) # M x P q_sqrt = np.array([np.tril(np.random.randn(Data.M, Data.M)) for _ in range(Data.P)]) # P x M x M Z = Data.X[:Data.M, ...] # M x D - Xs = np.ones((int(10e5), Data.D), dtype=float_type) + Xs = np.ones((Data.N, Data.D), dtype=float_type) feature = InducingPoints(Z.copy()) kernel = RBF(Data.D) @@ -220,20 +229,29 @@ def test_sample_conditional(session_tf, whiten): feed_dict = _create_feed_dict(placeholders, values) # Path 1 - sample = sample_conditional(placeholders["Xnew"], placeholders["Z"], kernel, - placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=whiten) - value = session_tf.run(sample, feed_dict=feed_dict) + sample_f = sample_conditional(placeholders["Xnew"], feature, kernel, + placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=whiten, + full_cov=full_cov, full_output_cov=full_output_cov, num_samples=int(1e5)) + value_f, mean_f, var_f = session_tf.run(sample_f, feed_dict=feed_dict) + value_f = value_f.reshape((-1,) + value_f.shape[2:]) # Path 2 - sample2 = sample_conditional(placeholders["Xnew"], feature, kernel, - placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=whiten) - value2 = session_tf.run(sample2, feed_dict=feed_dict) + if full_output_cov: + pytest.skip("sample_conditional with X instead of feature does not support full_output_cov") + + sample_x = sample_conditional(placeholders["Xnew"], placeholders["Z"], kernel, + placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=whiten, + full_cov=full_cov, full_output_cov=full_output_cov, num_samples=int(1e5)) + value_x, mean_x, var_x = session_tf.run(sample_x, feed_dict=feed_dict) + value_x = value_x.reshape((-1,) + value_x.shape[2:]) # check if mean and covariance of samples are similar - np.testing.assert_array_almost_equal(np.mean(value, axis=0), - np.mean(value2, axis=0), decimal=1) - np.testing.assert_array_almost_equal(np.cov(value, rowvar=False), - np.cov(value2, rowvar=False), decimal=1) + np.testing.assert_array_almost_equal(np.mean(value_x, axis=0), + np.mean(value_f, axis=0), decimal=1) + np.testing.assert_array_almost_equal(np.cov(value_x, rowvar=False), + np.cov(value_f, rowvar=False), decimal=1) + np.testing.assert_allclose(mean_x, mean_f) + np.testing.assert_allclose(var_x, var_f) def test_sample_conditional_mixedkernel(session_tf): @@ -255,7 +273,7 @@ def test_sample_conditional_mixedkernel(session_tf): sample = sample_conditional(placeholders["Xnew"], mixed_feature, mixed_kernel, placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=True) - value = session_tf.run(sample, feed_dict=feed_dict) + value, mean, var = session_tf.run(sample, feed_dict=feed_dict) # Path 2: independent kernels, mixed later @@ -263,7 +281,7 @@ def test_sample_conditional_mixedkernel(session_tf): shared_feature = mf.SharedIndependentMof(InducingPoints(Z.copy())) sample2 = sample_conditional(placeholders["Xnew"], shared_feature, separate_kernel, placeholders["q_mu"], q_sqrt=placeholders["q_sqrt"], white=True) - value2 = session_tf.run(sample2, feed_dict=feed_dict) + value2, mean2, var2 = session_tf.run(sample2, feed_dict=feed_dict) value2 = np.matmul(value2, W.T) # check if mean and covariance of samples are similar np.testing.assert_array_almost_equal(np.mean(value, axis=0),