From 701cf86480c33705dffe1c30dc5a957738b6af1a Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Tue, 19 Dec 2023 17:43:59 +0200 Subject: [PATCH] * Add utility to trim results, e.g. if to_results(trim=False) is used in JIT context, and later want to trim. It's a static method. --- jaxns/public.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/jaxns/public.py b/jaxns/public.py index 5c3f7817..c3494ad1 100644 --- a/jaxns/public.py +++ b/jaxns/public.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp +from jax import tree_map, core from jaxns.framework.bases import BaseAbstractModel from jaxns.internals.types import PRNGKey, IntArray, StaticStandardNestedSamplerState, TerminationCondition, \ @@ -41,7 +42,7 @@ def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num Args: model: a model to perform nested sampling on max_samples: maximum number of samples to take - num_live_points: number of live points to use. Defaults is c * (k + 1). + num_live_points: approximate number of live points to use. Defaults is c * (k + 1). s: number of slices to use per dimension. Defaults to 4. k: number of phantom samples to use. Defaults to D/2. c: number of parallel Markov-chains to use. Defaults to 20 * D. @@ -185,6 +186,29 @@ def to_results(self, termination_reason: IntArray, state: StaticStandardNestedSa trim=trim ) + @staticmethod + def trim_results(results: NestedSamplerResults) -> NestedSamplerResults: + """ + Trims the results to the number of samples taken. Requires static context. + + Args: + results: results to trim + + Returns: + trimmed results + """ + + if isinstance(results.total_num_samples, core.Tracer): + raise RuntimeError("Tracer detected, but expected imperative context.") + + def trim(x): + if x.size > 1: + return x[:results.total_num_samples] + return x + + results = tree_map(trim, results) + return results + class ApproximateNestedSampler(DefaultNestedSampler): def __init__(self, *args, **kwargs):