Skip to content

Commit

Permalink
* Add utility to trim results, e.g. if to_results(trim=False) is used…
Browse files Browse the repository at this point in the history
… in JIT context, and later want to trim. It's a static method.
  • Loading branch information
Joshuaalbert committed Dec 19, 2023
1 parent 34f8e89 commit 701cf86
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion jaxns/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import tree_map, core

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.types import PRNGKey, IntArray, StaticStandardNestedSamplerState, TerminationCondition, \
Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num
Args:
model: a model to perform nested sampling on
max_samples: maximum number of samples to take
num_live_points: number of live points to use. Defaults is c * (k + 1).
num_live_points: approximate number of live points to use. Defaults is c * (k + 1).
s: number of slices to use per dimension. Defaults to 4.
k: number of phantom samples to use. Defaults to D/2.
c: number of parallel Markov-chains to use. Defaults to 20 * D.
Expand Down Expand Up @@ -185,6 +186,29 @@ def to_results(self, termination_reason: IntArray, state: StaticStandardNestedSa
trim=trim
)

@staticmethod
def trim_results(results: NestedSamplerResults) -> NestedSamplerResults:
"""
Trims the results to the number of samples taken. Requires static context.
Args:
results: results to trim
Returns:
trimmed results
"""

if isinstance(results.total_num_samples, core.Tracer):
raise RuntimeError("Tracer detected, but expected imperative context.")

def trim(x):
if x.size > 1:
return x[:results.total_num_samples]
return x

results = tree_map(trim, results)
return results


class ApproximateNestedSampler(DefaultNestedSampler):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit 701cf86

Please sign in to comment.