Hessian vector products #129
Update:
using Flux.Tracker
inp = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(inp) = 0.5*inp'*(H*inp) # i'H*i function to take hessian of
hvp = H*v # True Hessian vector product
gg = H*inp # True gradient
ggvp = gg'v # True gradient vector product
x = param(inp)
fp(x) = Tracker.gradient(f,x)[1]
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x)
Hvp(x) # Fails
Error:
Nested AD not defined for broadcast
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##379#380{Symbol},Tuple{Void,Void}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:102
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
foreach(::Function, ::Tuple{Void,Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##363#368{2,Tuple{Tuple{Int64},Tuple{Int64}},Array{ForwardDiff.Dual{Void,Float64,2},1}},Tuple{Void,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##363#368{2,Tuple{Tuple{Int64},Tuple{Int64}},Array{ForwardDiff.Dual{Void,Float64,2},1}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Array{Float64,1}) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at back.jl:105
foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Void}, ::Tuple{Array{Float64,1},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{Array{Float64,1},TrackedArray{…,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##331#332{TrackedArray{…,Array{Float64,1}},Array{Float64,1}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Void}}, ::Int64) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at back.jl:117
(::Flux.Tracker.##6#7{Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}})(::Int64) at back.jl:130
(::Flux.Tracker.##9#11{Flux.Tracker.##6#7{Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}}})(::Int64) at back.jl:139
gradient(::Function, ::TrackedArray{…,Array{Float64,1}}, ::Vararg{TrackedArray{…,Array{Float64,1}},N} where N) at back.jl:151
gvp(::TrackedArray{…,Array{Float64,1}}) at hvp.jl:14
...
Update: I have gotten this particular example to work by removing the call to … I'm not entirely sure about the last factor of 2 in the final assertion below:
using Flux
using Flux.Tracker
inp = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(inp) = 0.5*inp'*(H*inp) # i'H*i function to take hessian of
hvp = H*v # True Hessian vector product
gg = H*inp # True gradient
ggvp = gg'v # True gradient vector product
x = param(inp)
fp(x) = Tracker.gradient(f,x)[1]
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x)[1]
@assert 2Hvp(x) == hvp # Works
It's broken in general because you're dropping the gradient of delta on this line. If you're lucky, removing that line might just work also :)
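To make that concrete for the quadratic in this example (my reading, using only the definitions above): the exact gradient is the sum of two equal contributions,

\nabla_x\left(\tfrac{1}{2} x^\top H x\right) \;=\; \tfrac{1}{2} H x + \tfrac{1}{2} H^\top x \;=\; H x \qquad (H = H^\top),

one reaching x through the tracked forward value H x and the other only through the backward delta. If the delta's own dependence on x is dropped, the nested gradient comes out as \tfrac{1}{2} H v instead of H v, which is consistent with the extra factor of 2 needed in the assertion above.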
Ah I see. Avoiding extracting … leads me to dualify(xs::AbstractArray{<:Dual}, ps), but I have issues figuring out how to create those second-order duals, even after having examined the code that does this in ForwardDiff.jl.
This may now work on this branch; can you give it a shot?
I run into some issues:
julia> fp(x)
ERROR: MethodError: *(::Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/array.jl:279
*(transA::Transpose{#s565,#s564} where #s564<:AbstractArray{T,2} where #s565, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/matmul.jl:83
Possible fix, define
*(::Transpose{#s565,#s564} where #s564<:AbstractArray{T,2} where #s565, ::TrackedArray{S,1,A} where A)
Stacktrace:
[1] (::getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}})(::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/array.jl:287
[2] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:103
[3] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:118
[4] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:106
[5] foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Adjoint{Float64,Array{Float64,1}}},Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Transpose{Float64,Array{Float64,1}}},TrackedArray{…,Array{Float64,1}}}) at ./abstractarray.jl:1844
[6] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){TrackedArray{…,Adjoint{Float64,Array{Float64,1}}},TrackedArray{…,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Adjoint{Float64,Array{Float64,1}}},Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:106
[7] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:118
[8] #6 at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:131 [inlined]
[9] (::getfield(Flux.Tracker, Symbol("##9#11")){getfield(Flux.Tracker, Symbol("##6#7")){Params,Flux.Tracker.TrackedReal{Float64}}})(::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:140
[10] gradient(::Function, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:152
[11] fp(::TrackedArray{…,Array{Float64,1}}) at ./REPL[49]:1
[12] top-level scope at none:0
If I define the suggested method …
As a side note, this error …
@baggepinnen are you still trying to make this work? Does it look like a more fundamental limitation, given that you tried a few AD frameworks? @MikeInnes would Zygote be helpful here?
@jklaise I ended up using ReverseDiff.jl for the outer differentiation and ForwardDiff.jl for the inner. I still cannot get it to work using either Flux or Zygote; see FluxML/Zygote.jl#115 for reference. My latest try using Flux is:
using Flux, LinearAlgebra
using Flux.Tracker
inp = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(inp,H) = 0.5*sum(inp .* (H*inp) ) # i'H*i function to take hessian of
hvp = H*v # True Hessian vector product
gg = H*inp # True gradient
ggvp = gg'v # True gradient vector product
x = param(inp)
fp(x) = Tracker.gradient(x->f(x,H),x,nest=true)[1] # This line errors if nest == true
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x, nest=true)[1]
@assert Hvp(x) == hvp # This line errors if nest == false in the definition of fp(x)
julia> @assert Hvp(x) == hvp
ERROR: MethodError: *(::Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /local/home/fredrikb/.julia/packages/Flux/8XpDt/src/tracker/lib/array.jl:354
*(transA::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:84
Possible fix, define
*(::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, ::TrackedArray{S,1,A} where A)
Stacktrace:
@baggepinnen: could you please post your working example using ReverseDiff.jl and ForwardDiff.jl? Many thanks!
I think I eventually hand-coded the Jacobian and used ReverseDiff.jl for the outer diff. Some old code using forward and reverse is here, though.
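For completeness, here is a self-contained nested-AD sketch that sidesteps Flux/Tracker entirely by nesting ForwardDiff over itself. It is not the code referred to above; the helper name hvp_ff is mine:

using ForwardDiff, LinearAlgebra

# Hessian-vector product without forming H: differentiating the gradient of f
# along the direction v gives d/dε ∇f(x + ε v) at ε = 0, which equals H(x) * v.
hvp_ff(f, x, v) = ForwardDiff.derivative(ε -> ForwardDiff.gradient(f, x .+ ε .* v), 0.0)

H = randn(3, 3); H = H + H'     # symmetric test Hessian
f(x) = 0.5 * dot(x, H * x)      # same quadratic as in the examples above
x, v = randn(3), randn(3)
@assert hvp_ff(f, x, v) ≈ H * v # matches the true Hessian-vector product

The inner ForwardDiff.gradient costs one forward pass per input dimension, so this only makes sense for small problems; presumably the appeal of the ReverseDiff-outer / ForwardDiff-inner combination mentioned above is that it needs only a single forward directional derivative inside and a single reverse pass outside.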
This has been discussed in Zygote and I think it's better tracked there :)
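For anyone landing here from the Zygote side, a pattern that often comes up there is forward-over-reverse: perturb the input along v with a dual number and differentiate the reverse-mode gradient with respect to that perturbation. A minimal sketch, assuming Zygote's pullbacks for the function at hand accept dual-number inputs (the helper name hvp_fr is mine, not from this thread):

using Zygote, ForwardDiff, LinearAlgebra

# Forward-over-reverse Hessian-vector product: Zygote supplies the reverse-mode
# gradient, ForwardDiff differentiates it along v at ε = 0.
hvp_fr(f, x, v) = ForwardDiff.derivative(ε -> Zygote.gradient(f, x .+ ε .* v)[1], 0.0)

H = randn(3, 3); H = H + H'
f(x) = 0.5 * dot(x, H * x)
x, v = randn(3), randn(3)
@assert hvp_fr(f, x, v) ≈ H * v

Unlike the all-forward sketch above, this keeps the gradient computation in reverse mode, so the cost stays at roughly one forward plus one reverse pass regardless of the input dimension.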
I'm trying to calculate Hessian-vector products without forming the full Hessian. The trick is described on slide 9 here:
http://rll.berkeley.edu/deeprlcourse/docs/lec5.pdf
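In symbols, the trick is to differentiate the scalar gradient-vector product a second time; because f is twice differentiable its Hessian H(x) is symmetric, so

H(x)\, v \;=\; \nabla_x\!\left[\,\big(\nabla_x f(x)\big)^{\top} v\,\right],

and the full Hessian never has to be formed: one gradient pass gives \nabla_x f(x), the dot product with v collapses it to a scalar, and a second gradient pass gives H(x)\,v.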
In TensorFlow, which is used in the example, one builds up the computational graph using calls to grad, which I'm struggling to replicate using Flux. This is my attempt so far; the issue is that I cannot figure out how to backpropagate all the way to the input.
Full output of code below
Any assistance accomplishing this would be greatly appreciated!