Custom Function GPU Compatibility Issue: Indexing #91

Closed
umbriquse opened this issue Dec 30, 2021 · 2 comments

@umbriquse
Contributor

Hello, I am curious whether you know of a way to access the nodal information returned from a GNNChain without using indices, or at least in a way that is GPU/CUDA friendly.

Below is the function in question; the issue arises when creating the vectors v and p.

function Network.forward(nn::SimpleGNN, state)
  c = nn.common.(state)
  applyV(graph) = nn.vhead(graph, graph.ndata.x)
  resultv = applyV.(c)
  v = [resultv[ind][indDepth] for indDepth in 1:1, ind in 1:length(state)]
  applyP(graph) = nn.phead(graph)
  resultp = applyP.(c)
  p = [resultp[ind].ndata.x[indDepth] for indDepth in 1:state[1].num_nodes, ind in 1:length(state) ]
  return (p, v)
end
modelP = GNNChain(Dense(innerSize, 1),softmax)
modelV = GNNChain( GlobalPool(mean),  # aggregate node-wise features into graph-wise features
                              Dense(innerSize, 1),
                              softmax);

In this case, modelP is the model called as nn.phead and modelV is the model called as nn.vhead.

@CarloLucibello
Member

CarloLucibello commented Dec 30, 2021

What is state, a vector of GNNGraphs?
If you are working on multiple graphs, instead of broadcasting the models' forward passes you should batch the graphs together and do a single forward pass. So maybe you want something like:

function Network.forward(nn::SimpleGNN, state)
  state = Flux.batch(state)
  c = nn.common(state) # no broadcasting
  v = nn.vhead(c, c.ndata.x)   # size: num_features x num_graphs
  p = nn.phead(c, c.ndata.x)   # size: num_features x num_tot_nodes 
  return p, v
end

If you need p to be a vector of vectors instead, you can do:

p = nn.phead(c)   # a batched graph
p = Flux.unbatch(p)
p = [g.ndata.x for g in p]  
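
To make the idea behind unbatching concrete, here is a minimal sketch in plain Julia that splits a batched node-feature matrix back into one slice per graph. It does not use GraphNeuralNetworks.jl; `p_batched` and `nodes_per_graph` are hypothetical stand-ins for the batched output and the per-graph node counts. It only slices with contiguous ranges, which should avoid element-by-element indexing, but it has not been tested on a GPU array.

```julia
# Hypothetical batched output: 1 feature row, one column per node,
# with the nodes of all graphs concatenated in graph order.
p_batched = Float32[0.1 0.2 0.3 0.4 0.5]   # 1 x 5 matrix
nodes_per_graph = [2, 3]                    # assumed node counts of two graphs

# Column range owned by each graph: here 1:2 and 3:5.
stops = cumsum(nodes_per_graph)
starts = stops .- nodes_per_graph .+ 1

# One matrix slice per graph, taken with range indexing only.
p_split = [p_batched[:, a:b] for (a, b) in zip(starts, stops)]

@show size.(p_split)
```

Concatenating the slices again with `hcat(p_split...)` gives back the original batched matrix.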

Btw: It is odd to take a softmax on a single feature. Maybe you want a sigmoid there?
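
To illustrate the softmax remark, here is a small sketch in plain Julia. `colsoftmax` and `logistic` are hand-written helpers for this example, standing in for Flux's `softmax` (which normalizes along the first dimension by default) and a sigmoid. With a single feature per node, each column has only one entry, so the softmax output carries no information.

```julia
# Hand-written softmax along dims=1 (illustrative helper, not the Flux function).
function colsoftmax(x)
    e = exp.(x .- maximum(x; dims=1))   # subtract the column max for stability
    return e ./ sum(e; dims=1)
end

# Hand-written logistic sigmoid (illustrative helper).
logistic(x) = 1 / (1 + exp(-x))

scores = Float32[-2.0 0.0 3.0]   # 1 feature x 3 nodes

soft = colsoftmax(scores)   # every entry is 1.0, whatever the input
sig = logistic.(scores)     # values in (0, 1) that still depend on the input
```

Because each column is normalized on its own, a one-row input always yields ones, whereas the sigmoid keeps the ordering of the scores.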

@umbriquse
Contributor Author

The code is changing rapidly, so there is some legacy code in it that might not make sense. Thank you for your help, that was the needed solution!
