Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 8, 2023
1 parent 72e6b94 commit acc9779
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions tests/test_hskgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
from gpax.utils import get_keys


def get_dummy_data(jax_ndarray=True, unsqueeze=False):
def get_dummy_data(unsqueeze=False):
X = onp.linspace(1, 2, 8) + 0.1 * onp.random.randn(8,)
y = (10 * X**2)
if unsqueeze:
X = X[:, None]
if jax_ndarray:
return jnp.array(X), jnp.array(y)
return X, y
return jnp.array(X), jnp.array(y)


def noise_fn(x, params):
return params["a"] + params["b"]*x
Expand All @@ -41,14 +40,31 @@ def test_fit(noise_kernel):
assert m.mcmc is not None


def test_fit_with_mean_fn():
def test_fit_with_custom_noise_lscale():
rng_key = get_keys()[0]
X, y = get_dummy_data()
m = VarNoiseGP(1, 'RBF', noise_lengthscale_prior_dist=dist.HalfNormal(1))
m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
assert m.mcmc is not None


def test_fit_with_noise_mean_fn():
rng_key = get_keys()[0]
X, y = get_dummy_data()
m = VarNoiseGP(1, 'RBF', noise_mean_fn=noise_fn, noise_mean_fn_prior=noise_fn_prior)
m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
assert m.mcmc is not None


def test_fit_with_noise_and_regular_mean_fn():
rng_key = get_keys()[0]
X, y = get_dummy_data()
m = VarNoiseGP(1, 'RBF', mean_fn = lambda x: 8*x**2,
noise_mean_fn=noise_fn, noise_mean_fn_prior=noise_fn_prior)
m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
assert m.mcmc is not None


def test_get_mvn_posterior():
X, y = get_dummy_data(unsqueeze=True)
X_test, _ = get_dummy_data(unsqueeze=True)
Expand Down Expand Up @@ -90,6 +106,30 @@ def test_get_mvn_posterior_with_mean_fn():
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


def test_get_mvn_posterior_with_noise_and_regular_mean_fn():
X, y = get_dummy_data(unsqueeze=True)
X_test, _ = get_dummy_data(unsqueeze=True)
params = {"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1),
"k_noise_length": jnp.array(0.5),
"k_noise_scale": jnp.array(1.0),
"log_var": jnp.ones(len(X)),
"a": jnp.array(1.0),
"b": jnp.array(1.0)
}
m = VarNoiseGP(1, 'RBF', noise_kernel='RBF',
mean_fn = lambda x: 8*x**2,
noise_mean_fn=noise_fn, noise_mean_fn_prior=noise_fn_prior)
m.X_train = X
m.y_train = y
mean, cov = m.get_mvn_posterior(X_test, params)
assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, (X_test.shape[0],))
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


def test_get_noise_samples():
rng_key = get_keys()[0]
X, y = get_dummy_data()
Expand Down

0 comments on commit acc9779

Please sign in to comment.