Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-fidelity/task DKL and GP + code refactoring #29

Merged
merged 45 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
afd8708
Add task kernels
ziatdinovmax Jul 10, 2023
7b3a926
Add multi-task GP
ziatdinovmax Jul 10, 2023
e76272f
do not use output scale in multitask gp
ziatdinovmax Jul 11, 2023
f0e6be3
Compute bare coreg kernel
ziatdinovmax Jul 11, 2023
ade70dd
jnp.ndarray -> jnp.array
ziatdinovmax Jul 11, 2023
adddb2c
roll back changes to the task kernel
ziatdinovmax Jul 12, 2023
6d16f40
Add multivariate and lcm kernels
ziatdinovmax Jul 14, 2023
fe32b18
Update docstrings
ziatdinovmax Jul 14, 2023
07354e6
Fix noise addition in LCMKernel
ziatdinovmax Jul 14, 2023
e1dbadb
change plate name from k_param to ard
ziatdinovmax Jul 14, 2023
dff0d7f
Pass number of letent functions to LCMKernel
ziatdinovmax Jul 14, 2023
5c80892
add simple multi-task gp (coregionalized kernel)
ziatdinovmax Jul 16, 2023
bfeab1a
Add LCM GP
ziatdinovmax Jul 16, 2023
ed52910
Update imports
ziatdinovmax Jul 16, 2023
a9cab07
Infer num_latents in LCMKernel automatically
ziatdinovmax Jul 17, 2023
2fa2414
Fix bug for zero noise with multivariate kernel
ziatdinovmax Jul 18, 2023
fda9a62
Add optional prior task weights
ziatdinovmax Jul 18, 2023
f4de4a1
remove redundant comments
ziatdinovmax Jul 18, 2023
be85171
Revert to the state before task_weights
ziatdinovmax Jul 18, 2023
a156797
Remove redundant noise variables
ziatdinovmax Jul 18, 2023
4005832
Rename low-rank matrix B into W
ziatdinovmax Jul 19, 2023
5d90a70
Add some kernel tests
ziatdinovmax Jul 19, 2023
4d6e53a
Add more kernel tests
ziatdinovmax Jul 19, 2023
aa96e24
Add more tests
ziatdinovmax Jul 19, 2023
bdb922e
Convert noise int to jnp.array
ziatdinovmax Jul 19, 2023
8f77586
Update docstrings
ziatdinovmax Jul 19, 2023
8c49f64
Add more tests
ziatdinovmax Jul 19, 2023
9284eef
Add more tests
ziatdinovmax Jul 19, 2023
381dc2e
Update docstrings and the order of args
ziatdinovmax Jul 19, 2023
1d3507d
Add more tests
ziatdinovmax Jul 20, 2023
9e4a897
Add more tests
ziatdinovmax Jul 20, 2023
9520ef8
Minor changes
ziatdinovmax Jul 20, 2023
d56e9f2
Use Value Error instead of AssertionError
ziatdinovmax Jul 22, 2023
8a4b40e
Allow passing custom priors w/out numpyro.sample
ziatdinovmax Jul 23, 2023
828ae45
Fix minor bugs
ziatdinovmax Jul 23, 2023
3bea414
Update tests
ziatdinovmax Jul 23, 2023
e208c29
Fix typo in the argument name
ziatdinovmax Jul 24, 2023
a12926d
Fix typo in the argument's name
ziatdinovmax Jul 24, 2023
167ec61
Add multi-task/fidelity dkl
ziatdinovmax Jul 24, 2023
889a5f4
Update prediction for multitask dkl
ziatdinovmax Jul 25, 2023
9cb25e3
Update import statements
ziatdinovmax Jul 25, 2023
3cd6ea6
Update docstrings
ziatdinovmax Jul 25, 2023
14eae4d
Refactor
ziatdinovmax Jul 27, 2023
65c42a4
Update docs
ziatdinovmax Jul 27, 2023
259faa2
Update docs
ziatdinovmax Jul 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Acquisition functions

.. autofunction:: gpax.acquisition.UE

.. autofunction:: gpax.acquisition.bUCB
.. autofunction:: gpax.acquisition.qUCB
4 changes: 3 additions & 1 deletion docs/source/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ Kernels

.. autofunction:: gpax.kernels.MaternKernel

.. autofunction:: gpax.kernels.PeriodicKernel
.. autofunction:: gpax.kernels.PeriodicKernel

.. autofunction:: gpax.kernels.NNGPKernel
48 changes: 41 additions & 7 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
@@ -1,34 +1,68 @@
GPax models
===========

