-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy pathtest_hsn_layer.py
40 lines (29 loc) · 1.12 KB
/
test_hsn_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Test the HSN layer."""
import torch
from topomodelx.nn.simplicial.hsn_layer import HSNLayer
class TestHSNLayer:
    """Test the HSN layer."""

    def test_forward(self):
        """Test the forward pass of the HSN layer.

        Builds random dense 0/1 node-to-edge incidence and node-to-node
        adjacency matrices, runs the layer on random node features of shape
        ``(n_nodes, channels)``, and checks the output keeps that shape.
        """
        # Seed the RNG so any failure is reproducible from run to run.
        torch.manual_seed(0)
        channels = 5
        n_nodes = 10
        n_edges = 20
        incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()
        adjacency_0 = torch.randint(0, 2, (n_nodes, n_nodes)).float()
        x_0 = torch.randn(n_nodes, channels)
        hsn = HSNLayer(channels)
        # Call the module itself rather than .forward so registered hooks run too.
        output = hsn(x_0, incidence_1, adjacency_0)
        assert output.shape == (n_nodes, channels)

    def test_reset_parameters(self):
        """Test the reset of the parameters.

        ``reset_parameters`` must run without error, keep every parameter's
        shape, and leave all parameter values finite. Any ``torch.nn.Conv2d``
        submodule is additionally expected to be zero after the reset.
        """
        channels = 5
        hsn = HSNLayer(channels)
        shapes_before = {name: p.shape for name, p in hsn.named_parameters()}
        hsn.reset_parameters()

        # These checks always run, so the test is never vacuous even when the
        # layer has no Conv2d submodules.
        shapes_after = {name: p.shape for name, p in hsn.named_parameters()}
        assert shapes_after == shapes_before
        for name, param in hsn.named_parameters():
            assert torch.isfinite(param).all(), f"non-finite values in {name}"

        # NOTE(review): if HSNLayer contains no torch.nn.Conv2d, this loop
        # asserts nothing; confirm the intended modules to check.
        for module in hsn.modules():
            if isinstance(module, torch.nn.Conv2d):
                # assert_allclose is deprecated; assert_close replaces it.
                torch.testing.assert_close(
                    module.weight, torch.zeros_like(module.weight)
                )
                # A Conv2d built with bias=False has bias set to None.
                if module.bias is not None:
                    torch.testing.assert_close(
                        module.bias, torch.zeros_like(module.bias)
                    )