Hessian vector products #129

Closed
baggepinnen opened this issue Dec 16, 2017 · 12 comments

@baggepinnen (Contributor) commented Dec 16, 2017

I'm trying to calculate Hessian-vector products without forming the full Hessian. The trick is described on slide 9 here: http://rll.berkeley.edu/deeprlcourse/docs/lec5.pdf
The idea is that Hv = ∇ₓ((∇ₓf(x))ᵀv), i.e. the Hessian-vector product is the gradient of the scalar gradient-vector product; a finite-difference sanity check of this identity is included at the end of this comment. In TensorFlow, which the slides use, one builds up the computational graph with repeated calls to grad, which I'm struggling to replicate using Flux.

This is my attempt so far; the issue is that I cannot figure out how to backpropagate all the way to the input.

#Goal: calculate H*v without forming H
inp = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(inp) = 0.5*sum(inp.*(H*inp)) # i'H*i function to take hessian of
hvp = H*v # True Hessian vector product
gg = H*inp # True gradient
ggvp = gg'v # True gradient vector product

x = param(inp)
back!(f(x)[1])
g = param(x.grad) # If this would be a function call that inserts a gradient calculation in a graph, problem would probably be solved
gvp = sum(g.*v) # Correct until here == gg'v
back!(gvp)
g.grad # Not == hvp, does not backpropagate all the way to x

Full output of the code is below:

julia> using Flux

julia> using Flux: back!

julia> inp = randn(3) # Input

julia> v = randn(3) # Vector

julia> H = randn(3,3); H = H+H' # Hessian

julia> f(inp) = 0.5*sum(inp.*(H*inp)) # i'H*i function to take hessian of

julia> hvp = H*v # True Hessian vector product
3-element Array{Float64,1}:
 -0.61915 
 -0.490359
  2.94145 

julia> gg = H*inp # True gradient
3-element Array{Float64,1}:
 -1.49463 
  1.60448 
  0.860397

julia> ggvp = gg'v # True gradient vector product
-3.36572811826236

julia> x = param(inp)
Tracked 3-element Array{Float64,1}:
  0.918377
 -0.7506  
 -1.07606 

julia> back!(f(x)[1])

julia> g = param(x.grad)
Tracked 3-element Array{Float64,1}:
 -1.49463 
  1.60448 
  0.860397

julia> gvp = sum(g.*v) # Correct until here == gg'v
Tracked 0-dimensional Array{Float64,0}:
-3.36573

julia> back!(gvp)

julia> g.grad
3-element Array{Float64,1}:
  0.364484
 -1.3813  
 -0.7028 

Any assistance accomplishing this would be greatly appreciated!
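
For reference, the identity can be checked numerically with plain arrays (no AD), since the Hessian-vector product is just the directional derivative of the gradient along v. A minimal check, reusing inp, v, H and hvp from above:

# Hv ≈ (∇f(x + εv) - ∇f(x - εv)) / (2ε); for this quadratic, ∇f(x) = H*x
ε = 1e-6
fd_hvp = (H*(inp .+ ε.*v) .- H*(inp .- ε.*v)) ./ (2ε)
@assert isapprox(fd_hvp, hvp, atol=1e-6)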

@baggepinnen (Contributor, Author)

Update:

using Flux.Tracker

inp = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(inp) = 0.5*inp'*(H*inp) # i'H*i function to take hessian of
hvp = H*v # True Hessian vector product
gg = H*inp # True gradient
ggvp = gg'v # True gradient vector product

x = param(inp)
fp(x) = Tracker.gradient(f,x)[1]
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x)
Hvp(x) # Fails

Error: 
Nested AD not defined for broadcast
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##379#380{Symbol},Tuple{Void,Void}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:102
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
foreach(::Function, ::Tuple{Void,Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{TrackedArray{,Array{Float64,1}},TrackedArray{,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##363#368{2,Tuple{Tuple{Int64},Tuple{Int64}},Array{ForwardDiff.Dual{Void,Float64,2},1}},Tuple{Void,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at back.jl:105
foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{TrackedArray{,Array{Float64,1}},TrackedArray{,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##363#368{2,Tuple{Tuple{Int64},Tuple{Int64}},Array{ForwardDiff.Dual{Void,Float64,2},1}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Array{Float64,1}) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at back.jl:117
(::Flux.Tracker.##4#5{Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at back.jl:105
foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Void}, ::Tuple{Array{Float64,1},TrackedArray{…,Array{Float64,1}}}, ::Vararg{Tuple{Array{Float64,1},TrackedArray{,Array{Float64,1}}},N} where N) at abstractarray.jl:1734
back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{Flux.Tracker.##331#332{TrackedArray{…,Array{Float64,1}},Array{Float64,1}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Void}}, ::Int64) at back.jl:105
back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at back.jl:117
(::Flux.Tracker.##6#7{Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}})(::Int64) at back.jl:130
(::Flux.Tracker.##9#11{Flux.Tracker.##6#7{Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}}})(::Int64) at back.jl:139
gradient(::Function, ::TrackedArray{…,Array{Float64,1}}, ::Vararg{TrackedArray{…,Array{Float64,1}},N} where N) at back.jl:151
gvp(::TrackedArray{…,Array{Float64,1}}) at hvp.jl:14
...

baggepinnen mentioned this issue Jul 26, 2018
@baggepinnen (Contributor, Author)

Update: I have gotten this particular example to work by removing the call to nobacksies, i.e. nobacksies(:broadcast, dxs).

I'm not entirely sure about the factor of 2 in 2Hvp(x) == hvp though, nor whether it is a lucky coincidence that nested AD worked in this case but would be erroneous for a general broadcast call.

using Flux
using Flux.Tracker

inp    = randn(3)             # Input
v      = randn(3)             # Vector
H      = randn(3,3); H = H+H' # Hessian
f(inp) = 0.5*inp'*(H*inp)     # i'H*i function to take hessian of
hvp    = H*v                  # True Hessian vector product
gg     = H*inp                # True gradient
ggvp   = gg'v                 # True gradient vector product

x      = param(inp)
fp(x)  = Tracker.gradient(f,x)[1]
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x)[1]
@assert 2Hvp(x) == hvp # Works

@MikeInnes (Member)

It's broken in general because you're dropping the gradient of delta on this line. If you're lucky, removing that line might also just work :)

@baggepinnen (Contributor, Author)

Ah, I see. Avoiding extracting data(Δ_) did mess with the Dual numbers, as you foresaw. It seems that Dual numbers do support second-order (nested) duals, but getting the tags and partials right is beyond me. I think the following method is needed

dualify(xs::AbstractArray{<:Dual}, ps)

but I have trouble figuring out how to create those second-order duals, even after examining the code that does this in ForwardDiff.jl.
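
For illustration, a minimal sketch of what nested (second-order) duals look like, using plain Symbols as stand-in tags rather than ForwardDiff's own Tag machinery; this only shows the layout such a dualify method would have to produce, not the method itself:

using ForwardDiff
using ForwardDiff: Dual, value, partials

f(x) = x^3

# The outer dual's value is itself an inner dual; the outer partial is the seed.
d = Dual{:outer}(Dual{:inner}(2.0, 1.0),   # value: x = 2, inner seed 1
                 Dual{:inner}(1.0, 0.0))   # outer seed

y = f(d)
value(value(y))              # f(2)   == 8.0
partials(value(y), 1)        # f'(2)  == 12.0
partials(partials(y, 1), 1)  # f''(2) == 12.0

# The same second derivative via nested API calls:
ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), 2.0)  # == 12.0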

@MikeInnes (Member)

This may now work on this branch; can you give it a shot?

@baggepinnen (Contributor, Author)

I run into some issues:

julia> fp(x)
ERROR: MethodError: *(::Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
  *(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/array.jl:279
  *(transA::Transpose{#s565,#s564} where #s564<:AbstractArray{T,2} where #s565, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/matmul.jl:83
Possible fix, define
  *(::Transpose{#s565,#s564} where #s564<:AbstractArray{T,2} where #s565, ::TrackedArray{S,1,A} where A)
Stacktrace:
 [1] (::getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{,Array{Float64,1}}})(::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/array.jl:287
 [2] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{,Array{Float64,1}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:103
 [3] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:118
 [4] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:106
 [5] foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Adjoint{Float64,Array{Float64,1}}},Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{,Transpose{Float64,Array{Float64,1}}},TrackedArray{,Array{Float64,1}}}) at ./abstractarray.jl:1844
 [6] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){TrackedArray{,Adjoint{Float64,Array{Float64,1}}},TrackedArray{,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Adjoint{Float64,Array{Float64,1}}},Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:106
 [7] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:118
 [8] #6 at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:131 [inlined]
 [9] (::getfield(Flux.Tracker, Symbol("##9#11")){getfield(Flux.Tracker, Symbol("##6#7")){Params,Flux.Tracker.TrackedReal{Float64}}})(::Int64) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:140
 [10] gradient(::Function, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:152
 [11] fp(::TrackedArray{…,Array{Float64,1}}) at ./REPL[49]:1
 [12] top-level scope at none:0

If I define f(inp) = 0.5*copy(inp')*(H*inp) to circumvent the method error, I instead get

julia> f(inp) = 0.5*copy(inp')*(H*inp)     # i'H*i function to take hessian of
f (generic function with 1 method)

julia> fp(x)  = Tracker.gradient(f,x)[1]
fp (generic function with 1 method)

julia> fp(x)
ERROR: MethodError: no method matching Float64(::Flux.Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:173
  Float64(::T<:Number) where T<:Number at boot.jl:725
  Float64(::Int8) at float.jl:60
  ...
Stacktrace:
 [1] Float64(::Flux.Tracker.TrackedReal{Float64}) at ./deprecated.jl:468
 [2] convert at ./number.jl:7 [inlined]
 [3] setindex! at ./array.jl:769 [inlined]
 [4] setindex! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/adjtrans.jl:131 [inlined]
 [5] _setindex! at ./abstractarray.jl:1029 [inlined]
 [6] setindex! at ./abstractarray.jl:1006 [inlined]
 [7] copyto!(::Adjoint{Float64,Array{Float64,1}}, ::TrackedArray{…,Adjoint{Float64,Array{Float64,1}}}) at ./multidimensional.jl:827
 [8] copymutable(::TrackedArray{…,Adjoint{Float64,Array{Float64,1}}}) at ./abstractarray.jl:831
 [9] copy at ./abstractarray.jl:781 [inlined]
 [10] f(::TrackedArray{…,Array{Float64,1}}) at ./REPL[51]:1
 [11] #8 at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:139 [inlined]
 [12] forward at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:126 [inlined]
 [13] forward(::Function, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:139
 [14] gradient(::Function, ::TrackedArray{…,Array{Float64,1}}) at /local/home/fredrikb/.julia/packages/Flux/Fux9f/src/tracker/back.jl:150
 [15] fp(::TrackedArray{…,Array{Float64,1}}) at ./REPL[52]:1
 [16] top-level scope at none:0

@baggepinnen (Contributor, Author)

As a side note, this error
ERROR: MethodError: no method matching Float64(::ADspecialType)
is the same one I get if I try this experiment using Yota.jl, AutoGrad.jl, some versions of ReverseDiff/ForwardDiff, and, if I remember correctly, also Nabla.jl.

@jklaise commented Mar 14, 2019

@baggepinnen are you still trying to make this work? Does it look like a more fundamental limitation given that you tried a few AD frameworks? @MikeInnes would Zygote be helpful here?

@baggepinnen (Contributor, Author)

@jklaise I ended up using ReverseDiff.jl for the outer differentiation and ForwardDiff.jl for the inner. I still cannot get it to work using either Flux or Zygote; see FluxML/Zygote.jl#115 for reference. My latest attempt using Flux is

using Flux, LinearAlgebra
using Flux.Tracker

inp    = randn(3)             # Input
v      = randn(3)             # Vector
H      = randn(3,3); H = H+H' # Hessian
f(inp,H) = 0.5*sum(inp .* (H*inp) )    # i'H*i function to take hessian of
hvp    = H*v                  # True Hessian vector product
gg     = H*inp                # True gradient
ggvp   = gg'v                 # True gradient vector product

x      = param(inp)
fp(x)  = Tracker.gradient(x->f(x,H),x,nest=true)[1] # This line errors if nest == true
@assert fp(x).data == gg # Works
gvp(x) = fp(x)'v # Gradient vector product
@assert gvp(x).data == ggvp # Works
Hvp(x) = Tracker.gradient(gvp,x, nest=true)[1]
@assert Hvp(x) == hvp # This line errors if nest == false in the definition of fp(x)


julia> @assert Hvp(x) == hvp
ERROR: MethodError: *(::Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
  *(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /local/home/fredrikb/.julia/packages/Flux/8XpDt/src/tracker/lib/array.jl:354
  *(transA::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/matmul.jl:84
Possible fix, define
  *(::Transpose{#s623,#s622} where #s622<:AbstractArray{T,2} where #s623, ::TrackedArray{S,1,A} where A)
Stacktrace:

@andreasko

@baggepinnen: could you please post your working example using ReverseDiff.jl and ForwardDiff.jl? Many thanks!

@baggepinnen (Contributor, Author)

I think I eventually hand-coded the Jacobian and used ReverseDiff.jl for the outer diff. Some old code mixing forward and reverse is here, though:
https://github.com/baggepinnen/JacProp.jl/blob/a24ff370c3c5ea654a3ae0af59cf77e2372aa2d1/src/linear_sys_manual.jl#L113
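
For the toy problem in this thread, a minimal sketch of that approach (hand-coded gradient for the inner step, ReverseDiff.jl for the outer differentiation; the names grad_f and gvp below are just mine):

using ReverseDiff

inp = randn(3)
v   = randn(3)
H   = randn(3,3); H = H + H'      # symmetric, so ∇f(x) = H*x for f(x) = 0.5x'H*x

grad_f(x) = H * x                 # hand-coded gradient of f
gvp(x)    = sum(grad_f(x) .* v)   # scalar gradient-vector product g(x)'v

hvp = ReverseDiff.gradient(gvp, inp)   # ∇ₓ(g(x)'v) = H'*v = H*v
@assert hvp ≈ H*v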

@ToucheSir (Member)

This has been discussed in Zygote and I think it's better tracked there :)
