Skip to content

Commit

Permalink
Refactor z position update in transformed MCMC (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D authored Dec 17, 2024
1 parent fb502b2 commit 250a969
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
26 changes: 14 additions & 12 deletions src/samplers/mcmc/mcmc_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,28 +278,30 @@ function mcmc_update_z_position!!(mcmc_state::MCMCState)
return mcmc_state_new
end


function mcmc_update_z_position!!(mc_state::MCMCChainState)

f_transform = mc_state.f_transform
proposed_sample_x = proposed_sample(mc_state)
sample_z = mc_state.sample_z

current_sample_x = current_sample(mc_state)
proposed_sample_x = proposed_sample(mc_state)

x_proposed, logd_x_proposed = proposed_sample_x.v, proposed_sample_x.logd
x_current, logd_x_current = current_sample_x.v, current_sample_x.logd
x_proposed, logd_x_proposed = proposed_sample_x.v, proposed_sample_x.logd

z_proposed_new, ladj_proposed = with_logabsdet_jacobian(inverse(f_transform), vec(x_proposed))
z_current_new, ladj_current = with_logabsdet_jacobian(inverse(f_transform), vec(x_current))
z_current_new, ladj_current = with_logabsdet_jacobian(inverse(f_transform), x_current[:])
z_proposed_new, ladj_proposed = with_logabsdet_jacobian(inverse(f_transform), x_proposed[:])

logd_z_proposed_new = logd_x_proposed - ladj_proposed
logd_z_current_new = logd_x_current - ladj_current
logd_z_proposed_new = logd_x_proposed - ladj_proposed

mc_state_tmp_1 = @set mc_state.sample_z.v[2] = vec(z_proposed_new)
mc_state_tmp_2 = @set mc_state_tmp_1.sample_z.logd[2] = logd_z_proposed_new

mc_state_tmp_3 = @set mc_state_tmp_2.sample_z.v[1] = vec(z_current_new)
mc_state_new = @set mc_state_tmp_3.sample_z.logd[1] = logd_z_current_new
mc_state_new = deepcopy(mc_state)

mc_state_new.sample_z.v[1] = vec(z_current_new)
mc_state_new.sample_z.v[2] = vec(z_proposed_new)

mc_state_new.sample_z.logd[1] = logd_z_current_new
mc_state_new.sample_z.logd[2] = logd_z_proposed_new

return mc_state_new
end

Expand Down
3 changes: 2 additions & 1 deletion test/optimization/test_mode_estimators.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BAT
using Test

using AutoDiffOperators
using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface
using UnPack, InverseFunctions
import ForwardDiff
Expand Down Expand Up @@ -104,7 +105,7 @@ using Optim, OptimizationOptimJL

context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff))
# result is not type-stable:
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), trafo = DoNotTransform()), 0.01, context, inferred = false)
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), pretransform = DoNotTransform()), 0.01, context, inferred = false)
end

@testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl
Expand Down

0 comments on commit 250a969

Please sign in to comment.