Commit

Remove haiku
Joshuaalbert committed Aug 20, 2024
1 parent c35bbe9 commit 0824ae5
Showing 12 changed files with 424 additions and 70 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -363,9 +363,13 @@ is the best way to achieve speed up.

# Change Log

20 Aug, 2024 -- JAXNS 2.6.0 released. Removed the haiku dependency and implemented our own context,
see `jaxns.framework.context`. The framework internals now use this context, providing an alternative to yielding `Prior`.
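The practical effect, visible in the diff below, is that model parameters are now typed as `jaxns.framework.context.MutableParams` rather than `haiku.MutableParams`; in both cases they are ordinary JAX pytrees that can be jitted over and differentiated. A self-contained toy of that pattern, with a plain dict standing in for `MutableParams` and a toy objective standing in for the negative log evidence (nothing here uses the jaxns context API itself):

```python
import jax
import jax.numpy as jnp

# A plain dict stands in for MutableParams: both are JAX pytrees of arrays that
# can be passed through jit/value_and_grad, as in the M-step code further down.
params = {"mu": jnp.asarray(0.0), "log_sigma": jnp.asarray(0.0)}
data = jnp.linspace(-1.0, 1.0, 8)


def toy_loss(params, data):
    # Toy stand-in for the negative log evidence: a Gaussian negative log-likelihood.
    sigma = jnp.exp(params["log_sigma"])
    return jnp.mean(0.5 * ((data - params["mu"]) / sigma) ** 2 + params["log_sigma"])


loss_value, grads = jax.value_and_grad(toy_loss)(params, data)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
```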

24 Jul, 2024 -- JAXNS 2.5.3 released. Replaced the framework U-space with W-space. Maintained the external API in U-space.

23 Jul, 2024 -- JAXNS 2.5.2 released. Added explicit density prior. Sped up parametrisation. Scan associative implemented.

27 May, 2024 -- JAXNS 2.5.1 released. Fixed minor accuracy degradation introduced in 2.4.13.

47 changes: 20 additions & 27 deletions jaxns/experimental/evidence_maximisation.py
@@ -12,23 +12,11 @@
from jaxopt import ArmijoSGD, BFGS
from tqdm import tqdm

from jaxns import DefaultNestedSampler, Model
from jaxns.framework.context import MutableParams
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.log_semiring import LogSpace
from jaxns.internals.logging import logger

try:
import haiku as hk
except ImportError:
print("You must `pip install dm-haiku` first.")
raise

try:
import optax
except ImportError:
print("You must `pip install optax` first.")
raise

from jaxns import DefaultNestedSampler, Model
from jaxns.internals.types import TerminationCondition, NestedSamplerResults, StaticStandardNestedSamplerState, \
IntArray, PRNGKey, float_type

@@ -106,7 +94,7 @@ def _create_e_step(self):
A compiled function that runs nested sampling and returns trimmed results.
"""

def _ns_solve(params: hk.MutableParams, key: random.PRNGKey) -> Tuple[
def _ns_solve(params: MutableParams, key: random.PRNGKey) -> Tuple[
IntArray, StaticStandardNestedSamplerState]:
model = self.model(params=params)
ns = DefaultNestedSampler(model=model, **self.ns_kwargs)
@@ -120,15 +108,15 @@ def _ns_solve(params: hk.MutableParams, key: random.PRNGKey) -> Tuple[
logger.info(f"E-step compilation time: {time.time() - t0:.2f}s")
ns = DefaultNestedSampler(model=self.model(params=self.model.params), **self.ns_kwargs)

def _e_step(key: PRNGKey, params: hk.MutableParams, p_bar: tqdm) -> NestedSamplerResults:
def _e_step(key: PRNGKey, params: MutableParams, p_bar: tqdm) -> NestedSamplerResults:
p_bar.set_description(f"Running E-step... {p_bar.desc}")
termination_reason, state = ns_solve_compiled(params, key)
# Trim results
return ns.to_results(termination_reason=termination_reason, state=state, trim=True)

return _e_step

def e_step(self, key: PRNGKey, params: hk.MutableParams, p_bar: tqdm) -> NestedSamplerResults:
def e_step(self, key: PRNGKey, params: MutableParams, p_bar: tqdm) -> NestedSamplerResults:
"""
The E-step is just nested sampling.
@@ -163,7 +151,7 @@ def _m_step_iterator(self, key: PRNGKey, data: MStepData):
yield batch

def _create_m_step_stochastic(self):
def log_evidence(params: hk.MutableParams, data: MStepData):
def log_evidence(params: MutableParams, data: MStepData):
# Compute the log evidence
model = self.model(params=params)
# To make manageable, we could do chunked_pmap
@@ -174,7 +162,7 @@ def log_evidence(params: hk.MutableParams, data: MStepData):
log_Z = logsumexp(log_dZ)
return log_Z

def loss(params: hk.MutableParams, data: MStepData):
def loss(params: MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
grad = jax.tree.map(jnp.negative, grad)
@@ -191,6 +179,11 @@ def loss(params: hk.MutableParams, data: MStepData):
return (obj, aux), grad

if self.solver == 'adam':
try:
import optax
except ImportError:
raise ImportError("optax must be installed to use the 'adam' solver")

solver = jaxopt.OptaxSolver(
fun=loss,
opt=optax.adam(learning_rate=1e-2),
@@ -215,7 +208,7 @@ def loss(params: hk.MutableParams, data: MStepData):
else:
raise ValueError(f"Unknown solver {self.solver}")

def _m_step_stochastic(key: PRNGKey, params: hk.MutableParams, data: MStepData) -> Tuple[hk.MutableParams, Any]:
def _m_step_stochastic(key: PRNGKey, params: MutableParams, data: MStepData) -> Tuple[MutableParams, Any]:
"""
The M-step is just evidence maximisation.
@@ -238,7 +231,7 @@ def _m_step_stochastic(key: PRNGKey, params: hk.MutableParams, data: MStepData)

def _create_m_step(self):

def log_evidence(params: hk.MutableParams, data: MStepData):
def log_evidence(params: MutableParams, data: MStepData):
# Compute the log evidence
model = self.model(params=params)

@@ -249,7 +242,7 @@ def op(log_Z, data):
log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data)
return log_Z

def loss(params: hk.MutableParams, data: MStepData):
def loss(params: MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
grad = jax.tree.map(jnp.negative, grad)
@@ -268,7 +261,7 @@ def loss(params: hk.MutableParams, data: MStepData):
)

@partial(jax.jit)
def _m_step(key: PRNGKey, params: hk.MutableParams, data: MStepData) -> Tuple[hk.MutableParams, Any]:
def _m_step(key: PRNGKey, params: MutableParams, data: MStepData) -> Tuple[MutableParams, Any]:
"""
The M-step is just evidence maximisation.
@@ -284,8 +277,8 @@ def _m_step(key: PRNGKey, params: hk.MutableParams, data: MStepData) -> Tuple[hk

return _m_step

def m_step(self, key: PRNGKey, params: hk.MutableParams, ns_results: NestedSamplerResults, p_bar: tqdm) -> Tuple[
hk.MutableParams, Any]:
def m_step(self, key: PRNGKey, params: MutableParams, ns_results: NestedSamplerResults, p_bar: tqdm) -> Tuple[
MutableParams, Any]:
"""
The M-step is just evidence maximisation. We pad the data to the next power of 2, to make JIT compilation
happen less frequently.
@@ -331,9 +324,9 @@ def _pad_to_n(x, fill_value, dtype):

return params, log_Z
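
A note on the `m_step` docstring above: padding the sample arrays up to the next power of 2 keeps the number of distinct input shapes small, so the jitted M-step recompiles only when the padded size crosses a power-of-2 boundary. A minimal, self-contained illustration of that idea (this helper is illustrative; the repository's own `_pad_to_n`, whose signature appears in the hunk header above, is not shown in full here):

```python
import math

import jax.numpy as jnp


def pad_leading_axis_to_next_power_of_2(x: jnp.ndarray, fill_value) -> jnp.ndarray:
    # Round the leading axis up to the next power of 2 so that, e.g., batches of
    # 1025..2048 samples all share one compiled shape.
    n = x.shape[0]
    target = 2 ** math.ceil(math.log2(max(n, 1)))
    fill = jnp.full((target - n,) + x.shape[1:], fill_value, x.dtype)
    return jnp.concatenate([x, fill], axis=0)


padded = pad_leading_axis_to_next_power_of_2(jnp.zeros((1300, 3)), fill_value=0.0)
assert padded.shape == (2048, 3)
```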

def train(self, num_steps: int = 10, params: Optional[hk.MutableParams] = None) -> \
def train(self, num_steps: int = 10, params: Optional[MutableParams] = None) -> \
Tuple[
NestedSamplerResults, hk.MutableParams]:
NestedSamplerResults, MutableParams]:
"""
Train the model using EM for num_steps.
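
Stepping back, `train` alternates the two steps shown above: each iteration runs nested sampling at the current parameters (E-step) and then maximises the log evidence with respect to them (M-step). A schematic, self-contained sketch of that alternation with toy stand-ins (this is not the jaxns API; in the real code the E-step is the compiled nested-sampling solve and the M-step the BFGS/adam evidence maximisation above):

```python
import jax
import jax.numpy as jnp
from jax import random


def e_step(key, params):
    # Toy stand-in for nested sampling: draw "posterior samples" at fixed params.
    return params["mu"] + random.normal(key, (1000,))


def m_step(params, samples, lr=0.1):
    # Toy stand-in for evidence maximisation: one gradient step on a surrogate objective.
    def objective(p):
        return jnp.mean((samples - p["mu"]) ** 2)

    grads = jax.grad(objective)(params)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)


params = {"mu": jnp.asarray(0.0)}
key = random.PRNGKey(42)
for _ in range(10):
    key, subkey = random.split(key)
    params = m_step(params, e_step(subkey, params))
```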