Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #204

Merged
merged 7 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ JAXNS is:

What can you do with JAXNS?

1) Compute the Bayesian evidence of a model or hypothesis (the ultimate scientific method);
1) Compute the Bayesian evidence of a model or hypothesis (the ultimate scientific method);
2) Produce high-quality samples from the posterior distribution;
3) Easily handle degenerate difficult multi-modal posteriors;
4) Model both discrete and continuous priors and likelihoods;
Expand Down Expand Up @@ -359,6 +359,9 @@ before importing JAXNS.

# Change Log

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
`devices`.

4 Nov, 2024 -- JAXNS 2.6.4 released. Resolved bias when using phantom points.

1 Oct, 2024 -- JAXNS 2.6.3 released. Enable pytrees in context.
Expand Down
171 changes: 171 additions & 0 deletions benchmarks/gh136/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pkg_resources
import tensorflow_probability.substrates.jax as tfp
from jax._src.scipy.linalg import solve_triangular

from jaxns import Model, Prior

try:
from jaxns import NestedSampler
except ImportError:
from jaxns import DefaultNestedSampler as NestedSampler

tfpd = tfp.distributions


def build_run_model(num_slices, gradient_guided, ndims):
def run_model(key, prior_mu, prior_cov, data_mu, data_cov):
def log_normal(x, mean, cov):
L = jnp.linalg.cholesky(cov)
dx = x - mean
dx = solve_triangular(L, dx, lower=True)
return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
- 0.5 * dx @ dx

true_logZ = log_normal(data_mu, prior_mu, prior_cov + data_cov)

J = jnp.linalg.solve(data_cov + prior_cov, prior_cov)
post_mu = prior_mu + J.T @ (data_mu - prior_mu)
post_cov = prior_cov - J.T @ (prior_cov + data_cov) @ J

# print("True logZ={}".format(true_logZ))
# print("True post_mu={}".format(post_mu))
# print("True post_cov={}".format(post_cov))

# KL posterior || prior
dist_posterior = tfpd.MultivariateNormalFullCovariance(loc=post_mu, covariance_matrix=post_cov)
dist_prior = tfpd.MultivariateNormalFullCovariance(loc=prior_mu, covariance_matrix=prior_cov)
H_true = -tfp.distributions.kl_divergence(dist_posterior, dist_prior)

# print("True H={}".format(H_true))

def prior_model():
x = yield Prior(
tfpd.MultivariateNormalTriL(loc=prior_mu, scale_tril=jnp.linalg.cholesky(prior_cov)),
name='x')
return x

def log_likelihood(x):
return tfpd.MultivariateNormalTriL(loc=data_mu, scale_tril=jnp.linalg.cholesky(data_cov)).log_prob(x)

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = NestedSampler(model=model, verbose=False, k=0, num_slices=num_slices, gradient_guided=gradient_guided)

termination_reason, state = ns(key)
results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)

error = results.H_mean - H_true
log_Z_error = results.log_Z_mean - true_logZ
return results.H_mean, H_true, error, log_Z_error

return run_model


def get_data(ndims):
prior_mu = 15 * jnp.ones(ndims)
prior_cov = jnp.diag(jnp.ones(ndims)) ** 2

data_mu = jnp.zeros(ndims)
data_cov = jnp.diag(jnp.ones(ndims)) ** 2
data_cov = jnp.where(data_cov == 0., 0.99, data_cov)
return prior_mu, prior_cov, data_mu, data_cov


def main():
jaxns_version = pkg_resources.get_distribution("jaxns").version
m = 3
d = 32

data = get_data(d)

# Row 1: Plot logZ error for gradient guided vs baseline for different s, with errorbars
# Row 2: Plot H error for gradient guided vs baseline for different s, with errorbars
# Row 3: Plot time taken for gradient guided vs baseline for different s, with errorbars

s_array = [10, 20, 30, 40, 80, 120]

run_model_baseline_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=False, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for
s in
s_array]
run_model_gg_aot_array = [
jax.jit(build_run_model(num_slices=s, gradient_guided=True, ndims=d)).lower(jax.random.PRNGKey(0), *data).compile() for s
in
s_array]

