Merge pull request #66 from ziatdinovmax/varinput
Showing 4 changed files with 265 additions and 2 deletions.
@@ -1 +1 @@
-version = '0.1.2'
+version = '0.1.3'
@@ -0,0 +1,183 @@
"""
uigp.py
=======
Fully Bayesian implementation of Gaussian process regression with uncertain (stochastic) inputs
Created by Maxim Ziatdinov (email: [email protected])
"""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from . import ExactGP

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

class UIGP(ExactGP):
    """
    Gaussian process with uncertain inputs

    This class extends the standard Gaussian process model to handle uncertain inputs.
    It incorporates the uncertainty in the input data into the GP model, providing
    more robust predictions.

    Args:
        input_dim:
            Number of input dimensions
        kernel:
            Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        mean_fn:
            Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
        kernel_prior:
            Optional custom priors over kernel hyperparameters. Use it when passing your custom kernel.
        mean_fn_prior:
            Optional priors over mean function parameters
        noise_prior_dist:
            Optional custom prior distribution over observational noise. Defaults to LogNormal(0, 1).
        lengthscale_prior_dist:
            Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1).
        sigma_x_prior_dist:
            Optional custom prior for the input uncertainty (sigma_x). Defaults to HalfNormal(0.1)
            under the assumption that the data is normalized to (0, 1).

    Examples:

        UIGP with a custom prior over sigma_x

        >>> # Get random number generator keys for training and prediction
        >>> rng_key, rng_key_predict = gpax.utils.get_keys()
        >>> # Initialize model
        >>> gp_model = gpax.UIGP(input_dim=1, kernel='Matern', sigma_x_prior_dist=gpax.utils.halfnormal_dist(0.5))
        >>> # Run HMC to obtain posterior samples for the model parameters
        >>> gp_model.fit(rng_key, X, y)  # X and y are arrays with dimensions (n, m) and (n,)
        >>> # Make a prediction on new inputs (n >> 1 for meaningful MCMC averaging over sampled X_new)
        >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, n=200)
    """
    def __init__(self,
                 input_dim: int,
                 kernel: Union[str, kernel_fn_type],
                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior_dist: Optional[dist.Distribution] = None,
                 lengthscale_prior_dist: Optional[dist.Distribution] = None,
                 sigma_x_prior_dist: Optional[dist.Distribution] = None
                 ) -> None:
        args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, noise_prior_dist, lengthscale_prior_dist)
        super(UIGP, self).__init__(*args)
        self.sigma_x_prior_dist = sigma_x_prior_dist

    def model(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None, **kwargs: float) -> None:
        """
        Gaussian process model for uncertain (stochastic) inputs
        """
        # Initialize mean function at zeros
        f_loc = jnp.zeros(X.shape[0])

        # Sample input X
        X_prime = self._sample_x(X)

        # Sample kernel parameters
        if self.kernel_prior:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Sample noise
        if self.noise_prior:  # this will be removed in future releases
            noise = self.noise_prior()
        else:
            noise = self._sample_noise()
        # Add mean function (if any)
        if self.mean_fn is not None:
            args = [X_prime]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # Compute kernel
        k = self.kernel(X_prime, X_prime, kernel_params, noise, **kwargs)
        # Sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
            obs=y,
        )

    def _sample_x(self, X):
        # Place a prior over the input noise scale and sample "denoised" inputs around X
        if self.sigma_x_prior_dist is not None:
            sigma_x_dist = self.sigma_x_prior_dist
        else:
            sigma_x_dist = dist.HalfNormal(.1)
        sigma_x = numpyro.sample("sigma_x", sigma_x_dist)
        return numpyro.sample("X_prime", dist.Normal(X, sigma_x))

    def get_mvn_posterior(
        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns parameters (mean and cov) of the multivariate normal posterior
        for a single sample of UIGP parameters
        """
        X_train_prime = params["X_prime"]
        noise = params["noise"]
        noise_p = noise * (1 - jnp.array(noiseless, int))
        y_residual = self.y_train.copy()
        if self.mean_fn is not None:
            args = [X_train_prime, params] if self.mean_fn_prior else [X_train_prime]
            y_residual -= self.mean_fn(*args).squeeze()
        # Compute kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs)
        k_pX = self.kernel(X_new, X_train_prime, params, jitter=0.0)
        k_XX = self.kernel(X_train_prime, X_train_prime, params, noise, **kwargs)
        # Compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()
        return mean, cov

    def _predict(
        self,
        rng_key: jnp.ndarray,
        X_new: jnp.ndarray,
        params: Dict[str, jnp.ndarray],
        n: int,
        noiseless: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Prediction with a single sample of UIGP parameters"""
        # Sample X_new using the learned standard deviation
        X_new_prime = dist.Normal(X_new, params["sigma_x"]).sample(rng_key, sample_shape=(n,))
        X_new_prime = X_new_prime.mean(0)
        # Get the predictive mean and covariance
        y_mean, K = self.get_mvn_posterior(X_new_prime, params, noiseless, **kwargs)
        # Draw samples from the posterior predictive for a given set of parameters
        y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
        return y_mean, y_sampled

    def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
        X = X if X.ndim > 1 else X[:, None]
        if y is not None:
            if not (X.max() == 1 and X.min() == 0):
                warnings.warn(
                    "The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is "
                    "normalized to (0, 1), which is not the case for your data. Therefore, the default prior "
                    "may not be optimal for your case. Consider passing a custom prior for sigma_x, for example, "
                    "`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly "
                    "or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper",
                    UserWarning,
                )
            return X, y.squeeze()
        return X

    def _print_summary(self):
        # Skip the (large) sampled X_prime array when printing MCMC diagnostics
        samples = self.get_samples(1)
        numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})
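
For context, here is a minimal end-to-end sketch of the new model, following the usage pattern in the class docstring above. The synthetic data, the random seed, and the rescaling step are illustrative assumptions; the gpax.UIGP and gpax.utils.get_keys entry points are taken from the docstring example.

import numpy as np
import jax.numpy as jnp
import gpax

# Synthetic 1D data with noisy (uncertain) inputs -- purely illustrative
np.random.seed(0)
X = np.linspace(0, 1, 32) + 0.05 * np.random.randn(32)
y = np.sin(10 * X) + 0.1 * np.random.randn(32)

# Rescale inputs to (0, 1) so the default HalfNormal(0.1) prior on sigma_x applies
X = (X - X.min()) / (X.max() - X.min())

# Fit with HMC and predict, as in the class docstring
rng_key, rng_key_predict = gpax.utils.get_keys()
gp_model = gpax.UIGP(input_dim=1, kernel='Matern')
gp_model.fit(rng_key, jnp.array(X), jnp.array(y))
y_pred, y_samples = gp_model.predict(rng_key_predict, jnp.linspace(0, 1, 100), n=200)

Note that n=200 controls how many stochastic realizations of the new inputs are drawn (and averaged) inside _predict, which is why the docstring recommends n >> 1 for meaningful MCMC averaging.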
@@ -0,0 +1,78 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
import numpyro.distributions as dist
from numpy.testing import assert_equal, assert_array_equal, assert_

sys.path.insert(0, "../gpax/")

from gpax.models.uigp import UIGP
from gpax.utils import get_keys


def get_dummy_data():
    X = onp.linspace(1, 2, 8) + 0.1 * onp.random.randn(8,)
    X_prime = onp.random.normal(X, 0.1)
    y = (10 * X_prime**2)
    return jnp.array(X_prime), jnp.array(y)


def test_fit():
    rng_key = get_keys()[0]
    X, y = get_dummy_data()
    m = UIGP(1, 'RBF')
    m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
    assert m.mcmc is not None


def test_fit_with_custom_sigma_x_prior():
    rng_key = get_keys()[0]
    X, y = get_dummy_data()
    m = UIGP(1, 'RBF', sigma_x_prior_dist=dist.HalfNormal(0.55))
    m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
    assert m.mcmc is not None


def test_get_mvn_posterior():
    X, y = get_dummy_data()
    X_test, _ = get_dummy_data()
    X = X[:, None]
    X_test = X_test[:, None]
    params = {"k_length": jnp.array([1.0]),
              "k_scale": jnp.array(1.0),
              "noise": jnp.array(0.1),
              "k_noise_length": jnp.array(0.5),
              "sigma_x": jnp.array(0.3),
              "X_prime": jnp.array(X + 0.1)
              }
    m = UIGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    mean, cov = m.get_mvn_posterior(X_test, params)
    assert isinstance(mean, jnp.ndarray)
    assert isinstance(cov, jnp.ndarray)
    assert_equal(mean.shape, (X_test.shape[0],))
    assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


@pytest.mark.parametrize("noiseless", [True, False])
def test_predict(noiseless):
    key = get_keys()[0]
    X, y = get_dummy_data()
    X_test, _ = get_dummy_data()
    X = X[:, None]
    X_test = X_test[:, None]
    params = {"k_length": jnp.array([1.0]),
              "k_scale": jnp.array(1.0),
              "noise": jnp.array(0.1),
              "k_noise_length": jnp.array(0.5),
              "sigma_x": jnp.array(0.3),
              "X_prime": jnp.array(X + 0.1)
              }
    m = UIGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    m._predict(key, X_test, params, 5, noiseless)
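
One detail the tests surface: get_dummy_data draws inputs around (1, 2) rather than (0, 1), so fitting with the default sigma_x prior will trigger the normalization warning in _set_data. A short sketch of the suggested remedy, widening the prior to match the input scale; the HalfNormal(0.5) scale is an illustrative assumption, and get_dummy_data is the helper defined in the test file above.

import numpyro.distributions as dist
from gpax.models.uigp import UIGP
from gpax.utils import get_keys

rng_key = get_keys()[0]
X, y = get_dummy_data()  # inputs span roughly (1, 2), not (0, 1)

# Widen the HalfNormal prior on sigma_x to match the unnormalized input scale,
# as recommended by the warning in UIGP._set_data
m = UIGP(1, 'RBF', sigma_x_prior_dist=dist.HalfNormal(0.5))
m.fit(rng_key, X, y, num_warmup=100, num_samples=100)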