Improved quadrature for likelihoods (GPflow#736)
* introduce ndiagquad to unify quadrature in likelihoods

* ndiagquad that can cope with several dimensions over which to integrate

* add pragma: no cover to error checking

* small change of signature

* fix for multi-parameter likelihoods

* preliminary quadrature test

* improved quadrature test

* fix weight normalisation in ndiagquad

* improve test

* add some type annotations

* more quadrature tests
st-- authored and Mark van der Wilk committed May 4, 2018
1 parent 38b2e90 commit b9e5817
Showing 3 changed files with 129 additions and 40 deletions.
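The thrust of the change: every Gauss-Hermite expectation in `likelihoods.py` previously hand-rolled the same change of variables; the new `ndiagquad` helper centralizes it. For f ~ N(mu, var), substituting f = mu + sqrt(2*var)*x reduces E[g(f)] to a Hermite sum with weights w_i/sqrt(pi). A minimal NumPy sketch of that identity (illustration only, not code from this commit):

```python
import numpy as np

# Gauss-Hermite approximation of E[g(f)] for f ~ N(mu, var),
# checked against the closed form E[exp(f)] = exp(mu + var/2).
x, w = np.polynomial.hermite.hermgauss(25)  # nodes/weights for weight exp(-x**2)
mu, var = 1.0, 3.0
fx = np.exp(mu + np.sqrt(2.0 * var) * x)    # g = exp at the shifted/scaled nodes
approx = (w / np.sqrt(np.pi)) @ fx          # 1/sqrt(pi) turns w into probability weights
assert np.isclose(approx, np.exp(mu + var / 2.0))
```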
51 changes: 14 additions & 37 deletions gpflow/likelihoods.py
@@ -29,6 +29,7 @@
from .params import Parameter
from .params import Parameterized
from .params import ParamList
+from .quadrature import ndiagquad
from .quadrature import hermgauss


@@ -60,21 +61,12 @@ def predict_mean_and_var(self, Fmu, Fvar):
Here, we implement a default Gauss-Hermite quadrature routine, but some
likelihoods (e.g. Gaussian) will implement specific cases.
"""
-gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
-gh_w /= np.sqrt(np.pi)
-gh_w = gh_w.reshape(-1, 1)
-shape = tf.shape(Fmu)
-Fmu, Fvar = [tf.reshape(e, (-1, 1)) for e in (Fmu, Fvar)]
-X = gh_x[None, :] * tf.sqrt(2.0 * Fvar) + Fmu
-
-# here's the quadrature for the mean
-E_y = tf.reshape(tf.matmul(self.conditional_mean(X), gh_w), shape)
-
-# here's the quadrature for the variance
-integrand = self.conditional_variance(X) \
-    + tf.square(self.conditional_mean(X))
-V_y = tf.reshape(tf.matmul(integrand, gh_w), shape) - tf.square(E_y)
-
+integrand2 = lambda *X: self.conditional_variance(*X) \
+    + tf.square(self.conditional_mean(*X))
+E_y, E_y2 = ndiagquad([self.conditional_mean, integrand2],
+                      self.num_gauss_hermite_points,
+                      Fmu, Fvar)
+V_y = E_y2 - tf.square(E_y)
return E_y, V_y

def predict_density(self, Fmu, Fvar, Y):
@@ -96,17 +88,10 @@ def predict_density(self, Fmu, Fvar, Y):
Here, we implement a default Gauss-Hermite quadrature routine, but some
likelihoods (Gaussian, Poisson) will implement specific cases.
"""
-gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
-
-gh_w = gh_w.reshape(-1, 1) / np.sqrt(np.pi)
-shape = tf.shape(Fmu)
-Fmu, Fvar, Y = [tf.reshape(e, (-1, 1)) for e in (Fmu, Fvar, Y)]
-X = gh_x[None, :] * tf.sqrt(2.0 * Fvar) + Fmu
-
-Y = tf.tile(Y, [1, self.num_gauss_hermite_points]) # broadcast Y to match X
-
-logp = self.logp(X, Y)
-return tf.reshape(tf.log(tf.matmul(tf.exp(logp), gh_w)), shape)
+exp_p = ndiagquad(lambda X, Y: tf.exp(self.logp(X, Y)),
+                  self.num_gauss_hermite_points,
+                  Fmu, Fvar, Y=Y)
+return tf.log(exp_p)

def variational_expectations(self, Fmu, Fvar, Y):
"""
@@ -128,17 +113,9 @@ def variational_expectations(self, Fmu, Fvar, Y):
Here, we implement a default Gauss-Hermite quadrature routine, but some
likelihoods (Gaussian, Poisson) will implement specific cases.
"""

-gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
-gh_x = gh_x.reshape(1, -1)
-gh_w = gh_w.reshape(-1, 1) / np.sqrt(np.pi)
-shape = tf.shape(Fmu)
-Fmu, Fvar, Y = [tf.reshape(e, (-1, 1)) for e in (Fmu, Fvar, Y)]
-X = gh_x * tf.sqrt(2.0 * Fvar) + Fmu
-Y = tf.tile(Y, [1, self.num_gauss_hermite_points]) # broadcast Y to match X
-
-logp = self.logp(X, Y)
-return tf.reshape(tf.matmul(logp, gh_w), shape)
+return ndiagquad(self.logp,
+                 self.num_gauss_hermite_points,
+                 Fmu, Fvar, Y=Y)


class Gaussian(Likelihood):
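A note on the rewritten `predict_mean_and_var`: it relies on the law of total variance, Var[y] = E[v(f) + m(f)^2] - (E[m(f)])^2, with m and v the likelihood's conditional mean and variance, which is why `V_y = E_y2 - tf.square(E_y)` suffices. A self-contained NumPy sketch of the decomposition with a toy Gaussian likelihood (the noise variance 0.1 is an arbitrary illustrative choice):

```python
import numpy as np

# E[y] and Var[y] for y | f ~ N(f, noise), f ~ N(mu, var), via the same
# two integrands predict_mean_and_var now hands to ndiagquad.
def quad(g, mu, var, n=25):
    x, w = np.polynomial.hermite.hermgauss(n)
    return (w / np.sqrt(np.pi)) @ g(mu + np.sqrt(2.0 * var) * x)

mu, var, noise = 0.7, 2.0, 0.1
cond_mean = lambda f: f             # conditional mean m(f)
cond_var = lambda f: noise + 0 * f  # conditional variance v(f), vectorized

E_y = quad(cond_mean, mu, var)
E_y2 = quad(lambda f: cond_var(f) + cond_mean(f) ** 2, mu, var)
V_y = E_y2 - E_y ** 2
assert np.isclose(E_y, mu) and np.isclose(V_y, var + noise)
```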
63 changes: 60 additions & 3 deletions gpflow/quadrature.py
@@ -1,6 +1,7 @@
from __future__ import print_function, absolute_import

import itertools
+from collections import Iterable

import numpy as np
import tensorflow as tf
@@ -9,13 +10,13 @@
from .core.errors import GPflowError


-def hermgauss(n):
+def hermgauss(n: int):
x, w = np.polynomial.hermite.hermgauss(n)
x, w = x.astype(settings.float_type), w.astype(settings.float_type)
return x, w


-def mvhermgauss(H, D):
+def mvhermgauss(H: int, D: int):
"""
Return the evaluation locations 'xn', and weights 'wn' for a multivariate
Gauss-Hermite quadrature.
@@ -33,7 +34,7 @@ def mvhermgauss(H, D):
return x, w


-def mvnquad(func, means, covs, H, Din=None, Dout=None):
+def mvnquad(func, means, covs, H: int, Din: int=None, Dout=None):
"""
Computes N Gaussian expectation integrals of a single function 'f'
using Gauss-Hermite quadrature.
@@ -77,3 +78,59 @@ def mvnquad(func, means, covs, H, Din=None, Dout=None):
wr = np.reshape(wn * np.pi ** (-Din * 0.5),
(-1,) + (1,) * (1 + len(Dout)))
return tf.reduce_sum(fX * wr, 0)


def ndiagquad(funcs, H: int, Fmu, Fvar, **Ys):
"""
Computes N Gaussian expectation integrals of one or more functions
using Gauss-Hermite quadrature. The Gaussians must be independent.
:param funcs: Callable or Iterable of Callables that operate elementwise
:param H: number of Gauss-Hermite quadrature points
:param Fmu: array/tensor or `Din`-tuple/list thereof
:param Fvar: array/tensor or `Din`-tuple/list thereof
:param **Ys: arrays/tensors; deterministic arguments to be passed by name
Fmu, Fvar, Ys should all have the same shape, with overall size `N`
:return: shape is the same as that of the first Fmu
"""
def unify(f_list):
"""
Stack a list of means/vars into a full block
"""
return tf.reshape(
tf.concat([tf.reshape(f, (-1, 1)) for f in f_list], axis=1),
(-1, 1, Din))

if isinstance(Fmu, (tuple, list)):
Din = len(Fmu)
shape = tf.shape(Fmu[0])
Fmu, Fvar = map(unify, [Fmu, Fvar]) # both N x 1 x Din
else:
Din = 1
shape = tf.shape(Fmu)
Fmu, Fvar = [tf.reshape(f, (-1, 1, 1)) for f in [Fmu, Fvar]]

xn, wn = mvhermgauss(H, Din)
# xn: H**Din x Din, wn: H**Din

gh_x = xn.reshape(1, -1, Din) # 1 x H**Din x Din
Xall = gh_x * tf.sqrt(2.0 * Fvar) + Fmu # N x H**Din x Din
Xs = [Xall[:, :, i] for i in range(Din)] # N x H**Din each

gh_w = wn.reshape(-1, 1) * np.pi ** (-0.5 * Din) # H**Din x 1

for name, Y in Ys.items():
Y = tf.reshape(Y, (-1, 1))
Y = tf.tile(Y, [1, H**Din]) # broadcast Y to match X
# without the tiling, some calls such as tf.where() (in bernoulli) fail
Ys[name] = Y # now N x H**Din

def eval_func(f):
feval = f(*Xs, **Ys) # f should be elementwise: return shape N x H**Din
return tf.reshape(tf.matmul(feval, gh_w), shape)

if isinstance(funcs, Iterable):
return [eval_func(f) for f in funcs]
else:
return eval_func(funcs)
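For a quick sense of the new API (a usage sketch under TensorFlow 1.x, using a plain `tf.Session` rather than the test helper below): passing a single callable returns a single tensor shaped like `Fmu`; passing a list of callables, as `predict_mean_and_var` now does, returns a list.

```python
import tensorflow as tf
import gpflow
from gpflow.quadrature import ndiagquad

# E[exp(f)] under f ~ N(1, 3), quadrature with 25 Gauss-Hermite points.
Fmu = tf.constant([1.0], dtype=gpflow.settings.float_type)
Fvar = tf.constant([3.0], dtype=gpflow.settings.float_type)
quad = ndiagquad(tf.exp, 25, Fmu, Fvar)
with tf.Session() as sess:
    print(sess.run(quad))  # ~ exp(1 + 3/2) ~ 12.18
```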
55 changes: 55 additions & 0 deletions tests/test_quadrature.py
@@ -0,0 +1,55 @@
import numpy as np
import pytest
import tensorflow as tf
from numpy.testing import assert_allclose

import gpflow
from gpflow.test_util import session_context

@pytest.fixture
def mu1(): return np.array([1.0, 1.3])

@pytest.fixture
def mu2(): return np.array([-2.0, 0.3])

@pytest.fixture
def var1(): return np.array([3.0, 3.5])

@pytest.fixture
def var2(): return np.array([4.0, 4.2])

def cast(x):
return tf.cast(np.asarray(x), dtype=gpflow.settings.float_type)

def test_diagquad_1d(mu1, var1):
with session_context() as session:
quad = gpflow.quadrature.ndiagquad(
lambda *X: tf.exp(X[0]), 25,
[cast(mu1)], [cast(var1)])
res = session.run(quad)
expected = np.exp(mu1 + var1/2)
assert_allclose(res, expected, atol=1e-10)


def test_diagquad_2d(mu1, var1, mu2, var2):
with session_context() as session:
alpha = 2.5
quad = gpflow.quadrature.ndiagquad(
lambda *X: tf.exp(X[0] + alpha * X[1]), 35,
[cast(mu1), cast(mu2)], [cast(var1), cast(var2)])
res = session.run(quad)
expected = np.exp(mu1 + var1/2 + alpha * mu2 + alpha**2 * var2/2)
assert_allclose(res, expected, atol=1e-10)


def test_diagquad_with_kwarg(mu2, var2):
with session_context() as session:
alpha = np.array([2.5, -1.3])
quad = gpflow.quadrature.ndiagquad(
lambda X, Y: tf.exp(X * Y), 25,
cast(mu2), cast(var2), Y=alpha)
res = session.run(quad)
expected = np.exp(alpha * mu2 + alpha**2 * var2/2)
assert_allclose(res, expected, atol=1e-10)
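The 2-D test exercises the tensor-product grid built by `mvhermgauss`. A standalone NumPy sketch of what that grid computes, using the first entries of the fixtures above (not the GPflow implementation itself):

```python
import numpy as np
from itertools import product

# Diagonal 2-D Gauss-Hermite: independent f1 ~ N(mu1, v1), f2 ~ N(mu2, v2),
# H**2 grid points, outer-product weights normalized by pi**(-D/2).
H, alpha = 35, 2.5
(mu1, v1), (mu2, v2) = (1.0, 3.0), (-2.0, 4.0)
x, w = np.polynomial.hermite.hermgauss(H)
X = np.array(list(product(x, x)))                        # H**2 x 2 nodes
W = np.array([a * b for a, b in product(w, w)]) / np.pi  # H**2 weights

f1 = mu1 + np.sqrt(2.0 * v1) * X[:, 0]
f2 = mu2 + np.sqrt(2.0 * v2) * X[:, 1]
approx = W @ np.exp(f1 + alpha * f2)
exact = np.exp(mu1 + v1 / 2 + alpha * mu2 + alpha ** 2 * v2 / 2)
assert np.isclose(approx, exact)
```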