H_errors = np.zeros((len(s_array), m, 2))
log_z_errors = np.zeros((len(s_array), m, 2))
dt = np.zeros((len(s_array), m, 2))

for s_idx in range(len(s_array)):
s = s_array[s_idx]
for i in range(m):
key = jax.random.PRNGKey(i * 42)
baseline_model = run_model_baseline_aot_array[s_idx]
gg_model = run_model_gg_aot_array[s_idx]
t0 = time.time()
H, H_true, H_error, log_Z_error = jax.block_until_ready(baseline_model(key, *data))
t1 = time.time()
dt[s_idx, i, 0] = t1 - t0
H_errors[s_idx, i, 0] = H_error
log_z_errors[s_idx, i, 0] = log_Z_error
print(f"Baseline: i={i} k=0 s={s} H={H} H_true={H_true} H_error={H_error} log_Z_error={log_Z_error}")
t0 = time.time()
H, H_true, H_error, log_Z_error = jax.block_until_ready(gg_model(key, *data))
t1 = time.time()
dt[s_idx, i, 1] = t1 - t0
H_errors[s_idx, i, 1] = H_error
log_z_errors[s_idx, i, 1] = log_Z_error
print(f"GG: i={i} k=0 s={s} H={H} H_true={H_true} H_error={H_error} log_Z_error={log_Z_error}")

fig, axs = plt.subplots(3, 1, figsize=(10, 15), sharex=True)
# Row 1
H_error_mean = np.mean(H_errors, axis=1) # [s, 2]
H_error_std = np.std(H_errors, axis=1) # [s, 2]
axs[0].plot(s_array, H_error_mean[:, 0], label="Baseline", c='b')
axs[0].plot(s_array, H_error_mean[:, 1], label="Gradient Guided", c='r')
axs[0].fill_between(s_array, H_error_mean[:, 0] - H_error_std[:, 0], H_error_mean[:, 0] + H_error_std[:, 0],
color='b', alpha=0.2)
axs[0].fill_between(s_array, H_error_mean[:, 1] - H_error_std[:, 1], H_error_mean[:, 1] + H_error_std[:, 1],
color='r', alpha=0.2)
axs[0].set_ylabel("H error")
axs[0].legend()

# Row 2
logZ_error_mean = np.mean(log_z_errors, axis=1) # [s, 2]
logZ_error_std = np.std(log_z_errors, axis=1) # [s, 2]
axs[1].plot(s_array, logZ_error_mean[:, 0], label="Baseline", c='b')
axs[1].plot(s_array, logZ_error_mean[:, 1], label="Gradient Guided", c='r')
axs[1].fill_between(s_array, logZ_error_mean[:, 0] - logZ_error_std[:, 0],
logZ_error_mean[:, 0] + logZ_error_std[:, 0], color='b', alpha=0.2)
axs[1].fill_between(s_array, logZ_error_mean[:, 1] - logZ_error_std[:, 1],
logZ_error_mean[:, 1] + logZ_error_std[:, 1], color='r', alpha=0.2)
axs[1].set_ylabel("logZ error")
axs[1].legend()

# Row 3
dt_mean = np.mean(dt, axis=1) # [s, 2]
dt_std = np.std(dt, axis=1) # [s, 2]
axs[2].plot(s_array, dt_mean[:, 0], label="Baseline", c='b')
axs[2].plot(s_array, dt_mean[:, 1], label="Gradient Guided", c='r')
axs[2].fill_between(s_array, dt_mean[:, 0] - dt_std[:, 0], dt_mean[:, 0] + dt_std[:, 0], color='b', alpha=0.2)
axs[2].fill_between(s_array, dt_mean[:, 1] - dt_std[:, 1], dt_mean[:, 1] + dt_std[:, 1], color='r', alpha=0.2)
axs[2].set_ylabel("Time taken")
axs[2].legend()
axs[2].set_xlabel(r"number of slices")

axs[0].set_title(f"Gradient guided vs baseline, D={d}, v{jaxns_version}")

plt.savefig(f"Gradient_guided_vs_baseline_D{d}_v{jaxns_version}.png")

plt.show()


