diff --git a/README.rst b/README.rst index 5acc5279..5e28cce7 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ anesthetic: nested sampling post-processing =========================================== :Authors: Will Handley and Lukas Hergt -:Version: 2.8.15 +:Version: 2.8.16 :Homepage: https://github.com/handley-lab/anesthetic :Documentation: http://anesthetic.readthedocs.io/ diff --git a/anesthetic/_version.py b/anesthetic/_version.py index 7cdefb37..f8f5d150 100644 --- a/anesthetic/_version.py +++ b/anesthetic/_version.py @@ -1 +1 @@ -__version__ = '2.8.15' +__version__ = '2.8.16' diff --git a/anesthetic/samples.py b/anesthetic/samples.py index d0dee2b1..44c81076 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -1302,15 +1302,17 @@ def recompute(self, logL_birth=None, inplace=False): n_bad = invalid.sum() n_equal = (samples.logL == samples.logL_birth).sum() if n_bad: - warnings.warn("%i out of %i samples have logL <= logL_birth," - "\n%i of which have logL == logL_birth." - "\nThis may just indicate numerical rounding " - "errors at the peak of the likelihood, but " - "further investigation of the chains files is " - "recommended." - "\nDropping the invalid samples." % - (n_bad, len(samples), n_equal), - RuntimeWarning) + n_inf = ((samples.logL == samples.logL_birth) & + (samples.logL == -np.inf)).sum() + if n_bad > n_inf: + warnings.warn( + "%i out of %i samples have logL <= logL_birth,\n" + "%i of which have logL == logL_birth.\n" + "This may just indicate numerical rounding errors at " + "the peak of the likelihood, but further " + "investigation of the chains files is recommended.\n" + "Dropping the invalid samples." + % (n_bad, len(samples), n_equal), RuntimeWarning) samples = samples[~invalid].reset_index(drop=True) samples.sort_values('logL', inplace=True) diff --git a/tests/test_samples.py b/tests/test_samples.py index f49aed3d..0dfdf4c8 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -1133,6 +1133,7 @@ def test_beta(): def test_beta_with_logL_infinities(): ns = read_chains("./tests/example_data/pc") ns.loc[:10, ('logL', r'$\ln\mathcal{L}$')] = -np.inf + ns.loc[1000, ('logL', r'$\ln\mathcal{L}$')] = -np.inf with pytest.warns(RuntimeWarning): ns.recompute(inplace=True) assert (ns.logL == -np.inf).sum() == 0