Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expected improvement utility function #460

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -16,13 +16,13 @@ help: ## Display this help

##@ Formatting
black: ## Format code in-place using black.
black ${PKGROOT}/ tests/ -l 79 .
black ${PKGROOT}/ tests/ -l 88 .

isort: ## Format imports in-place using isort.
isort ${PKGROOT}/ tests/

format: ## Code styling - black, isort
black ${PKGROOT}/ tests/ -l 100 .
black ${PKGROOT}/ tests/ -l 88 .
@printf "\033[1;34mBlack passes!\033[0m\n\n"
isort ${PKGROOT}/ tests/
@printf "\033[1;34misort passes!\033[0m\n\n"
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
@@ -728,7 +728,7 @@ def obtain_log_regret_statistics(
#
# - **Expected Improvement (EI)** ([Močkus, 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55)) - EI goes beyond PI by not only considering the
# probability of improving on the current best observed point, but also taking into
# account the \textit{magnitude} of improvement. Mathematically, this is defined as
# account the *magnitude* of improvement. Mathematically, this is defined as
# follows:
# $$
# \begin{aligned}
4 changes: 2 additions & 2 deletions docs/examples/decision_making.py
Original file line number Diff line number Diff line change
@@ -240,8 +240,8 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# %% [markdown]

# It is worth noting that `ThompsonSampling` is not the only utility function we could use,
# since our module also provides e.g. `ProbabilityOfImprovement`,
# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
# since our module also provides e.g. `ProbabilityOfImprovement`, `ExpectedImprovment`,
# which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).


# %% [markdown]
26 changes: 19 additions & 7 deletions gpjax/decision_making/test_functions/continuous_functions.py
Original file line number Diff line number Diff line change
@@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import (
ABC,
abstractmethod,
)
from abc import abstractmethod
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import (
Array,
Float,
Num,
)
import tensorflow_probability.substrates.jax as tfp

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.gps import AbstractMeanFunction
from gpjax.typing import KeyArray


class AbstractContinuousTestFunction(ABC):
class AbstractContinuousTestFunction(AbstractMeanFunction):
"""
Abstract base class for continuous test functions.

@@ -43,19 +43,28 @@ class AbstractContinuousTestFunction(ABC):
minimizer: Float[Array, "1 D"]
minimum: Float[Array, "1 1"]

def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
def generate_dataset(
self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
) -> Dataset:
"""
Generate a toy dataset from the test function.

Args:
num_points (int): Number of points to sample.
key (KeyArray): JAX PRNG key.
obs_stddev (float): (Optional) standard deviation of Gaussian distributed
noise added to observations.

Returns:
Dataset: Dataset of points sampled from the test function.
"""
X = self.search_space.sample(num_points=num_points, key=key)
y = self.evaluate(X)
gaussian_noise = tfp.distributions.Normal(
jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
)
y = self.evaluate(X) + jnp.transpose(
gaussian_noise.sample(sample_shape=[1], seed=key)
)
return Dataset(X=X, y=y)

def generate_test_points(
@@ -73,6 +82,9 @@ def generate_test_points(
"""
return self.search_space.sample(num_points=num_points, key=key)

def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
return self.evaluate(x)

@abstractmethod
def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
"""
17 changes: 8 additions & 9 deletions gpjax/decision_making/test_functions/non_conjugate_functions.py
Original file line number Diff line number Diff line change
@@ -17,15 +17,14 @@

import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Array,
Float,
Integer,
)

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.typing import KeyArray
from gpjax.typing import (
Array,
Float,
KeyArray,
)


@dataclass
@@ -74,7 +73,7 @@ def generate_test_points(
return self.search_space.sample(num_points=num_points, key=key)

@abstractmethod
def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
def evaluate(self, x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
"""
Evaluate the test function at a set of points. Function taken from
https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
@@ -83,8 +82,8 @@ def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
x (Float[Array, 'N D']): Points to evaluate the test function at.

Returns:
Integer[Array, 'N 1']: Values of the test function at the points.
Float[Array, 'N 1']: Values of the test function at the points.
"""
key = jr.key(42)
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
return jr.poisson(key, jnp.exp(f(x)))
return jnp.float64(jr.poisson(key, jnp.exp(f(x))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Must this be float64? If the user is using float32 arrays (ill-advised, but possible), then this will cause mixed precision errors. Can we perhaps wrap in a jnp.asarray()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have switched back to returning an array of integers, as these are samples from a Poisson distribution.

4 changes: 4 additions & 0 deletions gpjax/decision_making/utility_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,9 @@
SinglePointUtilityFunction,
UtilityFunction,
)
from gpjax.decision_making.utility_functions.expected_improvement import (
ExpectedImprovement,
)
from gpjax.decision_making.utility_functions.probability_of_improvement import (
ProbabilityOfImprovement,
)
@@ -27,6 +30,7 @@
"UtilityFunction",
"AbstractUtilityFunctionBuilder",
"AbstractSinglePointUtilityFunctionBuilder",
"ExpectedImprovement",
"SinglePointUtilityFunction",
"ThompsonSampling",
"ProbabilityOfImprovement",
112 changes: 112 additions & 0 deletions gpjax/decision_making/utility_functions/expected_improvement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from functools import partial

from beartype.typing import Mapping
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

from gpjax.dataset import Dataset
from gpjax.decision_making.utility_functions.base import (
AbstractSinglePointUtilityFunctionBuilder,
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import (
OBJECTIVE,
get_best_latent_observation_val,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
Array,
Float,
KeyArray,
)


@dataclass
class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
"""
Expected Improvement acquisition function as introduced by [Močkus,
1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
incumbent value is defined as the lowest posterior mean value evaluated at the the
previously observed points. This enables the acquisition function to be utilised with noisy observations.
"""

def build_utility_function(
self,
posteriors: Mapping[str, ConjugatePosterior],
datasets: Mapping[str, Dataset],
key: KeyArray,
) -> SinglePointUtilityFunction:
r"""
Build the Expected Improvement acquisition function. This computes the expected
improvement over the "best" of the previously observed points, utilising the
posterior distribution of the surrogate model. For posterior distribution
$`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
as:
```math
\alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
```

Args:
posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
used to form the utility function. One posteriors must correspond to the
`OBJECTIVE` key, as we utilise the objective posterior to form the utility
function.
datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
utility function. Keys in `datasets` should correspond to keys in
`posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
key (KeyArray): JAX PRNG key used for random number generation.

Returns:
SinglePointUtilityFunction: The Expected Improvement acquisition function to
to be *maximised* in order to decide which point to query next.
"""
self.check_objective_present(posteriors, datasets)
objective_posterior = posteriors[OBJECTIVE]
objective_dataset = datasets[OBJECTIVE]

if not isinstance(objective_posterior, ConjugatePosterior):
raise ValueError(
"Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
)

if (
objective_dataset.X is None
or objective_dataset.n == 0
or objective_dataset.y is None
):
raise ValueError("Objective dataset must contain at least one item")

Check warning on line 92 in gpjax/decision_making/utility_functions/expected_improvement.py

Codecov / codecov/patch

gpjax/decision_making/utility_functions/expected_improvement.py#L92

Added line #L92 was not covered by tests

eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
return partial(
_expected_improvement, objective_posterior, objective_dataset, eta
)


def _expected_improvement(
objective_posterior: ConjugatePosterior,
objective_dataset: Dataset,
eta: Float[Array, ""],
x: Float[Array, "N D"],
) -> Float[Array, "N 1"]:
latent_dist = objective_posterior(x, objective_dataset)
mean = latent_dist.mean()
var = latent_dist.variance()
normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
return jnp.expand_dims(
((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
)
Original file line number Diff line number Diff line change
@@ -23,7 +23,10 @@
AbstractSinglePointUtilityFunctionBuilder,
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.decision_making.utils import (
OBJECTIVE,
get_best_latent_observation_val,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
Array,
@@ -107,14 +110,9 @@ def build_utility_function(
)

def probability_of_improvement(x_test: Num[Array, "N D"]):
# Computing the posterior mean for the training dataset
# for computing the best_y value (as the minimum
# posterior mean of the objective function)
predictive_dist_for_training = objective_posterior.predict(
objective_dataset.X, objective_dataset
best_y = get_best_latent_observation_val(
objective_posterior, objective_dataset
)
best_y = predictive_dist_for_training.mean().min()

predictive_dist = objective_posterior.predict(x_test, objective_dataset)

normal_dist = tfp.distributions.Normal(
11 changes: 4 additions & 7 deletions gpjax/decision_making/utility_functions/thompson_sampling.py
Original file line number Diff line number Diff line change
@@ -22,10 +22,7 @@
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import (
ConjugatePosterior,
NonConjugatePosterior,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import KeyArray


@@ -59,7 +56,7 @@ def __post_init__(self):

def build_utility_function(
self,
posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior],
posteriors: Mapping[str, ConjugatePosterior],
datasets: Mapping[str, Dataset],
key: KeyArray,
) -> SinglePointUtilityFunction:
@@ -69,8 +66,8 @@ def build_utility_function(
are *maximised*.

Args:
posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
used to form the utility function. One of the posteriors must correspond
posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
be used to form the utility function. One of the posteriors must correspond
to the `OBJECTIVE` key, as we sample from the objective posterior to form
the utility function.
datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
14 changes: 14 additions & 0 deletions gpjax/decision_making/utils.py
Original file line number Diff line number Diff line change
@@ -17,8 +17,10 @@
Dict,
Final,
)
import jax.numpy as jnp

from gpjax.dataset import Dataset
from gpjax.gps import AbstractPosterior
from gpjax.typing import (
Array,
Float,
@@ -48,3 +50,15 @@ def build_function_evaluator(
dictionary of datasets storing the evaluated points.
"""
return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()}


def get_best_latent_observation_val(
posterior: AbstractPosterior, dataset: Dataset
) -> Float[Array, ""]:
"""
Takes a posterior and dataset and returns the best (latent) function value in the
dataset, corresponding to the minimum of the posterior mean value evaluated at
locations in the dataset. In the noiseless case, this corresponds to the minimum
value in the dataset.
"""
return jnp.min(posterior(dataset.X, dataset).mean())
Loading