outputsize for GNNChain #96

Closed
tclements opened this issue Jan 4, 2022 · 2 comments

@tclements (Contributor)

This is a feature request: it'd be nice to extend the functionality of Flux.outputsize to GNNChains. I imagine this could be applied either to a WithGraph, or to a GNNGraph plus a Tuple of input sizes. Here's a sketch of an MWE adapted from the docs:

using Flux, Graphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2 
g = rand_graph(10, 30)
X = randn(Float32, din, 10)
inputsize = size(X) 

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GCNConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))
wg = WithGraph(model, g)

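# Proposed API (not implemented yet): the predicted size should match the size of the actual output
@assert GraphNeuralNetworks.outputsize(model, g, inputsize) == size(model(g, X))
@assert GraphNeuralNetworks.outputsize(wg, inputsize) == size(wg(X))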

@CarloLucibello (Member)

Fortunately, the Flux.outputsize implementation is so generic that it works with basically anything. In some cases you just need to wrap the model in a closure, e.g. when it takes a graph argument in addition to the feature array:

julia> Flux.outputsize(wg, (3, 10))
(2, 10)

julia> Flux.outputsize(x -> model(g, x), (3, 10))
(2, 10)

@tclements (Contributor, Author)

Great, thank you!
