Fix issue of no more seeds left (several plateaus)
Joshuaalbert committed Sep 25, 2024
1 parent 8a304d1 commit 13dc750
Showing 35 changed files with 819 additions and 787 deletions.
17 changes: 8 additions & 9 deletions benchmarks/difficult_problems/main.py
@@ -152,21 +152,20 @@ def main():
     for model_name, model in all_models().items():
         print(f"Testing model {model_name}")
         model.sanity_check(jax.random.PRNGKey(0), 1000)
-        ns = NestedSampler(model=model,
-                           max_samples=1000000,
-                           verbose=True,
-                           difficult_model=True,
-                           parameter_estimation=True
-                           )
+        ns = NestedSampler(
+            model=model,
+            difficult_model=True,
+            parameter_estimation=True
+        )
         ns_jit = jax.jit(lambda key: ns(key))
         ns_compiled = ns_jit.lower(jax.random.PRNGKey(42)).compile()
         with Timer():
             termination_reason, state = ns_compiled(jax.random.PRNGKey(42))
             termination_reason.block_until_ready()
         results = ns.to_results(termination_reason=termination_reason, state=state)
-        ns.plot_diagnostics(results)
-        ns.summary(results)
-        ns.plot_cornerplot(results)
+        ns.plot_diagnostics(results, save_name=f"{model_name}_diagnostics.png")
+        ns.plot_cornerplot(results, save_name=f"{model_name}_cornerplot.png")
+        ns.summary(results, f_obj=f"{model_name}_summary.txt")


 if __name__ == '__main__':
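Reviewer note: the benchmark times the sampler through JAX's ahead-of-time API, `jax.jit(...).lower(...).compile()`, so compilation happens before the `Timer()` block and is excluded from the measurement. A minimal self-contained sketch of that pattern (the function below is a stand-in workload, not the jaxns sampler):

```python
import jax
import jax.numpy as jnp

def run(key):
    # Stand-in workload: any jit-able function of a PRNG key.
    return jnp.sum(jax.random.normal(key, (1000,)) ** 2)

# Lower and compile ahead of time so the timed call excludes compilation.
run_compiled = jax.jit(run).lower(jax.random.PRNGKey(42)).compile()

result = run_compiled(jax.random.PRNGKey(42))
result.block_until_ready()  # JAX dispatch is async; wait before stopping any timer
```

The compiled callable only accepts arguments matching the shapes and dtypes given to `lower`, which is exactly what the benchmark wants: one fixed problem, timed repeatedly.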
4 changes: 4 additions & 0 deletions benchmarks/gh117/main.py
@@ -55,6 +55,10 @@ def main():
 # Avg. time taken: 0.00562 seconds.
 # The best 3 of 10 runs took 0.00478 seconds.

+# 2.6.2
+# Avg. time taken: 0.00012 seconds.
+# The best 3 of 10 runs took 0.00007 seconds.
+

 if __name__ == '__main__':
     main()
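Reviewer note: the "Avg. time taken" and "best 3 of 10 runs" figures quoted in these comments follow the harness pattern used across these benchmarks. A hedged sketch of how such numbers are typically produced (the repo's actual `Timer` helper is not shown in this diff):

```python
import time
import numpy as np

def benchmark(fn, m: int = 10) -> None:
    # Run fn m times and report the mean plus the mean of the 3 fastest runs.
    dt = []
    for _ in range(m):
        t0 = time.time()
        fn()
        dt.append(time.time() - t0)
    dt = np.sort(np.asarray(dt))
    print(f"Avg. time taken: {dt.mean():.5f} seconds.")
    print(f"The best 3 of {m} runs took {dt[:3].mean():.5f} seconds.")
```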
1 change: 1 addition & 0 deletions benchmarks/gh168/main.py
@@ -82,6 +82,7 @@ def main():
         f"The best 3 of {m} runs took {best_3:.5f} seconds.")

     with open('results', 'a') as fp:
+        # jaxns_version,mean_error,mean_uncert,avg_time,best_3
         fp.write(f"{jaxns_version},{np.mean(errors)},{np.mean(uncerts)},{total_time / m},{best_3}\n")

 # Before fix
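Reviewer note: with the new header comment documenting the row layout, the `results` file parses cleanly. A small illustrative loader (not part of the repo):

```python
import numpy as np

# Column order per the header comment: jaxns_version,mean_error,mean_uncert,avg_time,best_3
rows = np.genfromtxt(
    "results",
    delimiter=",",
    dtype=None,
    encoding="utf-8",
    names=["jaxns_version", "mean_error", "mean_uncert", "avg_time", "best_3"],
)
for row in np.atleast_1d(rows):
    print(f"jaxns {row['jaxns_version']}: best-3 mean {row['best_3']:.5f} s")
```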
2 changes: 1 addition & 1 deletion benchmarks/gh168/run_benchmark.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 # Array of jaxns versions to be installed
-declare -a jaxns_versions=("2.4.6" "2.4.7" "2.4.8" "2.4.10" "2.4.11" "2.4.12" "2.4.13" "2.5.0")
+declare -a jaxns_versions=("2.4.6" "2.4.7" "2.4.8" "2.4.10" "2.4.11" "2.4.12" "2.4.13" "2.5.0" "2.6.0" "2.6.1" "2.6.2")

 # Path to your benchmark script
 benchmark_script="main.py"
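Reviewer note: only the version array of `run_benchmark.sh` appears in this diff. For reference, the sweep it drives can be rendered in Python roughly as follows (the install-then-run orchestration is an assumption about the script's behaviour, not a copy of it):

```python
import subprocess
import sys

jaxns_versions = ["2.4.6", "2.4.7", "2.4.8", "2.4.10", "2.4.11",
                  "2.4.12", "2.4.13", "2.5.0", "2.6.0", "2.6.1", "2.6.2"]

for version in jaxns_versions:
    # Pin the version, then run the benchmark, which appends a row to 'results'.
    subprocess.run([sys.executable, "-m", "pip", "install", f"jaxns=={version}"], check=True)
    subprocess.run([sys.executable, "main.py"], check=True)
```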
7 changes: 4 additions & 3 deletions benchmarks/parallel_problems/main.py
@@ -1,8 +1,8 @@
 import os
 import time

-# Force 2 jax hosts
-os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
+# Force multiple jax host devices (half the CPU count)
+os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()//2}"

 import jax
 import jax.numpy as jnp
@@ -61,7 +61,7 @@ def log_likelihood(x):
 def main():
     num_devices = len(jax.devices())
     jaxns_version = pkg_resources.get_distribution("jaxns").version
-    m = 3
+    m = 10
     run_model_aot = jax.jit(run_model).lower(jax.random.PRNGKey(0)).compile()
     dt = []

@@ -87,6 +87,7 @@ def main():
         f"The best 3 of {m} runs took {best_3:.5f} seconds.")

     with open('results', 'a') as fp:
+        # version,num_devices,mean_error,mean_uncert,avg_time,best_3
         fp.write(f"{jaxns_version},{num_devices},{np.mean(errors)},{np.mean(uncerts)},{total_time / m},{best_3}\n")


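Reviewer note: `--xla_force_host_platform_device_count` only takes effect if `XLA_FLAGS` is set before `jax` is first imported; afterwards the backends are already initialised and the flag is silently ignored, which is why the benchmark sets it at the top of the module. A standalone sketch:

```python
import os

# Must run before the first `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax

# On a CPU-only host this now reports 4 devices, each able to back one
# shard of a pmap/sharded computation.
print(jax.devices())
```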
7 changes: 5 additions & 2 deletions benchmarks/parallel_problems/results
@@ -1,5 +1,8 @@
 #version,num_devices,mean_error,mean_uncert,avg_time,best_3
-2.5.1,12,-0.8741127252578735,0.0003452669479884207,95.40295028686523,77.15287351608276
-2.5.1,12,-1.30291748046875,0.32637646794319153,102.9092493057251,34.30308310190836
 2.3.4,12,0.09217196161439745,0.3709631743726239,2.116373300552368,0.7054577668507894
 2.6.0,6,-0.06817162536483465,0.3668122147028353,1.1894174337387085,1.0906640688578289
+2.6.0,12,-0.06817162536483465,0.3668122147028353,1.9961725234985352,1.7415425777435303
+2.6.1,6,-0.06817162536483465,0.3668122147028353,1.1894174337387085,1.0906640688578289
+2.6.1,12,-0.06817162536483465,0.3668122147028353,1.9961725234985352,1.7415425777435303
+2.6.2,6,-0.011831551800402452,0.3695399992566994,0.8977173328399658,0.8630998134613037
+2.6.2,12,-0.011831551800402452,0.3695399992566994,1.2420992612838746,1.1599870522816975
4 changes: 2 additions & 2 deletions docs/conf.py
@@ -10,9 +10,9 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

 project = "jaxns"
-copyright = "2022, Joshua G. Albert"
+copyright = "2024, Joshua G. Albert"
 author = "Joshua G. Albert"
-release = "2.6.1"
+release = "2.6.2"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
243 changes: 0 additions & 243 deletions docs/examples/debug/plausible_logic.ipynb

This file was deleted.

300 changes: 90 additions & 210 deletions docs/examples/gamma_poission.ipynb

Large diffs are not rendered by default.

499 changes: 308 additions & 191 deletions docs/examples/gaussian_process_marginalisation_spectral.ipynb

Large diffs are not rendered by default.

252 changes: 252 additions & 0 deletions docs/examples/plausible_logic.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions jaxns/experimental/evidence_maximisation.py
@@ -17,10 +17,10 @@
 from jaxns.internals.cumulative_ops import cumulative_op_static
 from jaxns.internals.log_semiring import LogSpace
 from jaxns.internals.logging import logger
+from jaxns.internals.mixed_precision import mp_policy
 from jaxns.internals.types import IntArray, PRNGKey
 from jaxns.nested_samplers.common.types import TerminationCondition, NestedSamplerResults, \
     StaticStandardNestedSamplerState
-from jaxns.internals.mixed_precision import float_type

 __all__ = [
     'EvidenceMaximisation'
@@ -242,7 +242,7 @@ def op(log_Z, data):
             log_dZ = model.forward(data.U_samples) + data.log_weights
             return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val

-        log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data)
+        log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, mp_policy.measure_dtype), xs=data)
         return log_Z

     def loss(params: MutableParams, data: MStepData):
@@ -308,8 +308,8 @@ def _pad_to_n(x, fill_value, dtype):

         log_weights = ns_results.log_dp_mean - ns_results.log_L_samples + ns_results.log_Z_mean
         data = MStepData(
-            U_samples=_pad_to_n(ns_results.U_samples, 0.5, float_type),
-            log_weights=_pad_to_n(log_weights, -jnp.inf, float_type)
+            U_samples=_pad_to_n(ns_results.U_samples, 0.5, mp_policy.measure_dtype),
+            log_weights=_pad_to_n(log_weights, -jnp.inf, mp_policy.measure_dtype)
         )
         desc = p_bar.desc
         last_params = params
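Reviewer note: the `op` above adds evidence contributions in log space, so `cumulative_op_static` is accumulating a running log-sum-exp. A self-contained equivalent in plain JAX (names here are illustrative, not the jaxns internals):

```python
import jax
import jax.numpy as jnp

def accumulate_log_evidence(log_dZ: jax.Array) -> jax.Array:
    # Running log-sum-exp: log_Z <- logaddexp(log_Z, log_dZ_i), starting from log(0) = -inf.
    def op(log_Z, log_dz):
        log_Z = jnp.logaddexp(log_Z, log_dz)
        return log_Z, log_Z

    log_Z, _ = jax.lax.scan(op, jnp.asarray(-jnp.inf), log_dZ)
    return log_Z

# Sanity check: the running accumulation matches a one-shot logsumexp.
log_dZ = jnp.log(jnp.array([0.1, 0.2, 0.3]))
assert jnp.allclose(accumulate_log_evidence(log_dZ), jax.nn.logsumexp(log_dZ))
```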
5 changes: 3 additions & 2 deletions jaxns/experimental/global_optimisation.py
@@ -198,7 +198,7 @@ def _round(v, uncert_v):
         return float(v)

     def _print_termination_condition(_termination_reason: int):
-        termination_bit_mask = _bit_mask(int(_termination_reason), width=8)
+        termination_bit_mask = _bit_mask(int(_termination_reason), width=11)
         # 0-bit -> 1: used maximum allowed number of likelihood evaluations
         # 1-bit -> 2: reached goal log-likelihood contour
         # 2-bit -> 4: relative spread of log-likelihood values below threshold
@@ -215,7 +215,8 @@ def _print_termination_condition(_termination_reason: int):
             'Sampler efficiency too low',
             'All live-points are on a single plateau (potential numerical errors, consider 64-bit)',
             'relative spread of live points < rtol',
-            'absolute spread of live points < atol'
+            'absolute spread of live points < atol',
+            'no seed points left'
         ]):
             if bit == 1:
                 _print(condition)
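Reviewer note: widening the mask from 8 to 11 bits makes room for the newer termination flags, including the "no seed points left" condition this commit adds. For readers unfamiliar with the encoding, a generic decoder (the repo's `_bit_mask` helper is not shown in this diff):

```python
def decode_termination_reason(reason: int, messages: list[str]) -> list[str]:
    # Bit i of `reason` set means messages[i] fired; several conditions can fire at once.
    return [msg for i, msg in enumerate(messages) if (reason >> i) & 1]

# Example with placeholder messages: bits 0 and 10 set -> reason = 1 + 1024 = 1025.
messages = [f"condition {i}" for i in range(11)]
print(decode_termination_reason(1025, messages))  # ['condition 0', 'condition 10']
```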
4 changes: 2 additions & 2 deletions jaxns/framework/model.py
@@ -11,9 +11,9 @@
     parse_joint, transform_parametrised
 from jaxns.internals.logging import logger
 from jaxns.internals.maps import pytree_unravel
+from jaxns.internals.mixed_precision import mp_policy
 from jaxns.internals.types import PRNGKey, FloatArray, LikelihoodType, UType, XType, LikelihoodInputType, \
     WType
-from jaxns.internals.mixed_precision import float_type

 __all__ = [
     'Model'
@@ -132,7 +132,7 @@ def sample_U(self, key: PRNGKey) -> UType:
             raise RuntimeError("Model has not been initialised")

         def _sample_U():
-            return random.uniform(key=ctx.next_rng_key(), shape=(self.U_ndims,), dtype=float_type)
+            return random.uniform(key=ctx.next_rng_key(), shape=(self.U_ndims,), dtype=mp_policy.measure_dtype)

         return ctx.transform(_sample_U).apply(self._params, key).fn_val

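Reviewer note: the recurring `float_type` → `mp_policy.measure_dtype` substitution in this commit routes every dtype decision through a single mixed-precision policy. The policy object itself is not defined in this diff; a minimal sketch of the pattern, with attribute names taken from their usage above and everything else assumed:

```python
import dataclasses

import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class MixedPrecisionPolicy:
    measure_dtype: type = jnp.float32  # continuous quantities: samples, log-likelihoods, weights
    count_dtype: type = jnp.int32      # discrete quantities, e.g. Poisson draws

mp_policy = MixedPrecisionPolicy()

# Every array construction consults the policy, so moving the whole sampler
# to 64-bit becomes a one-line policy change rather than a grep.
u = jnp.zeros((3,), dtype=mp_policy.measure_dtype)
k = jnp.zeros((), dtype=mp_policy.count_dtype)
```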
4 changes: 2 additions & 2 deletions jaxns/framework/ops.py
@@ -3,15 +3,15 @@
 from typing import Tuple, Callable, Generator

 import jax
+import numpy as np
 from jax import numpy as jnp, lax

 from jaxns.framework.bases import PriorModelType, BaseAbstractPrior, PriorModelGen
 from jaxns.framework.prior import InvalidPriorName, SingularPrior, Prior
 from jaxns.internals.maps import pytree_unravel
+from jaxns.internals.mixed_precision import mp_policy
 from jaxns.internals.types import UType, XType, LikelihoodInputType, FloatArray, LikelihoodType, PRNGKey, \
     isinstance_namedtuple, WType, RandomVariableType
-from jaxns.internals.mixed_precision import float_type, mp_policy
-import numpy as np

 __all__ = [
     'simulate_prior_model'
10 changes: 5 additions & 5 deletions jaxns/framework/prior.py
@@ -9,8 +9,8 @@
 from jaxns.framework.bases import BaseAbstractPrior, BaseAbstractDistribution
 from jaxns.framework.wrapped_tfp_distribution import WrappedTFPDistribution
 from jaxns.internals.constraint_bijections import quick_unit
+from jaxns.internals.mixed_precision import mp_policy
 from jaxns.internals.types import FloatArray, IntArray, BoolArray, XType, UType
-from jaxns.internals.mixed_precision import float_type

 tfpd = tfp.distributions

@@ -56,7 +56,7 @@ def _forward(self, U: UType) -> Union[FloatArray, IntArray, BoolArray]:
         return self.value

     def _inverse(self, X: XType) -> UType:
-        return jnp.asarray([], float_type)
+        return jnp.asarray([], mp_policy.measure_dtype)

     def _log_prob(self, X: XType) -> FloatArray:
         return self.base_prior.log_prob(X)
@@ -129,15 +129,15 @@ def _forward(self, U: UType) -> Union[FloatArray, IntArray, BoolArray]:

     def _inverse(self, X: XType) -> FloatArray:
         if self._type == 'value':
-            return jnp.asarray([], float_type)
+            return jnp.asarray([], mp_policy.measure_dtype)
         elif self._type == 'dist':
             return self.dist.inverse(X)
         else:
             raise NotImplementedError()

     def _log_prob(self, X: XType) -> FloatArray:
         if self._type == 'value':
-            return jnp.asarray(0., float_type)
+            return jnp.asarray(0., mp_policy.measure_dtype)
         elif self._type == 'dist':
             return self.dist.log_prob(X=X)
         else:
@@ -189,7 +189,7 @@ def prior_to_parametrised_singular(prior: BaseAbstractPrior, random_init: bool =
     norm_U_base_param = ctx.get_parameter(
         name=name,
         shape=prior.base_shape,
-        dtype=float_type,
+        dtype=mp_policy.measure_dtype,
         init=initaliser
     )
     # transform [-inf, inf] -> [0,1]
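Reviewer note: the `# transform [-inf, inf] -> [0,1]` step is delegated to `quick_unit`, whose body is not in this diff. One standard choice for such a bijection is the logistic sigmoid with a clipped logit inverse; a sketch (jaxns may well use a different map):

```python
import jax
import jax.numpy as jnp

def to_unit(x):
    # R -> (0, 1): smooth and monotone, so the parameter stays unconstrained
    # for the optimiser while living on the unit interval downstream.
    return jax.nn.sigmoid(x)

def from_unit(u, eps=1e-7):
    # (0, 1) -> R: logit, clipped away from the endpoints for numerical safety.
    u = jnp.clip(u, eps, 1.0 - eps)
    return jnp.log(u) - jnp.log1p(-u)

x = jnp.array([-3.0, 0.0, 3.0])
assert jnp.allclose(from_unit(to_unit(x)), x, atol=1e-4)
```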
8 changes: 4 additions & 4 deletions jaxns/framework/special_priors.py
@@ -11,9 +11,9 @@
 from jaxns.framework.prior import SingularPrior, prior_to_parametrised_singular
 from jaxns.internals.interp_utils import InterpolatedArray
 from jaxns.internals.log_semiring import cumulative_logsumexp
+from jaxns.internals.mixed_precision import mp_policy
 from jaxns.internals.types import FloatArray, IntArray, BoolArray, UType, RandomVariableType, \
     MeasureType
-from jaxns.internals.mixed_precision import float_type, int_type

 tfpd = tfp.distributions

@@ -216,7 +216,7 @@ def __init__(self, *, n: int, low=None, high=None, fix_left: bool = False, fix_r
         self.fix_right = fix_right

     def _dtype(self):
-        return float_type
+        return mp_policy.measure_dtype

     def _base_shape(self) -> Tuple[int, ...]:
         num_base = self.n
@@ -355,7 +355,7 @@ def fixed_point_update(x, args):
 @partial(jax.jit, static_argnames=("unroll",))
 def _poisson_quantile(U, rate, unroll: bool = False):
     x, _ = _poisson_quantile_bisection(U, rate, unroll=unroll)
-    return x.astype(int_type)
+    return x.astype(mp_policy.count_dtype)


 class Poisson(SpecialPrior):
@@ -364,7 +364,7 @@ def __init__(self, *, rate=None, log_rate=None, name: Optional[str] = None):
         self.dist = tfpd.Poisson(rate=rate, log_rate=log_rate)

     def _dtype(self):
-        return int_type
+        return mp_policy.count_dtype

     def _base_shape(self) -> Tuple[int, ...]:
         return self._shape()
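Reviewer note: `_poisson_quantile_bisection` is referenced but not shown in this diff. For intuition, a standalone sketch of Poisson quantiles by bisection, using the identity P(X ≤ k) = Q(k + 1, λ) where Q is the regularised upper incomplete gamma (the actual jaxns routine may differ in details such as unrolling):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc

def poisson_cdf(k, rate):
    # P(X <= k) for X ~ Poisson(rate).
    return gammaincc(k + 1.0, rate)

def poisson_quantile(U, rate, upper=1e6, num_iters=60):
    # Bisect on the continuous relaxation of the CDF, then round up to the
    # smallest integer k with CDF(k) >= U.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        below = poisson_cdf(mid, rate) < U
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(0, num_iters, body,
                               (jnp.zeros_like(rate), jnp.full_like(rate, upper)))
    return jnp.ceil(hi).astype(jnp.int32)

# Example: the median of Poisson(4.2) is 4.
print(poisson_quantile(jnp.asarray(0.5), jnp.asarray(4.2)))
```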
(Diffs for the remaining changed files are not rendered here.)