Removed jitter in nat grads (GPflow#768)

hughsalimbeni authored and awav committed Jun 19, 2018
1 parent 61088fd · commit 3a16991

Showing 2 changed files with 28 additions and 14 deletions.
15 changes: 2 additions & 13 deletions gpflow/training/natgrad_optimizer.py
@@ -347,7 +347,7 @@ def natural_to_meanvarsqrt(nat_1, nat_2):
     mu = tf.matmul(S, nat_1)
     # We need the decomposition of S as L L^T, not as L^T L,
     # hence we need another cholesky.
-    return mu, _cholesky_with_jitter(S)
+    return mu, tf.cholesky(S)


 @swap_dimensions
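Why the extra Cholesky? Per the comment, S arrives at this point via a factorisation of the form L^T L (upper times lower), while the q_sqrt convention used downstream needs a lower-triangular square root with S = L L^T; those are generally different factors, so a fresh Cholesky of S is required. A standalone NumPy sketch of the distinction (an illustration, not GPflow code):

    import numpy as np

    # A lower-triangular L gives two different symmetric products.
    L = np.tril(np.random.randn(3, 3)) + 3.0 * np.eye(3)
    S = L.T @ L                      # the "wrong way round" factorisation of S

    A = np.linalg.cholesky(S)        # lower-triangular A with A @ A.T == S
    assert np.allclose(A @ A.T, S)   # the L L^T form that q_sqrt expects
    assert not np.allclose(A, L)     # generically a different factor than L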
@@ -370,25 +370,14 @@ def expectation_to_natural(eta_1, eta_2):
 @swap_dimensions
 def expectation_to_meanvarsqrt(eta_1, eta_2):
     var = eta_2 - tf.matmul(eta_1, eta_1, transpose_b=True)
-    return eta_1, _cholesky_with_jitter(var)
+    return eta_1, tf.cholesky(var)


 @swap_dimensions
 def meanvarsqrt_to_expectation(m, v_sqrt):
     v = tf.matmul(v_sqrt, v_sqrt, transpose_b=True)
     return m, v + tf.matmul(m, m, transpose_b=True)
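These two transforms are mutual inverses: for a Gaussian N(m, S) the expectation parameters are eta_1 = m and eta_2 = S + m m^T, so eta_2 − eta_1 eta_1^T recovers the covariance, whose Cholesky no longer receives jitter after this change. A quick NumPy round-trip check mirroring the pair above (without the @swap_dimensions batching):

    import numpy as np

    def meanvarsqrt_to_expectation(m, v_sqrt):
        v = v_sqrt @ v_sqrt.T
        return m, v + m @ m.T

    def expectation_to_meanvarsqrt(eta_1, eta_2):
        var = eta_2 - eta_1 @ eta_1.T
        return eta_1, np.linalg.cholesky(var)  # jitter-free, as in this commit

    m = np.random.randn(3, 1)
    v_sqrt = np.linalg.cholesky(0.5 * np.eye(3) + 0.1 * np.ones((3, 3)))
    m2, v_sqrt2 = expectation_to_meanvarsqrt(*meanvarsqrt_to_expectation(m, v_sqrt))
    assert np.allclose(m, m2) and np.allclose(v_sqrt, v_sqrt2)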


-def _cholesky_with_jitter(M):
-    """
-    Add jitter and take Cholesky.
-    :param M: Tensor of shape ...xNxN
-    :return: The Cholesky decomposition of the input `M`, a `tf.Tensor` of shape ...xNxN
-    """
-    N = tf.shape(M)[-1]
-    return tf.cholesky(M + settings.jitter * tf.eye(N, dtype=M.dtype))
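Deleting this helper is the heart of the commit: a fixed jitter before the Cholesky is harmless when the matrix entries are O(1), but these transforms also run on variational covariances that can be tiny, where the jitter swamps the matrix instead of stabilising it (issue #767, exercised by the new test below). A rough NumPy illustration, assuming a jitter level of 1e-6 (the actual settings.jitter value is an assumption here, not shown in the diff):

    import numpy as np

    jitter = 1e-6                        # assumed jitter level
    q_sqrt = 1e-3 * np.eye(3)            # small q_sqrt, as in the regression test below
    S = q_sqrt @ q_sqrt.T                # entries are O(1e-6), same order as the jitter

    plain = np.linalg.cholesky(S)
    jittered = np.linalg.cholesky(S + jitter * np.eye(3))
    print(plain[0, 0], jittered[0, 0])   # ~1.0e-3 vs ~1.4e-3: roughly 41% relative error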

 def _inverse_lower_triangular(M):
     """
     Take inverse of lower triangular (e.g. Cholesky) matrix. This function
27 changes: 26 additions & 1 deletion tests/test_optimizers.py

@@ -245,6 +245,31 @@ def test_scipy_optimizer_options(session_tf):
     assert o1.optimizer.optimizer_kwargs['options'][gtol] == gtol_value
     assert gtol not in o2.optimizer.optimizer_kwargs['options']

+def test_small_q_sqrt_handeled_correctly(session_tf):
+    """
+    This is an extra test to make sure things still work when q_sqrt is small. This was breaking (#767)
+    """
+    N, D = 3, 2
+    X = np.random.randn(N, D)
+    Y = np.random.randn(N, 1)
+    kern = gpflow.kernels.RBF(D)
+    lik_var = 0.1
+    lik = gpflow.likelihoods.Gaussian()
+    lik.variance = lik_var
+
+    m_vgp = gpflow.models.VGP(X, Y, kern, lik)
+    m_gpr = gpflow.models.GPR(X, Y, kern)
+    m_gpr.likelihood.variance = lik_var
+
+    m_vgp.set_trainable(False)
+    m_vgp.q_mu.set_trainable(True)
+    m_vgp.q_sqrt.set_trainable(True)
+    m_vgp.q_mu = np.random.randn(N, 1)
+    m_vgp.q_sqrt = np.eye(N)[None, :, :] * 1e-3
+    NatGradOptimizer(1.).minimize(m_vgp, [(m_vgp.q_mu, m_vgp.q_sqrt)], maxiter=1)
+
+    assert_allclose(m_gpr.compute_log_likelihood(),
+                    m_vgp.compute_log_likelihood(), atol=1e-4)
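A note on why maxiter=1 with step size 1. suffices here and in test_VGP_vs_GPR below: with a Gaussian likelihood the model is conjugate, and a single natural-gradient step of length 1 in the expectation parameterisation lands exactly on the optimal variational posterior, which for VGP equals the exact GP regression posterior. A standalone NumPy sketch of that target (a hypothetical helper, not GPflow API):

    import numpy as np

    def exact_gp_posterior(K, y, lik_var):
        # Exact GP regression posterior N(mu, S) over f at the training inputs;
        # one gamma = 1 natural-gradient step moves q(f) onto this distribution.
        A = K + lik_var * np.eye(len(y))
        mu = K @ np.linalg.solve(A, y)
        S = K - K @ np.linalg.solve(A, K)
        return mu, S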

 def test_VGP_vs_GPR(session_tf):
     """
@@ -269,7 +294,7 @@ def test_VGP_vs_GPR(session_tf):
     NatGradOptimizer(1.).minimize(m_vgp, [(m_vgp.q_mu, m_vgp.q_sqrt)], maxiter=1)

     assert_allclose(m_gpr.compute_log_likelihood(),
-                    m_vgp.compute_log_likelihood(), atol=1e-5)
+                    m_vgp.compute_log_likelihood(), atol=1e-4)


 def test_other_XiTransform_VGP_vs_GPR(session_tf, xi_transform=XiSqrtMeanVar()):
