
Commit

Benchmarking for #172
Joshuaalbert committed May 29, 2024
1 parent e29da3c commit 4af4dd0
Showing 5 changed files with 55 additions and 47 deletions.
4 changes: 2 additions & 2 deletions benchmarks/parallel_problems/main.py
@@ -61,7 +61,7 @@ def log_likelihood(x):
def main():
num_devices = len(jax.devices())
jaxns_version = pkg_resources.get_distribution("jaxns").version
m = 10
m = 1
run_model_aot = jax.jit(run_model).lower(jax.random.PRNGKey(0)).compile()
dt = []

@@ -77,7 +77,7 @@ def main():
errors.append(log_Z_error)
uncerts.append(log_Z_uncert)
total_time = sum(dt)
best_3 = sum(sorted(dt)[:3]) / 3.
best_3 = sum(sorted(dt)[:min(3, m)]) / 3.
# print(f"Errors: {errors}")
# print(f"Uncerts: {uncerts}")
print(f"JAXNS {jaxns_version}\n"
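For context, the benchmark harness times m repetitions of the ahead-of-time compiled model and reports both the total wall time and a best-of-three average; this commit drops m from 10 to 1 and clamps the slice so it cannot index past the available timings. A minimal sketch of the surrounding loop, assuming run_model_aot returns the log-evidence error and uncertainty as suggested by the hunk above (names outside the diff are illustrative):

import time

import jax


def time_runs(run_model_aot, m: int = 1):
    # Time m repetitions of the AOT-compiled model.
    dt, errors, uncerts = [], [], []
    for i in range(m):
        t0 = time.time()
        log_Z_error, log_Z_uncert = run_model_aot(jax.random.PRNGKey(i))
        jax.block_until_ready((log_Z_error, log_Z_uncert))  # count device execution, not just dispatch
        dt.append(time.time() - t0)
        errors.append(log_Z_error)
        uncerts.append(log_Z_uncert)
    total_time = sum(dt)
    # The divisor stays 3 even when m < 3, matching the committed code.
    best_3 = sum(sorted(dt)[:min(3, m)]) / 3.
    return errors, uncerts, total_time, best_3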
1 change: 1 addition & 0 deletions benchmarks/parallel_problems/results
@@ -1 +1,2 @@
2.5.1,12,-0.8741127252578735,0.0003452669479884207,95.40295028686523,77.15287351608276
2.5.1,12,-1.30291748046875,0.32637646794319153,102.9092493057251,34.30308310190836
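Read together with main.py, each row appears to record the jaxns version, device count, log-evidence error, log-evidence uncertainty, total wall time, and best-3 time; that ordering is inferred from the benchmark script rather than stated in this diff. A sketch for loading the file under that assumption:

import csv

# Assumed column layout; see the caveat above.
fields = ["jaxns_version", "num_devices", "log_Z_error", "log_Z_uncert", "total_time", "best_3"]
with open("benchmarks/parallel_problems/results") as f:
    rows = [dict(zip(fields, row)) for row in csv.reader(f)]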
76 changes: 40 additions & 36 deletions jaxns/nested_sampler/standard_static.py
@@ -1,9 +1,11 @@
import dataclasses
import warnings
from typing import Tuple, NamedTuple, Any, Union
from typing import Tuple, NamedTuple, Any, Union, List, Optional

import jax
from jax import random, pmap, numpy as jnp, lax, core, vmap
from jax import random, numpy as jnp, lax, core, vmap, pmap
from jax._src.lax import parallel
from jaxlib import xla_client

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.cumulative_ops import cumulative_op_static
@@ -581,49 +583,51 @@ def _repeat(a):
)


@dataclasses.dataclass(eq=False)
class StandardStaticNestedSampler(BaseAbstractNestedSampler):
"""
A static nested sampler that uses a fixed number of live points. This uses a uniform sampler to generate the
initial set of samples down to an efficiency threshold, then uses a provided sampler to generate the rest of the
samples until the termination condition is met.
"""
def __init__(self, init_efficiency_threshold: float, sampler: BaseAbstractSampler, num_live_points: int,
model: BaseAbstractModel, max_samples: int, num_parallel_workers: int = 1, verbose: bool = False):
"""
Initialise the static nested sampler.
Args:
init_efficiency_threshold: the efficiency threshold to use for the initial uniform sampling. If 0 then
turns it off.
sampler: the sampler to use after the initial uniform sampling.
num_live_points: the number of live points to use.
model: the model to use.
max_samples: the maximum number of samples to take.
num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
verbose: whether to log as we go.
"""
self.init_efficiency_threshold = init_efficiency_threshold
self.sampler = sampler
self.num_live_points = int(num_live_points)
self.num_parallel_workers = int(num_parallel_workers)
self.verbose = bool(verbose)
remainder = max_samples % self.num_live_points
extra = (max_samples - remainder) % self.num_live_points
Args:
init_efficiency_threshold: the efficiency threshold to use for the initial uniform sampling. If 0 then
turns it off.
sampler: the sampler to use after the initial uniform sampling.
num_live_points: the number of live points to use.
model: the model to use.
max_samples: the maximum number of samples to take.
devices: the devices to use; defaults to the first local device.
verbose: whether to log as we go.
"""
init_efficiency_threshold: float
sampler: BaseAbstractSampler
num_live_points: int
model: BaseAbstractModel
max_samples: int
devices: Optional[List[xla_client.Device]] = None
verbose: bool = False

def __post_init__(self):
if self.devices is None:
self.devices = jax.devices()[:1]

remainder = self.max_samples % self.num_live_points
extra = (self.max_samples - remainder) % self.num_live_points
if extra > 0:
warnings.warn(
f"Increasing max_samples ({max_samples}) by {extra} to closest multiple of "
f"Increasing max_samples ({self.max_samples}) by {extra} to closest multiple of "
f"num_live_points {self.num_live_points}."
)
max_samples = int(max_samples + extra)
if self.num_parallel_workers > 1:
logger.info(f"Using {self.num_parallel_workers} parallel workers, each running identical samplers.")
super().__init__(model=model, max_samples=max_samples)
self.max_samples = int(self.max_samples + extra)
if len(self.devices) > 1:
logger.info(f"Using {len(self.devices)} parallel workers, each running identical samplers.")
BaseAbstractNestedSampler.__init__(self, model=self.model, max_samples=self.max_samples)

def __repr__(self):
return f"StandardStaticNestedSampler(init_efficiency_threshold={self.init_efficiency_threshold}, " \
f"sampler={self.sampler}, num_live_points={self.num_live_points}, model={self.model}, " \
f"max_samples={self.max_samples}, num_parallel_workers={self.num_parallel_workers})"
f"max_samples={self.max_samples}, devices={self.devices})"

def _to_results(self, termination_reason: IntArray, state: StaticStandardNestedSamplerState,
trim: bool) -> NestedSamplerResults:
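The class is now a plain dataclass: the validation that previously lived in __init__ moves to __post_init__, and parallelism is expressed as an explicit device list instead of a worker count. A hedged construction sketch, where sampler and model stand in for whatever BaseAbstractSampler and BaseAbstractModel instances the caller already has (they are placeholders, not part of this diff):

import jax

from jaxns.nested_sampler.standard_static import StandardStaticNestedSampler

ns = StandardStaticNestedSampler(
    init_efficiency_threshold=0.1,
    sampler=sampler,  # placeholder: an existing BaseAbstractSampler
    num_live_points=200,
    model=model,  # placeholder: an existing BaseAbstractModel
    max_samples=100_000,
    devices=jax.devices(),  # one replica per listed device; None falls back to the first device
    verbose=False,
)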
@@ -772,10 +776,10 @@ def replica(key: PRNGKey) -> Tuple[StaticStandardNestedSamplerState, IntArray]:
num_samples_per_sync=self.num_live_points,
verbose=self.verbose
)
if self.num_parallel_workers > 1:
if len(self.devices) > 1:
# We need to do a final sampling run to make all the chains consistent,
# to a likelihood contour (i.e. standardise on L(X)). Would mean that some workers are idle.
target_log_L_contour = parallel.pmax(termination_register.log_L_contour, 'i')
target_log_L_contour = parallel.pmax(termination_register.log_L_contour, 'markov_chain')

termination_cond = TerminationCondition(
dlogZ=jnp.asarray(0., float_type),
@@ -793,10 +797,10 @@ def replica(key: PRNGKey) -> Tuple[StaticStandardNestedSamplerState, IntArray]:

return state, termination_reason

if self.num_parallel_workers > 1:
parallel_ns = pmap(replica, axis_name='i')
if len(self.devices) > 1:
parallel_ns = pmap(replica, axis_name='markov_chain')

keys = random.split(key, self.num_parallel_workers)
keys = random.split(key, len(self.devices))
batched_state, termination_reason = parallel_ns(keys)
state = unbatch_state(batched_state=batched_state)
else:
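The replica path now derives its width from len(self.devices) and renames the pmap axis from 'i' to 'markov_chain'; the same axis name must be used inside the mapped function for collectives such as pmax. A minimal standalone sketch of that pattern (the replica body here is a stand-in, not the jaxns one):

import jax
from jax import pmap, random
from jax.lax import pmax  # public counterpart of jax._src.lax.parallel.pmax

def replica(key):
    # Stand-in for one nested-sampling replica: produce a per-device contour.
    log_L_contour = random.uniform(key)
    # All replicas agree on the maximum contour across the 'markov_chain' axis.
    target = pmax(log_L_contour, axis_name='markov_chain')
    return log_L_contour, target

devices = jax.devices()
keys = random.split(random.PRNGKey(0), len(devices))
per_device, shared = pmap(replica, axis_name='markov_chain', devices=devices)(keys)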
14 changes: 11 additions & 3 deletions jaxns/public.py
@@ -5,6 +5,7 @@
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import core
from jaxlib import xla_client

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.logging import logger
@@ -38,7 +39,8 @@ def __init__(self, model: BaseAbstractModel,
s: Optional[int] = None,
k: Optional[int] = None,
c: Optional[int] = None,
num_parallel_workers: int = 1,
num_parallel_workers: int | None = None,
devices: Optional[xla_client.Device] = None,
difficult_model: bool = False,
parameter_estimation: bool = False,
init_efficiency_threshold: float = 0.1,
@@ -87,6 +89,12 @@ def __init__(self, model: BaseAbstractModel,
# Sanity check for max_samples (should be able to at least do one shrinkage)
if max_samples < self._c * (self._k + 1):
warnings.warn(f"max_samples={max_samples} is likely too small!")
if num_parallel_workers is not None:
warnings.warn("`num_parallel_workers` is deprecated. Use `devices` instead.")
if devices is None:
devices = jax.devices()[:num_parallel_workers]
else:
devices = devices[:num_parallel_workers]
self._nested_sampler = StandardStaticNestedSampler(
model=model,
num_live_points=self._c,
@@ -99,7 +107,7 @@
perfect=True
),
init_efficiency_threshold=init_efficiency_threshold,
num_parallel_workers=num_parallel_workers,
devices=devices,
verbose=verbose
)

Expand All @@ -111,7 +119,7 @@ def __init__(self, model: BaseAbstractModel,
self.load_results = load_results

def __repr__(self):
return f"DefaultNestedSampler(s={self._s}, c={self._c}, k={self._k})"
return f"DefaultNestedSampler(s={self._s}, c={self._c}, k={self._k}, devices={self._nested_sampler.devices})"

@property
def num_live_points(self) -> int:
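With num_parallel_workers deprecated, callers select parallelism by handing DefaultNestedSampler the devices directly. A hedged migration sketch, where model is a placeholder for an existing BaseAbstractModel:

import jax

from jaxns.public import DefaultNestedSampler

# Before (now deprecated):
#   ns = DefaultNestedSampler(model=model, max_samples=100_000, num_parallel_workers=2)
# After: pass the devices explicitly.
ns = DefaultNestedSampler(
    model=model,  # placeholder: an existing BaseAbstractModel
    max_samples=100_000,
    devices=jax.devices()[:2],
)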
7 changes: 1 addition & 6 deletions jaxns/tests/conftest.py
@@ -1,11 +1,6 @@
import os
from time import monotonic_ns

import jax

# Force 2 jax hosts
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import pytest
from jax import numpy as jnp, random
from jax._src.scipy.linalg import solve_triangular
@@ -14,10 +9,10 @@
from jaxns.framework.bases import PriorModelGen
from jaxns.framework.model import Model
from jaxns.framework.prior import Prior
from jaxns.internals.types import TerminationCondition
from jaxns.nested_sampler.standard_static import StandardStaticNestedSampler
from jaxns.public import DefaultNestedSampler
from jaxns.samplers.multi_ellipsoidal_samplers import MultiEllipsoidalSampler
from jaxns.internals.types import TerminationCondition
from jaxns.utils import bruteforce_evidence, summary

# from jaxns.nested_sampler import ApproximateNestedSampler, ExactNestedSampler
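The test fixture no longer forces two host devices, since parallelism is now driven by the devices argument rather than the ambient device count. If the multi-device path still needs exercising on a CPU-only machine, the flag can be set before jax is first imported; a minimal sketch (the assertion is only illustrative):

import os

# Must be set before the first `import jax` anywhere in the process.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax

assert len(jax.devices("cpu")) == 2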
