Skip to content

Commit

Permalink
EvolveGCNO
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 16, 2024
1 parent 13142f2 commit b229ab2
Showing 1 changed file with 106 additions and 9 deletions.
115 changes: 106 additions & 9 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,42 @@ end
"""
GNNRecurrence(cell)
Construct a recurrent layer that applies the `cell`
to process an entire temporal sequence of node features at once.
Construct a recurrent layer that applies the graph recurrent `cell` forward
multiple times to process an entire temporal sequence of node features at once.
The `cell` has to satisfy the following interface for the forward pass:
`yt, state = cell(g, xt, state)`, where `xt` is the input node features,
`yt` is the updated node features, `state` is the cell state to be updated.
# Forward
layer(g::GNNGraph, x, [state])
layer(g, x, [state])
- `g`: The input graph.
- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
- `state`: The current state of the cell.
Applies the recurrent cell to each timestep of the input sequence.
## Arguments
- `g`: The input `GNNGraph` or `TemporalSnapshotsGNNGraph`.
- If `GNNGraph`, the same graph is used for all timesteps.
- If `TemporalSnapshotsGNNGraph`, a different graph is used for each timestep. Not all cells support this.
- `x`: The time-varying node features.
- If `g` is `GNNGraph`, it is an array of size `in x timesteps x num_nodes`.
- If `g` is `TemporalSnapshotsGNNGraph`, it is an vector of length `timesteps`,
with element `t` of size `in x num_nodes_t`.
- `state`: The initial state for the cell.
If not provided, it is generated by calling `Flux.initialstates(cell)`.
Applies the recurrent cell to each timestep of the input sequence and returns the output as
an array of size `out_features x timesteps x num_nodes`.
## Return
Returns the updated node features:
- If `g` is `GNNGraph`, returns an array of size `out_features x timesteps x num_nodes`.
- If `g` is `TemporalSnapshotsGNNGraph`, returns a vector of length `timesteps`,
with element `t` of size `out_features x num_nodes_t`.
# Examples
The following example considers a static graph and a time-varying node features.
```jldoctest
julia> num_nodes, num_edges = 5, 10;
Expand All @@ -47,6 +66,9 @@ julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
GNNGraph:
num_nodes: 5
num_edges: 10
julia> x = rand(Float32, d_in, timesteps, num_nodes);
Expand All @@ -63,6 +85,38 @@ julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
Now consider a time-varying graph and time-varying node features.
```jldoctest
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> num_nodes = [10, 10, 10, 10, 10];
julia> num_edges = [10, 12, 14, 16, 18];
julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
julia> x = [rand(Float32, d_in, n) for n in num_nodes];
julia> cell = EvolveGCNOCell(d_in => d_out)
EvolveGCNOCell(2 => 3) # 321 parameters
julia> layer = GNNRecurrence(cell)
GNNRecurrence(
EvolveGCNOCell(2 => 3), # 321 parameters
) # Total: 5 arrays, 321 parameters, 1.535 KiB.
julia> y = layer(tg, x);
julia> length(y) # timesteps
5
julia> size(y[end]) # (d_out, num_nodes[end])
(3, 10)
```
"""
struct GNNRecurrence{G} <: GNNLayer
cell::G
Expand Down Expand Up @@ -437,7 +491,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen
- `out`: Number of output node features.
- `k`: Diffusion step for the `DConv`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init`: Convolution weights' initializer. Default `glorot_uniform`.
# Forward
Expand Down Expand Up @@ -651,3 +705,46 @@ end
function Base.show(io::IO, egcno::EvolveGCNOCell)
print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))")
end


"""
EvolveGCNO(args...; kws...)
Construct a recurrent layer corresponding to the [`EvolveGCNOCell`](@ref) cell.
It can be used to process an entire temporal sequence of graphs and node features at once.
The arguments are passed to the [`EvolveGCNOCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.
# Examples
```jldoctest
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> num_nodes = [10, 10, 10, 10, 10];
julia> num_edges = [10, 12, 14, 16, 18];
julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
julia> x = [rand(Float32, d_in, n) for n in num_nodes];
julia> cell = EvolveGCNO(d_in => d_out)
GNNRecurrence(
EvolveGCNOCell(2 => 3), # 321 parameters
) # Total: 5 arrays, 321 parameters, 1.535 KiB.
julia> y = layer(tg, x);
julia> length(y) # timesteps
5
julia> size(y[end]) # (d_out, num_nodes[end])
(3, 10)
```
"""
EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))

0 comments on commit b229ab2

Please sign in to comment.