if __name__ == '__main__':
main()
428 changes: 428 additions & 0 deletions docs/examples/efficient_parameter_estimation.ipynb

Large diffs are not rendered by default.

334 changes: 334 additions & 0 deletions docs/examples/gradient_guided.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions src/jaxns/internals/pytree_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import jax
import jax.numpy as jnp


def tree_dot(x, y):
dots = jax.tree.leaves(jax.tree.map(jnp.vdot, x, y))
return sum(dots[1:], start=dots[0])


def tree_norm(x):
norm2 = tree_dot(x, x)
if jnp.issubdtype(norm2.dtype, jnp.complexfloating):
return jnp.sqrt(norm2.real)
return jnp.sqrt(norm2)


def tree_mul(x, y):
return jax.tree.map(jax.lax.mul, x, y)


def tree_sub(x, y):
return jax.tree.map(jax.lax.sub, x, y)


def tree_div(x, y):
return jax.tree.map(jax.lax.div, x, y)
10 changes: 10 additions & 0 deletions src/jaxns/internals/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import jax
from jax import random, numpy as jnp
from jax.scipy import special

Expand Down Expand Up @@ -72,3 +73,12 @@ def resample_indicies(key: PRNGKey, log_weights: Optional[FloatArray] = None, S:
g = -random.gumbel(key, shape=(num_total,))
idx = jnp.argsort(g)[:S]
return idx


def sample_uniformly_masked(key, v, select_mask, num_samples: int, squeeze: bool = False):
# If no satisfied samples, then chooses randomly from them. Should never happen, but good to know.
log_weights = jnp.where(select_mask, 0., -jnp.inf)
sample_idxs = resample_indicies(key, log_weights=log_weights, S=num_samples, replace=True)
if squeeze:
sample_idxs = jnp.squeeze(sample_idxs)
return jax.tree.map(lambda x: x[sample_idxs], v)
3 changes: 2 additions & 1 deletion src/jaxns/nested_samplers/common/initialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ def create_init_termination_register() -> TerminationRegister:
plateau=jnp.asarray(False, jnp.bool_),
no_seed_points=jnp.asarray(False, jnp.bool_),
relative_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype),
absolute_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype)
absolute_spread=jnp.asarray(jnp.inf, mp_policy.measure_dtype),
peak_log_XL=jnp.asarray(-jnp.inf, mp_policy.measure_dtype)
)
8 changes: 8 additions & 0 deletions src/jaxns/nested_samplers/common/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def determine_termination(
8-bit -> 256: relative spread of live points < rtol
9-bit -> 512: absolute spread of live points < atol
10-bit -> 1024: no seed points left
11-bit -> 2048: XL < max(XL) * peak_XL_frac

Multiple flags are summed together

Expand Down Expand Up @@ -136,4 +137,11 @@ def _set_done_bit(bit_done, bit_reason, done, termination_reason):
done, termination_reason = _set_done_bit(termination_register.no_seed_points, 10,
done=done, termination_reason=termination_reason)

if term_cond.peak_XL_frac is not None:
log_XL = termination_register.evidence_calc.log_X_mean + termination_register.evidence_calc.log_L
peak_log_XL = termination_register.peak_log_XL
XL_reduction_reached = log_XL < peak_log_XL + jnp.log(term_cond.peak_XL_frac)
done, termination_reason = _set_done_bit(XL_reduction_reached, 11,
done=done, termination_reason=termination_reason)

return done, termination_reason
2 changes: 2 additions & 0 deletions src/jaxns/nested_samplers/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TerminationCondition(NamedTuple):
efficiency_threshold: Optional[FloatArray] = None
rtol: Optional[FloatArray] = None
atol: Optional[FloatArray] = None
peak_XL_frac: Optional[FloatArray] = None

def __and__(self, other):
return TerminationConditionConjunction(conds=[self, other])
Expand Down Expand Up @@ -134,6 +135,7 @@ class TerminationRegister(NamedTuple):
no_seed_points: BoolArray
relative_spread: FloatArray
absolute_spread: FloatArray
peak_log_XL: FloatArray


class NestedSamplerState(NamedTuple):
Expand Down
Loading
Loading