Skip to content

Commit

Permalink
Merge pull request #214 from Joshuaalbert/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Joshuaalbert authored Dec 7, 2024
2 parents 3550b0f + ad83526 commit e3c2a8f
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 23 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install -r requirements.txt
pip install -r requirements-tests.txt
pip install .
pip install .[tests]
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand All @@ -39,4 +36,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
pytest -s
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ before importing JAXNS.

# Change Log

7 Dec, 2024 -- JAXNS 2.6.7 released. Fix pip dependencies install.

13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting.

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2024, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.6.6"
release = "2.6.7"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jaxns"
version = "2.6.6"
version = "2.6.7"
description = "Nested Sampling in JAX"
readme = "README.md"
requires-python = ">=3.9"
Expand All @@ -19,10 +19,7 @@ classifiers = [
"Operating System :: OS Independent"
]
urls = { "Homepage" = "https://github.com/joshuaalbert/jaxns" }

[project.optional-dependencies]
# Define the extras here; they will be loaded dynamically from setup.py
notebooks = [] # Placeholders; extras will load from setup.py
dynamic = ["dependencies", "optional-dependencies"]

[tool.setuptools]
include-package-data = true
Expand Down
3 changes: 2 additions & 1 deletion requirements-tests.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
scikit-learn
networkx
psutil
pytest
pytest
flake8
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def load_requirements(file_name):
install_requires=load_requirements("requirements.txt"),
extras_require={
"examples": load_requirements("requirements-examples.txt"),
},
tests_require=load_requirements("requirements-tests.txt"),
"tests": load_requirements("requirements-tests.txt"),
}
)
18 changes: 16 additions & 2 deletions src/jaxns/framework/special_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"Poisson",
"UnnormalisedDirichlet",
"Empirical",
"TruncationWrapper"
"TruncationWrapper",
"ExplicitDensityPrior",
]


Expand Down Expand Up @@ -72,6 +73,7 @@ def _quantile(self, U):
sample = jnp.less(U, probs)
return sample.astype(self.dtype)


class Beta(SpecialPrior):
def __init__(self, *, concentration0=None, concentration1=None, name: Optional[str] = None):
super(Beta, self).__init__(name=name)
Expand Down Expand Up @@ -443,7 +445,8 @@ class Empirical(SpecialPrior):
Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.
"""

def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[str] = None):
def __init__(self, *, samples: jax.Array, support_min: Optional[FloatArray] = None,
support_max: Optional[FloatArray] = None, resolution: int = 100, name: Optional[str] = None):
super(Empirical, self).__init__(name=name)
if len(np.shape(samples)) < 1:
raise ValueError("Samples must have at least one dimension")
Expand All @@ -452,6 +455,17 @@ def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[
if resolution < 1:
raise ValueError("Resolution must be at least 1")
samples = jnp.asarray(samples)
# Add 1 point for each support endpoint
endpoints = []
if support_min is not None:
endpoints.append(support_min)
if support_max is not None:
endpoints.append(support_max)
if len(endpoints) > 0:
samples = jnp.concatenate([samples, jnp.asarray(endpoints)])

resolution = min(resolution, len(samples) - 1)

self._q = jnp.linspace(0., 100., resolution + 1)
self._percentiles = jnp.reshape(jnp.percentile(samples, self._q, axis=-1), (resolution + 1, -1))

Expand Down
8 changes: 4 additions & 4 deletions src/jaxns/framework/tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def test_forced_identifiability():


def test_empirical():
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(5, 2000), dtype=mp_policy.measure_dtype)
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(2000,), dtype=mp_policy.measure_dtype)
prior = Empirical(samples=samples, resolution=100, name='x')
assert prior._percentiles.shape == (101, 5)
assert prior._percentiles.shape == (101, 1)

x = prior.forward(jnp.ones(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
x = prior.forward(jnp.zeros(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))

x = prior.forward(0.5 * jnp.ones(prior.base_shape, mp_policy.measure_dtype))
Expand Down
6 changes: 3 additions & 3 deletions src/jaxns/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class NestedSampler:
max_samples: Optional[Union[int, float]] = None
num_live_points: Optional[int] = None
num_slices: Optional[int] = None
s: Optional[int] = None
s: Optional[Union[int, float]] = None
k: Optional[int] = None
c: Optional[int] = None
devices: Optional[List[xla_client.Device]] = None
Expand All @@ -70,9 +70,9 @@ def __post_init__(self):
# Determine number of slices per acceptance
if self.num_slices is None:
if self.difficult_model:
self.s = 10 if self.s is None else int(self.s)
self.s = 10 if self.s is None else float(self.s)
else:
self.s = 5 if self.s is None else int(self.s)
self.s = 5 if self.s is None else float(self.s)
if self.s <= 0:
raise ValueError(f"Expected s > 0, got s={self.s}")
self.num_slices = self.model.U_ndims * self.s
Expand Down

0 comments on commit e3c2a8f

Please sign in to comment.