diff --git a/test/nn/simplicial/test_sccnn.py b/test/nn/simplicial/test_sccnn.py index fb78bf42..93275ec7 100644 --- a/test/nn/simplicial/test_sccnn.py +++ b/test/nn/simplicial/test_sccnn.py @@ -43,7 +43,7 @@ def test_forward(self): incidence_2 = simplicial_complex.incidence_matrix(rank=2) laplacian_0 = simplicial_complex.hodge_laplacian_matrix(rank=0, weight=True) laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1, weight=True) - laplacian_up_1 = simplicial_complex.up_laplacian_matrix(rank=1, weight=True) + laplacian_up_1 = simplicial_complex.up_laplacian_matrix(rank=1) laplacian_2 = simplicial_complex.hodge_laplacian_matrix(rank=2, weight=True) incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse() diff --git a/test/nn/simplicial/test_scnn.py b/test/nn/simplicial/test_scnn.py index ca6baa32..82c62678 100644 --- a/test/nn/simplicial/test_scnn.py +++ b/test/nn/simplicial/test_scnn.py @@ -38,7 +38,7 @@ def test_forward(self): incidence_2 = simplicial_complex.incidence_matrix(rank=2) laplacian_0 = simplicial_complex.hodge_laplacian_matrix(rank=0, weight=True) laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1, weight=True) - laplacian_up_1 = simplicial_complex.up_laplacian_matrix(rank=1, weight=True) + laplacian_up_1 = simplicial_complex.up_laplacian_matrix(rank=1) laplacian_2 = simplicial_complex.hodge_laplacian_matrix(rank=2, weight=True) incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse() @@ -80,6 +80,7 @@ def get_simplicial_features(dataset, rank): conv_order_down=conv_order_down, conv_order_up=conv_order_up, n_layers=num_layers, + aggr=True, ) with torch.no_grad(): forward_pass = model(x_1, laplacian_down_1, laplacian_up_1) diff --git a/topomodelx/nn/simplicial/scnn.py b/topomodelx/nn/simplicial/scnn.py index fba2183e..cac78447 100644 --- a/topomodelx/nn/simplicial/scnn.py +++ b/topomodelx/nn/simplicial/scnn.py @@ -12,18 +12,21 @@ class SCNN(torch.nn.Module): Parameters ---------- - in_channels: int - Dimension of input features - intermediate_channels: int - Dimension of features of intermediate layers - out_channels: int - Dimension of output features - conv_order_down: int - Order of lower convolution - conv_order_up: int - Order of upper convolution - n_layers: int - Numer of layers + in_channels : int + Dimension of input features. + intermediate_channels : int + Dimension of features of intermediate layers. + out_channels : int + Dimension of output features. + conv_order_down : int + Order of lower convolution. + conv_order_up : int + Order of upper convolution. + aggr : bool + Whether to aggregate features on the nodes into 1 feature for the whole complex. + Default: False. + n_layers : int + Number of layers. """ def __init__( @@ -33,6 +36,7 @@ def __init__( out_channels, conv_order_down, conv_order_up, + aggr=False, aggr_norm=False, update_func=None, n_layers=2, @@ -59,7 +63,7 @@ def __init__( update_func=update_func, ) ) - + self.aggr = aggr self.linear = torch.nn.Linear(out_channels, 1) self.layers = torch.nn.ModuleList(layers) @@ -68,22 +72,25 @@ def forward(self, x, laplacian_down, laplacian_up): Parameters ---------- - x: tensor - shape = [n_simplices, channels] + x : tensor, shape=[n_simplices, channels] node/edge/face features - - laplacian: tensor - shape = [n_simplices,n_simplices] + laplacian_down : tensor, shape=[n_simplices, n_simplices] + Down Laplacian. For node features, laplacian_down = None + laplacian_up: tensor, shape=[n_edges, n_nodes] + Up Laplacian. - incidence_1: tensor - order 1 incidence matrix - shape = [n_edges, n_nodes] + Returns + ------- + one_dimensional_cells_mean : tensor, shape=[n_simplices, 1] + Mean on one-dimensional cells. """ for layer in self.layers: x = layer(x, laplacian_down, laplacian_up) x_1 = self.linear(x) + if not self.aggr: + return x_1 one_dimensional_cells_mean = torch.nanmean(x_1, dim=0) one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0 diff --git a/tutorials/simplicial/sccnn_train.ipynb b/tutorials/simplicial/sccnn_train.ipynb index ffc23a2b..f3f5e521 100644 --- a/tutorials/simplicial/sccnn_train.ipynb +++ b/tutorials/simplicial/sccnn_train.ipynb @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -84,16 +84,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "downloading shrec 16 small dataset...\n", - "\n", - "done!\n", "Loading shrec 16 small dataset...\n", "\n", "done!\n" @@ -113,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -146,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -162,10 +159,10 @@ "for simplex in simplexes:\n", " incidence_1 = simplex.incidence_matrix(rank=1)\n", " incidence_2 = simplex.incidence_matrix(rank=2)\n", - " laplacian_0 = simplex.hodge_laplacian_matrix(rank=0, weight=True)\n", - " laplacian_down_1 = simplex.down_laplacian_matrix(rank=1, weight=True)\n", - " laplacian_up_1 = simplex.up_laplacian_matrix(rank=1, weight=True)\n", - " laplacian_2 = simplex.hodge_laplacian_matrix(rank=2, weight=True)\n", + " laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)\n", + " laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)\n", + " laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)\n", + " laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)\n", "\n", " incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()\n", " incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()\n", @@ -193,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -245,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -286,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -301,16 +298,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 117979.8366\n", - "Test_loss: 406.8377\n", - "Epoch: 2 loss: 977.5059\n", - "Test_loss: 2.4212\n", - "Epoch: 3 loss: 217.5435\n", - "Test_loss: 5.1847\n", - "Epoch: 4 loss: 202.9460\n", - "Test_loss: 2.7320\n", - "Epoch: 5 loss: 212.8642\n", - "Test_loss: 3.1975\n" + "Epoch: 1 loss: 944857.9959\n", + "Test_loss: 214.5353\n", + "Epoch: 2 loss: 2055.6737\n", + "Test_loss: 8.6187\n", + "Epoch: 3 loss: 1060.7969\n", + "Test_loss: 1.4612\n", + "Epoch: 4 loss: 635.5951\n", + "Test_loss: 15.8349\n", + "Epoch: 5 loss: 408.8397\n", + "Test_loss: 45.9482\n" ] } ], @@ -420,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -444,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -475,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -497,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -510,22 +507,22 @@ } ], "source": [ - "laplacian_0 = dataset.hodge_laplacian_matrix(rank=0, weight=True)\n", - "laplacian_down_1 = dataset.down_laplacian_matrix(rank=1, weight=True)\n", - "laplacian_up_1 = dataset.up_laplacian_matrix(rank=1, weight=True)\n", - "laplacian_down_2 = dataset.down_laplacian_matrix(rank=2, weight=True)\n", - "laplacian_up_2 = dataset.up_laplacian_matrix(rank=2, weight=True)\n", - "\n", - "laplacian_0 = dataset.adjacency_matrix(rank=0, weight=True)\n", - "laplacian_down_1 = dataset.coadjacency_matrix(rank=1, weight=True)\n", - "laplacian_up_1 = dataset.adjacency_matrix(rank=1, weight=True)\n", - "laplacian_down_2 = dataset.coadjacency_matrix(rank=2, weight=True)\n", - "laplacian_up_2 = dataset.adjacency_matrix(rank=2, weight=True)" + "laplacian_0 = dataset.hodge_laplacian_matrix(rank=0)\n", + "laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)\n", + "laplacian_up_1 = dataset.up_laplacian_matrix(rank=1)\n", + "laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)\n", + "laplacian_up_2 = dataset.up_laplacian_matrix(rank=2)\n", + "\n", + "laplacian_0 = dataset.adjacency_matrix(rank=0)\n", + "laplacian_down_1 = dataset.coadjacency_matrix(rank=1)\n", + "laplacian_up_1 = dataset.adjacency_matrix(rank=1)\n", + "laplacian_down_2 = dataset.coadjacency_matrix(rank=2)\n", + "laplacian_up_2 = dataset.adjacency_matrix(rank=2)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -550,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -580,7 +577,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -601,7 +598,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -701,23 +698,23 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 62.1591 Train_acc: 0.1333\n", - "Test_acc: 0.5000\n", - "Epoch: 2 loss: 67.6678 Train_acc: 0.2333\n", + "Epoch: 1 loss: 31.6361 Train_acc: 0.6000\n", + "Test_acc: 0.7500\n", + "Epoch: 2 loss: 17.7911 Train_acc: 0.5667\n", "Test_acc: 0.7500\n", - "Epoch: 3 loss: 75.0000 Train_acc: 0.2333\n", + "Epoch: 3 loss: 6.6667 Train_acc: 0.6333\n", "Test_acc: 0.7500\n", - "Epoch: 4 loss: 74.3241 Train_acc: 0.2667\n", - "Test_acc: 0.2500\n", - "Epoch: 5 loss: 73.3333 Train_acc: 0.2000\n", - "Test_acc: 0.2500\n" + "Epoch: 4 loss: 6.6667 Train_acc: 0.6333\n", + "Test_acc: 1.0000\n", + "Epoch: 5 loss: 5.0000 Train_acc: 0.7000\n", + "Test_acc: 1.0000\n" ] } ], diff --git a/tutorials/simplicial/scnn_train.ipynb b/tutorials/simplicial/scnn_train.ipynb index 435548cd..d8f87b45 100644 --- a/tutorials/simplicial/scnn_train.ipynb +++ b/tutorials/simplicial/scnn_train.ipynb @@ -53,7 +53,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 1. Comples Classification" + "# 1. Complex Classification" ] }, { @@ -157,10 +157,10 @@ "for simplex in simplexes:\n", " incidence_1 = simplex.incidence_matrix(rank=1)\n", " incidence_2 = simplex.incidence_matrix(rank=2)\n", - " laplacian_0 = simplex.hodge_laplacian_matrix(rank=0, weight=True)\n", - " laplacian_down_1 = simplex.down_laplacian_matrix(rank=1, weight=True)\n", - " laplacian_up_1 = simplex.up_laplacian_matrix(rank=1, weight=True)\n", - " laplacian_2 = simplex.hodge_laplacian_matrix(rank=2, weight=True)\n", + " laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)\n", + " laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)\n", + " laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)\n", + " laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)\n", "\n", " incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()\n", " incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()\n", @@ -233,6 +233,7 @@ " conv_order_down=conv_order_down,\n", " conv_order_up=conv_order_up,\n", " n_layers=num_layers,\n", + " aggr=True,\n", ")\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", @@ -274,16 +275,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 1094.0324\n", - "Test_loss: 47.0409\n", - "Epoch: 2 loss: 176.8604\n", - "Test_loss: 44.5140\n", - "Epoch: 3 loss: 139.2026\n", - "Test_loss: 47.5221\n", - "Epoch: 4 loss: 127.0717\n", - "Test_loss: 84.9650\n", - "Epoch: 5 loss: 105.8648\n", - "Test_loss: 80.5553\n" + "Epoch: 1 loss: 1094.7131\n", + "Test_loss: 198.5218\n", + "Epoch: 2 loss: 103.6860\n", + "Test_loss: 177.6109\n", + "Epoch: 3 loss: 88.7219\n", + "Test_loss: 121.8411\n", + "Epoch: 4 loss: 85.7747\n", + "Test_loss: 90.2382\n", + "Epoch: 5 loss: 84.1926\n", + "Test_loss: 79.0308\n" ] } ], @@ -424,11 +425,11 @@ "metadata": {}, "outputs": [], "source": [ - "laplacian_0 = dataset.hodge_laplacian_matrix(rank=0, weight=True)\n", - "laplacian_down_1 = dataset.down_laplacian_matrix(rank=1, weight=True)\n", - "laplacian_up_1 = dataset.up_laplacian_matrix(rank=1, weight=True)\n", - "laplacian_down_2 = dataset.down_laplacian_matrix(rank=2, weight=True)\n", - "laplacian_up_2 = dataset.up_laplacian_matrix(rank=2, weight=True)\n", + "laplacian_0 = dataset.hodge_laplacian_matrix(rank=0)\n", + "laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)\n", + "laplacian_up_1 = dataset.up_laplacian_matrix(rank=1)\n", + "laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)\n", + "laplacian_up_2 = dataset.up_laplacian_matrix(rank=2)\n", "\n", "laplacian_0 = torch.from_numpy(laplacian_0.todense()).to_sparse()\n", "laplacian_down_1 = torch.from_numpy(laplacian_down_1.todense()).to_sparse()\n", @@ -586,7 +587,7 @@ "metadata": {}, "source": [ "# Create the SCNN for node classification\n", - "Use the SCNNLayer classm we create a neural network with stacked layers. A final linear layer produces an output with shape $n_{\\rm{nodes}}\\times 2$, so we can compare with the binary labels" + "Use the SCNNLayer classm we create a neural network with stacked layers, without aggregation." ] }, { @@ -595,93 +596,22 @@ "metadata": {}, "outputs": [], "source": [ - "from topomodelx.nn.simplicial.scnn_layer import SCNNLayer\n", - "\n", - "\n", - "class SCNN(torch.nn.Module):\n", - " \"\"\"Simplicial convolutional neural network implementation for binary node classification.\n", - "\n", - " Note: At the last layer, we obtain the output on simplcies, e.g., edges.\n", - " To perform the node classification task for this challenge, we consider a projection step via the incidence matrix B_1 which obtains the node labels from the edge output, which was also done in \"T Mitchell Roddenberry, Nicholas Glaze and Santiago Segarra. Principled simplicial neural networks for trajectory prediction. International Conference on Machine Learning. 2021\"\n", - "\n", - " Parameters\n", - " ----------\n", - " in_channels: int\n", - " Dimension of input features\n", - " intermediate_channels: int\n", - " Dimension of features of intermediate layers\n", - " out_channels: int\n", - " Dimension of output features\n", - " conv_order_down: int\n", - " Order of lower convolution\n", - " conv_order_up: int\n", - " Order of upper convolution\n", - " n_layers: int\n", - " Numer of layers\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " in_channels,\n", - " intermediate_channels,\n", - " out_channels,\n", - " conv_order_down,\n", - " conv_order_up,\n", - " aggr_norm=False,\n", - " update_func=None,\n", - " n_layers=2,\n", - " ):\n", - " super().__init__()\n", - " # First layer -- initial layer has the in_channels as input, and inter_channels as the output\n", - " layers = [\n", - " SCNNLayer(\n", - " in_channels=in_channels,\n", - " out_channels=intermediate_channels,\n", - " conv_order_down=conv_order_down,\n", - " conv_order_up=conv_order_up,\n", - " )\n", - " ]\n", - "\n", - " for _ in range(n_layers - 1):\n", - " layers.append(\n", - " SCNNLayer(\n", - " in_channels=intermediate_channels,\n", - " out_channels=out_channels,\n", - " conv_order_down=conv_order_down,\n", - " conv_order_up=conv_order_up,\n", - " aggr_norm=aggr_norm,\n", - " update_func=update_func,\n", - " )\n", - " )\n", - "\n", - " self.linear = torch.nn.Linear(out_channels, 2)\n", - " self.layers = torch.nn.ModuleList(layers)\n", - "\n", - " def forward(self, x, laplacian_down, laplacian_up, incidence_1):\n", - " \"\"\"Forward computation.\n", - "\n", - " Parameters\n", - " ---------\n", - " x: tensor\n", - " shape = [n_simplices, channels]\n", - " node/edge/face features\n", - "\n", - " laplacian: tensor\n", - " shape = [n_simplices,n_simplices]\n", - " For node features, laplacian_down = None\n", - "\n", - " incidence_1: tensor\n", - " order 1 incidence matrix\n", - " shape = [n_edges, n_nodes]\n", - " \"\"\"\n", - " for layer in self.layers:\n", - " x = layer(x, laplacian_down, laplacian_up)\n", - " \"\"\"\n", - " Project the output from edges to nodes \n", - " incidence_1 @ x\n", - " \"\"\"\n", - " logits = self.linear(incidence_1 @ x)\n", - " return torch.softmax(logits, dim=-1)" + "model = SCNN(\n", + " in_channels=in_channels,\n", + " intermediate_channels=intermediate_channels,\n", + " out_channels=out_channels,\n", + " conv_order_down=conv_order_down,\n", + " conv_order_up=conv_order_up,\n", + " n_layers=num_layers,\n", + " aggr=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will add a final computation that produces an output with shape $n_{\\rm{nodes}}\\times 2$, so we can compare with the binary labels." ] }, { @@ -701,7 +631,7 @@ "output_type": "stream", "text": [ "SCNN(\n", - " (linear): Linear(in_features=16, out_features=2, bias=True)\n", + " (linear): Linear(in_features=16, out_features=1, bias=True)\n", " (layers): ModuleList(\n", " (0-1): 2 x SCNNLayer()\n", " )\n", @@ -756,13 +686,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 0.7314 Train_acc: 0.4333\n", - "Epoch: 2 loss: 0.7310 Train_acc: 0.4333\n", - "Test_acc: 1.0000\n", - "Epoch: 3 loss: 0.7306 Train_acc: 0.4333\n", - "Epoch: 4 loss: 0.7302 Train_acc: 0.4333\n", - "Test_acc: 1.0000\n", - "Epoch: 5 loss: 0.7297 Train_acc: 0.4333\n" + "torch.Size([34, 1])\n", + "torch.Size([30, 2])\n", + "Epoch: 1 loss: 0.8799 Train_acc: 0.0000\n", + "torch.Size([34, 1])\n", + "torch.Size([30, 2])\n", + "Epoch: 2 loss: 0.8799 Train_acc: 0.0000\n", + "Test_acc: 0.0000\n", + "torch.Size([34, 1])\n", + "torch.Size([30, 2])\n", + "Epoch: 3 loss: 0.8799 Train_acc: 0.0000\n", + "torch.Size([34, 1])\n", + "torch.Size([30, 2])\n", + "Epoch: 4 loss: 0.8799 Train_acc: 0.0000\n", + "Test_acc: 0.0000\n", + "torch.Size([34, 1])\n", + "torch.Size([30, 2])\n", + "Epoch: 5 loss: 0.8799 Train_acc: 0.0000\n" ] } ], @@ -774,9 +714,14 @@ " model.train()\n", " optimizer.zero_grad()\n", "\n", - " y_hat = model(x, laplacian_down, laplacian_up, incidence_1)\n", + " y_hat = model(x, laplacian_down, laplacian_up)\n", + " y_hat = torch.softmax(\n", + " incidence_1 @ y_hat, dim=-1\n", + " ) # Transform features on edges to features on nodes\n", + " print(y_hat.shape)\n", + " print(y_train.shape)\n", " loss = torch.nn.functional.binary_cross_entropy_with_logits(\n", - " y_hat[: len(y_train)].float(), y_train.float()\n", + " y_hat[: len(y_train)].float().squeeze(), torch.argmax(y_train, dim=1).float()\n", " )\n", " epoch_loss.append(loss.item())\n", " loss.backward()\n", @@ -790,7 +735,10 @@ " )\n", " if epoch_i % test_interval == 0:\n", " with torch.no_grad():\n", - " y_hat_test = model(x, laplacian_down, laplacian_up, incidence_1)\n", + " y_hat_test = model(x, laplacian_down, laplacian_up)\n", + " y_hat_test = torch.softmax(\n", + " incidence_1 @ y_hat_test, dim=-1\n", + " ) # Transform features on edges to features on nodes\n", " y_pred_test = torch.where(\n", " y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)\n", " )\n", @@ -803,6 +751,13 @@ " )\n", " print(f\"Test_acc: {test_accuracy:.4f}\", flush=True)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {