Hello, do you know of a way to access the node information returned by a GNNChain without indexing, or at least in a way that is GPU/CUDA-friendly?
Below is the function in question. The problem is in how the arrays v and p are built.
```julia
function Network.forward(nn::SimpleGNN, state)
    c = nn.common.(state)
    applyV(graph) = nn.vhead(graph, graph.ndata.x)
    resultv = applyV.(c)
    v = [resultv[ind][indDepth] for indDepth in 1:1, ind in 1:length(state)]
    applyP(graph) = nn.phead(graph)
    resultp = applyP.(c)
    p = [resultp[ind].ndata.x[indDepth] for indDepth in 1:state[1].num_nodes, ind in 1:length(state)]
    return (p, v)
end
```
```julia
modelP = GNNChain(Dense(innerSize, 1), softmax)
modelV = GNNChain(GlobalPool(mean),  # aggregate node-wise features into graph-wise features
                  Dense(innerSize, 1),
                  softmax);
```
Here, modelP is used as nn.phead and modelV as nn.vhead.
What is state, a vector of GNNGraphs?
If you are working with multiple graphs, then instead of broadcasting the models' forward passes over them, you should batch the graphs together and do a single forward pass. So maybe you want something like:
```julia
function Network.forward(nn::SimpleGNN, state)
    state = Flux.batch(state)       # merge the vector of GNNGraphs into one batched graph
    c = nn.common(state)            # no broadcasting: a single forward pass
    v = nn.vhead(c, c.ndata.x)      # size: num_features x num_graphs
    p = nn.phead(c, c.ndata.x)      # size: num_features x num_tot_nodes
    return p, v
end
```
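Because `Flux.batch` places the node features of the input graphs one graph after another, and the original code assumes every graph has `state[1].num_nodes` nodes, the `1 × num_tot_nodes` output of `phead` can be reshaped into the same `num_nodes × num_graphs` layout the comprehension built, with no scalar indexing. (The `vhead` output is already `1 × num_graphs`, the shape of the original `v`.) The sketch below checks the reshape with plain arrays standing in for the batched features; `num_nodes`, `num_graphs`, and the random data are hypothetical, and the trick is only correct if all graphs really have the same number of nodes.

```julia
# Hypothetical sizes: every graph is assumed to have the same number of nodes,
# as in the original comprehension that uses state[1].num_nodes for all graphs.
num_nodes, num_graphs = 4, 3

# Stand-in for per-graph phead outputs: one 1 × num_nodes matrix per graph.
per_graph = [rand(Float32, 1, num_nodes) for _ in 1:num_graphs]

# Batching lays the node columns of each graph side by side,
# giving a single 1 × (num_nodes * num_graphs) matrix.
batched = reduce(hcat, per_graph)

# Original approach: scalar indexing into each graph's features.
p_loop = [per_graph[ind][indDepth] for indDepth in 1:num_nodes, ind in 1:num_graphs]

# Index-free alternative: reshape is column-major, so column `ind`
# of the result holds the nodes of graph `ind`.
p_reshaped = reshape(batched, num_nodes, num_graphs)

println(p_reshaped == p_loop)
```

In the real model the same line would be `reshape(p, state[1].num_nodes, length(state))` (hypothetical usage, with `state` taken before it is batched), which works the same way on a `CuArray` because `reshape` only changes the array's dimensions.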
If you need `p` to be a vector of vectors instead, you can do:
```julia
p = nn.phead(c)               # a batched graph
p = Flux.unbatch(p)           # split back into a vector of GNNGraphs
p = [g.ndata.x for g in p]    # node features of each graph
```
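When the graphs in `state` have different numbers of nodes, the node features can also be split per graph without rebuilding every graph: a batched `GNNGraph` records which graph each node came from in its `graph_indicator` field. The sketch below uses plain arrays, with a hand-written `graph_indicator` and made-up feature values standing in for the real batched graph. Whether the boolean-mask indexing is efficient on a `CuArray` is an assumption to verify against your CUDA.jl version.

```julia
# Hypothetical batch of three graphs with 2, 3 and 1 nodes.
# graph_indicator[i] is the index of the graph that node i belongs to;
# it is written out by hand here to stand in for the batched graph's field.
graph_indicator = [1, 1, 2, 2, 2, 3]
num_graphs = maximum(graph_indicator)

# Stand-in for the node-level phead output: one feature (row) per node.
x = Float32[0.1 0.2 0.3 0.4 0.5 0.6]

# Select each graph's columns with a boolean mask, then flatten to a vector.
p = [vec(x[:, graph_indicator .== i]) for i in 1:num_graphs]

println(p)
```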
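By the way, it is odd to take a softmax of a single feature: softmax normalizes across the feature dimension, so with only one feature every output is exactly 1. Maybe you want a sigmoid there?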
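To see why, here is a dependency-free sketch with hand-written stand-ins for softmax over `dims=1` (NNlib's default) and sigmoid; `mysoftmax`, `mysigmoid`, and the scores are hypothetical. A `1 × N` output like the one `Dense(innerSize, 1)` produces has a single row, so softmax divides each entry by itself.

```julia
# Hypothetical stand-in for softmax over dims=1 (the feature dimension),
# computed in the numerically stable form by subtracting the column max.
function mysoftmax(x)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

# Hypothetical stand-in for the logistic sigmoid on a single number.
mysigmoid(x) = 1 / (1 + exp(-x))

# A 1 × 4 output: one feature (row) for each of four nodes or graphs.
scores = Float32[-2.0 0.5 3.0 10.0]

# With one row, every column normalizes to exactly 1.
println(mysoftmax(scores))

# Sigmoid maps each score into (0, 1) independently.
println(mysigmoid.(scores))
```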