Failure to combine SparseDiffTools.autoback_hesvec and GCNConv #125

Closed
newalexander opened this issue Feb 10, 2022 · 3 comments

@newalexander

Hello! Nice work on the library; it is very usable. I'm trying to compute the Hessian-vector product of a loss function involving GNNGraph data points and a GNNChain model. I've been using the SparseDiffTools.jl function autoback_hesvec for this, which implements ForwardDiff.jl over Zygote.jl for the Hessian-vector calculation. However, this function fails in the GraphNeuralNetworks.jl setting. The other Hessian-vector functions in SparseDiffTools.jl do work, and an analogously constructed calculation using only Flux also works.
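
Roughly, as I understand it, autoback_hesvec does something like the following (an illustrative sketch of forward-over-reverse, not the actual SparseDiffTools.jl code; hvp_sketch is a hypothetical name): seed the parameters with ForwardDiff.Dual numbers whose partials are the tangent vector, take the Zygote gradient at those Dual parameters, and read the Hessian-vector product off the partials of that gradient.

using ForwardDiff, Zygote

# Illustrative forward-over-reverse Hessian-vector product (not the real SparseDiffTools implementation).
function hvp_sketch(f, x::AbstractVector, v::AbstractVector)
    T = typeof(ForwardDiff.Tag(f, eltype(x)))
    xdual = ForwardDiff.Dual{T}.(x, v)        # x + ε*v, with ε the dual perturbation
    g = first(Zygote.gradient(f, xdual))      # reverse-mode gradient evaluated at the Dual inputs
    return ForwardDiff.partials.(g, 1)        # d/dε ∇f(x + ε*v) = H(x) * v
end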

using GraphNeuralNetworks, Flux, Graphs, ForwardDiff, Random, SparseDiffTools


function gnn_test()
    Random.seed!(1234)

    g = GNNGraph(erdos_renyi(10,  30), ndata=rand(Float32, 3, 10), gdata=rand(Float32, 2))

    m = GNNChain(GCNConv(3 => 2, tanh), GlobalPool(+))
    ps, re = Flux.destructure(m)  # primal vector and restructure function
    ts = rand(Float32, size(ps))  # tangent vector

    loss(_ps) = Flux.Losses.mse(re(_ps)(g, g.ndata.x), g.gdata.u)

    numback_hesvec(loss, ps, ts) |> println  # works
    numauto_hesvec(loss, ps, ts)  |> println  # works
    autoback_hesvec(loss, ps, ts) |> println  # fails
end

function flux_test()
    Random.seed!(1234)

    x = rand(Float32, 10, 3)
    y = rand(Float32, 2, 3)

    m = Chain(Dense(10, 4, tanh), Dense(4, 2))
    ps, re = Flux.destructure(m)  # primal vector and restructure function
    ts = rand(Float32, size(ps))  # tangent vector

    loss(_ps) = Flux.Losses.mse(re(_ps)(x), y)

    numback_hesvec(loss, ps, ts) |> println  # works
    numauto_hesvec(loss, ps, ts)  |> println  # works
    autoback_hesvec(loss, ps, ts) |> println  # works
end

The full error message:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseArrays.SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1335
    @ ~/.julia/packages/ChainRules/3HAQW/src/rulesets/Base/arraymath.jl:37 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:189 [inlined]
 [11] (::typeof(∂(propagate)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:68 [inlined]
 [13] (::typeof(∂(#propagate#84)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:68 [inlined]
 [15] (::typeof(∂(propagate##kw)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/conv.jl:103 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/conv.jl:80 [inlined]
 [19] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:125 [inlined]
 [20] (::typeof(∂(applylayer)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:137 [inlined]
 [22] (::typeof(∂(applychain)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:139 [inlined]
 [24] (::typeof(∂(λ)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/JuliaProjects/GraphNetworkLayers/test/fwd.jl:15 [inlined]
 [26] (::typeof(∂(λ)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#57#58"{typeof(∂(λ))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [28] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [29] (::SparseDiffTools.var"#78#79"{var"#loss#5"{Flux.var"#66#68"{GNNChain{Tuple{GCNConv{Matrix{Float32}, Vector{Float32}, typeof(tanh)}, GlobalPool{typeof(+)}}}}, GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:39
 [30] autoback_hesvec(f::Function, x::Vector{Float32}, v::Vector{Float32})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:41
 [31] gnn_test()
    @ Main [script location]
 [32] top-level scope
    @ REPL[8]:1
 [33] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52
@CarloLucibello
Member

Thanks for this very clear report! The issue could be that ChainRulesCore.ProjectTo{SparseMatrixCSC} is not ForwardDiff.Dual-friendly.
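
To make this concrete, here is a stripped-down illustration of the projection step the stack trace points at (names are illustrative, and the behaviour is what I expect from the ChainRulesCore version in the trace): a projector built from a Float64 sparse matrix tries to write a Dual-valued cotangent back into a Float64 buffer.

using ChainRulesCore, SparseArrays, ForwardDiff

A  = sprand(5, 5, 0.5)                   # constant Float64 sparse matrix
pA = ProjectTo(A)                        # projector remembers eltype Float64 and the sparsity pattern
dA = ForwardDiff.Dual.(rand(5, 5), 1.0)  # Dual-valued cotangent, as produced by ForwardDiff-over-Zygote
pA(dA)                                   # expected to throw: MethodError: no method matching Float64(::ForwardDiff.Dual{...})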

@CarloLucibello
Member

CarloLucibello commented Feb 12, 2022

It works fine with the generic message-passing framework (e.g. for GATConv):

julia> using GraphNeuralNetworks, Flux, Graphs, ForwardDiff, Random, SparseDiffTools

julia> function gnn_test()
           Random.seed!(1234)

           g = GNNGraph(erdos_renyi(10,  30), ndata=rand(Float32, 3, 10))
           m = GATConv(3 => 2, tanh)
           ps, re = Flux.destructure(m)  # primal vector and restructure function
           ts = rand(Float32, size(ps))  # tangent vector

           loss(_ps) = sum(re(_ps)(g, g.ndata.x))

           numback_hesvec(loss, ps, ts) |> println  # works
           numauto_hesvec(loss, ps, ts)  |> println  # works
           autoback_hesvec(loss, ps, ts) |> println  # fails for GCNConv, works for GATConv
       end
gnn_test (generic function with 1 method)

julia> gnn_test()
Float32[4.4092455, -0.75034475, 2.9557762, -0.39457783, 2.5744586, 0.34606418, 8.512531, 0.21992864, -0.2590933, -0.5618502, 0.4025826, -1.1103446]
Float32[4.4092455, -0.75034475, 2.9557762, -0.39457783, 2.5744586, 0.34606418, 8.512531, 0.21992864, -0.2590933, -0.5618502, 0.4025826, -1.1103446]
Float32[4.4092455, -0.751315, 2.9551294, -0.39425442, 2.5746205, 0.34541732, 8.512531, 0.21992864, -0.25904277, -0.56184894, 0.40250173, -1.1103705]
Float32[4.409217, -0.7511167, 2.9556952, -0.3946823, 2.5747547, 0.34566167, 8.511975, 0.21927829, -0.25907597, -0.5618338, 0.40243906, -1.1103226]

but it doesn't like operations involving sparse matrices (i.e., what happens inside GCNConv).
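
For context, a hedged sketch (not the actual GraphNeuralNetworks.jl implementation) of why GCNConv touches sparse arrays: neighbour aggregation amounts to multiplying the features by a constant sparse normalized adjacency matrix, which stays a plain Float32 SparseMatrixCSC while the restructured weights become ForwardDiff.Dual, so its pullback has to project a Dual-valued cotangent onto a Float-eltype sparse array.

using Graphs, SparseArrays, LinearAlgebra

g = erdos_renyi(10, 30)
A = Float32.(adjacency_matrix(g))      # SparseMatrixCSC{Float32}
deg = vec(sum(A + I; dims=2))          # degrees of A + I (self-loops added)
Â = Diagonal(deg .^ -0.5f0) * (A + I) * Diagonal(deg .^ -0.5f0)  # normalized adjacency, still sparse
W = rand(Float32, 2, 3)                # layer weights (these become Duals after restructure)
X = rand(Float32, 3, 10)               # node features
H = tanh.(W * X * Â)                   # GCN-style layer: Â enters the pullback as a constant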

@CarloLucibello
Member

CarloLucibello commented Feb 12, 2022

The problem seems to be a very generic one, not strictly related to GNN.jl. Here is an MWE:

julia> using SparseArrays, SparseDiffTools

julia> x, t = rand(5), rand(5);

julia> A = sprand(5, 5, 0.5);

julia> loss(x) = sum(tanh.(A * x));

julia> numback_hesvec(loss, x, t) # works
5-element Vector{Float64}:
 -0.349703846209146
 -1.210662833747414
 -1.4030571895355597
 -0.47786341057923254
 -0.9171474544184983

julia> autoback_hesvec(loss, x, t)
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at ~/julia/julia-1.7.1/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at ~/julia/julia-1.7.1/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at ~/julia/julia-1.7.1/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1334
    @ ~/.julia/packages/ChainRules/GRzER/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ./REPL[40]:1 [inlined]
 [11] (::typeof((loss)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#57#58"{typeof((loss))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [14] (::SparseDiffTools.var"#78#79"{typeof(loss)})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/b2cgD/src/differentiation/jaches_products_zygote.jl:39
 [15] autoback_hesvec(f::Function, x::Vector{Float64}, v::Vector{Float64})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/b2cgD/src/differentiation/jaches_products_zygote.jl:41
 [16] top-level scope
    @ REPL[47]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52

I think this issue should be reported to SparseDiffTools.jl and probably fixed in ChainRules.jl.
