Commit d19471b

Merge pull request #66 from ziatdinovmax/varinput

ziatdinovmax authored Jan 5, 2024
2 parents 1055f1b + 6c35247
Showing 4 changed files with 265 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gpax/__version__.py
@@ -1 +1 @@
-version = '0.1.2'
+version = '0.1.3'
4 changes: 3 additions & 1 deletion gpax/models/__init__.py
@@ -10,6 +10,7 @@
 from .vi_mtdkl import viMTDKL
 from .mtgp import MultiTaskGP
 from .corgp import CoregGP
+from .uigp import UIGP

 __all__ = [
     "ExactGP",
@@ -23,5 +24,6 @@
     "viDKL",
     "viMTDKL",
     "MultiTaskGP",
-    "CoregGP"
+    "CoregGP",
+    "UIGP"
 ]
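With the export in place, the new model is reachable both from the models subpackage and, as the docstring examples in uigp.py below suggest, from the top-level namespace. A minimal import sketch; the instantiation arguments here are illustrative:

import gpax
from gpax.models import UIGP

# Both names refer to the same class (illustrative minimal instantiation)
gp_model = gpax.UIGP(input_dim=1, kernel='RBF')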
183 changes: 183 additions & 0 deletions gpax/models/uigp.py
@@ -0,0 +1,183 @@
"""
uigp.py
=======
Fully Bayesian implementation of Gaussian process regression with uncertain (stochastic) inputs
Created by Maxim Ziatdinov (email: [email protected])
"""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from . import ExactGP

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]


class UIGP(ExactGP):
    """
    Gaussian process with uncertain inputs

    This class extends the standard Gaussian process model to handle uncertain inputs.
    It incorporates the uncertainty in the input data into the GP model, yielding more
    robust predictions.

    Args:
        input_dim:
            Number of input dimensions
        kernel:
            Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        mean_fn:
            Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
        kernel_prior:
            Optional custom priors over kernel hyperparameters. Use it when passing your custom kernel.
        mean_fn_prior:
            Optional priors over mean function parameters
        noise_prior_dist:
            Optional custom prior distribution over observational noise. Defaults to LogNormal(0, 1).
        lengthscale_prior_dist:
            Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1).
        sigma_x_prior_dist:
            Optional custom prior for the input uncertainty (sigma_x). Defaults to HalfNormal(0.1)
            under the assumption that the data is normalized to (0, 1).

    Examples:

        UIGP with custom prior over sigma_x

        >>> # Get random number generator keys for training and prediction
        >>> rng_key, rng_key_predict = gpax.utils.get_keys()
        >>> # Initialize model
        >>> gp_model = gpax.UIGP(input_dim=1, kernel='Matern', sigma_x_prior_dist=gpax.utils.halfnormal_dist(0.5))
        >>> # Run HMC to obtain posterior samples for the model parameters
        >>> gp_model.fit(rng_key, X, y)  # X and y are arrays with dimensions (n, m) and (n,)
        >>> # Make a prediction on new inputs (n >> 1 for meaningful MCMC averaging over sampled X_new)
        >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, n=200)
    """

    def __init__(self,
                 input_dim: int,
                 kernel: Union[str, kernel_fn_type],
                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior_dist: Optional[dist.Distribution] = None,
                 lengthscale_prior_dist: Optional[dist.Distribution] = None,
                 sigma_x_prior_dist: Optional[dist.Distribution] = None
                 ) -> None:
        args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior,
                noise_prior, noise_prior_dist, lengthscale_prior_dist)
        super(UIGP, self).__init__(*args)
        self.sigma_x_prior_dist = sigma_x_prior_dist

    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
        """
        Gaussian process model for uncertain (stochastic) inputs
        """
        # Initialize mean function at zeros
        f_loc = jnp.zeros(X.shape[0])

        # Sample input X
        X_prime = self._sample_x(X)

        # Sample kernel parameters
        if self.kernel_prior:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Sample noise
        if self.noise_prior:  # this will be removed in future releases
            noise = self.noise_prior()
        else:
            noise = self._sample_noise()
        # Add mean function (if any)
        if self.mean_fn is not None:
            args = [X_prime]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # Compute the kernel
        k = self.kernel(X_prime, X_prime, kernel_params, noise, **kwargs)
        # Sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
            obs=y,
        )

    def _sample_x(self, X):
        if self.sigma_x_prior_dist is not None:
            sigma_x_dist = self.sigma_x_prior_dist
        else:
            sigma_x_dist = dist.HalfNormal(0.1)
        sigma_x = numpyro.sample("sigma_x", sigma_x_dist)
        return numpyro.sample("X_prime", dist.Normal(X, sigma_x))

    def get_mvn_posterior(
        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns parameters (mean and covariance) of the multivariate normal posterior
        for a single sample of UIGP parameters
        """
        X_train_prime = params["X_prime"]
        noise = params["noise"]
        noise_p = noise * (1 - jnp.array(noiseless, int))
        y_residual = self.y_train.copy()
        if self.mean_fn is not None:
            args = [X_train_prime, params] if self.mean_fn_prior else [X_train_prime]
            y_residual -= self.mean_fn(*args).squeeze()
        # Compute kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs)
        k_pX = self.kernel(X_new, X_train_prime, params, jitter=0.0)
        k_XX = self.kernel(X_train_prime, X_train_prime, params, noise, **kwargs)
        # Compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()
        return mean, cov

    def _predict(
        self,
        rng_key: jnp.ndarray,
        X_new: jnp.ndarray,
        params: Dict[str, jnp.ndarray],
        n: int,
        noiseless: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Prediction with a single sample of UIGP parameters"""
        # Sample X_new using the learned standard deviation
        X_new_prime = dist.Normal(X_new, params["sigma_x"]).sample(rng_key, sample_shape=(n,))
        X_new_prime = X_new_prime.mean(0)
        # Get the predictive mean and covariance
        y_mean, K = self.get_mvn_posterior(X_new_prime, params, noiseless, **kwargs)
        # Draw samples from the posterior predictive for a given set of parameters
        y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
        return y_mean, y_sampled

    def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
        X = X if X.ndim > 1 else X[:, None]
        if y is not None:
            if not (X.max() == 1 and X.min() == 0):
                warnings.warn(
                    "The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is "
                    "normalized to (0, 1), which is not the case for your data. Therefore, the default prior "
                    "may not be optimal for your case. Consider passing a custom prior for sigma_x, for example, "
                    "`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly "
                    "or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper",
                    UserWarning,
                )
            return X, y.squeeze()
        return X

    def _print_summary(self):
        samples = self.get_samples(1)
        numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})
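Taken together, the model samples a latent "true" input X_prime around the observed X with a learned scale sigma_x, and conditions the GP on X_prime rather than X. Because the default HalfNormal(0.1) prior on sigma_x assumes inputs normalized to (0, 1) (hence the warning in _set_data), a minimal end-to-end sketch might look like the following; the synthetic data and the min-max scaling are illustrative, while get_keys, fit, and predict follow the class docstring above:

import numpy as onp
import gpax

# Illustrative noisy-input data: the measured X deviates from the true inputs
X = onp.linspace(0, 10, 50) + 0.3 * onp.random.randn(50)
y = onp.sin(X) + 0.1 * onp.random.randn(50)

# Min-max normalize inputs to (0, 1) so the default HalfNormal(0.1)
# prior on sigma_x is on a sensible scale
X = (X - X.min()) / (X.max() - X.min())

rng_key, rng_key_predict = gpax.utils.get_keys()
gp_model = gpax.UIGP(input_dim=1, kernel='Matern')
gp_model.fit(rng_key, X, y, num_warmup=500, num_samples=500)

# n >> 1 for meaningful MCMC averaging over the sampled X_new
X_new = onp.linspace(0, 1, 100)
y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, n=200)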
78 changes: 78 additions & 0 deletions tests/test_uigp.py
@@ -0,0 +1,78 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
import numpyro.distributions as dist
from numpy.testing import assert_equal, assert_array_equal, assert_

sys.path.insert(0, "../gpax/")

from gpax.models.uigp import UIGP
from gpax.utils import get_keys


def get_dummy_data():
    X = onp.linspace(1, 2, 8) + 0.1 * onp.random.randn(8,)
    X_prime = onp.random.normal(X, 0.1)
    y = (10 * X_prime**2)
    return jnp.array(X_prime), jnp.array(y)


def test_fit():
    rng_key = get_keys()[0]
    X, y = get_dummy_data()
    m = UIGP(1, 'RBF')
    m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
    assert m.mcmc is not None


def test_fit_with_custom_sigma_x_prior():
    rng_key = get_keys()[0]
    X, y = get_dummy_data()
    m = UIGP(1, 'RBF', sigma_x_prior_dist=dist.HalfNormal(0.55))
    m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
    assert m.mcmc is not None


def test_get_mvn_posterior():
    X, y = get_dummy_data()
    X_test, _ = get_dummy_data()
    X = X[:, None]
    X_test = X_test[:, None]
    params = {"k_length": jnp.array([1.0]),
              "k_scale": jnp.array(1.0),
              "noise": jnp.array(0.1),
              "k_noise_length": jnp.array(0.5),
              "sigma_x": jnp.array(0.3),
              "X_prime": jnp.array(X + 0.1)
              }
    m = UIGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    mean, cov = m.get_mvn_posterior(X_test, params)
    assert isinstance(mean, jnp.ndarray)
    assert isinstance(cov, jnp.ndarray)
    assert_equal(mean.shape, (X_test.shape[0],))
    assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


@pytest.mark.parametrize("noiseless", [True, False])
def test_predict(noiseless):
    key = get_keys()[0]
    X, y = get_dummy_data()
    X_test, _ = get_dummy_data()
    X = X[:, None]
    X_test = X_test[:, None]
    params = {"k_length": jnp.array([1.0]),
              "k_scale": jnp.array(1.0),
              "noise": jnp.array(0.1),
              "k_noise_length": jnp.array(0.5),
              "sigma_x": jnp.array(0.3),
              "X_prime": jnp.array(X + 0.1)
              }
    m = UIGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    m._predict(key, X_test, params, 5, noiseless)
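Beyond checking that fit and predict run, one might also inspect the inferred input noise after fitting. A short sketch, assuming get_samples() returns a dict of posterior sample arrays keyed by site name, as _print_summary in uigp.py implies:

# After m.fit(...) as in test_fit above (illustrative)
samples = m.get_samples()
print("sigma_x posterior mean:", samples["sigma_x"].mean())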
