Multi-output conditionals (GPflow#724)
* Introduction of MultiOutputFeatures (Mof) and MultiOutputKernels (Mok).
These are used to specify a particular setup of multi-output correlation.

* Multiple dispatch for conditional. This allows GPflow to select the most efficient conditional code depending on your choice of Mof and Mok (see the usage sketch after this list).

* Multiple dispatch for Kuu and Kuf. Previously, Kuu(.) and Kuf(.) were member functions of the feature class. This became cumbersome, as the calculation of Kuu and Kuf also depends on the kernel used. In line with conditional, we now also use multiple dispatch to calculate Kuu and Kuf for a particular combination of Mok and Mof.

* The actual maths to efficiently calculate the output-correlated conditional (credits to @markvdw).

* A sample_conditional function that makes sure the most efficient code is used to draw a sample from the conditional distribution.

* Minor: we updated a couple of models to use the new multi-output conditional.
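
A minimal usage sketch of the new multi-output API (a hedged illustration, not part of the diff: the Mok/Mof class names and the SVGP feat keyword follow the gpflow.multioutput module this commit introduces, but exact names may differ; the multioutput notebook below is the canonical reference):

import numpy as np
import gpflow
import gpflow.multioutput.features as mf
import gpflow.multioutput.kernels as mk

N, D, M, P = 100, 2, 10, 3
X = np.random.randn(N, D)   # inputs
Y = np.random.randn(N, P)   # P outputs
Z = np.random.randn(M, D)   # inducing inputs

# One choice of Mok/Mof pair: P independent latent GPs sharing a single RBF kernel.
kern = mk.SharedIndependentMok(gpflow.kernels.RBF(D), P)
feat = mf.SharedIndependentMof(gpflow.features.InducingPoints(Z))

# conditional/sample_conditional dispatch on (type(feat), type(kern)), so the
# model automatically uses the most efficient conditional for this combination.
model = gpflow.models.SVGP(X, Y, kern, gpflow.likelihoods.Gaussian(), feat=feat)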
vdutor authored and awav committed Jun 18, 2018
1 parent 6baeb43 commit bb08f22
Showing 34 changed files with 3,021 additions and 243 deletions.
1 change: 1 addition & 0 deletions .coveragerc
@@ -1,4 +1,5 @@
[report]
omit = *tests*, setup.py
exclude_lines =
pragma: no cover
def __repr__
2 changes: 1 addition & 1 deletion .travis.yml
@@ -13,7 +13,7 @@ install:
- python setup.py install

script:
- pytest -W ignore::UserWarning --durations=5 --cov=./ -d --tx 3*popen//python=python3.6 --pyargs tests
- pytest -W ignore::UserWarning --durations=5 --cov=./gpflow -d --tx 3*popen//python=python3.6 --pyargs ./tests
- codecov --token=2ae2a756-f39c-467c-bd9c-4bdb3dc439c8

cache:
63 changes: 47 additions & 16 deletions doc/source/notebooks/multiclass.ipynb

Large diffs are not rendered by default.

1,133 changes: 1,133 additions & 0 deletions doc/source/notebooks/multioutput.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions gpflow/__init__.py
@@ -51,5 +51,8 @@
from .params import DataHolder
from .params import Minibatch
from .params import Parameterized

from .saver import Saver
from .saver import SaverContext

from . import multioutput
274 changes: 217 additions & 57 deletions gpflow/conditionals.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions gpflow/dispatch.py
@@ -0,0 +1,10 @@
from multipledispatch import dispatch, Dispatcher
from functools import partial

# By default multipledispatch uses a global namespace in multipledispatch.core.global_namespace
# We define our own GPflow namespace to avoid any conflict which may arise
gpflow_md_namespace = dict()
dispatch = partial(dispatch, namespace=gpflow_md_namespace)

conditional = Dispatcher('conditional')
sample_conditional = Dispatcher('sample_conditional')
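
A quick sketch of how this namespace is meant to be used (Foo and Bar are hypothetical stand-ins, not GPflow classes):

from gpflow.dispatch import dispatch

class Foo:
    pass

class Bar(Foo):
    pass

@dispatch(Foo, object)
def cov(feat, kern):
    return "generic implementation"

@dispatch(Bar, object)
def cov(feat, kern):
    return "specialised implementation"

cov(Bar(), None)  # returns "specialised implementation": the most specific match wins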
17 changes: 7 additions & 10 deletions gpflow/expectations.py
@@ -25,13 +25,10 @@
from .quadrature import mvnquad
from .probability_distributions import Gaussian, DiagonalGaussian, MarkovGaussian

from multipledispatch import dispatch
from functools import partial
from .dispatch import dispatch

# By default multipledispatch uses a global namespace in multipledispatch.core.global_namespace
# We define our own GPflow namespace to avoid any conflict which may arise
gpflow_md_namespace = dict()
dispatch = partial(dispatch, namespace=gpflow_md_namespace)

logger = settings.logger()


# Sections:
@@ -113,8 +110,8 @@ def _quadrature_expectation(p, obj1, feature1, obj2, feature2, num_gauss_hermite
"""
num_gauss_hermite_points = 100 if num_gauss_hermite_points is None else num_gauss_hermite_points

warnings.warn("Quadrature is used to calculate the expectation. This means that "
"an analytical implementations is not available for the given combination.")
logger.warn("Quadrature is used to calculate the expectation. This means that "
"an analytical implementations is not available for the given combination.")

if obj2 is None:
eval_func = lambda x: get_eval_func(obj1, feature1)(x)
@@ -155,8 +152,8 @@ def _quadrature_expectation(p, obj1, feature1, obj2, feature2, num_gauss_hermite
"""
num_gauss_hermite_points = 40 if num_gauss_hermite_points is None else num_gauss_hermite_points

warnings.warn("Quadrature is used to calculate the expectation. This means that "
"an analytical implementations is not available for the given combination.")
logger.warn("Quadrature is used to calculate the expectation. This means that "
"an analytical implementations is not available for the given combination.")

if obj2 is None:
eval_func = lambda x: get_eval_func(obj1, feature1)(x)
136 changes: 63 additions & 73 deletions gpflow/features.py
@@ -13,13 +13,18 @@
# limitations under the License.

from abc import abstractmethod
from functools import singledispatch
import warnings

import numpy as np
import tensorflow as tf

from . import conditionals, transforms, kernels, decors, settings
from . import transforms, kernels, settings
from .decors import params_as_tensors, params_as_tensors_for
from .params import Parameter, Parameterized
from .dispatch import dispatch


logger = settings.logger()


class InducingFeature(Parameterized):
@@ -35,23 +40,32 @@ def __len__(self) -> int:
"""
raise NotImplementedError()

@abstractmethod
def Kuu(self, kern, jitter=0.0):
"""
Calculates the covariance matrix between features for kernel `kern`.
Return shape M x M
M = len(feat)
"""
raise NotImplementedError()
warnings.warn('Please replace feature.Kuu(kernel) with Kuu(feature, kernel)',
DeprecationWarning)
return Kuu(self, kern, jitter=jitter)

@abstractmethod
def Kuf(self, kern, Xnew):
"""
Calculates the covariance matrix with function values at new points
`Xnew` for kernel `kern`.
Return shape M x N
M = len(feat)
N = len(Xnew)
"""
raise NotImplementedError()
warnings.warn('Please replace feature.Kuf(kernel, Xnew) with Kuf(feature, kernel, Xnew)',
DeprecationWarning)
return Kuf(self, kern, Xnew)


class InducingPoints(InducingFeature):
class InducingPointsBase(InducingFeature):
"""
Real-space inducing points
"""
@@ -66,19 +80,25 @@ def __init__(self, Z):
def __len__(self):
return self.Z.shape[0]

@decors.params_as_tensors
def Kuu(self, kern, jitter=0.0):
Kzz = kern.K(self.Z)
Kzz += jitter * tf.eye(len(self), dtype=settings.dtypes.float_type)
return Kzz

@decors.params_as_tensors
def Kuf(self, kern, Xnew):
Kzx = kern.K(self.Z, Xnew)
return Kzx
class InducingPoints(InducingPointsBase):
pass

@dispatch(InducingPoints, kernels.Kernel)
def Kuu(feat, kern, *, jitter=0.0):
with params_as_tensors_for(feat):
Kzz = kern.K(feat.Z)
Kzz += jitter * tf.eye(len(feat), dtype=settings.dtypes.float_type)
return Kzz

class Multiscale(InducingPoints):
@dispatch(InducingPoints, kernels.Kernel, object)
def Kuf(feat, kern, Xnew):
with params_as_tensors_for(feat):
Kzx = kern.K(feat.Z, Xnew)
return Kzx


class Multiscale(InducingPointsBase):
"""
Multi-scale inducing features
Originally proposed in
@@ -101,69 +121,39 @@ def __init__(self, Z, scales):
if self.Z.shape != scales.shape:
raise ValueError("Input locations `Z` and `scales` must have the same shape.") # pragma: no cover

def _cust_square_dist(self, A, B, sc):
@staticmethod
def _cust_square_dist(A, B, sc):
"""
Custom version of _square_dist that allows sc to provide per-datapoint length
scales. sc: N x M x D.
"""
return tf.reduce_sum(tf.square((tf.expand_dims(A, 1) - tf.expand_dims(B, 0)) / sc), 2)

@decors.params_as_tensors
def Kuf(self, kern, Xnew):
if isinstance(kern, kernels.RBF):
with decors.params_as_tensors_for(kern):
Xnew, _ = kern._slice(Xnew, None)
Zmu, Zlen = kern._slice(self.Z, self.scales)
idlengthscales = kern.lengthscales + Zlen
d = self._cust_square_dist(Xnew, Zmu, idlengthscales)
Kuf = tf.transpose(kern.variance * tf.exp(-d / 2) *
tf.reshape(tf.reduce_prod(kern.lengthscales / idlengthscales, 1),
(1, -1)))
return Kuf
else:
raise NotImplementedError(
"Multiscale features not implemented for `%s`." % str(type(kern)))

@decors.params_as_tensors
def Kuu(self, kern, jitter=0.0):
if isinstance(kern, kernels.RBF):
with decors.params_as_tensors_for(kern):
Zmu, Zlen = kern._slice(self.Z, self.scales)
idlengthscales2 = tf.square(kern.lengthscales + Zlen)
sc = tf.sqrt(
tf.expand_dims(idlengthscales2, 0) + tf.expand_dims(idlengthscales2, 1) - tf.square(
kern.lengthscales))
d = self._cust_square_dist(Zmu, Zmu, sc)
Kzz = kern.variance * tf.exp(-d / 2) * tf.reduce_prod(kern.lengthscales / sc, 2)
Kzz += jitter * tf.eye(len(self), dtype=settings.float_type)
return Kzz
else:
raise NotImplementedError(
"Multiscale features not implemented for `%s`." % str(type(kern)))


@singledispatch
def conditional(feat, kern, Xnew, f, *, full_cov=False, q_sqrt=None, white=False):
"""
Note the changed function signature compared to conditionals.conditional()
to allow for single dispatch on the first argument.
"""
raise NotImplementedError("No implementation for {} found".format(type(feat).__name__))


@conditional.register(InducingPoints)
@conditional.register(Multiscale)
def default_feature_conditional(feat, kern, Xnew, f, *, full_cov=False, q_sqrt=None, white=False):
"""
Uses the same code path as conditionals.conditional(), except Kuu/Kuf
matrices are constructed using the feature.
To use this with features defined in external modules, register your
feature class using
>>> gpflow.features.conditional.register(YourFeatureClass,
... gpflow.features.default_feature_conditional)
"""
return conditionals.feature_conditional(Xnew, feat, kern, f, full_cov=full_cov, q_sqrt=q_sqrt,
white=white)
@dispatch(Multiscale, kernels.RBF, object)
def Kuf(feat, kern, Xnew):
with params_as_tensors_for(feat, kern):
Xnew, _ = kern._slice(Xnew, None)
Zmu, Zlen = kern._slice(feat.Z, feat.scales)
idlengthscales = kern.lengthscales + Zlen
d = feat._cust_square_dist(Xnew, Zmu, idlengthscales)
Kuf = tf.transpose(kern.variance * tf.exp(-d / 2) *
tf.reshape(tf.reduce_prod(kern.lengthscales / idlengthscales, 1),
(1, -1)))
return Kuf

@dispatch(Multiscale, kernels.RBF)
def Kuu(feat, kern, *, jitter=0.0):
with params_as_tensors_for(feat, kern):
Zmu, Zlen = kern._slice(feat.Z, feat.scales)
idlengthscales2 = tf.square(kern.lengthscales + Zlen)
sc = tf.sqrt(
tf.expand_dims(idlengthscales2, 0) + tf.expand_dims(idlengthscales2, 1) - tf.square(
kern.lengthscales))
d = feat._cust_square_dist(Zmu, Zmu, sc)
Kzz = kern.variance * tf.exp(-d / 2) * tf.reduce_prod(kern.lengthscales / sc, 2)
Kzz += jitter * tf.eye(len(feat), dtype=settings.float_type)
return Kzz


def inducingpoint_wrapper(feat, Z):
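
For reference, a short sketch of calling the dispatched covariance builders directly (assuming Kuu and Kuf are importable from gpflow.features after this change; with a Multiscale feature and an RBF kernel, dispatch would instead pick the specialised implementations above):

import numpy as np
import gpflow
from gpflow.features import InducingPoints, Kuu, Kuf

Z = np.random.randn(20, 1)          # M = 20 inducing inputs
Xnew = np.random.randn(5, 1)        # N = 5 new points
feat = InducingPoints(Z)
kern = gpflow.kernels.RBF(1)

Kzz = Kuu(feat, kern, jitter=1e-6)  # M x M tensor, jittered for a stable Cholesky
Kzx = Kuf(feat, kern, Xnew)         # M x N tensor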
2 changes: 0 additions & 2 deletions gpflow/gpflowrc
@@ -4,8 +4,6 @@ level = WARNING

[verbosity]
tf_compile_verb = False
hmc_verb = True
optimisation_verb = False

[dtypes]
float_type = float64
6 changes: 3 additions & 3 deletions gpflow/kullback_leiblers.py
@@ -35,7 +35,7 @@ def gauss_kl(q_mu, q_sqrt, K=None):
q_mu is a matrix (M x L), each column contains a mean.
q_sqrt can be a 3D tensor (L xM x M), each matrix within is a lower
q_sqrt can be a 3D tensor (L x M x M), each matrix within is a lower
triangular square-root matrix of the covariance of q.
q_sqrt can be a matrix (M x L), each column represents the diagonal of a
square-root matrix of the covariance of q.
@@ -70,7 +70,7 @@ def gauss_kl(q_mu, q_sqrt, K=None):
mahalanobis = tf.reduce_sum(tf.square(alpha))

# Constant term: - B * M
constant = tf.cast(-tf.size(q_mu, out_type=tf.int64), dtype=settings.float_type)
constant = - tf.cast(tf.size(q_mu, out_type=tf.int64), dtype=settings.float_type)

# Log-determinant of the covariance of q(x):
logdet_qcov = tf.reduce_sum(tf.log(tf.square(Lq_diag)))
@@ -101,4 +101,4 @@ def gauss_kl(q_mu, q_sqrt, K=None):
scale = 1.0 if batch else tf.cast(B, settings.float_type)
twoKL += scale * sum_log_sqdiag_Lp

return 0.5 * twoKL
return 0.5 * twoKL
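
A tiny sanity-check sketch of gauss_kl (an illustration under the shapes documented above; K=None means the prior is the white N(0, I)):

import numpy as np
import tensorflow as tf
from gpflow.kullback_leiblers import gauss_kl

M, L = 4, 2
q_mu = tf.constant(np.zeros((M, L)))                       # M x L, one mean per column
q_sqrt = tf.constant(np.tile(np.eye(M)[None], (L, 1, 1)))  # L x M x M lower triangular
kl = gauss_kl(q_mu, q_sqrt)  # q equals the white prior here, so this evaluates to 0.0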
9 changes: 6 additions & 3 deletions gpflow/logdensities.py
@@ -21,8 +21,11 @@
from . import settings


logger = settings.logger()


def gaussian(x, mu, var):
return -0.5 * (np.log(2 * np.pi) + tf.log(var) + tf.square(mu-x)/var)
return -0.5 * (np.log(2 * np.pi) + tf.log(var) + tf.square(mu-x) / var)


def lognormal(x, mu, var):
@@ -86,11 +89,11 @@ def multivariate_normal(x, mu, L):
x[n] ~ N(mu, LL^T) or x ~ N(mu[n], LL^T) or x[n] ~ N(mu[n], LL^T)
"""
if x.shape.ndims is None:
warnings.warn('Shape of x must be 2D at computation.')
logger.warn('Shape of x must be 2D at computation.')
elif x.shape.ndims != 2:
raise ValueError('Shape of x must be 2D.')
if mu.shape.ndims is None:
warnings.warn('Shape of mu may be unknown or not 2D.')
logger.warn('Shape of mu may be unknown or not 2D.')
elif mu.shape.ndims != 2:
raise ValueError('Shape of mu must be 2D.')

22 changes: 8 additions & 14 deletions gpflow/models/gpr.py
@@ -17,6 +17,7 @@
from .. import likelihoods
from .. import settings

from ..conditionals import base_conditional
from ..params import DataHolder
from ..decors import params_as_tensors
from ..decors import name_scope
@@ -78,17 +79,10 @@ def _build_predict(self, Xnew, full_cov=False):
where F* are points on the GP at Xnew, Y are noisy observations at X.
"""
Kx = self.kern.K(self.X, Xnew)
K = self.kern.K(self.X) + tf.eye(tf.shape(self.X)[0], dtype=settings.float_type) * self.likelihood.variance
L = tf.cholesky(K)
A = tf.matrix_triangular_solve(L, Kx, lower=True)
V = tf.matrix_triangular_solve(L, self.Y - self.mean_function(self.X))
fmean = tf.matmul(A, V, transpose_a=True) + self.mean_function(Xnew)
if full_cov:
fvar = self.kern.K(Xnew) - tf.matmul(A, A, transpose_a=True)
shape = tf.stack([1, 1, tf.shape(self.Y)[1]])
fvar = tf.tile(tf.expand_dims(fvar, 2), shape)
else:
fvar = self.kern.Kdiag(Xnew) - tf.reduce_sum(tf.square(A), 0)
fvar = tf.tile(tf.reshape(fvar, (-1, 1)), [1, tf.shape(self.Y)[1]])
return fmean, fvar
y = self.Y - self.mean_function(self.X)
Kmn = self.kern.K(self.X, Xnew)
Kmm_sigma = self.kern.K(self.X) + tf.eye(tf.shape(self.X)[0], dtype=settings.float_type) * self.likelihood.variance
Knn = self.kern.K(Xnew) if full_cov else self.kern.Kdiag(Xnew)
f_mean, f_var = base_conditional(Kmn, Kmm_sigma, Knn, y, full_cov=full_cov, white=False) # N x P, N x P or P x N x N
return f_mean + self.mean_function(Xnew), f_var
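
Nothing changes from the user's perspective; GPR prediction now simply routes through the shared base_conditional code path (a usage sketch):

import numpy as np
import gpflow

X = np.random.randn(30, 1)
Y = np.sin(X) + 0.1 * np.random.randn(30, 1)
m = gpflow.models.GPR(X, Y, kern=gpflow.kernels.RBF(1))
mean, var = m.predict_f(np.linspace(-2, 2, 50)[:, None])  # mean and variance at 50 test points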

4 changes: 2 additions & 2 deletions gpflow/models/model.py
@@ -161,11 +161,11 @@ def predict_f_samples(self, Xnew, num_samples):
Produce samples from the posterior latent function(s) at the points
Xnew.
"""
mu, var = self._build_predict(Xnew, full_cov=True)
mu, var = self._build_predict(Xnew, full_cov=True) # N x P, # P x N x N
jitter = tf.eye(tf.shape(mu)[0], dtype=settings.float_type) * settings.numerics.jitter_level
samples = []
for i in range(self.num_latent):
L = tf.cholesky(var[:, :, i] + jitter)
L = tf.cholesky(var[i, :, :] + jitter)
shape = tf.stack([tf.shape(L)[0], num_samples])
V = tf.random_normal(shape, dtype=settings.float_type)
samples.append(mu[:, i:i + 1] + tf.matmul(L, V))
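
Continuing the GPR sketch above, drawing joint samples exercises this full-covariance path (note the full covariance is now indexed as P x N x N):

Xtest = np.linspace(-2, 2, 50)[:, None]
samples = m.predict_f_samples(Xtest, 10)  # 10 joint posterior draws over the 50 test points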