
Gradient of edge weights is nothing with fused e_mul_xj #113

Closed
learning-chip opened this issue Jan 21, 2022 · 4 comments · Fixed by #123
Comments

@learning-chip

learning-chip commented Jan 21, 2022

#107 breaks Zygote autodiff: Zygote.gradient() returns nothing for the fused kernel, while it returns the correct gradient for the unfused one. This bug further breaks GNN training, with a hard-to-understand error like MethodError: no method matching vec(::Nothing).

To reproduce

using GraphNeuralNetworks
using SparseArrays
import Random: seed!
using Zygote

n = 32
seed!(0)
A = sprand(n, n, 0.1)
b = rand(1, n)
g = GNNGraph(A)
A_val = reshape(A.nzval, 1, :)

"""SpMV followed by a scalar loss function"""
function forward_fused(g, b, A_val)
    out = propagate(
        e_mul_xj, g, +; xj=b, e=A_val
        )
    return sum(abs2, out)
end

function forward_unfused(g, b, A_val)
    out = propagate(
        (xi, xj, e) -> e .* xj, g, +; xj=b, e=A_val
        )
    return sum(abs2, out)
end

forward_fused(g, b, vec(A_val)) == forward_unfused(g, b, A_val)  # true, forward passes agree

grad_builtin = gradient(A -> sum(abs2, b * A), A)[1];  # returns a sparse CSC matrix containing the gradient

grad_gnn1 = gradient(
    A_vals -> forward_unfused(g, b, A_vals), 
    A_val
)[1]

isequal(vec(grad_gnn1), grad_builtin.nzval)  # true, gradient agrees with the reference

# edge feature not flattened, so the "fused" function does not actually invoke the fused kernel
grad_gnn2 = gradient(
    A_vals -> forward_fused(g, b, A_vals), 
    A_val
)[1]

isequal(vec(grad_gnn2), grad_builtin.nzval)  # true, gradient agrees with the reference

# passing flattened edge feature, activating fusion
grad_gnn3 = gradient(
    A_vals -> forward_fused(g, b, A_vals), 
    vec(A_val)
)[1]  # bug: returns nothing

Package versions

  • GraphNeuralNetworks.jl 0.3.10 (from git master)
  • Zygote.jl 0.6.33
@CarloLucibello
Member

Unfortunately it seems hard to support gradients with respect to edge weights when doing fused operations.
I don't know how to make the construction of a sparse matrix out of a vector differentiable.

@CarloLucibello changed the title from "Fused e_mul_xj kernel breaks zygote autodiff" to "Gradient of edge weights is nothing with fused e_mul_xj" on Jan 21, 2022
@CarloLucibello
Member

We have to solve this issue:

julia> using SparseArrays, Zygote

julia> s, t, w = [1,2], [2,3], [0.5,0.5]
([1, 2], [2, 3], [0.5, 0.5])

julia> gradient(w -> sum(sparse(s,t,w)), w)
ERROR: Need an adjoint for constructor SparseMatrixCSC{Float64, Int64}. Gradient is of type FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{SparseMatrixCSC{Float64, Int64}, Nothing, false})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/lib/lib.jl:324
  [3] (::Zygote.var"#1786#back#229"{Zygote.Jnew{SparseMatrixCSC{Float64, Int64}, Nothing, false}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:31 [inlined]
  [5] (::typeof((SparseMatrixCSC{Float64, Int64})))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [6] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:44 [inlined]
  [7] (::typeof((SparseMatrixCSC)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [8] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:961 [inlined]
  [9] (::typeof((sparse!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [10] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:798 [inlined]
 [11] (::typeof((sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [12] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:987 [inlined]
 [13] (::typeof((sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [14] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:983 [inlined]
 [15] (::typeof((sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[7]:1 [inlined]
 [17] (::typeof((#6)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#57#58"{typeof((#6))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:41
 [19] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:76
 [20] top-level scope
    @ REPL[7]:1
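
One possible direction (an untested sketch, assuming ChainRulesCore; the rule eventually added upstream may differ) is to define an rrule for sparse(I, J, V) that backpropagates only into the value vector, reading the cotangent at the stored positions:

using ChainRulesCore, SparseArrays, Zygote

function ChainRulesCore.rrule(::typeof(sparse), I::AbstractVector{<:Integer},
                              J::AbstractVector{<:Integer}, V::AbstractVector)
    A = sparse(I, J, V)
    function sparse_pullback(ΔA)
        dA = unthunk(ΔA)
        # Each V[k] contributes to A[I[k], J[k]], so its cotangent is dA[I[k], J[k]];
        # indexing also works when dA is the FillArrays.Fill seen in the error above.
        dV = [dA[i, j] for (i, j) in zip(I, J)]
        return (NoTangent(), NoTangent(), NoTangent(), dV)
    end
    return A, sparse_pullback
end

# With such a rule in place (possibly after Zygote.refresh() in a running session),
# the example above should return a gradient instead of erroring:
# gradient(w -> sum(sparse(s, t, w)), w)   # expected: ([1.0, 1.0],)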

@learning-chip
Author

gradient(w -> sum(sparse(s,t,w)), w)

This is definitely doable with PyTorch sparse tensors. I am not familiar enough with Zygote's autodiff rules (still learning), but let me just post the PyTorch (1.9.1) code for reference...

import torch

# gradient w.r.t. to value of sparse matrix
A_indices = torch.tensor([[0, 1], [1, 2]])
A_vals = torch.tensor([0.5, 0.5], requires_grad=True)
A_coo = torch.sparse_coo_tensor(A_indices, A_vals, (3, 3))  # or torch.sparse_csr_tensor()
B = torch.sparse.mm(A_coo, A_coo)  # some sparse linear algebra
loss = B.coalesce().values().pow(2).sum()
loss.backward()

A_vals.grad  # tensor([0.2500, 0.2500])

# gradient w.r.t. sparse matrix itself
A_new = A_coo.detach().requires_grad_(True)  # new leaf node
loss2 = torch.sparse.mm(A_new, A_new).coalesce().values().pow(2).sum()
loss2.backward()

A_new.grad  # a sparse tensor with the same pattern as A and values tensor([0.2500, 0.2500])

torch.equal(A_vals.grad, A_new.grad.coalesce().values())  # True

@CarloLucibello
Member

CarloLucibello commented Jan 22, 2022

Differentiability of sparse is being taken care of in JuliaDiff/ChainRules.jl#579