Gaussian Processes (Fully Bayesian Implementation)
Gaussian Processes - Fully Bayesian Implementation
--------------------------------------------------
.. autoclass:: gpax.gp.ExactGP
.. autoclass:: gpax.models.gp.ExactGP
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. autoclass:: gpax.vgp.vExactGP
.. autoclass:: gpax.models.vgp.vExactGP
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

Deep Kernel Learning (Fully Bayesian Implementation)
Gaussian Processes - Approximate Bayesian
------------------------------------------
.. autoclass:: gpax.models.gp.ExactGP
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

Deep Kernel Learning - Fully Bayesian Implementation
----------------------------------------------------
.. autoclass:: gpax.dkl.DKL
.. autoclass:: gpax.models.dkl.DKL
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

Deep Kernel Learning (Approximate Bayesian)
Deep Kernel Learning - Approximate Bayesian
-------------------------------------------
.. autoclass:: gpax.vidkl.viDKL
.. autoclass:: gpax.models.vidkl.viDKL
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

Infinite-width Bayesian Neural Networks
----------------------------------------
.. autoclass:: gpax.models.ibnn.iBNN
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

Multi-Task Learning
--------------------
.. autoclass:: gpax.models.mtgp.MultiTaskGP
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. autoclass:: gpax.models.vi_mtdkl.viMTDKL
:members:
:inherited-members:
:undoc-members:
Expand Down
19 changes: 9 additions & 10 deletions gpax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from . import utils, kernels, acquisition
from .gp import ExactGP
from .vgp import vExactGP
from .bnn import DKL, viDKL, iBNN, vi_iBNN
from .vigp import viGP
from .spm import sPM
from .hypo import sample_next

from .__version__ import version as __version__
from . import utils
from . import kernels
from . import acquisition
from .hypo import sample_next
from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, vExactGP,
vi_iBNN, viDKL, viGP, viMTDKL)

__all__ = ["utils", "kernels", "acquisition", "ExactGP", "vExactGP", "DKL",
"viDKL", "iBNN", "vi_iBNN", "viGP", "sPM", "sample_next", "__version__"]
__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM",
"CoregGP", "sample_next", "__version__"]
1 change: 1 addition & 0 deletions gpax/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .acquisition import *
2 changes: 1 addition & 1 deletion gpax/acquisition.py → gpax/acquisition/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as onp
import numpyro.distributions as dist

from .gp import ExactGP
from ..models.gp import ExactGP


def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
Expand Down
6 changes: 0 additions & 6 deletions gpax/bnn/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions gpax/hypo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import numpy as np
import numpyro

from .gp import ExactGP
from .spm import sPM
from .models.gp import ExactGP
from .models.spm import sPM
from .utils import get_keys


Expand Down
16 changes: 16 additions & 0 deletions gpax/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .kernels import (MaternKernel, NNGPKernel, PeriodicKernel, RBFKernel,
get_kernel, nngp_erf, nngp_relu)
from .mtkernels import (LCMKernel, MultitaskKernel, MultivariateKernel,
index_kernel)

__all__ = [
"RBFKernel",
"MaternKernel",
"PeriodicKernel",
"NNGPKernel",
"get_kernel",
"index_kernel",
"MultitaskKernel",
"MultivariateKernel",
"LCMKernel"
]
8 changes: 4 additions & 4 deletions gpax/kernels.py → gpax/kernels/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def add_jitter(x, jitter=1e-6):
def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray,
lengthscale: Union[jnp.ndarray, float] = 1.
) -> jnp.ndarray:
"""
r"""
Computes a square of scaled distance, :math:`\|\frac{X-Z}{l}\|^2`,
between X and Z are vectors with :math:`n x num_features` dimensions
"""
Expand Down Expand Up @@ -115,7 +115,7 @@ def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray,


def nngp_erf(x1: jnp.ndarray, x2: jnp.ndarray,
var_b: jnp.array, var_w: jnp.array,
var_b: jnp.array, var_w: jnp.array,
depth: int = 3) -> jnp.array:
"""
Computes the Neural Network Gaussian Process (NNGP) kernel value for
Expand Down Expand Up @@ -187,7 +187,7 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3

Args:
activation: activation function ('erf' or 'relu')
depth: The number of layers in the corresponding infinite-width neural network.
depth: The number of layers in the corresponding infinite-width neural network.
Controls the level of recursion in the computation.

Returns:
Expand All @@ -196,7 +196,7 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3
nngp_single_pair_ = nngp_relu if activation == 'relu' else nngp_erf

def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray,
params: Dict[str, jnp.ndarray],
params: Dict[str, jnp.ndarray],
noise: jnp.ndarray = 0, **kwargs
) -> jnp.ndarray:
"""
Expand Down
Loading