
Added GMMConv #147

Merged
merged 36 commits into from
Mar 23, 2022

Conversation

itsmohitanand
Contributor

Added GMMConv from the paper: Geometric deep learning on graphs and manifolds using mixture model CNNs

@CarloLucibello
Member

Tests are missing

@codecov

codecov bot commented Mar 21, 2022

Codecov Report

Merging #147 (65488a1) into master (04d026b) will decrease coverage by 0.08%.
The diff coverage is 70.58%.

@@            Coverage Diff             @@
##           master     #147      +/-   ##
==========================================
- Coverage   85.61%   85.52%   -0.09%     
==========================================
  Files          15       15              
  Lines        1251     1285      +34     
==========================================
+ Hits         1071     1099      +28     
- Misses        180      186       +6     
Impacted Files Coverage Δ
src/layers/conv.jl 78.70% <70.58%> (-0.82%) ⬇️
src/msgpass.jl 73.80% <0.00%> (+9.52%) ⬆️

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 04d026b...65488a1. Read the comment docs.

itsmohitanand and others added 6 commits March 21, 2022 16:36
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
learnable param in doc, num_edge in the last dim (remove permutedims)
@itsmohitanand
Contributor Author

Added the changes.

  • Added the learnable parameters to the docstring
  • Renamed u to e in line with the package convention
  • Made num_edges the last dimension, as per convention

Still need to add the test.

@itsmohitanand
Contributor Author

Test added as well

itsmohitanand and others added 3 commits March 22, 2022 08:34
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
@CarloLucibello
Member

A first test failure could be solved by adding

(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))

itsmohitanand and others added 3 commits March 22, 2022 14:52
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
added changes for test
@itsmohitanand
Contributor Author

I understand error 1: even if my input is Float32, my output is Float64. I will look into why this is happening. I could not understand the second test failure.

(nin, ein), out = ch
mu = init(ein, K)
sigma_inv = init(ein, K)
b = bias ? Flux.create_bias(ones(out), true) : false
Member

Suggested change
b = bias ? Flux.create_bias(ones(out), true) : false
b = bias ? Flux.create_bias(mu, true, out) : false

This should fix the Float64 issue
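A minimal plain-Julia sketch (no Flux required; `mu` here just stands in for the layer's parameter) of why the original line produced Float64: `ones(n)` always allocates Float64, whereas deriving the bias from an existing Float32 array keeps its eltype, which is what passing `mu` to `Flux.create_bias` achieves.

```julia
mu = rand(Float32, 2, 4)                     # stand-in for the layer's mu parameter

b_old = ones(size(mu, 2))                    # what `ones(out)` gave: always Float64
b_new = fill!(similar(mu, size(mu, 2)), 0)   # eltype follows mu: Float32

eltype(b_old), eltype(b_new)                 # (Float64, Float32)
```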

Contributor Author

The problem is here. Both w and mu are Float32 before this point, but the result becomes Float64 here. How exactly does @. work? w = @. -0.5 * (w - mu)^2

Contributor Author

I don't know if this is the most elegant solution, but I had to change the code at two different points.

  • w = @. -0.5 * (w - mu)^2 becomes w = @. -((w - mu)^2) / 2
  • m = 1 / d .* m becomes m = m ./ reshape(d, (1, g.num_nodes))
    The problem is that when the integer in-degrees, say d = [1, 2, 1], are fed into 1 / d, the result is promoted to Float64. The literal 0.5 causes the same promotion, I guess, which is why it becomes a division by the integer 2.
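The promotion rules behind both fixes can be checked in plain Julia (a small sketch; for integer d, 1 ./ d always yields Float64):

```julia
w = rand(Float32, 3)

eltype(@. -0.5 * w^2)   # Float64: the literal 0.5 is a Float64 and promotes w
eltype(@. -(w^2) / 2)   # Float32: dividing a Float32 by an Int stays Float32

d = [1, 2, 1]           # integer in-degrees
eltype(1 ./ d)          # Float64: integer reciprocals are computed in Float64
```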

@CarloLucibello
Member

Could not understand the second test.

Me neither. Seems to be a Zygote error due to a not supported try/catch. But we already have @assert and @warn in other layers and they don't seem to cause problems. Maybe try to simplify the forward a bit and test locally with something like

l = GMMConv((2,2)=>2)
g = rand_graph(5, 10, ndata=rand(Float32,2,5), edata=rand(Float32,2,10))
gradient(() -> sum(l(g, g.ndata.x, g.edata.e)), Flux.params(l))  

Co-authored-by: Carlo Lucibello <[email protected]>
@itsmohitanand
Contributor Author

Could not understand the second test.

Me neither. Seems to be a Zygote error due to a not supported try/catch. But we already have @assert and @warn in other layers and they don't seem to cause problems. Maybe try to simplify the forward a bit and test locally with something like

l = GMMConv((2,2)=>2)
g = rand_graph(5, 10, ndata=rand(Float32,2,5), edata=rand(Float32,2,10))
gradient(() -> sum(l(g, g.ndata.x, g.edata.e)), Flux.params(l))  

It is probably related to the string interpolation in the error message during gradient computation. Changing this
"Pseudo-cordinate dim $(size(e)) does not match (ein=$(ein),num_edge=$(g.num_edges))"
to
"Pseudo-cordinate dim does not match (ein, num_edge))"
makes everything run smoothly.
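A hedged sketch of an alternative workaround, not what the PR ended up doing: keep the informative interpolated message but hide the check from differentiation with `Zygote.ignore` (assuming Zygote 0.6 is available; the helper name `check_pseudocoord_dims` is illustrative, not from the PR):

```julia
using Zygote

# Illustrative helper: Zygote.ignore keeps the string interpolation
# out of the differentiated code path, so the informative message survives.
function check_pseudocoord_dims(e, ein, num_edges)
    Zygote.ignore() do
        size(e) == (ein, num_edges) ||
            error("Pseudo-coordinate dim $(size(e)) does not match (ein=$ein, num_edge=$num_edges)")
    end
    return nothing
end
```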

CarloLucibello and others added 2 commits March 23, 2022 06:30
changed the aggr method to mean and got rid of dividing by in degrees (creating NAN)
@itsmohitanand
Contributor Author

NaNs appeared, possibly because of dividing by the in-degree; instead I changed the aggregation method to mean. This should get rid of the NaN values.
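A plain-Julia illustration of the failure mode, with hypothetical numbers: under sum aggregation an isolated node receives a zero message and has in-degree 0, so the normalization step computes 0/0 = NaN. A mean aggregator avoids the explicit division entirely.

```julia
d = Float32[2, 0, 1]        # in-degrees; the second node is isolated
m = Float32[4 0 3]          # summed messages; the isolated node received 0

bad = m ./ reshape(d, 1, :) # [2.0 NaN 3.0]: the 0/0 entry produces NaN
any(isnan, bad)             # true
```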

itsmohitanand and others added 3 commits March 23, 2022 07:40
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
σ only printed if different from identity, same with residual
@CarloLucibello
Member

Great, impressive work, thanks!

@CarloLucibello CarloLucibello merged commit 8eddb5f into JuliaGraphs:master Mar 23, 2022
@itsmohitanand
Contributor Author

Great, impressive work, thanks!

Thanks a lot for all the help! I will try some other conv layer this weekend. :) And then I will look into temporal GNNs.
