-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make to_vec(::Integer)
an empty vector
#189
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,7 +55,7 @@ is defined. Each 2-`Tuple` in `xẋs` contains the value `x` and its tangent `x | |
function jvp(fdm, f, (x, ẋ)::Tuple{Any, Any}) | ||
x_vec, vec_to_x = to_vec(x) | ||
_, vec_to_y = to_vec(f(x)) | ||
return vec_to_y(_jvp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], x_vec, to_vec(ẋ)[1])) | ||
return _int2zero(vec_to_y(_jvp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], x_vec, to_vec(ẋ)[1]))) | ||
end | ||
function jvp(fdm, f, xẋs::Tuple{Any, Any}...) | ||
x, ẋ = collect(zip(xẋs...)) | ||
|
@@ -70,7 +70,7 @@ Compute an adjoint with any types of arguments `x` for which [`to_vec`](@ref) is | |
function j′vp(fdm, f, ȳ, x) | ||
x_vec, vec_to_x = to_vec(x) | ||
ȳ_vec, _ = to_vec(ȳ) | ||
return (vec_to_x(_j′vp(fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec)), ) | ||
return (_int2zero(vec_to_x(_j′vp(fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec))), ) | ||
end | ||
|
||
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1] | ||
|
@@ -85,4 +85,14 @@ end | |
|
||
Compute the gradient of `f` for any `xs` for which [`to_vec`](@ref) is defined. | ||
""" | ||
grad(fdm, f, xs...) = j′vp(fdm, f, 1, xs...) # `j′vp` with seed of 1 | ||
grad(fdm, f, xs...) = j′vp(fdm, f, 1.0, xs...) # `j′vp` with seed of 1 | ||
|
||
# This deals with the fact that integers are non perturbable | ||
# v, b = to_vec(1); | ||
# v == [] | ||
# b(v) == 1 | ||
# which means that jvp always returns the integer itself, since [] - [] == [] | ||
_int2zero(x) = x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect that this will miss a few cases (e.g. vectors of integers), but probably that's fine for now -- we can revisit later it needs be. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On a more general note: is the need for this function a manifestation of the fact that FiniteDifferences doesn't really know how to handle tangents properly? |
||
_int2zero(x::Tuple) = map(_int2zero, x) | ||
_int2zero(x::NamedTuple) = map(_int2zero, x) | ||
_int2zero(::Integer) = ZeroTangent() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,8 +67,11 @@ function test_to_vec(x::T; check_inferred=true) where {T} | |
return nothing | ||
end | ||
|
||
myrandn(T::Type{<:Number}, args...) = randn(T, args...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i should probably rename this to something sensible.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure this should have |
||
myrandn(T::Type{<:Integer}, args...) = rand(T[-1, 0, 2], args...) | ||
|
||
@testset "to_vec" begin | ||
@testset "$T" for T in (Float32, ComplexF32, Float64, ComplexF64) | ||
@testset "$T" for T in (Int64, Float32, ComplexF32, Float64, ComplexF64) | ||
if T == Float64 | ||
test_to_vec(1.0) | ||
test_to_vec(1) | ||
|
@@ -79,52 +82,52 @@ end | |
test_to_vec(T[]) | ||
test_to_vec(Vector{T}[]) | ||
test_to_vec(Matrix{T}[]) | ||
test_to_vec(randn(T, 3)) | ||
test_to_vec(randn(T, 5, 11)) | ||
test_to_vec(randn(T, 13, 17, 19)) | ||
test_to_vec(randn(T, 13, 0, 19)) | ||
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred=false) | ||
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred=false) | ||
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false) | ||
test_to_vec(UpperTriangular(randn(T, 13, 13))) | ||
test_to_vec(Diagonal(randn(T, 7))) | ||
test_to_vec(DummyType(randn(T, 2, 9))) | ||
test_to_vec(myrandn(T, 3)) | ||
test_to_vec(myrandn(T, 5, 11)) | ||
test_to_vec(myrandn(T, 13, 17, 19)) | ||
test_to_vec(myrandn(T, 13, 0, 19)) | ||
test_to_vec([1.0, myrandn(T, 2), myrandn(T, 1), 2.0]; check_inferred=false) | ||
test_to_vec([myrandn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred=false) | ||
test_to_vec(reshape([1.0, myrandn(T, 5, 4, 3), myrandn(T, 4, 3), 2.0], 2, 2); check_inferred=false) | ||
test_to_vec(UpperTriangular(myrandn(T, 13, 13))) | ||
test_to_vec(Diagonal(myrandn(T, 7))) | ||
test_to_vec(DummyType(myrandn(T, 2, 9))) | ||
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false) | ||
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false) | ||
test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector | ||
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix | ||
test_to_vec(@view myrandn(T, 10)[1:4]) # SubArray -- Vector | ||
test_to_vec(@view myrandn(T, 10, 2)[1:4, :]) # SubArray -- Matrix | ||
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ())) | ||
|
||
@testset "$Op" for Op in (Symmetric, Hermitian) | ||
test_to_vec(Op(randn(T, 11, 11))) | ||
test_to_vec(Op(myrandn(T, 11, 11))) | ||
@testset "$uplo" for uplo in (:L, :U) | ||
A = Op(randn(T, 11, 11), uplo) | ||
A = Op(myrandn(T, 11, 11), uplo) | ||
test_to_vec(A) | ||
x_vec, back = to_vec(A) | ||
@test back(x_vec).uplo == A.uplo | ||
end | ||
end | ||
|
||
@testset "$Op" for Op in (Adjoint, Transpose) | ||
test_to_vec(Op(randn(T, 4, 4))) | ||
test_to_vec(Op(randn(T, 6))) | ||
test_to_vec(Op(randn(T, 2, 5))) | ||
test_to_vec(Op(myrandn(T, 4, 4))) | ||
test_to_vec(Op(myrandn(T, 6))) | ||
test_to_vec(Op(myrandn(T, 2, 5))) | ||
|
||
# Ensure that if an `AbstractVector` is `Adjoint`ed, then the reconstructed | ||
# version also contains an `AbstractVector`, rather than an `AbstractMatrix` | ||
# whose 2nd dimension is of size 1. | ||
@testset "Vector" begin | ||
x_vec, back = to_vec(Op(randn(T, 5))) | ||
x_vec, back = to_vec(Op(myrandn(T, 5))) | ||
@test parent(back(x_vec)) isa AbstractVector | ||
end | ||
end | ||
|
||
@testset "PermutedDimsArray" begin | ||
test_to_vec(PermutedDimsArray(randn(T, 3, 1), (2, 1))) | ||
test_to_vec(PermutedDimsArray(randn(T, 4, 2, 3), (3, 1, 2))) | ||
test_to_vec(PermutedDimsArray(myrandn(T, 3, 1), (2, 1))) | ||
test_to_vec(PermutedDimsArray(myrandn(T, 4, 2, 3), (3, 1, 2))) | ||
test_to_vec( | ||
PermutedDimsArray( | ||
[randn(T, 3) for _ in 1:3, _ in 1:2, _ in 1:4], (2, 1, 3), | ||
[myrandn(T, 3) for _ in 1:3, _ in 1:2, _ in 1:4], (2, 1, 3), | ||
), | ||
) | ||
end | ||
|
@@ -133,7 +136,7 @@ end | |
# (100, 100) is needed to test for the NaNs that can appear in the | ||
# qr(M).T matrix | ||
for dims in [(7, 3), (100, 100)] | ||
M = randn(T, dims...) | ||
M = myrandn(T, dims...) | ||
P = M * M' + I # Positive definite matrix | ||
test_to_vec(svd(M)) | ||
test_to_vec(cholesky(P)) | ||
|
@@ -172,25 +175,25 @@ end | |
|
||
@testset "Tuples" begin | ||
test_to_vec((5, 4)) | ||
test_to_vec((5, randn(T, 5)); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1 | ||
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false) | ||
test_to_vec((5, randn(T, 4, 3, 2), UpperTriangular(randn(T, 4, 4)), 2.5); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1 | ||
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred=false) | ||
test_to_vec((DummyType(randn(T, 2, 7)), DummyType(randn(T, 3, 9)))) | ||
test_to_vec((DummyType(randn(T, 3, 2)), randn(T, 11, 8))) | ||
test_to_vec((5, myrandn(T, 5)); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1 | ||
test_to_vec((myrandn(T, 4), myrandn(T, 4, 3, 2), 1); check_inferred=false) | ||
test_to_vec((5, myrandn(T, 4, 3, 2), UpperTriangular(myrandn(T, 4, 4)), 2.5); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1 | ||
test_to_vec(((6, 5), 3, myrandn(T, 3, 2, 0, 1)); check_inferred=false) | ||
test_to_vec((DummyType(myrandn(T, 2, 7)), DummyType(myrandn(T, 3, 9)))) | ||
test_to_vec((DummyType(myrandn(T, 3, 2)), myrandn(T, 11, 8))) | ||
end | ||
@testset "NamedTuple" begin | ||
if T == Float64 | ||
test_to_vec((a=5, b=randn(10, 11), c=(5, 4, 3)); check_inferred = VERSION ≥ v"1.2") | ||
else | ||
test_to_vec((a=3 + 2im, b=randn(T, 10, 11), c=(5+im, 2-im, 1+im)); check_inferred = VERSION ≥ v"1.2") | ||
test_to_vec((a=3 + 2im, b=myrandn(T, 10, 11), c=(5+im, 2-im, 1+im)); check_inferred = VERSION ≥ v"1.2") | ||
end | ||
end | ||
@testset "Dictionary" begin | ||
if T == Float64 | ||
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred=false) | ||
else | ||
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred=false) | ||
test_to_vec(Dict(:a=>3 + 2im, :b=>myrandn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred=false) | ||
end | ||
end | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be breaking instead? i think of it as a bug fix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My view is treat it as a bug fix for now -- if it turns out to be super breaking and we later decide it's actually a breaking change, we can always bump the minor version number then.