
Graph classification: multiple graphs associated with a common label #282

Closed
msainsburydale opened this issue May 22, 2023 · 3 comments


msainsburydale commented May 22, 2023

Hello, many thanks for a great package.

My task can be summarised as graph classification, where (i) multiple graphs are associated with a common label, and where (ii) all graphs have the same structure (i.e., only the node data changes between graphs).

For example, a single input instance may be:

using GraphNeuralNetworks
d = 1                     # dimension of response variable
n = 100                   # number of nodes in the graph
e = 200                   # number of edges in the graph
m = 30                    # number of replicates of the graph
g = rand_graph(n, e)      # fixed structure for all graphs

# Multiple graphs: these would all be associated with a 
# single label that we would like to predict
graphs = [GNNGraph(g; ndata = rand(d, n)) for _ in 1:m]

The usual approach is to batch the graphs into a single super graph:

using Flux: batch
batch(graphs)
# GNNGraph:
#   num_nodes: 3000
#   num_edges: 6000
#   num_graphs: 30
#   ndata:
# 	x = 1×3000 Matrix{Float64}

This approach is natural when the graph structure varies between graphs. However, it is inefficient when the graphs have a fixed structure (particularly with respect to memory, but presumably also in terms of performing the required operations during the propagation and readout modules).
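To make the memory point concrete, here is a back-of-envelope sketch. It assumes the edge list is stored in COO form with `Int64` indices (an assumption about the internal representation, made only for illustration):

```julia
# Rough storage comparison for the example above
e = 200                                      # edges per graph
m = 30                                       # replicates
bytes_batched = 2 * e * m * sizeof(Int64)    # supergraph repeats the edge list m times
bytes_single  = 2 * e * sizeof(Int64)        # fixed structure stored once
bytes_batched ÷ bytes_single                 # factor of m = 30
```

The replicated node data itself occupies the same space either way; the overhead of batching comes from duplicating the (identical) graph structure.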

It would be more efficient to use a single graph with ndata storing the replicated data in, for example, the third dimension:

GNNGraph(g; ndata = rand(d, n, m))

This gives an error, since ndata expects matrices — or, I think, arrays whose last dimension equals the number of nodes (here the last dimension is m rather than n).

Do you have any suggestions for how best to proceed, in a way that aligns with the philosophy of the package?

@CarloLucibello
Member

CarloLucibello commented May 22, 2023

This is an interesting application. ndata wants the last dimension to be the same as the number of nodes, but the support is not limited to matrices, so you can add node features as

g.ndata.x = rand(Float32, d, m, n)

I'm not sure though that graph convolutions will handle properly these tensors. Ideally they should, so let me know how it goes.

@msainsburydale
Author

Thanks, that's very helpful.

The approach doesn't work directly "out of the box", at least for GraphConv, but it was not too difficult to get it working. GraphConv expects the node data to be an AbstractMatrix, so its method fails when the data is a three-dimensional array. I therefore defined a method for three-dimensional arrays as follows, making use of batched_mul.

using GraphNeuralNetworks
using NNlib: ⊠   # ⊠ is infix shorthand for batched_mul
import GraphNeuralNetworks: GraphConv, check_num_nodes   # check_num_nodes is not exported

function (l::GraphConv)(g::GNNGraph, x::AbstractArray{T, 3}) where {T}
    check_num_nodes(g, x)
    m = GraphNeuralNetworks.propagate(copy_xj, g, l.aggr, xj = x)
    # batched_mul applies the weight matrices to every replicate at once
    l.σ.(l.weight1 ⊠ x .+ l.weight2 ⊠ m .+ l.bias)
end
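As a sanity check on the semantics the method above relies on, here is a plain-Julia sketch (no GraphNeuralNetworks or NNlib needed) of what, as I understand it, matrix ⊠ 3-array computes: one ordinary matrix multiplication per trailing batch slice, with the matrix shared across slices.

```julia
# Sketch of batched_mul semantics when the left operand is a plain matrix
W = reshape(1.0:6.0, 2, 3)       # out × d weight, with out = 2, d = 3
x = reshape(1.0:24.0, 3, 2, 4)   # d × m × n node features, with m = 2, n = 4
y = similar(x, 2, 2, 4)
for k in axes(x, 3)
    y[:, :, k] = W * x[:, :, k]  # one ordinary matmul per trailing slice
end
size(y)  # (2, 2, 4)
```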

I did some testing with the same simple example discussed above, focusing mainly on the dimensions of the output as a sanity check.

d = 2                      # dimension of response variable
n = 100                    # number of nodes in the graph
e = 200                    # number of edges in the graph
m = 30                     # number of replicates of the graph
g = rand_graph(n, e)       # fixed structure for all graphs
g.ndata.x = rand(d, m, n)  # node data varies between graphs

# One layer only:
out = 16
l = GraphConv(d => out)
l(g)
size(l(g)) # (16, 30, 100)


# Propagation and global-pooling modules:
gnn = GNNChain(
	GraphConv(d => out),
	GraphConv(out => out),
	GraphConv(out => out),
	GlobalPool(+)
)
gnn(g)
u = gnn(g).gdata.u
size(u)    # (16, 30, 1) 

The pooled features are a three-dimensional array of size out × m × 1, which is very close to the format of the pooled features one would obtain when "batching" the graph replicates into a single supergraph (in that case, the pooled features are a matrix of size out × m). But I suppose that Flux.flatten can be added to the full network architecture to deal with this inconsistency.

Thanks again for your help.

@CarloLucibello
Member

yes something like

gnn = GNNChain(
	GraphConv(d => out),
	GraphConv(out => out),
	GraphConv(out => out),
	GlobalPool(+),
	x -> reshape(x, size(x)[1:end-1])
)

should work
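For reference, here is a quick pure-Julia check that the anonymous function above drops the trailing singleton graph dimension left by GlobalPool (the array shape mirrors the example output earlier in the thread):

```julia
# Dropping the trailing singleton dimension from the pooled features
x = rand(16, 30, 1)               # out × m × num_graphs
y = reshape(x, size(x)[1:end-1])  # size(x)[1:end-1] is the tuple (16, 30)
size(y)  # (16, 30)
```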
