From 1851e0d03aeb50009765fbc73ddb281dc12f70c9 Mon Sep 17 00:00:00 2001 From: st-- Date: Wed, 27 Jun 2018 12:19:26 +0100 Subject: [PATCH] Monte-Carlo likelihoods (#799) * add @markvdw's stochastic likelihood, including the softmax * initial MC likelihood * remove MonteCarloLikelihood base class from tests * fix test * var of predict_mean_and_var and predict_density for MC likelihood * factor out MC sampling * add comment for variance bias * add tests * fixes * use same integration as for GH quadrature in MonteCarloLikelihood.predict_mean_and_var() * . * increase rtol * move to proper use of super() * move MC integration to quadrature module, similar to ndiagquad * seed to make test deterministic * add Assert for shape of Y * tidy up studentT likelihood * fix for heteroskedastic likelihoods -- requires logp to always call the Y argument Y * fix doc * add assert and equivalence tests for SoftMax * remove erroneously added file * rename "probit" to inv_probit (which is what it actually is) * add assert for num_classes to SoftMax * fix whitespace * Update RELEASE.md * Update RELEASE.md --- RELEASE.md | 4 + gpflow/likelihoods.py | 218 ++++++++++++++++++++++++++++---------- gpflow/logdensities.py | 15 ++- gpflow/quadrature.py | 46 ++++++++ tests/test_dataholders.py | 2 +- tests/test_likelihoods.py | 126 +++++++++++++++++++++- tests/test_variational.py | 2 +- 7 files changed, 344 insertions(+), 69 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index c00a43592..18029f65b 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,7 @@ +# Master branch + - Added likelihoods where expectations are evaluated with Monte Carlo (`MonteCarloLikelihood`) (#799). + - Added `SoftMax` likelihood (#799). + # Release 1.1 - Added inter-domain inducing features. Inducing points are used by default and are now set with `model.feature.Z`. diff --git a/gpflow/likelihoods.py b/gpflow/likelihoods.py index 0a381415c..4ddc40374 100644 --- a/gpflow/likelihoods.py +++ b/gpflow/likelihoods.py @@ -27,17 +27,17 @@ from .params import Parameter from .params import Parameterized from .params import ParamList -from .quadrature import ndiagquad +from .quadrature import ndiagquad, ndiag_mc from .quadrature import hermgauss class Likelihood(Parameterized): - def __init__(self, name=None): - super(Likelihood, self).__init__(name) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.num_gauss_hermite_points = 20 def predict_mean_and_var(self, Fmu, Fvar): - """ + r""" Given a Normal distribution for the latent function, return the mean of Y @@ -59,18 +59,17 @@ def predict_mean_and_var(self, Fmu, Fvar): Here, we implement a default Gauss-Hermite quadrature routine, but some likelihoods (e.g. Gaussian) will implement specific cases. """ - integrand2 = lambda *X: self.conditional_variance(*X) \ - + tf.square(self.conditional_mean(*X)) + integrand2 = lambda *X: self.conditional_variance(*X) + tf.square(self.conditional_mean(*X)) E_y, E_y2 = ndiagquad([self.conditional_mean, integrand2], - self.num_gauss_hermite_points, - Fmu, Fvar) + self.num_gauss_hermite_points, + Fmu, Fvar) V_y = E_y2 - tf.square(E_y) return E_y, V_y def predict_density(self, Fmu, Fvar, Y): - """ + r""" Given a Normal distribution for the latent function, and a datum Y, - compute the (log) predictive density of Y. + compute the log predictive density of Y. i.e. 
if q(f) = N(Fmu, Fvar) @@ -81,17 +80,17 @@ def predict_density(self, Fmu, Fvar, Y): then this method computes the predictive density - \int p(y=Y|f)q(f) df + \log \int p(y=Y|f)q(f) df Here, we implement a default Gauss-Hermite quadrature routine, but some likelihoods (Gaussian, Poisson) will implement specific cases. """ - return ndiagquad(lambda X, Y: self.logp(X, Y), - self.num_gauss_hermite_points, - Fmu, Fvar, logspace=True, Y=Y) + return ndiagquad(self.logp, + self.num_gauss_hermite_points, + Fmu, Fvar, logspace=True, Y=Y) def variational_expectations(self, Fmu, Fvar, Y): - """ + r""" Compute the expected log density of the data, given a Gaussian distribution for the function values. @@ -110,16 +109,16 @@ def variational_expectations(self, Fmu, Fvar, Y): Here, we implement a default Gauss-Hermite quadrature routine, but some likelihoods (Gaussian, Poisson) will implement specific cases. """ - return ndiagquad(lambda X, Y: self.logp(X, Y), - self.num_gauss_hermite_points, - Fmu, Fvar, Y=Y) + return ndiagquad(self.logp, + self.num_gauss_hermite_points, + Fmu, Fvar, Y=Y) class Gaussian(Likelihood): - def __init__(self, var=1.0, name=None): - super().__init__(name=name) + def __init__(self, variance=1.0, **kwargs): + super().__init__(**kwargs) self.variance = Parameter( - var, transform=transforms.positive, dtype=settings.float_type) + variance, transform=transforms.positive, dtype=settings.float_type) @params_as_tensors def logp(self, F, Y): @@ -162,8 +161,8 @@ class Poisson(Likelihood): of size 'binsize') and using this Poisson likelihood. """ - def __init__(self, invlink=tf.exp, binsize=1.): - Likelihood.__init__(self) + def __init__(self, invlink=tf.exp, binsize=1., **kwargs): + super().__init__(**kwargs) self.invlink = invlink self.binsize = np.double(binsize) @@ -183,8 +182,8 @@ def variational_expectations(self, Fmu, Fvar, Y): return super(Poisson, self).variational_expectations(Fmu, Fvar, Y) class Exponential(Likelihood): - def __init__(self, invlink=tf.exp): - super().__init__() + def __init__(self, invlink=tf.exp, **kwargs): + super().__init__(**kwargs) self.invlink = invlink def logp(self, F, Y): @@ -203,15 +202,19 @@ def variational_expectations(self, Fmu, Fvar, Y): class StudentT(Likelihood): - def __init__(self, deg_free=3.0): - Likelihood.__init__(self) - self.deg_free = deg_free - self.scale = Parameter(1.0, transform=transforms.positive, + def __init__(self, scale=1.0, df=3.0, **kwargs): + """ + :param scale float: scale parameter + :param df float: degrees of freedom + """ + super().__init__(**kwargs) + self.df = df + self.scale = Parameter(scale, transform=transforms.positive, dtype=settings.float_type) @params_as_tensors def logp(self, F, Y): - return logdensities.student_t(Y, F, self.scale, self.deg_free) + return logdensities.student_t(Y, F, self.scale, self.df) @params_as_tensors def conditional_mean(self, F): @@ -219,25 +222,26 @@ def conditional_mean(self, F): @params_as_tensors def conditional_variance(self, F): - var = self.scale**2 * (self.deg_free / (self.deg_free - 2.0)) + var = self.scale**2 * (self.df / (self.df - 2.0)) return tf.fill(tf.shape(F), tf.squeeze(var)) -def probit(x): - return 0.5 * (1.0 + tf.erf(x / np.sqrt(2.0))) * (1 - 2e-3) + 1e-3 +def inv_probit(x): + jitter = 1e-3 # ensures output is strictly between 0 and 1 + return 0.5 * (1.0 + tf.erf(x / np.sqrt(2.0))) * (1 - 2*jitter) + jitter class Bernoulli(Likelihood): - def __init__(self, invlink=probit): - Likelihood.__init__(self) + def __init__(self, invlink=inv_probit, **kwargs): + 
super().__init__(**kwargs) self.invlink = invlink def logp(self, F, Y): return logdensities.bernoulli(Y, self.invlink(F)) def predict_mean_and_var(self, Fmu, Fvar): - if self.invlink is probit: - p = probit(Fmu / tf.sqrt(1 + Fvar)) + if self.invlink is inv_probit: + p = inv_probit(Fmu / tf.sqrt(1 + Fvar)) return p, p - tf.square(p) else: # for other invlink, use quadrature @@ -260,8 +264,8 @@ class Gamma(Likelihood): Use the transformed GP to give the *scale* (inverse rate) of the Gamma """ - def __init__(self, invlink=tf.exp): - Likelihood.__init__(self) + def __init__(self, invlink=tf.exp, **kwargs): + super().__init__(**kwargs) self.invlink = invlink self.shape = Parameter(1.0, transform=transforms.positive) @@ -304,8 +308,8 @@ class Beta(Likelihood): beta = scale * (1-m) """ - def __init__(self, invlink=probit, scale=1.0): - Likelihood.__init__(self) + def __init__(self, invlink=inv_probit, scale=1.0, **kwargs): + super().__init__(**kwargs) self.scale = Parameter(scale, transform=transforms.positive) self.invlink = invlink @@ -339,8 +343,8 @@ class RobustMax(Parameterized): eps/(k-1) otherwise. """ - def __init__(self, num_classes, epsilon=1e-3, name=None): - super().__init__(name) + def __init__(self, num_classes, epsilon=1e-3, **kwargs): + super().__init__(**kwargs) self.epsilon = Parameter(epsilon, transforms.Logistic(), trainable=False, dtype=settings.float_type, prior=priors.Beta(0.2, 5.)) self.num_classes = num_classes @@ -383,13 +387,13 @@ def prob_is_largest(self, Y, mu, var, gh_x, gh_w): class MultiClass(Likelihood): - def __init__(self, num_classes, invlink=None): + def __init__(self, num_classes, invlink=None, **kwargs): """ A likelihood that can do multi-way classification. Currently the only valid choice of inverse-link function (invlink) is an instance of RobustMax. """ - Likelihood.__init__(self) + super().__init__(**kwargs) self.num_classes = num_classes if invlink is None: invlink = RobustMax(self.num_classes) @@ -451,12 +455,12 @@ def conditional_variance(self, F): class SwitchedLikelihood(Likelihood): - def __init__(self, likelihood_list): + def __init__(self, likelihood_list, **kwargs): """ In this likelihood, we assume at extra column of Y, which contains integers that specify a likelihood from the list of likelihoods. """ - Likelihood.__init__(self) + super().__init__(**kwargs) for l in likelihood_list: assert isinstance(l, Likelihood) self.likelihood_list = ParamList(likelihood_list) @@ -523,7 +527,7 @@ class Ordinal(Likelihood): ... p(Y=K|F) = 1 - phi((a_{K-1} - F) / sigma) - where phi is the cumulative density function of a Gaussian (the probit + where phi is the cumulative density function of a Gaussian (the inverse probit function) and sigma is a parameter to be learned. A reference is: @article{chu2005gaussian, @@ -536,13 +540,13 @@ class Ordinal(Likelihood): year={2005} } """ - def __init__(self, bin_edges): + def __init__(self, bin_edges, **kwargs): """ bin_edges is a numpy array specifying at which function value the - output label should switch. In the possible Y values are 0...K, then + output label should switch. If the possible Y values are 0...K, then the size of bin_edges should be (K-1). 
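+
+        For example, bin_edges = np.array([-1.0, 1.0]) defines three ordinal
+        levels, so Y takes values in {0, 1, 2} (the configuration used in the
+        likelihood tests).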
""" - Likelihood.__init__(self) + super().__init__(**kwargs) self.bin_edges = bin_edges self.num_bins = bin_edges.size + 1 self.sigma = Parameter(1.0, transform=transforms.positive) @@ -555,8 +559,8 @@ def logp(self, F, Y): selected_bins_left = tf.gather(scaled_bins_left, Y) selected_bins_right = tf.gather(scaled_bins_right, Y) - return tf.log(probit(selected_bins_left - F / self.sigma) - - probit(selected_bins_right - F / self.sigma) + 1e-6) + return tf.log(inv_probit(selected_bins_left - F / self.sigma) - + inv_probit(selected_bins_right - F / self.sigma) + 1e-6) @params_as_tensors def _make_phi(self, F): @@ -569,8 +573,8 @@ def _make_phi(self, F): """ scaled_bins_left = tf.concat([self.bin_edges / self.sigma, np.array([np.inf])], 0) scaled_bins_right = tf.concat([np.array([-np.inf]), self.bin_edges/self.sigma], 0) - return probit(scaled_bins_left - tf.reshape(F, (-1, 1)) / self.sigma)\ - - probit(scaled_bins_right - tf.reshape(F, (-1, 1)) / self.sigma) + return inv_probit(scaled_bins_left - tf.reshape(F, (-1, 1)) / self.sigma) \ + - inv_probit(scaled_bins_right - tf.reshape(F, (-1, 1)) / self.sigma) def conditional_mean(self, F): phi = self._make_phi(F) @@ -583,3 +587,107 @@ def conditional_variance(self, F): E_y = tf.matmul(phi, Ys) E_y2 = tf.matmul(phi, tf.square(Ys)) return tf.reshape(E_y2 - tf.square(E_y), tf.shape(F)) + + +class MonteCarloLikelihood(Likelihood): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_monte_carlo_points = 100 + del self.num_gauss_hermite_points + + def _mc_quadrature(self, funcs, Fmu, Fvar, logspace: bool=False, epsilon=None, **Ys): + return ndiag_mc(funcs, self.num_monte_carlo_points, Fmu, Fvar, logspace, epsilon, **Ys) + + def predict_mean_and_var(self, Fmu, Fvar, epsilon=None): + r""" + Given a Normal distribution for the latent function, + return the mean of Y + + if + q(f) = N(Fmu, Fvar) + + and this object represents + + p(y|f) + + then this method computes the predictive mean + + \int\int y p(y|f)q(f) df dy + + and the predictive variance + + \int\int y^2 p(y|f)q(f) df dy - [ \int\int y^2 p(y|f)q(f) df dy ]^2 + + Here, we implement a default Monte Carlo routine. + """ + integrand2 = lambda *X: self.conditional_variance(*X) + tf.square(self.conditional_mean(*X)) + E_y, E_y2 = self._mc_quadrature([self.conditional_mean, integrand2], + Fmu, Fvar, epsilon=epsilon) + V_y = E_y2 - tf.square(E_y) + return E_y, V_y # N x D + + def predict_density(self, Fmu, Fvar, Y, epsilon=None): + r""" + Given a Normal distribution for the latent function, and a datum Y, + compute the log predictive density of Y. + + i.e. if + q(f) = N(Fmu, Fvar) + + and this object represents + + p(y|f) + + then this method computes the predictive density + + \log \int p(y=Y|f)q(f) df + + Here, we implement a default Monte Carlo routine. + """ + return self._mc_quadrature(self.logp, Fmu, Fvar, Y=Y, logspace=True, epsilon=epsilon) + + def variational_expectations(self, Fmu, Fvar, Y, epsilon=None): + r""" + Compute the expected log density of the data, given a Gaussian + distribution for the function values. + + if + q(f) = N(Fmu, Fvar) - Fmu: N x D Fvar: N x D + + and this object represents + + p(y|f) - Y: N x 1 + + then this method computes + + \int (\log p(y|f)) q(f) df. + + + Here, we implement a default Monte Carlo quadrature routine. + """ + return self._mc_quadrature(self.logp, Fmu, Fvar, Y=Y, epsilon=epsilon) + + +class GaussianMC(MonteCarloLikelihood, Gaussian): + """ + Stochastic version of Gaussian likelihood for comparison. 
+ """ + pass + + +class SoftMax(MonteCarloLikelihood): + """ + The soft-max multi-class likelihood. + """ + + def __init__(self, num_classes, **kwargs): + super().__init__(**kwargs) + self.num_classes = num_classes + + def logp(self, F, Y): + with tf.control_dependencies([tf.assert_equal(tf.shape(Y)[1], 1), + tf.assert_equal(tf.shape(F)[1], self.num_classes)]): + return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=F, labels=Y[:, 0])[:, None] + + def conditional_mean(self, F): + return tf.nn.softmax(F) diff --git a/gpflow/logdensities.py b/gpflow/logdensities.py index 1305738d7..44f989c05 100644 --- a/gpflow/logdensities.py +++ b/gpflow/logdensities.py @@ -46,18 +46,17 @@ def exponential(x, scale): def gamma(x, shape, scale): - return -shape * tf.log(scale) - tf.lgamma(shape)\ + return -shape * tf.log(scale) - tf.lgamma(shape) \ + (shape - 1.) * tf.log(x) - x / scale -def student_t(x, mean, scale, deg_free): - const = tf.lgamma(tf.cast((deg_free + 1.) * 0.5, settings.float_type))\ - - tf.lgamma(tf.cast(deg_free * 0.5, settings.float_type))\ - - 0.5*(tf.log(tf.square(scale)) + tf.cast(tf.log(deg_free), settings.float_type) - + np.log(np.pi)) +def student_t(x, mean, scale, df): + df = tf.cast(df, settings.float_type) + const = tf.lgamma((df + 1.) * 0.5) - tf.lgamma(df * 0.5) \ + - 0.5 * (tf.log(tf.square(scale)) + tf.log(df) + np.log(np.pi)) const = tf.cast(const, settings.float_type) - return const - 0.5*(deg_free + 1.) * \ - tf.log(1. + (1. / deg_free) * (tf.square((x - mean) / scale))) + return const - 0.5 * (df + 1.) * \ + tf.log(1. + (1. / df) * (tf.square((x - mean) / scale))) def beta(x, alpha, beta): diff --git a/gpflow/quadrature.py b/gpflow/quadrature.py index 7051a7f90..80346fed7 100644 --- a/gpflow/quadrature.py +++ b/gpflow/quadrature.py @@ -154,3 +154,49 @@ def eval_func(f): return [eval_func(f) for f in funcs] else: return eval_func(funcs) + + +def ndiag_mc(funcs, S: int, Fmu, Fvar, logspace: bool=False, epsilon=None, **Ys): + """ + Computes N Gaussian expectation integrals of one or more functions + using Monte Carlo samples. The Gaussians must be independent. 
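+
+    In outline (mirroring the implementation below): S samples are drawn as
+        X[s, n, d] = Fmu[n, d] + sqrt(Fvar[n, d]) * epsilon[s, n, d]
+    with epsilon ~ N(0, 1) unless an epsilon tensor of shape (S, N, D) is
+    supplied; each integrand is evaluated elementwise on these samples and
+    averaged over S. With logspace=True, the log-integrands are combined by
+    logsumexp over the samples minus log(S), giving an estimate of the
+    log-expectation.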
+ + :param funcs: the integrand(s): + Callable or Iterable of Callables that operates elementwise + :param S: number of Monte Carlo sampling points + :param Fmu: array/tensor + :param Fvar: array/tensor + :param logspace: if True, funcs are the log-integrands and this calculates + the log-expectation of exp(funcs) + :param **Ys: arrays/tensors; deterministic arguments to be passed by name + + Fmu, Fvar, Ys should all have same shape, with overall size `N` + :return: shape is the same as that of the first Fmu + """ + N, D = tf.shape(Fmu)[0], tf.shape(Fvar)[1] + + if epsilon is None: + epsilon = tf.random_normal((S, N, D), dtype=settings.float_type) + + mc_x = Fmu[None, :, :] + tf.sqrt(Fvar[None, :, :]) * epsilon + mc_Xr = tf.reshape(mc_x, (S * N, D)) + + for name, Y in Ys.items(): + Y = tf.reshape(Y, (-1, 1)) + mc_Yr = tf.tile(Y, [S, 1]) # broadcast Y to match X + # without the tiling, some calls such as tf.where() (in bernoulli) fail + Ys[name] = mc_Yr # now S * N x D + + def eval_func(func): + feval = func(mc_Xr, **Ys) + feval = tf.reshape(feval, (S, N, D)) + if logspace: + log_S = tf.log(tf.cast(S, settings.float_type)) + return tf.reduce_logsumexp(feval, axis=0) - log_S # N x D + else: + return tf.reduce_mean(feval, axis=0) + + if isinstance(funcs, Iterable): + return [eval_func(f) for f in funcs] + else: + return eval_func(funcs) diff --git a/tests/test_dataholders.py b/tests/test_dataholders.py index 9e41cb8d6..cf1cdf755 100644 --- a/tests/test_dataholders.py +++ b/tests/test_dataholders.py @@ -201,7 +201,7 @@ def test_change_variable_size(self): m.X = gpflow.Minibatch(arr, shuffle=False) for i in range(length): assert_allclose(m.X.read_value(session=session), [arr[i]]) - + length = 20 arr = np.random.randn(length, 2) m.X = arr diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index dbada9c61..a28909eeb 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -36,6 +36,8 @@ def getLikelihoodSetups(includeMultiClass=True, addNonStandardLinks=False): test_setups = [] rng = np.random.RandomState(1) for likelihoodClass in gpflow.likelihoods.Likelihood.__subclasses__(): + if likelihoodClass == gpflow.likelihoods.MonteCarloLikelihood: + continue # abstract base class if likelihoodClass == gpflow.likelihoods.Ordinal: test_setups.append( LikelihoodSetup(likelihoodClass(np.array([-1, 1])), @@ -130,8 +132,8 @@ def test_var_exp(self): class TestQuadrature(GPflowTestCase): """ - Where quadratre methods have been overwritten, make sure the new code - does something close to the quadrature + Where quadrature methods have been overwritten, make sure the new code + does something close to the quadrature """ def setUp(self): self.test_graph = tf.Graph() @@ -176,6 +178,124 @@ def test_pred_density(self): F2 = session.run(F2) assert_allclose(F1, F2, test_setup.tolerance, test_setup.tolerance) + def test_pred_mean_and_var(self): + # get all the likelihoods where predict_density has been overwritten. 
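+        # (for this test: predict_mean_and_var, checked against the generic
+        # quadrature fallback in the base Likelihood class)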
+ for test_setup in self.test_setups: + with self.test_context() as session: + if not test_setup.is_analytic: + continue + l = test_setup.likelihood + l.compile() + # 'build' the functions + F1 = l.predict_mean_and_var(self.Fmu, self.Fvar) + F2 = gpflow.likelihoods.Likelihood.predict_mean_and_var(l, self.Fmu, self.Fvar) + # compile and run the functions: + F1 = session.run(F1) + F2 = session.run(F2) + assert_allclose(F1, F2, test_setup.tolerance, test_setup.tolerance) + + +class TestMonteCarlo(GPflowTestCase): + def setUp(self): + self.test_graph = tf.Graph() + self.rng = np.random.RandomState() + self.rng.seed(1) + self.Fmu, self.Fvar, self.Y = self.rng.randn(3, 10, 1).astype(settings.float_type) + self.Fvar = 0.01 * (self.Fvar ** 2) + + def test_var_exp(self): + with self.test_context() as session: + tf.set_random_seed(1) + l = gpflow.likelihoods.GaussianMC(0.3) + l.num_monte_carlo_points = 1000000 + # 'build' the functions + l.compile() + F1 = l.variational_expectations(self.Fmu, self.Fvar, self.Y) + F2 = gpflow.likelihoods.Gaussian.variational_expectations( + l, self.Fmu, self.Fvar, self.Y) + # compile and run the functions: + F1 = session.run(F1) + F2 = session.run(F2) + assert_allclose(F1, F2, rtol=5e-4, atol=1e-4) + + def test_pred_density(self): + with self.test_context() as session: + tf.set_random_seed(1) + l = gpflow.likelihoods.GaussianMC(0.3) + l.num_monte_carlo_points = 1000000 + l.compile() + # 'build' the functions + F1 = l.predict_density(self.Fmu, self.Fvar, self.Y) + F2 = gpflow.likelihoods.Gaussian.predict_density(l, self.Fmu, self.Fvar, self.Y) + # compile and run the functions: + F1 = session.run(F1) + F2 = session.run(F2) + assert_allclose(F1, F2, rtol=5e-4, atol=1e-4) + + def test_pred_mean_and_var(self): + with self.test_context() as session: + tf.set_random_seed(1) + l = gpflow.likelihoods.GaussianMC(0.3) + l.num_monte_carlo_points = 1000000 + l.compile() + # 'build' the functions + F1 = l.predict_mean_and_var(self.Fmu, self.Fvar) + F2 = gpflow.likelihoods.Gaussian.predict_mean_and_var(l, self.Fmu, self.Fvar) + # compile and run the functions: + F1m, F1v = session.run(F1) + F2m, F2v = session.run(F2) + assert_allclose(F1m, F2m, rtol=5e-4, atol=1e-4) + assert_allclose(F1v, F2v, rtol=5e-4, atol=1e-4) + + +class TestSoftMax(GPflowTestCase): + def setUp(self): + self.test_graph = tf.Graph() + self.rng = np.random.RandomState(1) + + def prepare(self, dimF, dimY, num=10): + feed = {} + + def make_tensor(data, dtype=settings.float_type): + tensor = tf.placeholder(dtype) + feed[tensor] = data.astype(dtype) + return tensor + + F = make_tensor(self.rng.randn(num, dimF)) + Y = make_tensor(self.rng.randn(num, dimY) > 0, settings.int_type) # 0 or 1 + return F, Y, feed + + def test_y_shape_assert(self): + with self.test_context() as sess: + F, Y, feed = self.prepare(dimF=5, dimY=2) + l = gpflow.likelihoods.SoftMax(5) + l.compile() + try: + sess.run(l.logp(F, Y), feed_dict=feed) + except tf.errors.InvalidArgumentError as e: + assert "assertion failed" in e.message + + def test_bernoulli_equivalence(self): + with self.test_context() as sess: + F, Y, feed = self.prepare(dimF=2, dimY=1) + + q = tf.exp(F[:,0] - F[:,1])[:,None] + p = 1. / (1. 
+ q) + + l = gpflow.likelihoods.SoftMax(2) + l.compile() + + logp_softmax = sess.run(l.logp(F, Y), feed_dict=feed) + logp_bernoulli = sess.run(gpflow.logdensities.bernoulli(Y, p), feed_dict=feed) + + assert_allclose(logp_softmax, logp_bernoulli) + + cm_softmax = sess.run(l.conditional_mean(F), feed_dict=feed) + p_np = sess.run(p, feed_dict=feed) + cm_bernoulli = np.c_[1 - p_np, p_np] + + assert_allclose(cm_softmax, cm_bernoulli) + class TestRobustMaxMulticlass(GPflowTestCase): """ @@ -273,8 +393,6 @@ def testEpsK1Changes(self): - - class TestMulticlassIndexFix(GPflowTestCase): """ A regression test for a bug in multiclass likelihood. diff --git a/tests/test_variational.py b/tests/test_variational.py index 1e3f67f84..4e17d7639 100644 --- a/tests/test_variational.py +++ b/tests/test_variational.py @@ -84,7 +84,7 @@ class VariationalUnivariateTest(GPflowTestCase): posteriorStd = np.sqrt(posteriorVariance) def likelihood(self): - return gpflow.likelihoods.Gaussian(var=self.noiseVariance) + return gpflow.likelihoods.Gaussian(variance=self.noiseVariance) def get_model(self, is_diagonal, is_whitened): m = gpflow.models.SVGP(