Skip to content

Commit

Permalink
fix NNConv docs (#488)
Browse files Browse the repository at this point in the history
* fix nnconv docstring

* cleanup
  • Loading branch information
CarloLucibello authored Aug 30, 2024
1 parent cb82352 commit 82a7450
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,8 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix
# Arguments
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `in`: The dimension of input node features.
- `out`: The dimension of output node features.
- `f`: A (possibly learnable) function acting on edge features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `σ`: Activation function.
Expand All @@ -670,22 +670,26 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix
# Examples:
```julia
n_in = 3
n_in_edge = 10
n_out = 5
# create data
s = [1,1,2,3]
t = [2,3,1,1]
in_channel = 3
out_channel = 5
edim = 10
g = GNNGraph(s, t)
# create dense layer
nn = Dense(edim => out_channel * in_channel)
nn = Dense(n_in_edge => n_out * n_in)
# create layer
l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +)
l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)
x = randn(Float32, n_in, g.num_nodes)
e = randn(Float32, n_in_edge, g.num_edges)
# forward pass
y = l(g, x)
y = l(g, x, e)
```
"""
struct NNConv{W, B, NN, F, A} <: GNNLayer
Expand Down

0 comments on commit 82a7450

Please sign in to comment.