Skip to content

Commit

Permalink
positive(::Array) (#54)
Browse files Browse the repository at this point in the history
* PositiveArray implementation

* Add empty tests file

* Formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Correctness testing

* Add performance tests

* Check poor performance of naive approach

* Document positive more thoroughly

* Formatting

* Bump patch version

* Improve test comments

* Include allocation counter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
Will Tebbutt and github-actions[bot] authored Aug 25, 2022
1 parent 3ff6a53 commit d0a1ac6
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.4.3"
version = "0.4.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ In particular, we've seen an example of how ParameterHandling.jl can be used to
gap between the "flat" representation of parameters that `Optim` likes to work with, and the
"structured" representation that it's convenient to write optimisation algorithms with.

# Gotchas
# Gotchas and Performance Tips

1. `Integer`s typically don't take part in the kind of optimisation procedures that this package is designed to handle. Consequently, `flatten(::Integer)` produces an empty vector.
2. `deferred` has some type-stability issues when used in conjunction with abstract types. For example, `flatten(deferred(Normal, 5.0, 4.0))` won't infer properly. A simple work around is to write a function `normal(args...) = Normal(args...)` and work with `deferred(normal, 5.0, 4.0)` instead.
3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (when `map(positive, x)` is extremely slow).
1 change: 1 addition & 0 deletions src/ParameterHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include("parameters_base.jl")
include("parameters_meta.jl")
include("parameters_scalar.jl")
include("parameters_matrix.jl")
include("parameters_array.jl")

include("test_utils.jl")

Expand Down
34 changes: 34 additions & 0 deletions src/parameters_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
struct PositiveArray{T<:Array{<:Real},V,Tε<:Real} <: AbstractParameter
unconstrained_value::T
transform::V
ε::Tε
end

value(x::PositiveArray) = map(exp, x.unconstrained_value) .+ x.ε

function flatten(::Type{T}, x::PositiveArray{<:Array{V}}) where {T<:Real,V<:Real}
v, unflatten_to_array = flatten(T, x.unconstrained_value)
transform = x.transform
ε = x.ε
function unflatten_PositiveArray(v::AbstractVector{T})
return PositiveArray(unflatten_to_array(v), transform, ε)
end
return v, unflatten_PositiveArray
end

"""
positive(x::Array{<:Real})
Roughly equivalent to `map(positive, x)`, but implemented such that unflattening can be
efficiently differentiated through using algorithmic differentiation (Zygote in particular).
"""
function positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val))))
all(val .> 0) || throw(ArgumentError("Not all elements of val are positive."))
all(val .> ε) || throw(ArgumentError("Not all elements of val greater than ε ()."))

inverse_transform = inverse(transform)
unconstrained_value = map(x -> inverse_transform(x - ε), val)
return PositiveArray(
unconstrained_value, transform, convert(eltype(unconstrained_value), ε)
)
end
54 changes: 54 additions & 0 deletions test/parameters_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
@testset "parameters_array" begin
@testset "postive" begin
@testset "$val" for val in [[5.0, 4.0], [0.001f0], fill(1e-7, 1, 2)]
p = positive(val)
test_parameter_interface(p)
@test value(p) val
@test typeof(value(p)) === typeof(val)
end

# Test edge cases around the size of the value relative to the error tol.
@test_throws ArgumentError positive([-0.1, 0.1])
@test_throws ArgumentError positive(fill(1e-12, 1, 2, 3))
@test value(positive(fill(1e-11, 3, 2, 1), exp, 1e-12)) fill(1e-11, 3, 2, 1)

# These tests assume that if the number of allocations is roughly constant in the
# size of `x`, then performance is acceptable. This is demonstrated by requiring
# that the number of allocations (100) is a lot smaller than the total length of
# the array in question (1_000_000). The bound (100) is quite loose because there
# are typically serveral 10s of allocations made by Zygote for book-keeping
# purposes etc.
@testset "zygote performance" begin
x = rand(1000, 1000) .+ 0.1
flat_x, unflatten = value_flatten(positive(x))

# primal evaluation
count_allocs(unflatten, flat_x)
@test count_allocs(unflatten, flat_x) < 100

# forward evaluation
count_allocs(Zygote.pullback, unflatten, flat_x)
@test count_allocs(Zygote.pullback, unflatten, flat_x) < 100

# pullback
out, pb = Zygote.pullback(unflatten, flat_x)
count_allocs(pb, out)
@test count_allocs(pb, out) < 100
end

# Check that this optimisation is actually necessary -- i.e. that the performance
# of the equivalent operation, `map(positive, x)`, is indeed poor, esp. with AD.
# Poor performance is demonstrated by showing that there's at least one allocation
# per element. A smaller array than the previous test set is used because it can
# be _really_ slow for large arrays (several seconds), which is undesirable in
# unit tests.
@testset "zygote performance of scalar equivalent" begin
x = rand(1000) .+ 0.1
flat_x, unflatten = value_flatten(map(positive, x))

# forward evaluation
count_allocs(Zygote.pullback, unflatten, flat_x)
count_allocs(Zygote.pullback, unflatten, flat_x) > 1000
end
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@ using ParameterHandling.TestUtils: test_flatten_interface, test_parameter_interf

const tuple_infers = VERSION < v"1.5" ? false : true

function count_allocs(f, args...)
stats = @timed f(args...)
return Base.gc_alloc_count(stats.gcstats)
end

@testset "ParameterHandling.jl" begin
include("flatten.jl")
include("parameters.jl")
include("parameters_meta.jl")
include("parameters_scalar.jl")
include("parameters_matrix.jl")
include("parameters_array.jl")
end

2 comments on commit d0a1ac6

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67037

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.4 -m "<description of version>" d0a1ac6e8991530190d4c01379d65e86543d9445
git push origin v0.4.4

Please sign in to comment.