Skip to content

Commit

Permalink
Bounded Array + Zygote performance (#55)
Browse files Browse the repository at this point in the history
* Implement bounded for array

* Bump patch

* Fix positive test

* Improve readme

* Formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Reduce test-case size

* Loosen allocation bound

* Fix test

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
Will Tebbutt and github-actions[bot] authored Aug 25, 2022
1 parent d0a1ac6 commit dd13a55
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.4.4"
version = "0.4.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,4 @@ gap between the "flat" representation of parameters that `Optim` likes to work w

1. `Integer`s typically don't take part in the kind of optimisation procedures that this package is designed to handle. Consequently, `flatten(::Integer)` produces an empty vector.
2. `deferred` has some type-stability issues when used in conjunction with abstract types. For example, `flatten(deferred(Normal, 5.0, 4.0))` won't infer properly. A simple work around is to write a function `normal(args...) = Normal(args...)` and work with `deferred(normal, 5.0, 4.0)` instead.
3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (when `map(positive, x)` is extremely slow).
3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (when `map(positive, x)` is extremely slow). The same thing applies to `bounded` values. Prefer `bounded(x, lb, ub)` to e.g. `bounded.(x, lb, ub)`.
52 changes: 52 additions & 0 deletions src/parameters_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,55 @@ function positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val))))
unconstrained_value, transform, convert(eltype(unconstrained_value), ε)
)
end

struct BoundedArray{T<:Real,Ta<:AbstractArray{T},V,Tε<:Real} <: AbstractParameter
unconstrained_value::Ta
lower_bound::T
upper_bound::T
transform::V
ε::Tε
end

value(x::BoundedArray) = x.transform(x.unconstrained_value)

function flatten(::Type{T}, x::BoundedArray) where {T<:Real}
v, unflatten_to_Array = flatten(T, x.unconstrained_value)

function unflatten_Bounded(v_new::Vector{T})
return BoundedArray(
unflatten_to_Array(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε
)
end

return v, unflatten_Bounded
end

"""
bounded(val::Array{<:Real}, lower_bound::Real, upper_bound::Real)
Roughly equivalent to `bounded.(val, lower_bound, upper_bound)`, but implemented such that
unflattening can be efficiently differentiated through using algorithmic differentiation
(Zygote in particular).
"""
function bounded(val::Array{<:Real}, lower_bound::Real, upper_bound::Real)
lb = convert(eltype(val), lower_bound)
ub = convert(eltype(val), upper_bound)

# construct open interval
ε = convert(eltype(val), 1e-12)
lb_plus_ε = lb + ε
ub_minus_ε = ub - ε

if any(val .> ub_minus_ε) || any(val .< lb_plus_ε)
throw(
ArgumentError("At least one element of `val`, $val, outside of specified bounds
($lower_bound, $upper_bound).")
)
end

length_interval = ub_minus_ε - lb_plus_ε
unconstrained_val = logit.((val .- lb_plus_ε) ./ length_interval)
transform(x) = lb_plus_ε .+ length_interval .* logistic.(x)

return BoundedArray(unconstrained_val, lb, ub, transform, ε)
end
43 changes: 42 additions & 1 deletion test/parameters_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,48 @@

# forward evaluation
count_allocs(Zygote.pullback, unflatten, flat_x)
count_allocs(Zygote.pullback, unflatten, flat_x) > 1000
@test count_allocs(Zygote.pullback, unflatten, flat_x) > 1000
end
end

@testset "bounded" begin
@testset "$val" for val in [
[-0.05, 0.5], [-0.1 + 1e-12, 2.0 - 1e-11], fill(2.0 - 1e-12, 1, 2, 3)
]
p = bounded(val, -0.1, 2.0)
test_parameter_interface(p)
@test value(p) val
end

@test_throws ArgumentError bounded([-0.05], 0.0, 1.0)

# Same style of performance test as for positive(::Array). See above for info.
@testset "zygote performance" begin
x = rand(1000, 1000) .* 1.98 .- 0.99
flat_x, unflatten = value_flatten(bounded(x, -1.0, 1.0))

# primal evaluation
count_allocs(unflatten, flat_x)
@test count_allocs(unflatten, flat_x) < 300

# forward evaluation
count_allocs(Zygote.pullback, unflatten, flat_x)
@test count_allocs(Zygote.pullback, unflatten, flat_x) < 300

# pullback
out, pb = Zygote.pullback(unflatten, flat_x)
count_allocs(pb, out)
@test count_allocs(pb, out) < 300
end

# Same style of performance test as for `map(positive, x)`. See above for info.
@testset "zygote performance of scalar equivalent" begin
x = rand(1000) .* 1.98 .- 0.99
flat_x, unflatten = value_flatten(map(x -> bounded(x, -1.0, 1.0), x))

# forward evaluation
count_allocs(Zygote.pullback, unflatten, flat_x)
@test count_allocs(Zygote.pullback, unflatten, flat_x) > 1000
end
end
end

2 comments on commit dd13a55

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67054

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.5 -m "<description of version>" dd13a55f68cd666f05315de8aff22f422038ad1e
git push origin v0.4.5

Please sign in to comment.