Skip to content

Commit

Permalink
Merge pull request #65 from ziatdinovmax/optimize_acq
Browse files Browse the repository at this point in the history
Acquisition function optimization
  • Loading branch information
ziatdinovmax authored Jan 3, 2024
2 parents 57882c9 + 5f63586 commit 63ec983
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 3 deletions.
3 changes: 2 additions & 1 deletion gpax/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .acquisition import UCB, EI, POI, UE, Thompson, KG
from .batch_acquisition import qEI, qPOI, qUCB, qKG
from .optimize import optimize_acq

__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"]
__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG", "optimize_acq"]
97 changes: 97 additions & 0 deletions gpax/acquisition/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
optimize.py
==============
Optimize continuous acquisition functions
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Type, Callable, Union, List, Tuple

import jax.numpy as jnp
import jax.random as jra
import numpy as onp

from ..models.gp import ExactGP


def optimize_acq(rng_key: jnp.ndarray,
model: Type[ExactGP],
acq_fn: Callable,
num_initial_guesses: int,
lower_bound: Union[List, Tuple, float, onp.ndarray, jnp.ndarray],
upper_bound: Union[List, Tuple, float, onp.ndarray, jnp.ndarray],
**kwargs) -> jnp.ndarray:
"""
Optimizes an acquisition function for a given Gaussian Process model using the JAXopt library.
This function finds the point that maximizes the acquisition function within the specified bounds.
It uses L-BFGS-B algorithm through ScipyBoundedMinimize from JAXopt.
Args:
rng_key: A JAX random key for stochastic processes.
model: The Gaussian Process model to be used.
acq_fn: The acquisition function to be maximized.
num_initial_guesses: Number of random initial guesses for the optimization.
lower_bound: Lower bounds for the optimization.
upper_bound: Upper bounds for the optimization.
**kwargs: Additional keyword arguments to be passed to the acquisition function.
Returns:
Parameter(s) that maximize the acquisition function within the specified bounds.
Note:
Ensure JAXopt is installed to use this function (`pip install jaxopt`).
The acquisition function is minimized using its negative value to find the maximum.
Examples:
Optimize EI given a trained GP model for 1D problem
>>> acq_fn = gpax.acquisition.EI
>>> num_initial_guesses = 10
>>> lower_bound = -2.0
>>> upper_bound = 2.0
>>> x_next = gpax.acquisition.optimize_acq(
>>> rng_key, gp_model, acq_fn,
>>> num_initial_guesses, lower_bound, upper_bound,
>>> maximize=False, noiseless=True)
"""

try:
import jaxopt # noqa: F401
except ImportError as e:
raise ImportError(
"You need to install `jaxopt` to be able to use this feature. "
"It can be installed with `pip install jaxopt`."
) from e

def acq(x):
x = jnp.array([x])
x = x[None] if x.ndim == 0 else x
obj = -acq_fn(rng_key, model, x, **kwargs)
return jnp.reshape(obj, ())

lower_bound = ensure_array(lower_bound)
upper_bound = ensure_array(upper_bound)

initial_guesses = jra.uniform(
rng_key, shape=(num_initial_guesses, lower_bound.shape[0]),
minval=lower_bound, maxval=upper_bound)
initial_acq_vals = acq_fn(rng_key, model, initial_guesses, **kwargs)
best_initial_guess = initial_guesses[initial_acq_vals.argmax()].squeeze()

minimizer = jaxopt.ScipyBoundedMinimize(fun=acq, method='l-bfgs-b')
result = minimizer.run(best_initial_guess, bounds=(lower_bound, upper_bound))

return result.params


def ensure_array(x):
if not isinstance(x, jnp.ndarray):
if isinstance(x, (list, tuple, float, onp.ndarray)):
x = jnp.array([x]) if isinstance(x, float) else jnp.array(x)
else:
raise TypeError(f"Expected input to be a list, tuple, float, or jnp.ndarray, got {type(x)} instead.")
return x
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ matplotlib>=3.1.1
jax>=0.2.21
numpyro>=0.8.0
dm-haiku>=0.0.5
jaxopt>0.8.0
typing-extensions>=4.4.0
2 changes: 0 additions & 2 deletions tests/test_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as onp
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpy.testing import assert_equal, assert_

sys.path.insert(0, "../gpax/")
Expand Down
36 changes: 36 additions & 0 deletions tests/test_optimize_acq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
from numpy.testing import assert_

sys.path.insert(0, "../gpax/")

from gpax.models.gp import ExactGP
from gpax.acquisition.optimize import optimize_acq
from gpax.acquisition.acquisition import UCB, EI
from gpax.utils import get_keys


def get_inputs():
X = onp.random.uniform(-2, 2, size=(4,))
y = X**3
return X, y


@pytest.mark.parametrize("acq_fn", [UCB, EI])
def test_optimize_acq(acq_fn):
lower_bound = -2.0
upper_bound = 2.0
num_initial_guesses = 3
key1, key2 = get_keys()
X, y = get_inputs()
model = ExactGP(1, 'RBF')
model.fit(key1, X, y, num_warmup=50, num_samples=50)
x_next = optimize_acq(
key2, model, acq_fn, num_initial_guesses, lower_bound, upper_bound)
assert_(isinstance(x_next, jnp.ndarray))




0 comments on commit 63ec983

Please sign in to comment.