diff --git a/src/NeuralGraphPDE.jl b/src/NeuralGraphPDE.jl
index c3105c8..87a8210 100644
--- a/src/NeuralGraphPDE.jl
+++ b/src/NeuralGraphPDE.jl
@@ -11,7 +11,6 @@ import GraphNeuralNetworks: propagate, apply_edges
 import Lux: initialparameters, parameterlength, statelength, Chain, applychain, initialstates
 
 include("utils.jl")
-#include("msgpass.jl") seems we don't need it!
 include("layers.jl")
 
 export AbstractGNNLayer, AbstractGNNContainerLayer
diff --git a/src/layers.jl b/src/layers.jl
index cdce95b..aacdcf0 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -35,29 +35,29 @@ Edge convolutional layer from [Learning continuous-time PDEs from sparse data wi
 \mathbf{u}_i' = \square_{j \in N(i)}\, \phi([\mathbf{u}_i, \mathbf{u}_j; \mathbf{x}_j - \mathbf{x}_i])
 ```
 
-## Arguments
+# Arguments
 
 - `ϕ`: A neural network.
 - `initialgraph`: `GNNGraph` or a function that returns a `GNNGraph`
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 
-## Inputs
+# Inputs
 
 - `u`: Trainable node embeddings, `NamedTuple` or `Array`.
 
-## Returns
+# Returns
 
 - `NamedTuple` or `Array` that is consistent with `u`, with a different number of channels.
 
-## Parameters
+# Parameters
 
 - Parameters of `ϕ`.
 
-## States
+# States
 
 - `graph`: `GNNGraph` where `graph.ndata.x` represents the spatial coordinates of nodes. You can also put other nontrainable node features in `graph.ndata` with arbitrary keys. They will be concatenated like `u`.
 
-## Examples
+# Examples
 
 ```julia
 s = [1, 1, 2, 3]
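For reviewers of the `src/layers.jl` hunk above, here is a minimal usage sketch of the documented layer. It assumes the docstring belongs to `ExplicitEdgeConv` and that the layer follows the usual Lux `setup`/apply convention; the feature sizes (4 node channels, 3 spatial coordinates, 5 output channels) are illustrative and not taken from the patch.

```julia
using Lux, GraphNeuralNetworks, NeuralGraphPDE, Random

# Small graph; ndata.x stores the (nontrainable) spatial coordinates.
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)
g = GNNGraph(g, ndata = (; x = rand(Float32, 3, g.num_nodes)))

# ϕ maps [u_i; u_j; x_j - x_i] (4 + 4 + 3 inputs) to 5 output channels.
ϕ = Dense(4 + 4 + 3 => 5, tanh)
l = ExplicitEdgeConv(ϕ; initialgraph = g)   # layer name assumed, not shown in the hunk

rng = Random.default_rng()
ps, st = Lux.setup(rng, l)                  # the graph is kept in the layer states

u = rand(Float32, 4, g.num_nodes)           # trainable node embeddings
y, st = l(u, ps, st)                        # y has size 5 × num_nodes
```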