diff --git a/docs/source/acquisition.rst b/docs/source/acquisition.rst
index 59ed2a5..47770fd 100644
--- a/docs/source/acquisition.rst
+++ b/docs/source/acquisition.rst
@@ -9,4 +9,4 @@ Acquisition functions
 .. autofunction:: gpax.acquisition.UE
-.. autofunction:: gpax.acquisition.bUCB
\ No newline at end of file
+.. autofunction:: gpax.acquisition.qUCB
\ No newline at end of file
diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst
index ee8e0a5..7f35afa 100644
--- a/docs/source/kernels.rst
+++ b/docs/source/kernels.rst
@@ -5,4 +5,6 @@ Kernels
 .. autofunction:: gpax.kernels.MaternKernel
-.. autofunction:: gpax.kernels.PeriodicKernel
\ No newline at end of file
+.. autofunction:: gpax.kernels.PeriodicKernel
+
+.. autofunction:: gpax.kernels.NNGPKernel
\ No newline at end of file
diff --git a/docs/source/models.rst b/docs/source/models.rst
index 7931d10..c89d8d6 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -1,34 +1,68 @@
 GPax models
 ===========
-Gaussian Processes (Fully Bayesian Implementation)
+Gaussian Processes - Fully Bayesian Implementation
 --------------------------------------------------
-.. autoclass:: gpax.gp.ExactGP
+.. autoclass:: gpax.models.gp.ExactGP
     :members:
     :inherited-members:
     :undoc-members:
     :member-order: bysource
     :show-inheritance:
-.. autoclass:: gpax.vgp.vExactGP
+.. autoclass:: gpax.models.vgp.vExactGP
     :members:
     :inherited-members:
     :undoc-members:
     :member-order: bysource
     :show-inheritance:
-Deep Kernel Learning (Fully Bayesian Implementation)
+Gaussian Processes - Approximate Bayesian
+------------------------------------------
+.. autoclass:: gpax.models.vigp.viGP
+    :members:
+    :inherited-members:
+    :undoc-members:
+    :member-order: bysource
+    :show-inheritance:
+
+Deep Kernel Learning - Fully Bayesian Implementation
 ----------------------------------------------------
-.. autoclass:: gpax.dkl.DKL
+.. autoclass:: gpax.models.dkl.DKL
     :members:
     :inherited-members:
     :undoc-members:
     :member-order: bysource
     :show-inheritance:
-Deep Kernel Learning (Approximate Bayesian)
+Deep Kernel Learning - Approximate Bayesian
 -------------------------------------------
-.. autoclass:: gpax.vidkl.viDKL
+.. autoclass:: gpax.models.vidkl.viDKL
     :members:
     :inherited-members:
     :undoc-members:
     :member-order: bysource
     :show-inheritance:
+
+Infinite-width Bayesian Neural Networks
+----------------------------------------
+.. autoclass:: gpax.models.ibnn.iBNN
+    :members:
+    :inherited-members:
+    :undoc-members:
+    :member-order: bysource
+    :show-inheritance:
+
+Multi-Task Learning
+--------------------
+.. autoclass:: gpax.models.mtgp.MultiTaskGP
+    :members:
+    :inherited-members:
+    :undoc-members:
+    :member-order: bysource
+    :show-inheritance:
+
+.. autoclass:: gpax.models.vi_mtdkl.viMTDKL
+    :members:
+    :inherited-members:
+    :undoc-members:
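For reference, the new module layout documented above can be smoke-tested with the imports below. This snippet is illustrative only and not part of the patch; it assumes the package is installed and that qUCB (which replaces bUCB in acquisition.rst) is exported from gpax.acquisition:

    from gpax.models import ExactGP, vExactGP, DKL, viDKL, iBNN, MultiTaskGP, CoregGP
    from gpax.kernels import NNGPKernel, LCMKernel
    from gpax.acquisition import qUCB  # renamed from bUCB in this PR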
diff --git a/gpax/__init__.py b/gpax/__init__.py
index d68f831..9111430 100644
--- a/gpax/__init__.py
+++ b/gpax/__init__.py
@@ -1,12 +1,11 @@
-from . import utils, kernels, acquisition
-from .gp import ExactGP
-from .vgp import vExactGP
-from .bnn import DKL, viDKL, iBNN, vi_iBNN
-from .vigp import viGP
-from .spm import sPM
-from .hypo import sample_next
-
 from .__version__ import version as __version__
+from . import utils
+from . import kernels
+from . import acquisition
+from .hypo import sample_next
+from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, sPM, vExactGP,
+                     vi_iBNN, viDKL, viGP, viMTDKL)

-__all__ = ["utils", "kernels", "acquisition", "ExactGP", "vExactGP", "DKL",
-           "viDKL", "iBNN", "vi_iBNN", "viGP", "sPM", "sample_next", "__version__"]
+__all__ = ["utils", "kernels", "acquisition", "ExactGP", "vExactGP", "DKL",
+           "viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM",
+           "CoregGP", "sample_next", "__version__"]
diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py
new file mode 100644
index 0000000..a12e435
--- /dev/null
+++ b/gpax/acquisition/__init__.py
@@ -0,0 +1 @@
+from .acquisition import *
\ No newline at end of file
diff --git a/gpax/acquisition.py b/gpax/acquisition/acquisition.py
similarity index 99%
rename from gpax/acquisition.py
rename to gpax/acquisition/acquisition.py
index 301f63d..b670b5d 100644
--- a/gpax/acquisition.py
+++ b/gpax/acquisition/acquisition.py
@@ -15,7 +15,7 @@
 import numpy as onp
 import numpyro.distributions as dist

-from .gp import ExactGP
+from ..models.gp import ExactGP

 def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
diff --git a/gpax/bnn/__init__.py b/gpax/bnn/__init__.py
deleted file mode 100644
index 30f2569..0000000
--- a/gpax/bnn/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .dkl import DKL
-from .vidkl import viDKL
-from .ibnn import iBNN
-from .vi_ibnn import vi_iBNN
-
-__all__ = ["DKL", "viDKL", "iBNN", "vi_iBNN"]
diff --git a/gpax/hypo.py b/gpax/hypo.py
index 1a77250..4d54aa1 100644
--- a/gpax/hypo.py
+++ b/gpax/hypo.py
@@ -13,8 +13,8 @@
 import numpy as np
 import numpyro

-from .gp import ExactGP
-from .spm import sPM
+from .models.gp import ExactGP
+from .models.spm import sPM
 from .utils import get_keys
diff --git a/gpax/kernels/__init__.py b/gpax/kernels/__init__.py
new file mode 100644
index 0000000..cd18852
--- /dev/null
+++ b/gpax/kernels/__init__.py
@@ -0,0 +1,16 @@
+from .kernels import (MaternKernel, NNGPKernel, PeriodicKernel, RBFKernel,
+                      get_kernel, nngp_erf, nngp_relu)
+from .mtkernels import (LCMKernel, MultitaskKernel, MultivariateKernel,
+                        index_kernel)
+
+__all__ = [
+    "RBFKernel",
+    "MaternKernel",
+    "PeriodicKernel",
+    "NNGPKernel",
+    "get_kernel",
+    "index_kernel",
+    "MultitaskKernel",
+    "MultivariateKernel",
+    "LCMKernel"
+]
diff --git a/gpax/kernels.py b/gpax/kernels/kernels.py
similarity index 98%
rename from gpax/kernels.py
rename to gpax/kernels/kernels.py
index fb91199..fb8d4fb 100644
--- a/gpax/kernels.py
+++ b/gpax/kernels/kernels.py
@@ -28,7 +28,7 @@ def add_jitter(x, jitter=1e-6):
 def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray,
                            lengthscale: Union[jnp.ndarray, float] = 1.
                            ) -> jnp.ndarray:
-    """
+    r"""
    Computes a square of scaled distance, :math:`\|\frac{X-Z}{l}\|^2`,
    where X and Z are vectors with :math:`n x num_features` dimensions
    """
@@ -115,7 +115,7 @@ def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray,

 def nngp_erf(x1: jnp.ndarray, x2: jnp.ndarray,
-             var_b: jnp.array, var_w: jnp.array, 
+             var_b: jnp.array, var_w: jnp.array,
              depth: int = 3) -> jnp.array:
     """
     Computes the Neural Network Gaussian Process (NNGP) kernel value for
@@ -187,7 +187,7 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3
     Args:
         activation: activation function ('erf' or 'relu')
-        depth: The number of layers in the corresponding infinite-width neural network. 
+        depth: The number of layers in the corresponding infinite-width neural network.
            Controls the level of recursion in the computation.
     Returns:
@@ -196,7 +196,7 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3
     nngp_single_pair_ = nngp_relu if activation == 'relu' else nngp_erf
     def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray,
-                        params: Dict[str, jnp.ndarray], 
+                        params: Dict[str, jnp.ndarray],
                         noise: jnp.ndarray = 0, **kwargs
                         ) -> jnp.ndarray:
         """
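For reference, a minimal usage sketch of the NNGPKernel factory introduced above, mirroring tests/test_kernels.py further down in this diff (array shapes and parameter names are taken from those tests; illustrative, not part of the patch):

    import numpy as onp
    import jax.numpy as jnp
    from gpax.kernels import NNGPKernel

    x1, x2 = onp.random.randn(5, 2), onp.random.randn(5, 2)
    params = {"var_b": jnp.array(1.0), "var_w": jnp.array(1.0)}
    kernel = NNGPKernel(activation='erf', depth=3)  # returns a kernel function
    k = kernel(x1, x2, params)                      # covariance matrix of shape (5, 5)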
diff --git a/gpax/kernels/mtkernels.py b/gpax/kernels/mtkernels.py
new file mode 100644
index 0000000..b3101dc
--- /dev/null
+++ b/gpax/kernels/mtkernels.py
@@ -0,0 +1,229 @@
+"""
+mtkernels.py
+==========
+
+Multi-task kernel functions
+
+Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
+"""
+
+
+from typing import Dict, Callable
+
+import jax.numpy as jnp
+from jax import vmap
+
+from .kernels import add_jitter, get_kernel
+
+kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
+
+
+def index_kernel(indices1, indices2, params):
+    r"""
+    Computes the task kernel matrix for given task indices.
+    The task covariance between two discrete indices i and j
+    is calculated as:
+
+    .. math::
+        task\_kernel\_values[i, j] = WW^T[i, j] + v[i] \delta_{ij}
+
+    where :math:`WW^T` is the matrix product of :math:`W` with its transpose, :math:`v[i]`
+    is the variance of task :math:`i`, and :math:`\delta_{ij}` is the Kronecker delta
+    which is 1 if :math:`i == j` and 0 otherwise.
+
+    Args:
+        indices1:
+            An array of task indices for the first set of data points.
+            Each entry is an integer that indicates the task associated
+            with a data point.
+        indices2:
+            An array of task indices for the second set of data points.
+            Each entry is an integer that indicates the task associated
+            with a data point.
+        params:
+            Dictionary of parameters for the task kernel. It includes:
+            'W': The coregionalization matrix of shape (num_tasks, rank).
+            The product :math:`WW^T` is a symmetric positive semi-definite matrix
+            that determines the correlation structure between the tasks.
+            'v':
+            The vector of task variances of shape (num_tasks,).
+            These values form the diagonal added to :math:`WW^T` and determine
+            the variance of each task.
+
+    Returns:
+        Computed kernel matrix of the shape (len(indices1), len(indices2)).
+        Each entry task_kernel_values[i, j] is the covariance between the tasks
+        associated with data point i in indices1 and data point j in indices2.
+    """
+    W = params["W"]
+    v = params["v"]
+    B = jnp.dot(W, W.T) + jnp.diag(v)
+    return B[jnp.ix_(indices1, indices2)]
+
+
+def MultitaskKernel(base_kernel, **kwargs1):
+    r"""
+    Constructs a multi-task kernel given a base data kernel.
+    The multi-task kernel is defined as
+
+    .. math::
+        K(x_i, y_j) = k_{data}(x, y) * k_{task}(i, j)
+
+    where *x* and *y* are data points and *i* and *j* are the tasks
+    associated with these points. The task indices are passed as the
+    last column in the input data vectors.
+
+    Args:
+        base_kernel:
+            The name of the base data kernel or a function that computes
+            the base data kernel. This kernel is used to compute the
+            similarities in the input space. The built-in kernels are 'RBF',
+            'Matern', 'Periodic', and 'NNGP'.
+
+        **kwargs1:
+            Additional keyword arguments to pass to the `get_kernel`
+            function when constructing the base data kernel.
+
+    Returns:
+        The constructed multi-task kernel function.
+    """
+
+    data_kernel = get_kernel(base_kernel, **kwargs1)
+
+    def multi_task_kernel(X, Z, params, noise=0, **kwargs2):
+        """
+        Computes multi-task kernel matrix, given two input arrays and
+        a dictionary with kernel parameters. The input arrays must have the
+        shape (N, D+1) where N is the number of data points and D is the feature
+        dimension. The last column contains task indices.
+        """
+
+        # Extract input data and task indices from X and Z
+        X_data, indices_X = X[:, :-1], X[:, -1].astype(int)
+        Z_data, indices_Z = Z[:, :-1], Z[:, -1].astype(int)
+
+        # Compute data and task kernels
+        k_data = data_kernel(X_data, Z_data, params, 0, **kwargs2)  # noise will be added later
+        k_task = index_kernel(indices_X, indices_Z, params)
+
+        # Compute the multi-task kernel
+        K = k_data * k_task
+
+        # Add noise associated with each task
+        if X.shape == Z.shape:
+            # Get the noise corresponding to each sample's task
+            if isinstance(noise, (int, float)):
+                noise = jnp.ones(1) * noise
+            sample_noise = noise[indices_X]
+            # Add small jitter for numerical stability
+            sample_noise = add_jitter(sample_noise, **kwargs2)
+            # Add the noise to the diagonal of the kernel matrix
+            K = K.at[jnp.diag_indices(K.shape[0])].add(sample_noise)
+
+        return K
+
+    return multi_task_kernel
+
+
+def MultivariateKernel(base_kernel, num_tasks, **kwargs1):
+    r"""
+    Construct a multivariate kernel given a base data kernel assuming
+    that all tasks share the same input space. For situations where not all
+    tasks share the same input parameters, see MultitaskKernel.
+    The multivariate kernel is defined as a Kronecker product between
+    data and task kernels
+
+    .. math::
+        K(x_i, y_j) = k_{data}(x, y) * k_{task}(i, j)
+
+    where *x* and *y* are data points and *i* and *j* are the tasks
+    associated with these points.
+
+    Args:
+        base_kernel:
+            The name of the base data kernel or a function that computes
+            the base data kernel. This kernel is used to compute the
+            similarities in the input space. The built-in kernels are 'RBF',
+            'Matern', 'Periodic', and 'NNGP'.
+        num_tasks:
+            number of tasks
+
+        **kwargs1 : dict
+            Additional keyword arguments to pass to the `get_kernel`
+            function when constructing the base data kernel.
+
+    Returns:
+        The constructed multi-task kernel function.
+    """
+
+    data_kernel = get_kernel(base_kernel, **kwargs1)
+
+    def multivariate_kernel(X, Z, params, noise=0, **kwargs2):
+        """
+        Computes multivariate kernel matrix, given two input arrays and
+        a dictionary with kernel parameters. The input arrays must have the
+        shape (N, D) where N is the number of data points and D is the feature
+        dimension.
+        """
+
+        # Compute data and task kernels
+        task_labels = jnp.arange(num_tasks)
+        k_data = data_kernel(X, Z, params, 0, **kwargs2)  # noise will be added later
+        k_task = index_kernel(task_labels, task_labels, params)
+
+        # Compute the multi-task kernel
+        K = jnp.kron(k_data, k_task)
+
+        # Add noise associated with each task
+        if X.shape == Z.shape:
+            # Make sure noise is a jax ndarray with a proper shape
+            if isinstance(noise, (float, int)):
+                noise = jnp.ones(num_tasks) * noise
+            # Add small jitter for numerical stability
+            noise = add_jitter(noise, **kwargs2)
+            # Create a block-diagonal noise matrix with the noise terms
+            # on the diagonal of each block
+            noise_matrix = jnp.kron(jnp.eye(k_data.shape[0]), jnp.diag(noise))
+            # Add the noise to the diagonal of the kernel matrix
+            K += noise_matrix
+
+        return K
+
+    return multivariate_kernel
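For reference, a shape-level sketch of MultitaskKernel on inputs with task indices in the last column, following the pattern used in tests/test_kernels.py below (all values illustrative):

    import numpy as onp
    import jax.numpy as jnp
    from gpax.kernels import MultitaskKernel

    x = onp.random.randn(5, 1)
    X = onp.column_stack([x, onp.random.randint(0, 2, size=5)])  # task ids appended
    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
              "W": jnp.ones((2, 1)), "v": jnp.ones(2)}  # 2 tasks, rank-1 W
    k = MultitaskKernel('RBF')(X, X, params, noise=jnp.ones(2))  # -> (5, 5)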
+def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1):
+    """
+    Construct kernel for a Linear Model of Coregionalization (LMC)
+
+    Args:
+        base_kernel:
+            The name of the data kernel or a function that computes
+            the data kernel. This kernel is used to compute the
+            similarities in the input space. The built-in kernels are 'RBF',
+            'Matern', 'Periodic', and 'NNGP'.
+        shared_input_space:
+            If True (default), assumes that all tasks share the same input space and
+            uses a multivariate kernel (Kronecker product).
+            If False, assumes that different tasks have different numbers of observations
+            and uses a multitask kernel (elementwise multiplication). In that case, the task
+            indices must be appended as the last column of the input vector.
+        num_tasks: int, optional
+            Number of tasks. This is only used if `shared_input_space` is True.
+        **kwargs1:
+            Additional keyword arguments to pass to the `get_kernel`
+            function when constructing the base data kernel.
+
+    Returns:
+        The constructed LMC kernel function.
+    """
+
+    if shared_input_space:
+        multi_kernel = MultivariateKernel(base_kernel, num_tasks, **kwargs1)
+    else:
+        multi_kernel = MultitaskKernel(base_kernel, **kwargs1)
+
+    def lcm_kernel(X, Z, params, noise=0, **kwargs2):
+        k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params)
+        return k.sum(0)
+
+    return lcm_kernel
\ No newline at end of file
diff --git a/gpax/models/__init__.py b/gpax/models/__init__.py
new file mode 100644
index 0000000..84f2f9a
--- /dev/null
+++ b/gpax/models/__init__.py
@@ -0,0 +1,25 @@
+from .gp import ExactGP
+from .vgp import vExactGP
+from .vigp import viGP
+from .spm import sPM
+from .ibnn import iBNN
+from .vi_ibnn import vi_iBNN
+from .dkl import DKL
+from .vidkl import viDKL
+from .vi_mtdkl import viMTDKL
+from .mtgp import MultiTaskGP
+from .corgp import CoregGP
+
+__all__ = [
+    "ExactGP",
+    "vExactGP",
+    "viGP",
+    "sPM",
+    "iBNN",
+    "vi_iBNN",
+    "DKL",
+    "viDKL",
+    "viMTDKL",
+    "MultiTaskGP",
+    "CoregGP"
+]
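For reference, an end-to-end sketch for the new CoregGP model defined next, condensed from tests/test_corgp.py at the bottom of this diff (data and MCMC settings are illustrative):

    import numpy as onp
    import jax.numpy as jnp
    from gpax.models import CoregGP
    from gpax.utils import get_keys

    X = onp.linspace(1, 2, 20)
    X = onp.column_stack([X, onp.random.randint(0, 2, size=len(X))])  # append task indices
    y = 10 * X[:, 0] ** 2
    m = CoregGP(input_dim=1, data_kernel='Matern')
    m.fit(get_keys()[0], jnp.array(X), jnp.array(y), num_warmup=50, num_samples=50)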
diff --git a/gpax/models/corgp.py b/gpax/models/corgp.py
new file mode 100644
index 0000000..0e352a8
--- /dev/null
+++ b/gpax/models/corgp.py
@@ -0,0 +1,113 @@
+from typing import Callable, Dict, Optional
+
+import jax.numpy as jnp
+import numpy as onp
+import numpyro
+import numpyro.distributions as dist
+
+from .gp import ExactGP
+from ..kernels import MultitaskKernel
+
+
+class CoregGP(ExactGP):
+
+    """
+    Coregionalized Gaussian Process model
+
+    Args:
+        input_dim:
+            Number of input dimensions
+        data_kernel:
+            Kernel function operating on data inputs ('RBF', 'Matern', 'Periodic', or a custom function)
+        mean_fn:
+            Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
+        data_kernel_prior:
+            Optional custom priors over the data kernel hyperparameters; uses LogNormal(0,1) by default
+        mean_fn_prior:
+            Optional priors over mean function parameters
+        noise_prior:
+            Optional custom prior for observation noise; uses LogNormal(0,1) by default.
+        task_kernel_prior:
+            Optional custom priors over task kernel parameters;
+            Defaults to Normal(0, 10) for weights W and LogNormal(0, 1) for variances v.
+        rank: int
+            Rank of the weight matrix in the task kernel. Cannot be larger than the number of tasks.
+            Higher rank implies higher correlation. Defaults to 1.
+
+    """
+    def __init__(self, input_dim: int, data_kernel: str,
+                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
+                 data_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 task_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 rank: int = 1, **kwargs) -> None:
+        args = (input_dim, None, mean_fn, None, mean_fn_prior, noise_prior)
+        super(CoregGP, self).__init__(*args)
+        self.num_tasks = None
+        self.rank = rank
+        self.kernel = MultitaskKernel(data_kernel, **kwargs)
+        self.data_kernel_prior = data_kernel_prior
+        self.task_kernel_prior = task_kernel_prior
+        self.kernel_name = data_kernel
+
+    def model(self,
+              X: jnp.ndarray,
+              y: jnp.ndarray = None,
+              **kwargs: float
+              ) -> None:
+        """Multitask GP probabilistic model with inputs X and targets y"""
+        self.num_tasks = len(onp.unique(X[:, -1]))
+        # Initialize mean function at zeros
+        f_loc = jnp.zeros(X.shape[0])
+
+        # Sample data kernel parameters
+        if self.data_kernel_prior:
+            data_kernel_params = self.data_kernel_prior()
+        else:
+            data_kernel_params = self._sample_kernel_params(output_scale=False)
+
+        # Sample task kernel parameters
+        if self.task_kernel_prior:
+            task_kernel_params = self.task_kernel_prior()
+        else:
+            task_kernel_params = self._sample_task_kernel_params(self.num_tasks, self.rank)
+
+        # Combine two dictionaries with parameters
+        kernel_params = {**data_kernel_params, **task_kernel_params}
+
+        # Sample noise
+        if self.noise_prior:
+            noise = self.noise_prior()
+        else:  # consider using numpyro.plate here
+            noise = numpyro.sample(
+                "noise", dist.LogNormal(
+                    jnp.zeros(self.num_tasks), jnp.ones(self.num_tasks))
+            )
+
+        # Compute multitask_kernel
+        k = self.kernel(X, X, kernel_params, noise)
+
+        # Add mean function (if any)
+        if self.mean_fn is not None:
+            args = [X]
+            if self.mean_fn_prior is not None:
+                args += [self.mean_fn_prior()]
+            f_loc += self.mean_fn(*args).squeeze()
+
+        # sample y according to the standard Gaussian process formula
+        numpyro.sample(
+            "y",
+            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
+            obs=y,
+        )
+
+    def _sample_task_kernel_params(self, n_tasks, rank):
+        """
+        Sample task kernel parameters with default weakly-informative priors
+        """
+        W = numpyro.sample("W", numpyro.distributions.Normal(
+            jnp.zeros(shape=(n_tasks, rank)), 10*jnp.ones(shape=(n_tasks, rank))))
+        v = numpyro.sample("v", numpyro.distributions.LogNormal(
+            jnp.zeros(shape=(n_tasks,)), jnp.ones(shape=(n_tasks,))))
+        return {"W": W, "v": v}
diff --git a/gpax/bnn/dkl.py b/gpax/models/dkl.py
similarity index 95%
rename from gpax/bnn/dkl.py
rename to gpax/models/dkl.py
index 30fde16..e43204e 100644
--- a/gpax/bnn/dkl.py
+++ b/gpax/models/dkl.py
@@ -16,7 +16,7 @@
 import numpyro.distributions as dist
 from jax import jit

-from ..vgp import vExactGP
+from .vgp import vExactGP

 class DKL(vExactGP):
@@ -40,6 +40,11 @@ class DKL(vExactGP):
         latent_prior:
             Optional prior over the latent space (BNN embedding); uses none by default
+        **kwargs:
+            Optional custom prior distributions over observational noise (noise_prior_dist)
+            and kernel lengthscale (lengthscale_prior_dist)
+
+
     Examples:

         DKL with image patches as inputs and a 1-d vector as targets
@@ -61,9 +66,10 @@
 def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
              kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
              nn: 
Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, - latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None + latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None, + **kwargs ) -> None: - super(DKL, self).__init__(input_dim, kernel, None, kernel_prior) + super(DKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs) self.nn = nn if nn else mlp self.nn_prior = nn_prior if nn_prior else mlp_prior(input_dim, z_dim) self.kernel_dim = z_dim @@ -88,8 +94,7 @@ def model(self, else: kernel_params = self._sample_kernel_params(task_dim) # Sample noise - with numpyro.plate('obs_noise', task_dim): - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + noise = self._sample_noise(task_dim) # GP's mean function f_loc = jnp.zeros(z.shape[:2]) # compute kernel(s) @@ -109,7 +114,7 @@ def _get_mvn_posterior(self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: - noise = params["noise"] + noise = params.pop("noise") noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn(X_train, params) diff --git a/gpax/gp.py b/gpax/models/gp.py similarity index 85% rename from gpax/gp.py rename to gpax/models/gp.py index 3d92b47..1486f1a 100644 --- a/gpax/gp.py +++ b/gpax/models/gp.py @@ -7,6 +7,7 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ +import warnings from functools import partial from typing import Callable, Dict, Optional, Tuple, Type, Union @@ -19,9 +20,10 @@ from jax import jit from numpyro.infer import MCMC, NUTS, init_to_median, Predictive -from .kernels import get_kernel -from .utils import split_in_batches +from ..kernels import get_kernel +from ..utils import split_in_batches +kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear @@ -38,11 +40,13 @@ class ExactGP: mean_fn: Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic) kernel_prior: - Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default + Optional custom priors over kernel hyperparameters. Use it when passing your custom kernel. mean_fn_prior: Optional priors over mean function parameters - noise_prior: - Optional custom prior for observation noise; uses LogNormal(0,1) by default. + noise_prior_dist: + Optional custom prior distribution over observational noise. Defaults to LogNormal(0,1). + lengthscale_prior_dist: + Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1). 
Examples: @@ -57,23 +61,11 @@ class ExactGP: >>> # Make a noiseless prediction on new inputs >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True) - GP for noiseless observations - - >>> # Initialize model - >>> gp_model = gpax.ExactGP( - >>> input_dim=1, kernel='RBF', - >>> noise_prior = lambda: numpyro.deterministic("noise", 0) # zero observational noise - >>> ) - >>> # Run HMC to obtain posterior samples for the GP model parameters - >>> gp_model.fit(rng_key, X, y) # X and y are arrays with dimensions (n, 1) and (n,) - >>> # Make prediction on new inputs - >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new) - GP with custom noise prior >>> gp_model = gpax.ExactGP( >>> input_dim=1, kernel='RBF', - >>> noise_prior = lambda: numpyro.sample("noise", numpyro.distributions.HalfNormal(.1)) + >>> noise_prior_dist = numpyro.distributions.HalfNormal(.1) >>> ) >>> # Run HMC to obtain posterior samples for the GP model parameters >>> gp_model.fit(rng_key, X, y) # X and y are arrays with dimensions (n, 1) and (n,) @@ -101,13 +93,25 @@ class ExactGP: >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True) """ - def __init__(self, input_dim: int, kernel: str, + def __init__(self, input_dim: int, kernel: Union[str, kernel_fn_type], mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, - noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None + noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior_dist: Optional[dist.Distribution] = None, + lengthscale_prior_dist: Optional[dist.Distribution] = None ) -> None: clear_cache() + if noise_prior is not None: + warnings.warn("`noise_prior` is deprecated and will be removed in a future version. " + "Please use `noise_prior_dist` instead, which accepts an instance of a " + "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, " + "rather than a function that calls `numpyro.sample`.", FutureWarning) + if kernel_prior is not None: + warnings.warn("`kernel_prior` will remain available for complex priors. However, for " + "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. 
" + "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, " + "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.", UserWarning) self.kernel_dim = input_dim self.kernel = get_kernel(kernel) self.kernel_name = kernel if isinstance(kernel, str) else None @@ -115,6 +119,8 @@ def __init__(self, input_dim: int, kernel: str, self.kernel_prior = kernel_prior self.mean_fn_prior = mean_fn_prior self.noise_prior = noise_prior + self.noise_prior_dist = noise_prior_dist + self.lengthscale_prior_dist = lengthscale_prior_dist self.X_train = None self.y_train = None self.mcmc = None @@ -133,10 +139,10 @@ def model(self, else: kernel_params = self._sample_kernel_params() # Sample noise - if self.noise_prior: + if self.noise_prior: # this will be removed in the future releases noise = self.noise_prior() else: - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + noise = self._sample_noise() # Add mean function (if any) if self.mean_fn is not None: args = [X] @@ -169,8 +175,8 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, Args: rng_key: random number generator key - X: 2D feature vector with *(number of points, number of features)* dimensions - y: 1D target vector with *(n,)* dimensions + X: 2D feature vector + y: 1D target vector num_warmup: number of HMC warmup states num_samples: number of HMC samples num_chains: number of HMC chains @@ -206,6 +212,35 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, if print_summary: self._print_summary() + def _sample_noise(self) -> jnp.ndarray: + if self.noise_prior_dist is not None: + noise_dist = self.noise_prior_dist + else: + noise_dist = dist.LogNormal(0, 1) + return numpyro.sample("noise", noise_dist) + + def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.ndarray]: + """ + Sample kernel parameters with default + weakly-informative log-normal priors + """ + if self.lengthscale_prior_dist is not None: + length_dist = self.lengthscale_prior_dist + else: + length_dist = dist.LogNormal(0.0, 1.0) + with numpyro.plate('ard', self.kernel_dim): # allows using ARD kernel for kernel_dim > 1 + length = numpyro.sample("k_length", length_dist) + if output_scale: + scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0)) + else: + scale = numpyro.deterministic("k_scale", jnp.array(1.0)) + if self.kernel_name == 'Periodic': + period = numpyro.sample("period", dist.LogNormal(0.0, 1.0)) + kernel_params = { + "k_length": length, "k_scale": scale, + "period": period if self.kernel_name == "Periodic" else None} + return kernel_params + def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: """Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim) @@ -219,7 +254,7 @@ def get_mvn_posterior(self, Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters """ - noise = params["noise"] + noise = params.pop("noise") noise_p = noise * (1 - jnp.array(noiseless, int)) y_residual = self.y_train.copy() if self.mean_fn is not None: @@ -248,24 +283,6 @@ def _predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled - def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.ndarray]: - """ - Sample kernel parameters with default - weakly-informative log-normal priors - """ - with numpyro.plate('k_param', self.kernel_dim): # allows 
using ARD kernel for kernel_dim > 1
-            length = numpyro.sample("k_length", dist.LogNormal(0.0, 1.0))
-        if output_scale:
-            scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
-        else:
-            scale = numpyro.deterministic("k_scale", jnp.ndarray(1.0))
-        if self.kernel_name == 'Periodic':
-            period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
-        kernel_params = {
-            "k_length": length, "k_scale": scale,
-            "period": period if self.kernel_name == "Periodic" else None}
-        return kernel_params
-
     def _predict_in_batches(self, rng_key: jnp.ndarray,
                             X_new: jnp.ndarray,  batch_size: int = 100,
                             batch_dim: int = 0,
diff --git a/gpax/bnn/ibnn.py b/gpax/models/ibnn.py
similarity index 98%
rename from gpax/bnn/ibnn.py
rename to gpax/models/ibnn.py
index abf34b4..6c59f4e 100644
--- a/gpax/bnn/ibnn.py
+++ b/gpax/models/ibnn.py
@@ -13,7 +13,7 @@
 import numpyro
 import numpyro.distributions as dist

-from .. import ExactGP
+from .gp import ExactGP
 from ..kernels import get_kernel
diff --git a/gpax/models/mtgp.py b/gpax/models/mtgp.py
new file mode 100644
index 0000000..bcf5fb7
--- /dev/null
+++ b/gpax/models/mtgp.py
@@ -0,0 +1,206 @@
+from typing import Callable, Dict, Optional
+
+import jax.numpy as jnp
+import numpy as onp
+import numpyro
+import numpyro.distributions as dist
+
+from .gp import ExactGP
+from ..kernels import LCMKernel
+
+
+class MultiTaskGP(ExactGP):
+    """
+    Gaussian process for multi-task/fidelity learning
+
+    Args:
+        input_dim:
+            Number of input dimensions
+        data_kernel:
+            Kernel function operating on data inputs ('RBF', 'Matern', 'Periodic', or a custom function)
+        num_latents:
+            Number of latent functions. Typically equal to or less than the number of tasks
+        shared_input_space:
+            If True (default), assumes that all tasks share the same input space and
+            uses a multivariate kernel (Kronecker product). If False, assumes that different tasks
+            have different numbers of observations and uses a multitask kernel (elementwise multiplication).
+            In that case, the task indices must be appended as the last column of the input vector.
+        num_tasks:
+            Number of tasks. This is only needed if `shared_input_space` is True.
+        rank:
+            Rank of the weight matrix in the task kernel. Cannot be larger than the number of tasks.
+            Higher rank implies higher correlation. Uses *(num_tasks - 1)* when not specified.
+        mean_fn:
+            Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
+        data_kernel_prior:
+            Optional custom priors over the data kernel hyperparameters
+        mean_fn_prior:
+            Optional priors over mean function parameters
+        noise_prior_dist:
+            Optional custom prior distribution over observational noise. Defaults to LogNormal(0,1).
+        lengthscale_prior_dist:
+            Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1)
+        W_prior_dist:
+            Optional custom prior distribution over W in the task kernel, :math:`WW^T + diag(v)`.
+            Defaults to Normal(0, 10).
+        v_prior_dist:
+            Optional custom prior distribution over v in the task kernel, :math:`WW^T + diag(v)`.
+            Must be non-negative. Defaults to LogNormal(0, 1)
+        output_scale:
+            Option to sample data kernel's output scale. 
+ Defaults to False to avoid over-parameterization (the scale is already absorbed into task kernel) + """ + def __init__(self, input_dim: int, data_kernel: str, + num_latents: int = None, shared_input_space: bool = True, + num_tasks: int = None, rank: Optional[int] = None, + mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, + data_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior_dist: Optional[dist.Distribution] = None, + lengthscale_prior_dist: Optional[dist.Distribution] = None, + W_prior_dist: Optional[dist.Distribution] = None, + v_prior_dist: Optional[dist.Distribution] = None, + output_scale: bool = False, **kwargs) -> None: + args = (input_dim, None, mean_fn, None, mean_fn_prior, noise_prior) + super(MultiTaskGP, self).__init__(*args) + if shared_input_space: + if num_tasks is None: + raise ValueError("Please specify num_tasks") + else: + if num_latents is None: + raise ValueError("Please specify num_latents") + self.num_tasks = num_tasks + self.num_latents = num_tasks if num_latents is None else num_latents + self.rank = rank + self.kernel = LCMKernel( + data_kernel, shared_input_space, num_tasks, **kwargs) + self.data_kernel_name = data_kernel if isinstance(data_kernel, str) else None + self.data_kernel_prior = data_kernel_prior + self.noise_prior = noise_prior # will be removed + self.noise_prior_dist = noise_prior_dist + self.lengthscale_prior_dist = lengthscale_prior_dist + self.W_prior_dist = W_prior_dist + self.v_prior_dist = v_prior_dist + self.shared_input = shared_input_space + self.output_scale = output_scale + + def model(self, + X: jnp.ndarray, + y: jnp.ndarray = None, + **kwargs: float + ) -> None: + """Multitask GP probabilistic model with inputs X and targets y""" + + # Initialize mean function at zeros + if self.shared_input: + f_loc = jnp.zeros(self.num_tasks * X.shape[0]) + else: + f_loc = jnp.zeros(X.shape[0]) + + # Check that we have necessary info for sampling kernel params + if not self.shared_input and self.num_tasks is None: + self.num_tasks = len(onp.unique(self.X_train[:, -1])) + + if self.rank is None: + self.rank = self.num_tasks - 1 + + # Sample data kernel parameters + if self.data_kernel_prior: + data_kernel_params = self.data_kernel_prior() + else: + data_kernel_params = self._sample_kernel_params() + + # Sample task kernel parameters + task_kernel_params = self._sample_task_kernel_params() + + # Combine two dictionaries with parameters + kernel_params = {**data_kernel_params, **task_kernel_params} + + # Sample noise + if self.noise_prior: # this will be removed in the future releases + noise = self.noise_prior() + else: + noise = self._sample_noise() + + # Compute multitask_kernel + k = self.kernel(X, X, kernel_params, noise, **kwargs) + + # Add mean function (if any) + if self.mean_fn is not None: + args = [X] + if self.mean_fn_prior is not None: + args += [self.mean_fn_prior()] + f_loc += self.mean_fn(*args).squeeze() + + # Sample y according to the standard Gaussian process formula + numpyro.sample( + "y", + dist.MultivariateNormal(loc=f_loc, covariance_matrix=k), + obs=y, + ) + + def _sample_noise(self): + """Sample observational noise""" + if self.noise_prior_dist is not None: + noise_dist = self.noise_prior_dist + else: + noise_dist = dist.LogNormal( + jnp.zeros(self.num_tasks), + jnp.ones(self.num_tasks)) + + noise = 
numpyro.sample("noise", noise_dist.to_event(1))
+        return noise
+
+    def _sample_task_kernel_params(self):
+        """
+        Sample task kernel parameters with default weakly-informative priors
+        or custom priors for all the latent functions
+        """
+        if self.W_prior_dist is not None:
+            W_dist = self.W_prior_dist
+        else:
+            W_dist = dist.Normal(
+                jnp.zeros(shape=(self.num_latents, self.num_tasks, self.rank)),  # loc
+                10*jnp.ones(shape=(self.num_latents, self.num_tasks, self.rank))  # scale
+            )
+        if self.v_prior_dist is not None:
+            v_dist = self.v_prior_dist
+        else:
+            v_dist = dist.LogNormal(
+                jnp.zeros(shape=(self.num_latents, self.num_tasks)),  # loc
+                jnp.ones(shape=(self.num_latents, self.num_tasks))  # scale
+            )
+        with numpyro.plate("latent_plate_task", self.num_latents):
+            W = numpyro.sample("W", W_dist.to_event(2))
+            v = numpyro.sample("v", v_dist.to_event(1))
+        return {"W": W, "v": v}
+
+    def _sample_kernel_params(self):
+        """
+        Sample data ("base") kernel parameters with default weakly-informative
+        priors for all the latent functions. Optionally allows specifying a custom
+        prior over the kernel lengthscale.
+        """
+        squeezer = lambda x: x.squeeze() if self.num_latents > 1 else x
+        if self.lengthscale_prior_dist is not None:
+            length_dist = self.lengthscale_prior_dist
+        else:
+            length_dist = dist.LogNormal(0.0, 1.0)
+        with numpyro.plate("latent_plate_data", self.num_latents, dim=-2):
+            with numpyro.plate("ard", self.kernel_dim, dim=-1):
+                length = numpyro.sample("k_length", length_dist)
+            if self.output_scale:
+                scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
+            else:
+                scale = numpyro.deterministic("k_scale", jnp.ones(self.num_latents))
+            if self.data_kernel_name == 'Periodic':
+                period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
+        kernel_params = {
+            "k_length": squeezer(length), "k_scale": squeezer(scale),
+            "period": squeezer(period) if self.data_kernel_name == "Periodic" else None
+        }
+        return kernel_params
diff --git a/gpax/spm.py b/gpax/models/spm.py
similarity index 100%
rename from gpax/spm.py
rename to gpax/models/spm.py
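For reference, a usage sketch for MultiTaskGP with per-task observations (task ids in the last input column); fit and predict are inherited from ExactGP. The data and MCMC settings are illustrative:

    import numpy as onp
    import jax.numpy as jnp
    from gpax.models import MultiTaskGP
    from gpax.utils import get_keys

    X = onp.random.randn(30, 1)
    X = onp.column_stack([X, onp.random.randint(0, 2, size=30)])  # 2 tasks
    y = onp.sin(X[:, 0]) + 0.1 * X[:, 1]
    model = MultiTaskGP(input_dim=1, data_kernel='Matern',
                        num_latents=2, shared_input_space=False)
    model.fit(get_keys()[0], jnp.array(X), jnp.array(y), num_warmup=100, num_samples=100)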
diff --git a/gpax/vgp.py b/gpax/models/vgp.py
similarity index 86%
rename from gpax/vgp.py
rename to gpax/models/vgp.py
index 64cabb0..17eb6a8 100644
--- a/gpax/vgp.py
+++ b/gpax/models/vgp.py
@@ -31,20 +31,29 @@ class vExactGP(ExactGP):
         kernel_prior: optional custom priors over kernel hyperparameters (uses LogNormal(0,1) by default)
         mean_fn_prior: optional priors over mean function parameters
         noise_prior: optional custom prior for observation noise
+        noise_prior_dist:
+            Optional custom prior distribution over observational noise. Defaults to LogNormal(0,1).
+        lengthscale_prior_dist:
+            Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1).
+
     """
     def __init__(self, input_dim: int, kernel: str,
                  mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None
+                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 noise_prior_dist: Optional[dist.Distribution] = None,
+                 lengthscale_prior_dist: Optional[dist.Distribution] = None
                  ) -> None:
         args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior)
         super(vExactGP, self).__init__(*args)
+        self.noise_prior_dist = noise_prior_dist
+        self.lengthscale_prior_dist = lengthscale_prior_dist

-    def model(self, 
+    def model(self,
               X: jnp.ndarray,
-              y: jnp.ndarray = None, 
+              y: jnp.ndarray = None,
               **kwargs: float
               ) -> None:
         """GP probabilistic model with inputs X and vector-valued targets y"""
@@ -58,11 +67,10 @@ def model,
         else:
             kernel_params = self._sample_kernel_params(task_dim=task_dim)
         # Sample noise for each task
-        with numpyro.plate("noise_plate", task_dim):
-            if self.noise_prior:
-                noise = self.noise_prior()
-            else:
-                noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0))
+        if self.noise_prior:  # this will be removed in the future releases
+            noise = self.noise_prior()
+        else:
+            noise = self._sample_noise(task_dim)
         # Add mean function (if any)
         if self.mean_fn is not None:
             args = [X]
@@ -80,6 +88,36 @@ def model,
             obs=y,
         )

+    def _sample_noise(self, task_dim) -> jnp.ndarray:
+        if self.noise_prior_dist is not None:
+            noise_dist = self.noise_prior_dist
+        else:
+            noise_dist = dist.LogNormal(0, 1)
+        with numpyro.plate("noise_plate", task_dim):
+            noise = numpyro.sample("noise", noise_dist)
+        return noise
+
+    def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp.ndarray]:
+        """
+        Sample kernel parameters with default
+        weakly-informative log-normal priors
+        """
+        if self.lengthscale_prior_dist is not None:
+            length_dist = self.lengthscale_prior_dist
+        else:
+            length_dist = dist.LogNormal(0.0, 1.0)
+        with numpyro.plate("plate_1", task_dim, dim=-2):  # task dimension
+            with numpyro.plate('lengthscale', self.kernel_dim, dim=-1):  # allows using ARD kernel for kernel_dim > 1
+                length = numpyro.sample("k_length", length_dist)
+        with numpyro.plate("plate_2", task_dim):  # task dimension
+            scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
+            if self.kernel_name == 'Periodic':
+                period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
+        kernel_params = {
+            "k_length": length, "k_scale": scale,
+            "period": period if self.kernel_name == "Periodic" else None}
+        return kernel_params
+
     @partial(jit, static_argnames='self')
     def _get_mvn_posterior(self,
                            X_train: jnp.ndarray, y_train: jnp.ndarray,
@@ -133,23 +171,6 @@ def get_mvn_posterior,
             self._get_mvn_posterior)(*vmap_args, noiseless=noiseless, jitter=jitter)
         return mean, cov

-    def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp.ndarray]:
-        """
-        Sample kernel parameters with default
-        weakly-informative log-normal priors
-        """
-        with numpyro.plate("plate_1", task_dim, dim=-2):  # task dimension
-            with numpyro.plate('lengthscale', self.kernel_dim, dim=-1):  # allows using ARD kernel for kernel_dim > 1
-                length = numpyro.sample("k_length", dist.LogNormal(0.0, 1.0))
-        with numpyro.plate("plate_2", task_dim):  # task dimension'
-            scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
-            if self.kernel_name == 'Periodic':
-                period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
-        kernel_params = {
-            "k_length": length, "k_scale": scale,
-            "period": period if self.kernel_name == "Periodic" else None}
-        return kernel_params
-
     def predict_in_batches(self, rng_key: jnp.ndarray,
                            X_new: jnp.ndarray,  batch_size: int = 100,
                            samples: Optional[Dict[str, jnp.ndarray]] = None,
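The vgp.py changes above thread the new lengthscale_prior_dist argument through vExactGP; a minimal sketch follows (the Gamma prior is the example named in the deprecation warning earlier in this diff, not a recommended default, and the batched X/y shapes are assumptions about the vector-valued GP interface):

    import numpyro
    from gpax.models import vExactGP

    model = vExactGP(input_dim=1, kernel='RBF',
                     lengthscale_prior_dist=numpyro.distributions.Gamma(2, 5))
    # fitting then proceeds as for ExactGP, with a leading task dimension,
    # e.g. X of shape (num_tasks, n, 1) and y of shape (num_tasks, n)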
diff --git a/gpax/bnn/vi_ibnn.py b/gpax/models/vi_ibnn.py
similarity index 98%
rename from gpax/bnn/vi_ibnn.py
rename to gpax/models/vi_ibnn.py
index 08dddf6..b580366 100644
--- a/gpax/bnn/vi_ibnn.py
+++ b/gpax/models/vi_ibnn.py
@@ -13,7 +13,7 @@
 import numpyro
 import numpyro.distributions as dist

-from ..vigp import viGP
+from .vigp import viGP
 from ..kernels import get_kernel
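Before the multi-task DKL below, a self-contained sketch of the LCM kernel it builds on, with two latent functions and task indices in the last column; the parameter shapes follow the vmap over the leading (latent) axis in lcm_kernel above (values illustrative):

    import numpy as onp
    import jax.numpy as jnp
    from gpax.kernels import LCMKernel

    X = jnp.column_stack([onp.random.randn(6, 1), onp.array([0, 0, 0, 1, 1, 1])])
    params = {"k_length": jnp.ones(2), "k_scale": jnp.ones(2),  # one entry per latent
              "W": jnp.ones((2, 2, 1)), "v": jnp.ones((2, 2))}  # (latents, tasks, rank)
    lcm = LCMKernel('RBF', shared_input_space=False)
    k = lcm(X, X, params)  # (6, 6), summed over the two latent functions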
diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py
new file mode 100644
index 0000000..6aad009
--- /dev/null
+++ b/gpax/models/vi_mtdkl.py
@@ -0,0 +1,230 @@
+from functools import partial
+from typing import Callable, Dict, Optional, Tuple
+
+import jax
+from jax import jit
+import jax.numpy as jnp
+import numpy as onp
+import numpyro
+import numpyro.distributions as dist
+from numpyro.contrib.module import random_haiku_module
+
+from ..kernels import LCMKernel
+from .vidkl import viDKL
+
+
+class viMTDKL(viDKL):
+    """
+    Implementation of the variational inference-based deep kernel learning
+
+    Args:
+        input_dim:
+            Number of input dimensions, not counting the column with task indices (if any)
+        z_dim:
+            Latent space dimensionality (defaults to 2)
+        data_kernel:
+            Kernel function operating on data inputs ('RBF', 'Matern', 'Periodic', or a custom function)
+        num_latents:
+            Number of latent functions. Typically equal to or less than the number of tasks
+        shared_input_space:
+            If True (default), assumes that all tasks share the same input space and
+            uses a multivariate kernel (Kronecker product). If False, assumes that different tasks
+            have different numbers of observations and uses a multitask kernel (elementwise multiplication).
+            In that case, the task indices must be appended as the last column of the input vector.
+        num_tasks:
+            Number of tasks. This is only needed if `shared_input_space` is True.
+        rank:
+            Rank of the weight matrix in the task kernel. Cannot be larger than the number of tasks.
+            Higher rank implies higher correlation. Uses *(num_tasks - 1)* when not specified.
+        data_kernel_prior:
+            Optional priors over kernel hyperparameters; uses LogNormal(0,1) by default
+        nn:
+            Custom neural network ('feature extractor'); uses a 3-layer MLP
+            with ReLU activations by default
+        latent_prior:
+            Optional prior over the latent space (NN embedding); uses none by default
+        guide:
+            Auto-guide option, use 'delta' (default) or 'normal'
+        W_prior_dist:
+            Optional custom prior distribution over W in the task kernel, :math:`WW^T + diag(v)`.
+            Defaults to Normal(0, 10).
+        v_prior_dist:
+            Optional custom prior distribution over v in the task kernel, :math:`WW^T + diag(v)`.
+            Must be non-negative. Defaults to LogNormal(0, 1)
+        task_kernel_prior:
+            Optional custom priors over task kernel parameters;
+            Defaults to Normal(0, 10) for weights W and LogNormal(0, 1) for variances v.
+
+        **kwargs:
+            Optional custom prior distributions over observational noise (noise_prior_dist)
+            and kernel lengthscale (lengthscale_prior_dist)
+    """
+
+    def __init__(self, input_dim: int, z_dim: int = 2, data_kernel: str = 'RBF',
+                 num_latents: int = None, shared_input_space: bool = True,
+                 num_tasks: int = None, rank: Optional[int] = None,
+                 data_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
+                 guide: str = 'delta',
+                 W_prior_dist: Optional[dist.Distribution] = None,
+                 v_prior_dist: Optional[dist.Distribution] = None,
+                 task_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+                 **kwargs) -> None:
+        args = (input_dim, z_dim, None, None, nn, None, guide)
+        super(viMTDKL, self).__init__(*args, **kwargs)
+        if shared_input_space:
+            if num_tasks is None:
+                raise ValueError("Please specify num_tasks")
+        else:
+            if num_latents is None:
+                raise ValueError("Please specify num_latents")
+        self.num_tasks = num_tasks
+        self.num_latents = num_tasks if num_latents is None else num_latents
+        self.rank = rank
+        self.kernel = LCMKernel(
+            data_kernel, shared_input_space, num_tasks, **kwargs)
+        self.data_kernel_prior = data_kernel_prior
+        self.task_kernel_prior = task_kernel_prior
+        self.shared_input = shared_input_space
+        self.W_prior_dist = W_prior_dist
+        self.v_prior_dist = v_prior_dist
+
+    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
+        """Multitask DKL probabilistic model"""
+
+        # Check that we have necessary info for sampling kernel params
+        if not self.shared_input and self.num_tasks is None:
+            self.num_tasks = len(onp.unique(self.X_train[:, -1]))
+
+        if self.rank is None:
+            self.rank = self.num_tasks - 1
+
+        # NN part
+        feature_extractor = random_haiku_module(
+            "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim),
+            prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal()))
+        z = feature_extractor(X if self.shared_input else X[:, :-1])
+        if not self.shared_input:
+            z = jnp.column_stack((z, X[:, -1]))
+
+        # Initialize GP kernel mean function at zeros
+        if self.shared_input:
+            f_loc = jnp.zeros(self.num_tasks * X.shape[0])
+        else:
+            f_loc = jnp.zeros(X.shape[0])
+
+        # Sample data kernel parameters
+        if self.data_kernel_prior:
+            data_kernel_params = self.data_kernel_prior()
+        else:
+            data_kernel_params = self._sample_kernel_params()
+
+        # Sample task kernel parameters
+        if self.task_kernel_prior:
+            task_kernel_params = self.task_kernel_prior()
+        else:
+            task_kernel_params = self._sample_task_kernel_params()
+
+        # Combine two dictionaries with parameters
+        kernel_params = {**data_kernel_params, **task_kernel_params}
+
+        # Sample noise
+        if self.noise_prior:  # this will be removed in the future releases
+            noise = self.noise_prior()
+        else:
+            noise = self._sample_noise()
+
+        # Compute multitask_kernel
+        k = self.kernel(z, z, kernel_params, noise, **kwargs)
+
+        # Sample y according to the standard Gaussian process formula
+        numpyro.sample(
+            "y",
+            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
+            obs=y,
+        )
+
+    def _sample_noise(self):
+        """Sample observational noise"""
+        if self.noise_prior_dist is not None:
+            noise_dist = self.noise_prior_dist
+        else:
+            noise_dist = dist.LogNormal(
+                jnp.zeros(self.num_tasks),
+                jnp.ones(self.num_tasks))
+
+        noise = numpyro.sample("noise", noise_dist.to_event(1))
+        return noise
+
+    def _sample_task_kernel_params(self):
+        """
+        Sample task kernel parameters with default weakly-informative priors
+        or custom priors for all the latent functions
+        """
+        if self.W_prior_dist is not None:
+            W_dist = self.W_prior_dist
+        else:
+            W_dist = dist.Normal(
+                jnp.zeros(shape=(self.num_latents, self.num_tasks, self.rank)),  # loc
+                10*jnp.ones(shape=(self.num_latents, self.num_tasks, self.rank))  # scale
+            )
+        if self.v_prior_dist is not None:
+            v_dist = self.v_prior_dist
+        else:
+            v_dist = dist.LogNormal(
+                jnp.zeros(shape=(self.num_latents, self.num_tasks)),  # loc
+                jnp.ones(shape=(self.num_latents, self.num_tasks))  # scale
+            )
+        with numpyro.plate("latent_plate_task", self.num_latents):
+            W = numpyro.sample("W", W_dist.to_event(2))
+            v = numpyro.sample("v", v_dist.to_event(1))
+        return {"W": W, "v": v}
+
+    def _sample_kernel_params(self):
+        """
+        Sample data ("base") kernel parameters with default weakly-informative
+        priors for all the latent functions. Optionally, a custom prior can be
+        placed over the kernel lengthscale via `lengthscale_prior_dist`.
+        """
+        squeezer = lambda x: x.squeeze() if self.num_latents > 1 else x
+        if self.lengthscale_prior_dist is not None:
+            length_dist = self.lengthscale_prior_dist
+        else:
+            length_dist = dist.LogNormal(0.0, 1.0)
+        with numpyro.plate("latent_plate_data", self.num_latents, dim=-2):
+            with numpyro.plate("ard", self.kernel_dim, dim=-1):
+                length = numpyro.sample("k_length", length_dist)
+            scale = numpyro.sample("k_scale", dist.Normal(1.0, 1e-4))
+        return {"k_length": squeezer(length), "k_scale": squeezer(scale)}
+
+    @partial(jit, static_argnames='self')
+    def get_mvn_posterior(self,
+                          X_train: jnp.ndarray,
+                          y_train: jnp.ndarray,
+                          X_new: jnp.ndarray,
+                          nn_params: Dict[str, jnp.ndarray],
+                          k_params: Dict[str, jnp.ndarray],
+                          noiseless: bool = False,
+                          **kwargs
+                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        """
+        Returns predictive mean and covariance at new points
+        (mean and cov, where cov.diagonal() is 'uncertainty')
+        given a single set of DKL parameters
+        """
+        noise = k_params.pop("noise")
+        noise_p = noise * (1 - jnp.array(noiseless, int))
+        # embed data into the latent space
+        z_train = self.nn_module.apply(
+            nn_params, jax.random.PRNGKey(0),
+            X_train if self.shared_input else X_train[:, :-1])
+        z_test = self.nn_module.apply(
+            nn_params, jax.random.PRNGKey(0),
+            X_new if self.shared_input else X_new[:, :-1])
+        if not self.shared_input:
+            z_train = jnp.column_stack((z_train, X_train[:, -1]))
+            z_test = jnp.column_stack((z_test, X_new[:, -1]))
+        # compute kernel matrices for train and test data
+        k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
+        k_pX = self.kernel(z_test, z_train, k_params, jitter=0.0)
+        k_XX = self.kernel(z_train, z_train, k_params, noise, **kwargs)
+        # compute the predictive covariance and mean
+        K_xx_inv = jnp.linalg.inv(k_XX)
+        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
+        return mean, cov
diff --git a/gpax/bnn/vidkl.py b/gpax/models/vidkl.py
similarity index 97%
rename from gpax/bnn/vidkl.py
rename to gpax/models/vidkl.py
index 225ce52..eb38bd9 100644
--- a/gpax/bnn/vidkl.py
+++ b/gpax/models/vidkl.py
@@ -20,8 +20,7 @@
 from jax import jit
 import haiku as hk

-from ..gp import ExactGP
-from ..kernels import get_kernel
+from .gp import ExactGP
 from ..utils import get_haiku_dict
@@ -45,6 +44,11 @@ class viDKL(ExactGP):
             Optional prior over the latent space (NN embedding); uses none by default
         guide:
             Auto-guide option, use 'delta' (default) or 'normal'
+
+        **kwargs:
+            Optional custom prior distributions over observational noise (noise_prior_dist)
+            and kernel lengthscale (lengthscale_prior_dist)
+
     Examples:
@@ -66,9 +70,9 @@
 def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
              kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] 
= None, nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None, - guide: str = 'delta' + guide: str = 'delta', **kwargs ) -> None: - super(viDKL, self).__init__(input_dim, kernel, None, kernel_prior) + super(viDKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs) if guide not in ['delta', 'normal']: raise NotImplementedError("Select guide between 'delta' and 'normal'") nn_module = nn if nn else MLP @@ -95,7 +99,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None: else: kernel_params = self._sample_kernel_params() # Sample noise - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + noise = self._sample_noise() # GP's mean function f_loc = jnp.zeros(z.shape[0]) # compute kernel @@ -192,7 +196,7 @@ def get_mvn_posterior(self, (mean and cov, where cov.diagonal() is 'uncertainty') given a single set of DKL parameters """ - noise = k_params["noise"] + noise = k_params.pop("noise") noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( diff --git a/gpax/vigp.py b/gpax/models/vigp.py similarity index 95% rename from gpax/vigp.py rename to gpax/models/vigp.py index aa41d10..f68781c 100644 --- a/gpax/vigp.py +++ b/gpax/models/vigp.py @@ -13,6 +13,7 @@ import jaxlib import jax.numpy as jnp import numpyro +import numpyro.distributions as dist from numpyro.infer import SVI, Trace_ELBO from numpyro.infer.autoguide import AutoDelta, AutoNormal @@ -58,8 +59,11 @@ def __init__(self, input_dim: int, kernel: str, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior_dist: Optional[dist.Distribution] = None, + lengthscale_prior_dist: Optional[dist.Distribution] = None, guide: str = 'delta') -> None: - args = input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior + args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, + noise_prior_dist, lengthscale_prior_dist) super(viGP, self).__init__(*args) self.X_train = None self.y_train = None diff --git a/tests/test_acq.py b/tests/test_acq.py index 21ecbaf..a4fd3bd 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -6,9 +6,8 @@ sys.path.insert(0, "../gpax/") -from gpax.gp import ExactGP -from gpax.bnn.dkl import DKL -from gpax.bnn.vidkl import viDKL +from gpax.models.gp import ExactGP +from gpax.models.vidkl import viDKL from gpax.utils import get_keys from gpax.acquisition import EI, UCB, UE, Thompson diff --git a/tests/test_corgp.py b/tests/test_corgp.py new file mode 100644 index 0000000..ca3b6e6 --- /dev/null +++ b/tests/test_corgp.py @@ -0,0 +1,52 @@ +import sys +import pytest +import numpy as onp +import jax.numpy as jnp +import numpyro +from numpy.testing import assert_ + +sys.path.insert(0, "../gpax/") + +from gpax.models.corgp import CoregGP +from gpax.utils import get_keys + + +def get_dummy_data(): + X = onp.linspace(1, 2, 20) + 0.1 * onp.random.randn(20,) + y = (10 * X**2) + return jnp.array(X), jnp.array(y) + + +def attach_indices(X, num_tasks): + indices = onp.random.randint(0, num_tasks, size=len(X)) + return onp.column_stack([X, indices]) + + +def dummy_mean_fn(x, params): + return params["a"] * x[:, :-1]**params["b"] + + +def dummy_mean_fn_priors(): + a = numpyro.sample("a", numpyro.distributions.LogNormal(0, 1)) + b = numpyro.sample("b", 
numpyro.distributions.Normal(3, 1)) + return {"a": a, "b": b} + + +@pytest.mark.parametrize("num_tasks", [2, 3]) +@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern', 'Periodic']) +def test_fit_corgp(data_kernel, num_tasks): + rng_key = get_keys()[0] + X, y = get_dummy_data() + X = attach_indices(X, num_tasks) + m = CoregGP(1, data_kernel) + m.fit(rng_key, X, y, num_warmup=50, num_samples=50) + assert_(isinstance(m.get_samples(), dict)) + + +def test_fit_corgp_meanfn(): + rng_key = get_keys()[0] + X, y = get_dummy_data() + X = attach_indices(X, 2) + m = CoregGP(1, 'Matern', mean_fn=dummy_mean_fn, mean_fn_prior=dummy_mean_fn_priors) + m.fit(rng_key, X, y, num_warmup=50, num_samples=50) + assert_(isinstance(m.get_samples(), dict)) diff --git a/tests/test_dkl.py b/tests/test_dkl.py index c01a0cc..52e75e6 100644 --- a/tests/test_dkl.py +++ b/tests/test_dkl.py @@ -7,7 +7,7 @@ sys.path.insert(0, "../gpax/") -from gpax.bnn.dkl import DKL +from gpax.models.dkl import DKL from gpax.utils import get_keys diff --git a/tests/test_gp.py b/tests/test_gp.py index 22c09bf..25edadf 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -8,7 +8,7 @@ sys.path.insert(0, "../gpax/") -from gpax.gp import ExactGP +from gpax.models.gp import ExactGP from gpax.utils import get_keys @@ -98,6 +98,35 @@ def test_sample_periodic_kernel(): assert isinstance(v, jnp.ndarray) +def test_sample_noise(): + m = ExactGP(1, 'RBF') + with numpyro.handlers.seed(rng_seed=1): + noise = m._sample_noise() + assert isinstance(noise, jnp.ndarray) + + +def test_sample_noise_custom_prior(): + noise_prior_dist = numpyro.distributions.HalfNormal(.1) + m1 = ExactGP(1, 'RBF') + with numpyro.handlers.seed(rng_seed=1): + noise1 = m1._sample_noise() + m2 = ExactGP(1, 'RBF', noise_prior_dist=noise_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + noise2 = m2._sample_noise() + assert_(not onp.array_equal(noise1, noise2)) + + +def test_sample_kernel_custom_lscale_prior(): + lscale_prior_dist = numpyro.distributions.Normal(20, .1) + m1 = ExactGP(1, 'RBF') + with numpyro.handlers.seed(rng_seed=1): + lscale1 = m1._sample_kernel_params()["k_length"] + m2 = ExactGP(1, 'RBF', lengthscale_prior_dist=lscale_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + lscale2 = m2._sample_kernel_params()["k_length"] + assert_(not onp.array_equal(lscale1, lscale2)) + + @pytest.mark.parametrize("kernel", ['RBF', 'Matern']) def test_fit_with_custom_kernel_priors(kernel): rng_key = get_keys()[0] diff --git a/tests/test_ibnn.py b/tests/test_ibnn.py index 62088f7..c69dcc5 100644 --- a/tests/test_ibnn.py +++ b/tests/test_ibnn.py @@ -2,14 +2,12 @@ import pytest import numpy as onp import jax.numpy as jnp -import jax -import numpyro -from numpy.testing import assert_equal, assert_array_equal, assert_ +from numpy.testing import assert_equal sys.path.insert(0, "../gpax/") -from gpax.bnn import iBNN, vi_iBNN +from gpax.models import iBNN, vi_iBNN from gpax.utils import get_keys diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100644 index 0000000..ff78c36 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,243 @@ +import sys +import pytest +import numpy as onp +import jax.numpy as jnp +from numpy.testing import assert_equal, assert_ + +sys.path.insert(0, "../gpax/") + +from gpax.kernels import (RBFKernel, MaternKernel, PeriodicKernel, + index_kernel, nngp_erf, nngp_relu, NNGPKernel, + MultitaskKernel, MultivariateKernel, LCMKernel) + + +@pytest.mark.parametrize("kernel", [RBFKernel, MaternKernel]) +@pytest.mark.parametrize("dim", [1, 
2]) +def test_data_kernel_shapes(kernel, dim): + x1 = onp.random.randn(5, dim) + x2 = onp.random.randn(5, dim) + params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0)} + k = kernel(x1, x2, params) + assert_equal(k.shape, (5, 5)) + + +@pytest.mark.parametrize("dim", [1, 2]) +def test_periodkernel_shapes(dim): + x1 = onp.random.randn(5, dim) + x2 = onp.random.randn(5, dim) + params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0), "period": jnp.array(1.0)} + k = PeriodicKernel(x1, x2, params) + assert_equal(k.shape, (5, 5)) + + +@pytest.mark.parametrize("kernel", [RBFKernel, MaternKernel]) +@pytest.mark.parametrize("dim", [1, 2]) +def test_data_kernel_ard_shapes(kernel, dim): + x1 = onp.random.randn(5, dim) + x2 = onp.random.randn(5, dim) + params = {"k_length": jnp.ones(dim), "k_scale": jnp.array(1.0)} + k = kernel(x1, x2, params) + assert_equal(k.shape, (5, 5)) + + +def test_index_kernel_shapes(): + indices1 = jnp.array([0, 1, 2]) + indices2 = jnp.array([2, 1, 0]) + params = {"W": jnp.array([[1, 2], [3, 4]]), "v": jnp.array([1, 2])} + result = index_kernel(indices1, indices2, params) + assert_(result.shape == (len(indices1), len(indices2)), "Incorrect shape of result") + + +def test_index_kernel_shapes_uneven_obs(): + indices1 = jnp.array([1, 2]) + indices2 = jnp.array([2, 1, 0]) + params = {"W": jnp.array([[1, 2], [3, 4]]), "v": jnp.array([1, 2])} + result = index_kernel(indices1, indices2, params) + assert_(result.shape == (len(indices1), len(indices2)), "Incorrect shape of result") + + +def test_index_kernel_computations(): + indices1 = jnp.array([0, 1]) + indices2 = jnp.array([1, 0]) + params = {"W": jnp.array([[1, 0], [0, 1]]), "v": jnp.array([1, 0])} + result = index_kernel(indices1, indices2, params) + expected_result = jnp.array([[0, 2], [1, 0]]) + assert_(jnp.allclose(result, expected_result), "Incorrect computations") + + +@pytest.mark.parametrize("depth", [1, 2, 3]) +@pytest.mark.parametrize("kernel", [nngp_erf, nngp_relu]) +@pytest.mark.parametrize("dim", [1, 2]) +def test_nngp_shapes(kernel, dim, depth): + x1 = onp.random.randn(1, dim) + x2 = onp.random.randn(1, dim) + var_b = jnp.array(1.0) + var_w = jnp.array(1.0) + k = kernel(x1, x2, var_b, var_w, depth) + assert_equal(k.shape, (1,)) + + +@pytest.mark.parametrize("depth", [1, 2, 3]) +@pytest.mark.parametrize("activation", ["erf", "relu"]) +@pytest.mark.parametrize("dim", [1, 2]) +def test_NNGPKernel(activation, dim, depth): + x1 = onp.random.randn(5, dim) + x2 = onp.random.randn(5, dim) + params = {"var_b": jnp.array(1.0), "var_w": jnp.array(1.0)} + kernel = NNGPKernel(activation, depth) + k = kernel(x1, x2, params) + assert_equal(k.shape, (5, 5)) + + +def test_NNGPKernel_activations(): + x1 = onp.random.randn(5, 1) + x2 = onp.random.randn(5, 1) + params = {"var_b": jnp.array(1.0), "var_w": jnp.array(1.0)} + kernel1 = NNGPKernel(activation='erf') + k1 = kernel1(x1, x2, params) + kernel2 = NNGPKernel(activation='relu') + k2 = kernel2(x1, x2, params) + assert_(not jnp.allclose(k1, k2, rtol=1e-3)) + + +def test_MultiTaskKernel(): + base_kernel = 'RBF' + mtkernel = MultitaskKernel(base_kernel) + assert_(callable(mtkernel), "The result of MultitaskKernel should be a function.") + + +@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel]) +@pytest.mark.parametrize("dim", [1, 2]) +def test_multitask_kernel_shapes_test_noiseless(data_kernel, dim): + x1 = onp.random.randn(5, dim) + x2 = onp.random.randn(3, dim) + x1 = onp.column_stack([x1, onp.zeros_like(x1)]) + x2 = onp.column_stack([x2, 
onp.ones_like(x2)])
+    x12 = onp.vstack([x1, x2])
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.array([[1, 0], [0, 1]]), "v": jnp.array([1, 0])}
+    mtkernel = MultitaskKernel(data_kernel)
+    k = mtkernel(x12, x12, params)
+    assert_equal(k.shape, (len(x12), len(x12)))
+
+
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_multitask_kernel_shapes_test_noisy(data_kernel, dim):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(3, dim)
+    x1 = onp.column_stack([x1, onp.zeros_like(x1)])
+    x2 = onp.column_stack([x2, onp.ones_like(x2)])
+    x12 = onp.vstack([x1, x2])
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.array([[1, 0], [0, 1]]), "v": jnp.array([1, 0])}
+    noise = jnp.array([1.0, 1.0])
+    mtkernel = MultitaskKernel(data_kernel)
+    k = mtkernel(x12, x12, params, noise)
+    assert_equal(k.shape, (len(x12), len(x12)))
+
+
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_multitask_kernel_shapes_train(data_kernel, dim):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(5, dim)
+    x1 = onp.column_stack([x1, onp.zeros_like(x1)])
+    x2 = onp.column_stack([x2, onp.ones_like(x2)])
+    x12 = onp.vstack([x1, x2])
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.array([[1, 0], [0, 1]]), "v": jnp.array([1, 0])}
+    noise = jnp.array([1.0, 1.0])
+    mtkernel = MultitaskKernel(data_kernel)
+    k = mtkernel(x12, x12, params, noise)
+    assert_equal(k.shape, (len(x12), len(x12)))
+
+
+def test_MultiVariateKernel():
+    base_kernel = 'RBF'
+    num_tasks = 2
+    mtkernel = MultivariateKernel(base_kernel, num_tasks)
+    assert_(callable(mtkernel), "The result of MultivariateKernel should be a function.")
+
+
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("rank", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_multivariate_kernel_shapes_test_noisy(data_kernel, dim, num_tasks, rank):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(3, dim)
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.ones((num_tasks, rank)), "v": jnp.ones(num_tasks)}
+    noise = jnp.ones(num_tasks)
+    mtkernel = MultivariateKernel(data_kernel, num_tasks)
+    k = mtkernel(x1, x2, params, noise)
+    assert_equal(k.shape, (num_tasks*len(x1), num_tasks*len(x2)))
+
+
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("rank", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_multivariate_kernel_shapes_test_noiseless(data_kernel, dim, num_tasks, rank):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(3, dim)
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.ones((num_tasks, rank)), "v": jnp.ones(num_tasks)}
+    mtkernel = MultivariateKernel(data_kernel, num_tasks)
+    k = mtkernel(x1, x2, params)
+    assert_equal(k.shape, (num_tasks*len(x1), num_tasks*len(x2)))
+
+
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("rank", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_multivariate_kernel_shapes_train(data_kernel, dim, num_tasks, rank):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(5, dim)
+    params = {"k_length": jnp.array(1.0), "k_scale": jnp.array(1.0),
+              "W": jnp.ones((num_tasks, rank)), "v": jnp.ones(num_tasks)}
+    noise = jnp.ones(num_tasks)
+    mtkernel = MultivariateKernel(data_kernel, num_tasks)
+    k = mtkernel(x1, x2, params, noise)
+    assert_equal(k.shape, (num_tasks*len(x1), num_tasks*len(x2)))
+
+
+def test_LCMKernel():
+    base_kernel = 'RBF'
+    lcm_kernel = LCMKernel(base_kernel)
+    assert_(callable(lcm_kernel), "The result of LCMKernel should be a function.")
+
+
+@pytest.mark.parametrize("num_latent", [1, 2])
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_LCMKernel_shapes_multitask(data_kernel, dim, num_latent):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(3, dim)
+    x1 = onp.column_stack([x1, onp.zeros_like(x1)])
+    x2 = onp.column_stack([x2, onp.ones_like(x2)])
+    x12 = onp.vstack([x1, x2])
+    params = {"k_length": jnp.ones(num_latent), "k_scale": jnp.ones(num_latent),
+              "W": jnp.ones((num_latent, 2, 2)), "v": jnp.ones((num_latent, 2))}
+    noise = jnp.array([1.0, 1.0])
+    mtkernel = LCMKernel(data_kernel, shared_input_space=False)
+    k = mtkernel(x12, x12, params, noise)
+    assert_equal(k.shape, (len(x12), len(x12)))
+
+
+@pytest.mark.parametrize("num_latent", [1, 2])
+@pytest.mark.parametrize("data_kernel", [RBFKernel, MaternKernel])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("rank", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_LCMKernel_shapes_multivariate(data_kernel, dim, num_latent, rank, num_tasks):
+    x1 = onp.random.randn(5, dim)
+    x2 = onp.random.randn(3, dim)
+    params = {"k_length": jnp.ones(num_latent), "k_scale": jnp.ones(num_latent),
+              "W": jnp.ones((num_latent, num_tasks, rank)), "v": jnp.ones((num_latent, num_tasks))}
+    noise = jnp.ones(num_tasks)
+    mtkernel = LCMKernel(data_kernel, shared_input_space=True, num_tasks=num_tasks)
+    k = mtkernel(x1, x2, params, noise)
+    assert_equal(k.shape, (num_tasks*len(x1), num_tasks*len(x2)))
diff --git a/tests/test_mtgp.py b/tests/test_mtgp.py
new file mode 100644
index 0000000..b081249
--- /dev/null
+++ b/tests/test_mtgp.py
@@ -0,0 +1,104 @@
+import sys
+import pytest
+import numpy as onp
+import jax.numpy as jnp
+import numpyro
+from numpy.testing import assert_
+
+sys.path.insert(0, "../gpax/")
+
+from gpax.models.mtgp import MultiTaskGP
+from gpax.utils import get_keys
+
+
+def get_dummy_data():
+    X = onp.linspace(1, 2, 20) + 0.1 * onp.random.randn(20,)
+    y = (10 * X**2)
+    return jnp.array(X), jnp.array(y)
+
+
+def attach_indices(X, num_tasks):
+    indices = onp.random.randint(0, num_tasks, size=len(X))
+    return onp.column_stack([X, indices])
+
+
+def dummy_mean_fn(x, params):
+    return params["a"] * x[:, :-1]**params["b"]
+
+
+def dummy_mean_fn_priors():
+    a = numpyro.sample("a", numpyro.distributions.LogNormal(0, 1))
+    b = numpyro.sample("b", numpyro.distributions.Normal(3, 1))
+    return {"a": a, "b": b}
+
+
+@pytest.mark.parametrize("num_latents", [1, 2])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern', 'Periodic'])
+def test_fit_multitask(data_kernel, num_tasks, num_latents):
+    rng_key = get_keys()[0]
+    X, y = get_dummy_data()
+    X = attach_indices(X, num_tasks)
+    m = MultiTaskGP(1, data_kernel, num_latents=num_latents, shared_input_space=False)
+    m.fit(rng_key, X, y, num_warmup=50, num_samples=50)
+    assert_(isinstance(m.get_samples(), dict))
+
+
+@pytest.mark.parametrize("num_latents", [1, 2])
+@pytest.mark.parametrize("num_tasks", [2, 3])
+@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern',
'Periodic']) +def test_fit_multivariate(data_kernel, num_tasks, num_latents): + rng_key = get_keys()[0] + X, y = get_dummy_data() + y = jnp.repeat(y[:, None], num_tasks, axis=1).reshape(-1) + m = MultiTaskGP( + 1, data_kernel, num_latents=num_latents, + num_tasks=num_tasks, shared_input_space=True) + m.fit(rng_key, X, y, num_warmup=50, num_samples=50) + assert_(isinstance(m.get_samples(), dict)) + + +def test_fit_multitask_meanfn(): + rng_key = get_keys()[0] + X, y = get_dummy_data() + X = attach_indices(X, 2) + m = MultiTaskGP(1, 'Matern', num_latents=2, shared_input_space=False, + mean_fn=dummy_mean_fn, mean_fn_prior=dummy_mean_fn_priors) + m.fit(rng_key, X, y, num_warmup=50, num_samples=50) + assert_(isinstance(m.get_samples(), dict)) + + +def test_sample_kernel_custom_lscale_prior(): + lscale_prior_dist = numpyro.distributions.Normal(20, .1) + m1 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2) + with numpyro.handlers.seed(rng_seed=1): + lscale1 = m1._sample_kernel_params()["k_length"] + m2 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2, + lengthscale_prior_dist=lscale_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + lscale2 = m2._sample_kernel_params()["k_length"] + assert_(not onp.array_equal(lscale1, lscale2)) + + +def test_sample_task_kernel_custom_W_prior(): + W_prior_dist = numpyro.distributions.Normal(20*jnp.ones((2, 2, 2)), 0.1*jnp.ones((2, 2, 2))) + m1 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2) + with numpyro.handlers.seed(rng_seed=1): + W1 = m1._sample_task_kernel_params()["W"] + m2 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2, + W_prior_dist=W_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + W2 = m2._sample_task_kernel_params()["W"] + assert_(not onp.array_equal(W1, W2)) + + +def test_sample_task_kernel_custom_v_prior(): + v_prior_dist = numpyro.distributions.Normal(20*jnp.ones((2, 2)), 0.1*jnp.ones((2, 2))) + m1 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2) + with numpyro.handlers.seed(rng_seed=1): + v1 = m1._sample_task_kernel_params()["v"] + m2 = MultiTaskGP(1, 'RBF', num_latents=2, num_tasks=2, rank=2, + v_prior_dist=v_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + v2 = m2._sample_task_kernel_params()["v"] + assert_(not onp.array_equal(v1, v2)) diff --git a/tests/test_spm.py b/tests/test_spm.py index e31b96e..b023189 100644 --- a/tests/test_spm.py +++ b/tests/test_spm.py @@ -8,7 +8,7 @@ sys.path.insert(0, "../gpax/") -from gpax.spm import sPM +from gpax.models.spm import sPM from gpax.utils import get_keys diff --git a/tests/test_vgp.py b/tests/test_vgp.py index bbaad76..78b1570 100644 --- a/tests/test_vgp.py +++ b/tests/test_vgp.py @@ -8,7 +8,7 @@ sys.path.insert(0, "../gpax/") -from gpax.vgp import vExactGP +from gpax.models.vgp import vExactGP from gpax.utils import get_keys @@ -83,6 +83,24 @@ def test_sample_periodic_kernel(): assert isinstance(v, jnp.ndarray) +def test_sample_noise(): + m = vExactGP(1, 'RBF') + with numpyro.handlers.seed(rng_seed=1): + noise = m._sample_noise(3) + assert_equal(noise.shape[0], 3) + + +def test_sample_noise_custom_prior(): + noise_prior_dist = numpyro.distributions.HalfNormal(.1) + m1 = vExactGP(1, 'RBF') + with numpyro.handlers.seed(rng_seed=1): + noise1 = m1._sample_noise(3) + m2 = vExactGP(1, 'RBF', noise_prior_dist=noise_prior_dist) + with numpyro.handlers.seed(rng_seed=1): + noise2 = m2._sample_noise(3) + assert_(not onp.array_equal(noise1, noise2)) + + def test_get_mvn_posterior(): X, y = get_dummy_data(unsqueeze=True) X_test, _ = 
get_dummy_data(unsqueeze=True) diff --git a/tests/test_vidkl.py b/tests/test_vidkl.py index 187570e..9a1c620 100644 --- a/tests/test_vidkl.py +++ b/tests/test_vidkl.py @@ -9,7 +9,7 @@ sys.path.insert(0, "../gpax/") -from gpax.bnn.vidkl import viDKL, MLP +from gpax.models.vidkl import viDKL, MLP from gpax.utils import get_keys diff --git a/tests/test_vigp.py b/tests/test_vigp.py index 9801aa5..777a165 100644 --- a/tests/test_vigp.py +++ b/tests/test_vigp.py @@ -8,7 +8,7 @@ sys.path.insert(0, "../gpax/") -from gpax.vigp import viGP +from gpax.models.vigp import viGP from gpax.utils import get_keys, enable_x64 enable_x64() diff --git a/tests/test_vimtdkl.py b/tests/test_vimtdkl.py new file mode 100644 index 0000000..be2b733 --- /dev/null +++ b/tests/test_vimtdkl.py @@ -0,0 +1,66 @@ +import sys +import pytest +import numpy as onp +import jax.numpy as jnp +from numpy.testing import assert_, assert_equal + +sys.path.insert(0, "../gpax/") + +from gpax.models.vi_mtdkl import viMTDKL +from gpax.utils import get_keys + + +def get_dummy_data(): + X = onp.random.randn(21, 36) + y = onp.random.randn(21,) + return jnp.array(X), jnp.array(y) + + +def attach_indices(X, num_tasks): + indices = onp.random.randint(0, num_tasks, size=len(X)) + return jnp.column_stack([X, indices]) + + +@pytest.mark.parametrize("num_latents", [1, 2]) +@pytest.mark.parametrize("num_tasks", [2, 3]) +@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern']) +def test_fit_multitask(data_kernel, num_tasks, num_latents): + rng_key = get_keys()[0] + X, y = get_dummy_data() + X = attach_indices(X, num_tasks) + m = viMTDKL(X.shape[-1] - 1, 2, data_kernel, num_latents=num_latents, shared_input_space=False) + m.fit(rng_key, X, y, num_steps=10) + assert_(isinstance(m.kernel_params, dict)) + assert_(isinstance(m.nn_params, dict)) + + +@pytest.mark.parametrize("num_latents", [1, 2]) +@pytest.mark.parametrize("num_tasks", [2, 3]) +@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern']) +def test_fit_multitask_shared_input(data_kernel, num_tasks, num_latents): + rng_key = get_keys()[0] + X, y = get_dummy_data() + y = jnp.repeat(y[:, None], num_tasks, axis=1).reshape(-1) + m = viMTDKL(X.shape[-1], 2, data_kernel, num_latents=num_latents, + shared_input_space=True, num_tasks=num_tasks) + m.fit(rng_key, X, y, num_steps=10) + assert_(isinstance(m.kernel_params, dict)) + assert_(isinstance(m.nn_params, dict)) + + +@pytest.mark.parametrize("num_latents", [1, 2]) +@pytest.mark.parametrize("num_tasks", [2, 3]) +@pytest.mark.parametrize("data_kernel", ['RBF', 'Matern']) +def test_fit_predict_multitask(data_kernel, num_tasks, num_latents): + rng_key = get_keys()[0] + X, y = get_dummy_data() + X = attach_indices(X, num_tasks) + m = viMTDKL(X.shape[-1] - 1, 2, data_kernel, num_latents=num_latents, shared_input_space=False) + m.fit(rng_key, X, y, num_steps=10) + X_test, _ = get_dummy_data() + X_test = jnp.column_stack([X_test, jnp.ones(len(X_test))]) + mean, var = m.predict(rng_key, X_test) + assert_(isinstance(mean, jnp.ndarray)) + assert_(isinstance(var, jnp.ndarray)) + assert_equal(len(mean), len(X_test)) + assert_equal(len(var), len(X_test))
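
Reviewer note: the expected values in test_index_kernel_computations are consistent with index_kernel realizing the standard intrinsic-coregionalization task matrix B = W W^T + diag(v), looked up as K[i, j] = B[indices1[i], indices2[j]]. A quick numerical check under that assumption (illustrative code, not an excerpt from gpax):

    import jax.numpy as jnp

    # Task covariance implied by the test's parameters.
    W = jnp.array([[1, 0], [0, 1]])
    v = jnp.array([1, 0])
    B = W @ W.T + jnp.diag(v)      # [[2, 0], [0, 1]]

    # Kernel block between two task-index vectors.
    t1, t2 = jnp.array([0, 1]), jnp.array([1, 0])
    K = B[t1][:, t2]               # rows follow t1, columns follow t2
    # K == [[0, 2], [1, 0]], matching expected_result in the test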
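
Separately, the new *_prior_dist keyword arguments let callers pass ready-made numpyro distributions as hyperparameter priors; the tests added to tests/test_gp.py and tests/test_vgp.py exercise this for ExactGP and vExactGP, and the vigp.py hunk threads the same arguments through viGP. A minimal sketch (the HalfNormal choice mirrors the tests; it is not a package default):

    import numpyro
    from gpax.models.gp import ExactGP

    # Replace the default observational-noise prior with a custom distribution.
    m = ExactGP(1, 'RBF', noise_prior_dist=numpyro.distributions.HalfNormal(.1))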
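
Putting the pieces together, a minimal end-to-end sketch of the new multi-task API, distilled from tests/test_mtgp.py (the Matern kernel and LogNormal lengthscale prior are illustrative choices, not defaults):

    import numpy as onp
    import jax.numpy as jnp
    import numpyro

    from gpax.models.mtgp import MultiTaskGP
    from gpax.utils import get_keys

    rng_key = get_keys()[0]

    # Two tasks observed on a shared 1-D input; with shared_input_space=False
    # the task index travels as the last column of X, exactly as the
    # attach_indices helper does in the tests.
    x = onp.linspace(1, 2, 20) + 0.1 * onp.random.randn(20)
    y = jnp.array(10 * x**2)
    indices = onp.random.randint(0, 2, size=len(x))
    X = jnp.column_stack([x, indices])

    m = MultiTaskGP(1, 'Matern', num_latents=2, shared_input_space=False,
                    lengthscale_prior_dist=numpyro.distributions.LogNormal(0, 1))
    m.fit(rng_key, X, y, num_warmup=50, num_samples=50)
    samples = m.get_samples()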