
[GNNLux] Adding NNConv Layer #478

Closed · wants to merge 47 commits

Conversation

@rbSparky (Contributor) commented Aug 2, 2024

Adding the NNConv layer according to #461.

@rbSparky (Contributor, Author) commented Aug 2, 2024

If some of the args are not needed for NNConv, please let me know.

Review thread on GNNLux/src/layers/conv.jl (outdated, resolved)
@rbSparky (Contributor, Author) commented

Writing tests

@rbSparky (Contributor, Author) commented

The dimension mismatch is in the GNNlib implementation of nn_conv:

function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
    check_num_nodes(g, x)
    message = Fix1(nn_conv_message, l)
    m = propagate(message, g, l.aggr, xj = x, e = e)
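    # the DimensionMismatch is raised by the multiplication in the return statement below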
    return l.σ.(l.weight * x .+ m .+ l.bias)
end

specifically in the return statement.

Any suggestions or ideas as to why this could be happening?

@rbSparky rbSparky marked this pull request as ready for review August 25, 2024 23:35
@rbSparky rbSparky changed the title [WIP] [GNNLux] Adding NNConv Layer [GNNLux] Adding NNConv Layer Aug 25, 2024
Review thread on GNNLux/Project.toml (outdated, resolved)
@CarloLucibello (Member) commented

This code works fine:

using GraphNeuralNetworks, Flux

n_in = 3
n_in_edge = 10
n_out = 5

s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)

nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)
x = randn(Float32, n_in, g.num_nodes)
e = randn(Float32, n_in_edge, g.num_edges)
y = l(g, x, e)  
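# y should have size (n_out, g.num_nodes) == (5, 3)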

Try to run the corresponding GNNLux version and see if you get an error; we'll debug from there.

@CarloLucibello (Member) commented

Any progress here?

@rbSparky (Contributor, Author) commented Sep 4, 2024

Any progress here?

Yes, I got the same error:

Got exception outside of a @test
  DimensionMismatch: A has dimensions (15,10) but B has dimensions (3,3)

@CarloLucibello (Member) commented

Can you paste the code that errors?

Comment on lines 102 to 120
@testset "NNConv" begin
n_in = 3
n_in_edge = 10
n_out = 5

s = [1,1,2,3]
t = [2,3,1,1]
g2 = GNNGraph(s, t)

nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, aggr = +)
x = randn(Float32, n_in, g2.num_nodes)
e = randn(Float32, n_in_edge, g2.num_edges)
#y = l(g, x, e) # just to see if it runs without an error
#edim = 10
#nn = Dense(edim, in_dims * out_dims)
#l = NNConv(in_dims => out_dims, nn, tanh, aggr = +)
test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e)
end
@rbSparky (Contributor, Author) commented Sep 5, 2024

Can you paste the code that errors?

I used the code you gave for the tests (in the test file), just without the bias, but that shouldn't cause an issue since the problem is with the multiplication. It errors while testing.

@CarloLucibello (Member) commented

As I said, can you translate this

using GraphNeuralNetworks, Flux

n_in = 3
n_in_edge = 10
n_out = 5

s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)

nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)
x = randn(Float32, n_in, g.num_nodes)
e = randn(Float32, n_in_edge, g.num_edges)
y = l(g, x, e)  

to the corresponding GNNLux code and see what happens?

@rbSparky (Contributor, Author) commented Sep 7, 2024

I tried to run it locally today, but it seems an update to some package leads to this error while importing GNNlib (usual setup for the repo: activate, instantiate, etc.):

ERROR: LoadError: Failed to precompile MLDataDevices

Could you check as well?

@rbSparky (Contributor, Author) commented Sep 7, 2024

Try to run the corresponding GNNLux version

        n_in = 3
        n_in_edge = 10
        n_out = 5
        
        s = [1,1,2,3]
        t = [2,3,1,1]
        g2 = GNNGraph(s, t)
        
        nn = Dense(n_in_edge => n_out * n_in)
        l = NNConv(n_in => n_out, nn, tanh, aggr = +)
        x = randn(Float32, n_in, g2.num_nodes)
        e = randn(Float32, n_in_edge, g2.num_edges)

        ps = LuxCore.initialparameters(rng, l)
        st = LuxCore.initialstates(rng, l)
        
        y = l(g2, x, e, ps, st)

I added this to the test file so we can see what errors.

@rbSparky (Contributor, Author) commented Sep 8, 2024

error:

  DimensionMismatch: A has dimensions (15,10) but B has dimensions (3,3)

@rbSparky (Contributor, Author) commented

Any suggestions?

@rbSparky (Contributor, Author) commented

Still getting the same dimension error.

@CarloLucibello (Member) commented

The problem is that in Lux < v1.0, ps.weight corresponded to the internal weights of the nn, due to LuxDL/Lux.jl#795. So you were passing weights of the wrong shape to GNNlib.nn_conv for the l.weight * x multiplication: with the sizes above, (15, 10) is the inner Dense layer's weight of size (n_out * n_in, n_in_edge), while the layer's own weight should be (n_out, n_in) = (5, 3).

In order to fix the layer you need to do two things:

  • rebase this branch on top of master, so that we now work with Lux >= v1.0 and the weights of the internal nn are found in ps.nn.weight (this was a breaking change in Lux's container layers)
  • use this definition for the layer (I didn't fully test it, though):
@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
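    # (:nn,) lists the wrapped sub-layer fields, following Lux's container-layer convention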
    nn <: AbstractLuxLayer    
    aggr
    in_dims::Int
    out_dims::Int
    use_bias::Bool
    init_weight
    init_bias
    σ
end

function NNConv(ch::Pair{Int, Int}, nn, σ = identity; 
                aggr = +, 
                init_bias = zeros32,
                use_bias::Bool = true,
                init_weight = glorot_uniform)
    in_dims, out_dims = ch
    σ = NNlib.fast_act(σ)
    return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
    weight = l.init_weight(rng, l.out_dims, l.in_dims)
    ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight)
    if l.use_bias
        ps = (; ps..., bias = l.init_bias(rng, l.out_dims))
    end
    return ps
end

function LuxCore.initialstates(rng::AbstractRNG, l::NNConv)
    return (; nn = LuxCore.initialstates(rng, l.nn))
end

function LuxCore.parameterlength(l::NNConv)
    n = parameterlength(l.nn) + l.in_dims * l.out_dims
    if l.use_bias
        n += l.out_dims
    end
    return n
end

LuxCore.statelength(l::NNConv) = statelength(l.nn)

function (l::NNConv)(g, x, e, ps, st)
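    # wrap the inner Lux layer with its parameters and state so GNNlib.nn_conv can call it like a plain function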
    nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
    y = GNNlib.nn_conv(m, g, x, e)
    stnew = _getstate(nn)
    return y, stnew
end
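
For reference, a minimal usage sketch of the definition above, reusing the sizes from the earlier test snippet (an illustration added for clarity, not part of the original comment; it assumes the fixed layer on the rebased branch):

using GNNLux, GNNGraphs, Lux, LuxCore, Random

rng = Random.default_rng()

n_in = 3
n_in_edge = 10
n_out = 5

s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)

# the inner network maps each edge feature vector to a flattened (n_out * n_in) weight matrix
nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, aggr = +)

x = randn(Float32, n_in, g.num_nodes)
e = randn(Float32, n_in_edge, g.num_edges)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

# Lux convention: the layer returns (output, state)
y, st = l(g, x, e, ps, st)
@assert size(y) == (n_out, g.num_nodes)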

@CarloLucibello (Member) commented

Since there are so many commits and a merge in this branch, a rebase could be hard. Close this and open a new PR if needed.

@rbSparky (Contributor, Author) commented

Rebasing and making a new PR.

@rbSparky rbSparky closed this Sep 14, 2024
@rbSparky (Contributor, Author) commented

Follow-up PR: #491
