From dd13a55f68cd666f05315de8aff22f422038ad1e Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 25 Aug 2022 14:37:42 +0100 Subject: [PATCH] Bounded Array + Zygote performance (#55) * Implement bounded for array * Bump patch * Fix positive test * Improve readme * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Reduce test-case size * Loosen allocation bound * Fix test Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- README.md | 2 +- src/parameters_array.jl | 52 ++++++++++++++++++++++++++++++++++++++++ test/parameters_array.jl | 43 ++++++++++++++++++++++++++++++++- 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ffabe97..74226e0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterHandling" uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5" authors = ["Invenia Technical Computing Corporation"] -version = "0.4.4" +version = "0.4.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/README.md b/README.md index 2431f69..141699e 100644 --- a/README.md +++ b/README.md @@ -280,4 +280,4 @@ gap between the "flat" representation of parameters that `Optim` likes to work w 1. `Integer`s typically don't take part in the kind of optimisation procedures that this package is designed to handle. Consequently, `flatten(::Integer)` produces an empty vector. 2. `deferred` has some type-stability issues when used in conjunction with abstract types. For example, `flatten(deferred(Normal, 5.0, 4.0))` won't infer properly. A simple work around is to write a function `normal(args...) = Normal(args...)` and work with `deferred(normal, 5.0, 4.0)` instead. -3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (when `map(positive, x)` is extremely slow). +3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (when `map(positive, x)` is extremely slow). The same thing applies to `bounded` values. Prefer `bounded(x, lb, ub)` to e.g. `bounded.(x, lb, ub)`. diff --git a/src/parameters_array.jl b/src/parameters_array.jl index 883c8aa..25eadb6 100644 --- a/src/parameters_array.jl +++ b/src/parameters_array.jl @@ -32,3 +32,55 @@ function positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val)))) unconstrained_value, transform, convert(eltype(unconstrained_value), ε) ) end + +struct BoundedArray{T<:Real,Ta<:AbstractArray{T},V,Tε<:Real} <: AbstractParameter + unconstrained_value::Ta + lower_bound::T + upper_bound::T + transform::V + ε::Tε +end + +value(x::BoundedArray) = x.transform(x.unconstrained_value) + +function flatten(::Type{T}, x::BoundedArray) where {T<:Real} + v, unflatten_to_Array = flatten(T, x.unconstrained_value) + + function unflatten_Bounded(v_new::Vector{T}) + return BoundedArray( + unflatten_to_Array(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε + ) + end + + return v, unflatten_Bounded +end + +""" + bounded(val::Array{<:Real}, lower_bound::Real, upper_bound::Real) + +Roughly equivalent to `bounded.(val, lower_bound, upper_bound)`, but implemented such that +unflattening can be efficiently differentiated through using algorithmic differentiation +(Zygote in particular). +""" +function bounded(val::Array{<:Real}, lower_bound::Real, upper_bound::Real) + lb = convert(eltype(val), lower_bound) + ub = convert(eltype(val), upper_bound) + + # construct open interval + ε = convert(eltype(val), 1e-12) + lb_plus_ε = lb + ε + ub_minus_ε = ub - ε + + if any(val .> ub_minus_ε) || any(val .< lb_plus_ε) + throw( + ArgumentError("At least one element of `val`, $val, outside of specified bounds + ($lower_bound, $upper_bound).") + ) + end + + length_interval = ub_minus_ε - lb_plus_ε + unconstrained_val = logit.((val .- lb_plus_ε) ./ length_interval) + transform(x) = lb_plus_ε .+ length_interval .* logistic.(x) + + return BoundedArray(unconstrained_val, lb, ub, transform, ε) +end diff --git a/test/parameters_array.jl b/test/parameters_array.jl index 5bec38b..4c2917c 100644 --- a/test/parameters_array.jl +++ b/test/parameters_array.jl @@ -48,7 +48,48 @@ # forward evaluation count_allocs(Zygote.pullback, unflatten, flat_x) - count_allocs(Zygote.pullback, unflatten, flat_x) > 1000 + @test count_allocs(Zygote.pullback, unflatten, flat_x) > 1000 + end + end + + @testset "bounded" begin + @testset "$val" for val in [ + [-0.05, 0.5], [-0.1 + 1e-12, 2.0 - 1e-11], fill(2.0 - 1e-12, 1, 2, 3) + ] + p = bounded(val, -0.1, 2.0) + test_parameter_interface(p) + @test value(p) ≈ val + end + + @test_throws ArgumentError bounded([-0.05], 0.0, 1.0) + + # Same style of performance test as for positive(::Array). See above for info. + @testset "zygote performance" begin + x = rand(1000, 1000) .* 1.98 .- 0.99 + flat_x, unflatten = value_flatten(bounded(x, -1.0, 1.0)) + + # primal evaluation + count_allocs(unflatten, flat_x) + @test count_allocs(unflatten, flat_x) < 300 + + # forward evaluation + count_allocs(Zygote.pullback, unflatten, flat_x) + @test count_allocs(Zygote.pullback, unflatten, flat_x) < 300 + + # pullback + out, pb = Zygote.pullback(unflatten, flat_x) + count_allocs(pb, out) + @test count_allocs(pb, out) < 300 + end + + # Same style of performance test as for `map(positive, x)`. See above for info. + @testset "zygote performance of scalar equivalent" begin + x = rand(1000) .* 1.98 .- 0.99 + flat_x, unflatten = value_flatten(map(x -> bounded(x, -1.0, 1.0), x)) + + # forward evaluation + count_allocs(Zygote.pullback, unflatten, flat_x) + @test count_allocs(Zygote.pullback, unflatten, flat_x) > 1000 end end end