Skip to content

Commit

Permalink
ready for test
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Dec 27, 2024
1 parent 37b5f57 commit 996258c
Showing 1 changed file with 30 additions and 69 deletions.
99 changes: 30 additions & 69 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,13 @@
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import optax
import blackjax
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, handle_nans
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState
from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
)
from blackjax.diagnostics import effective_sample_size
from blackjax.mcmc.adjusted_mclmc import rescale
from blackjax.util import pytree_size, incremental_value_update, run_inference_algorithm

from blackjax.mcmc.integrators import (
generate_euclidean_integrator,
generate_isokinetic_integrator,
mclachlan,
yoshida,
velocity_verlet,
omelyan,
isokinetic_mclachlan,
isokinetic_velocity_verlet,
isokinetic_yoshida,
isokinetic_omelyan,
)
from blackjax.util import incremental_value_update, pytree_size

Lratio_lowerbound = 0.0
Lratio_upperbound = 2.0
Expand Down Expand Up @@ -92,12 +76,7 @@ def adjusted_mclmc_find_L_and_step_size(

for i in range(num_windows):
window_key = jax.random.fold_in(part1_key, i)
(
state,
params,
eigenvector

) = adjusted_mclmc_make_L_step_size_adaptation(
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -106,25 +85,22 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(
state, params, num_steps, window_key
)
)(state, params, num_steps, window_key)

if frac_tune3 != 0:
for i in range(2):
part2_key = jax.random.fold_in(part2_key, i)
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params = adjusted_mclmc_make_adaptation_L(
mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector,
mclmc_kernel,
frac=frac_tune3,
Lfactor=0.5,
max=max,
eigenvector=eigenvector,
)(state, params, num_steps, part2_key1)

(
state,
params,
_

) = adjusted_mclmc_make_L_step_size_adaptation(
(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -134,12 +110,7 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(
state, params, num_steps, part2_key2
)



)(state, params, num_steps, part2_key2)

return state, params

Expand All @@ -152,7 +123,7 @@ def adjusted_mclmc_make_L_step_size_adaptation(
target,
diagonal_preconditioning,
fix_L_first_da=False,
max='avg',
max="avg",
tuning_factor=1.0,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
Expand All @@ -161,8 +132,6 @@ def dual_avg_step(fix_L, update_da):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""

def step(iteration_state, weight_and_key):


mask, rng_key = weight_and_key
(
previous_state,
Expand All @@ -180,7 +149,6 @@ def step(iteration_state, weight_and_key):
step_size=params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
)


# step updating
success, state, step_size_max, energy_change = handle_nans(
Expand Down Expand Up @@ -231,7 +199,7 @@ def step(iteration_state, weight_and_key):
+ (1 - mask) * params.L,
)

if max!='max_svd':
if max != "max_svd":
state_position = None
else:
state_position = state.position
Expand Down Expand Up @@ -259,9 +227,6 @@ def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da
),
xs=(mask, keys),
)




def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
Expand Down Expand Up @@ -294,26 +259,21 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
update_da=update_da,
)


final_stepsize = final_da(dual_avg_state)
params = params._replace(step_size=final_stepsize)




# determine L
eigenvector = None
eigenvector = None
if num_steps2 != 0.0:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)

if max=='max':
contract = lambda x: jnp.sqrt(jnp.max(x)*dim)*tuning_factor
if max == "max":
contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor

elif max == "avg":
contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor

elif max=='avg':
contract = lambda x: jnp.sqrt(jnp.sum(x))*tuning_factor

else:
raise ValueError("max should be either 'max' or 'avg'")

Expand Down Expand Up @@ -346,13 +306,14 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

params = params._replace(step_size=final_da(dual_avg_state))


return state, params, eigenvector

return L_step_size_adaptation


def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor, max='avg', eigenvector=None):
def adjusted_mclmc_make_adaptation_L(
kernel, frac, Lfactor, max="avg", eigenvector=None
):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
Expand All @@ -375,29 +336,29 @@ def step(state, key):
xs=adaptation_L_keys,
)

if max=='max':
if max == "max":
contract = jnp.min
else:
contract = jnp.mean

flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)


if eigenvector is not None:

flat_samples = jnp.expand_dims(jnp.einsum('ij,j', flat_samples, eigenvector),1)
flat_samples = jnp.expand_dims(
jnp.einsum("ij,j", flat_samples, eigenvector), 1
)

# number of effective samples per 1 actual sample
ess = contract(effective_sample_size(flat_samples[None, ...]))/num_steps
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(L=jnp.clip(Lfactor * params.L / jnp.mean(ess), max=params.L*2))
return state, params._replace(
L=jnp.clip(Lfactor * params.L / jnp.mean(ess), max=params.L * 2)
)

return adaptation_L


def handle_nans(
previous_state, next_state, step_size, step_size_max, kinetic_change
):
def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
"""if there are nans, let's reduce the stepsize, and not update the state. The
function returns the old state in this case."""

Expand Down

0 comments on commit 996258c

Please sign in to comment.