Skip to content

Commit

Permalink
Fix Optimization ext
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Dec 11, 2024
1 parent 159f24f commit fb502b2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ function test_bat_optimization_ext()
end

AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
Base.convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg

BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()


function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
adm = convert_ad(ADTypes.AbstractADType, adsel)
adm = convert(ADTypes.AbstractADType, reverse_ad_selector(adsel))
optimization_function = Optimization.OptimizationFunction(f, adm)
return optimization_function
end
Expand Down Expand Up @@ -59,7 +59,9 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg,
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)

algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)
# Not all algorithms support abstol, just filter all NaN-valued opts out:
filtered_algopts = NamedTuple(filter(p -> !isnan(p[2]), pairs(algopts)))
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...)

transformed_mode = optimization_result.u
result_mode = inv_trafo(transformed_mode)
Expand Down
4 changes: 4 additions & 0 deletions test/optimization/test_mode_estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ using Optim, OptimizationOptimJL
context = BATContext(rng = Philox4x((0, 0)))
# result is not type-stable:
test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), pretransform = DoNotTransform()), 0.01, context, inferred = false)

context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff))
# result is not type-stable:
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), trafo = DoNotTransform()), 0.01, context, inferred = false)
end

@testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl
Expand Down

0 comments on commit fb502b2

Please sign in to comment.