-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* PositiveArray implementation * Add empty tests file * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Correctness testing * Add performance tests * Check poor performance of naive approach * Document positive more thoroughly * Formatting * Bump patch version * Improve test comments * Include allocation counter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
3ff6a53
commit d0a1ac6
Showing
6 changed files
with
98 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
struct PositiveArray{T<:Array{<:Real},V,Tε<:Real} <: AbstractParameter | ||
unconstrained_value::T | ||
transform::V | ||
ε::Tε | ||
end | ||
|
||
value(x::PositiveArray) = map(exp, x.unconstrained_value) .+ x.ε | ||
|
||
function flatten(::Type{T}, x::PositiveArray{<:Array{V}}) where {T<:Real,V<:Real} | ||
v, unflatten_to_array = flatten(T, x.unconstrained_value) | ||
transform = x.transform | ||
ε = x.ε | ||
function unflatten_PositiveArray(v::AbstractVector{T}) | ||
return PositiveArray(unflatten_to_array(v), transform, ε) | ||
end | ||
return v, unflatten_PositiveArray | ||
end | ||
|
||
""" | ||
positive(x::Array{<:Real}) | ||
Roughly equivalent to `map(positive, x)`, but implemented such that unflattening can be | ||
efficiently differentiated through using algorithmic differentiation (Zygote in particular). | ||
""" | ||
function positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val)))) | ||
all(val .> 0) || throw(ArgumentError("Not all elements of val are positive.")) | ||
all(val .> ε) || throw(ArgumentError("Not all elements of val greater than ε ($ε).")) | ||
|
||
inverse_transform = inverse(transform) | ||
unconstrained_value = map(x -> inverse_transform(x - ε), val) | ||
return PositiveArray( | ||
unconstrained_value, transform, convert(eltype(unconstrained_value), ε) | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
@testset "parameters_array" begin | ||
@testset "postive" begin | ||
@testset "$val" for val in [[5.0, 4.0], [0.001f0], fill(1e-7, 1, 2)] | ||
p = positive(val) | ||
test_parameter_interface(p) | ||
@test value(p) ≈ val | ||
@test typeof(value(p)) === typeof(val) | ||
end | ||
|
||
# Test edge cases around the size of the value relative to the error tol. | ||
@test_throws ArgumentError positive([-0.1, 0.1]) | ||
@test_throws ArgumentError positive(fill(1e-12, 1, 2, 3)) | ||
@test value(positive(fill(1e-11, 3, 2, 1), exp, 1e-12)) ≈ fill(1e-11, 3, 2, 1) | ||
|
||
# These tests assume that if the number of allocations is roughly constant in the | ||
# size of `x`, then performance is acceptable. This is demonstrated by requiring | ||
# that the number of allocations (100) is a lot smaller than the total length of | ||
# the array in question (1_000_000). The bound (100) is quite loose because there | ||
# are typically serveral 10s of allocations made by Zygote for book-keeping | ||
# purposes etc. | ||
@testset "zygote performance" begin | ||
x = rand(1000, 1000) .+ 0.1 | ||
flat_x, unflatten = value_flatten(positive(x)) | ||
|
||
# primal evaluation | ||
count_allocs(unflatten, flat_x) | ||
@test count_allocs(unflatten, flat_x) < 100 | ||
|
||
# forward evaluation | ||
count_allocs(Zygote.pullback, unflatten, flat_x) | ||
@test count_allocs(Zygote.pullback, unflatten, flat_x) < 100 | ||
|
||
# pullback | ||
out, pb = Zygote.pullback(unflatten, flat_x) | ||
count_allocs(pb, out) | ||
@test count_allocs(pb, out) < 100 | ||
end | ||
|
||
# Check that this optimisation is actually necessary -- i.e. that the performance | ||
# of the equivalent operation, `map(positive, x)`, is indeed poor, esp. with AD. | ||
# Poor performance is demonstrated by showing that there's at least one allocation | ||
# per element. A smaller array than the previous test set is used because it can | ||
# be _really_ slow for large arrays (several seconds), which is undesirable in | ||
# unit tests. | ||
@testset "zygote performance of scalar equivalent" begin | ||
x = rand(1000) .+ 0.1 | ||
flat_x, unflatten = value_flatten(map(positive, x)) | ||
|
||
# forward evaluation | ||
count_allocs(Zygote.pullback, unflatten, flat_x) | ||
count_allocs(Zygote.pullback, unflatten, flat_x) > 1000 | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
d0a1ac6
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register()
d0a1ac6
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/67037
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: