diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 386717fbd..41bdb4e2b 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -278,28 +278,30 @@ function mcmc_update_z_position!!(mcmc_state::MCMCState) return mcmc_state_new end - function mcmc_update_z_position!!(mc_state::MCMCChainState) - f_transform = mc_state.f_transform - proposed_sample_x = proposed_sample(mc_state) + sample_z = mc_state.sample_z + current_sample_x = current_sample(mc_state) + proposed_sample_x = proposed_sample(mc_state) - x_proposed, logd_x_proposed = proposed_sample_x.v, proposed_sample_x.logd x_current, logd_x_current = current_sample_x.v, current_sample_x.logd + x_proposed, logd_x_proposed = proposed_sample_x.v, proposed_sample_x.logd - z_proposed_new, ladj_proposed = with_logabsdet_jacobian(inverse(f_transform), vec(x_proposed)) - z_current_new, ladj_current = with_logabsdet_jacobian(inverse(f_transform), vec(x_current)) + z_current_new, ladj_current = with_logabsdet_jacobian(inverse(f_transform), x_current[:]) + z_proposed_new, ladj_proposed = with_logabsdet_jacobian(inverse(f_transform), x_proposed[:]) - logd_z_proposed_new = logd_x_proposed - ladj_proposed logd_z_current_new = logd_x_current - ladj_current + logd_z_proposed_new = logd_x_proposed - ladj_proposed - mc_state_tmp_1 = @set mc_state.sample_z.v[2] = vec(z_proposed_new) - mc_state_tmp_2 = @set mc_state_tmp_1.sample_z.logd[2] = logd_z_proposed_new - - mc_state_tmp_3 = @set mc_state_tmp_2.sample_z.v[1] = vec(z_current_new) - mc_state_new = @set mc_state_tmp_3.sample_z.logd[1] = logd_z_current_new + mc_state_new = deepcopy(mc_state) + mc_state_new.sample_z.v[1] = vec(z_current_new) + mc_state_new.sample_z.v[2] = vec(z_proposed_new) + + mc_state_new.sample_z.logd[1] = logd_z_current_new + mc_state_new.sample_z.logd[2] = logd_z_proposed_new + return mc_state_new end diff --git a/test/optimization/test_mode_estimators.jl b/test/optimization/test_mode_estimators.jl index dc40caeab..c2f8df778 100644 --- a/test/optimization/test_mode_estimators.jl +++ b/test/optimization/test_mode_estimators.jl @@ -1,6 +1,7 @@ using BAT using Test +using AutoDiffOperators using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface using UnPack, InverseFunctions import ForwardDiff @@ -104,7 +105,7 @@ using Optim, OptimizationOptimJL context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff)) # result is not type-stable: - test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), trafo = DoNotTransform()), 0.01, context, inferred = false) + test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), pretransform = DoNotTransform()), 0.01, context, inferred = false) end @testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl