Skip to content

Commit

Permalink
Merge pull request #188 from Joshuaalbert/shard-map-parallel
Browse files Browse the repository at this point in the history
Shard map parallel
  • Loading branch information
Joshuaalbert authored Sep 23, 2024
2 parents c3e47ec + 48b58b9 commit 2262320
Show file tree
Hide file tree
Showing 89 changed files with 4,591 additions and 3,547 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
set -o allexport
source deployment/local.env
set +o allexport
pytest
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ Given a probabilistic model, JAXNS can perform nested sampling on it. This allow
posterior samples.

```python
from jaxns import DefaultNestedSampler
from jaxns import NestedSampler

ns = DefaultNestedSampler(model=model, max_samples=1e5)
ns = NestedSampler(model=model, max_samples=1e5)

# Run the sampler
termination_reason, state = ns(jax.random.PRNGKey(42))
Expand Down Expand Up @@ -363,6 +363,8 @@ is the best way to achieve speed up.

# Change Log

24 Sep, 2024 -- JAXNS 2.6.0 released. Sharded parallel JAXNS. Rewrite of internals to support sharded parallelisation.

20 Aug, 2024 -- JAXNS 2.6.0 released. Removed haiku dependency. Implemented our own
context. `jaxns.framework.context.convert_external_params` enables interfacing with any external NN libary.

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/difficult_problems/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

from jaxns import Model, Prior, DefaultNestedSampler
from jaxns import Model, Prior, NestedSampler

tfpd = tfp.distributions

Expand Down Expand Up @@ -152,7 +152,7 @@ def main():
for model_name, model in all_models().items():
print(f"Testing model {model_name}")
model.sanity_check(jax.random.PRNGKey(0), 1000)
ns = DefaultNestedSampler(model=model,
ns = NestedSampler(model=model,
max_samples=1000000,
verbose=True,
difficult_model=True,
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/gh117/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tensorflow_probability.substrates.jax as tfp
from jax import random

from jaxns import Model, Prior, DefaultNestedSampler
from jaxns import Model, Prior, NestedSampler


tfpd = tfp.distributions
Expand All @@ -23,7 +23,7 @@ def prior_model():


# Create the nested sampler class. In this case without any tuning.
exact_ns = DefaultNestedSampler(model=model, max_samples=max_samples)
exact_ns = NestedSampler(model=model, max_samples=max_samples)

termination_reason, state = exact_ns(random.PRNGKey(42))
return termination_reason
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/gh168/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow_probability.substrates.jax as tfp
from jax._src.scipy.linalg import solve_triangular

from jaxns import Model, Prior, DefaultNestedSampler
from jaxns import Model, Prior, NestedSampler

tfpd = tfp.distributions

Expand Down Expand Up @@ -47,7 +47,7 @@ def log_likelihood(x):

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = DefaultNestedSampler(model=model, max_samples=100000, verbose=False)
ns = NestedSampler(model=model, max_samples=100000, verbose=False)

termination_reason, state = ns(key)
results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)
Expand Down
11 changes: 5 additions & 6 deletions benchmarks/parallel_problems/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tensorflow_probability.substrates.jax as tfp
from jax._src.scipy.linalg import solve_triangular

from jaxns import Model, Prior, DefaultNestedSampler
from jaxns import Model, Prior, NestedSampler, jaxify_likelihood

tfpd = tfp.distributions

Expand Down Expand Up @@ -51,7 +51,7 @@ def log_likelihood(x):

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

ns = DefaultNestedSampler(model=model, max_samples=100000, verbose=False, num_parallel_workers=len(jax.devices()))
ns = NestedSampler(model=model, max_samples=100000, verbose=False)

termination_reason, state = ns(key)
results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)
Expand All @@ -61,7 +61,7 @@ def log_likelihood(x):
def main():
num_devices = len(jax.devices())
jaxns_version = pkg_resources.get_distribution("jaxns").version
m = 1
m = 3
run_model_aot = jax.jit(run_model).lower(jax.random.PRNGKey(0)).compile()
dt = []

Expand All @@ -70,14 +70,13 @@ def main():

for i in range(m):
t0 = time.time()
log_Z_error, log_Z_uncert = run_model_aot(jax.random.PRNGKey(i))
log_Z_error.block_until_ready()
log_Z_error, log_Z_uncert = jax.block_until_ready(run_model_aot(jax.random.PRNGKey(i)))
t1 = time.time()
dt.append(t1 - t0)
errors.append(log_Z_error)
uncerts.append(log_Z_uncert)
total_time = sum(dt)
best_3 = sum(sorted(dt)[:min(3, m)]) / 3.
best_3 = sum(sorted(dt)[:3]) / 3.
# print(f"Errors: {errors}")
# print(f"Uncerts: {uncerts}")
print(f"JAXNS {jaxns_version}\n"
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/parallel_problems/results
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
2.5.1,12,-0.8741127252578735,0.0003452669479884207,95.40295028686523,77.15287351608276
2.5.1,12,-1.30291748046875,0.32637646794319153,102.9092493057251,34.30308310190836
2.3.4,12,0.09217196161439745,0.3709631743726239,2.116373300552368,0.7054577668507894
2.6.0,6,-0.06817162536483465,0.3668122147028353,1.1894174337387085,1.0906640688578289
2.6.0,12,-0.06817162536483465,0.3668122147028353,1.9961725234985352,1.7415425777435303
18 changes: 0 additions & 18 deletions deployment/Dockerfile

This file was deleted.

12 changes: 0 additions & 12 deletions deployment/docker-compose.yaml

This file was deleted.

10 changes: 0 additions & 10 deletions deployment/launch-tests.sh

This file was deleted.

2 changes: 0 additions & 2 deletions deployment/local.env

This file was deleted.

117 changes: 65 additions & 52 deletions docs/examples/Jones_scalar_modelling.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 2262320

Please sign in to comment.