From 2b55a36b0f434c08c70787354209e6bae4a69cd9 Mon Sep 17 00:00:00 2001 From: ALi Date: Mon, 19 Jul 2021 14:02:11 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Implement=20unet=20model=20using=20?= =?UTF-8?q?3=20upsampling=20and=20downsampling=20features?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- unet_semantic_segmentation.ipynb | 107 ++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 2 deletions(-) diff --git a/unet_semantic_segmentation.ipynb b/unet_semantic_segmentation.ipynb index 4da900c..d25e381 100644 --- a/unet_semantic_segmentation.ipynb +++ b/unet_semantic_segmentation.ipynb @@ -331,9 +331,112 @@ }, { "cell_type": "markdown", - "source": [], + "source": [ + "## UNET Model" + ], "metadata": { - "collapsed": false + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 69, + "outputs": [], + "source": [ + "class UNET(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super().__init__()\n", + " self.conv1 = self.contract_block(in_channels, 32, 7, 3)\n", + " self.conv2 = self.contract_block(32, 64, 3, 1)\n", + " self.conv3 = self.contract_block(64, 128, 3, 1)\n", + "\n", + " self.upconv3 = self.expand_block(128, 64, 3, 1)\n", + " self.upconv2 = self.expand_block(64 * 2, 32, 3, 1)\n", + " self.upconv1 = self.expand_block(32 * 2, out_channels, 3, 1)\n", + "\n", + " def contract_block(self, in_channels, out_channels, kernel_size, padding):\n", + " contract = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(),\n", + " nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", + " )\n", + " return contract\n", + "\n", + " def expand_block(self, in_channels, out_channels, kernel_size, padding):\n", + " expand = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(),\n", + " nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n", + " )\n", + " return expand\n", + "\n", + " def forward(self, x):\n", + " # down-sampling\n", + " conv1 = self.conv1(x)\n", + " conv2 = self.conv2(conv1)\n", + " conv3 = self.conv3(conv2)\n", + "\n", + " # upsampling\n", + " upconv3 = self.upconv3(conv3)\n", + " upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))\n", + " upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))\n", + "\n", + " return upconv1" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 70, + "outputs": [], + "source": [ + "unet = UNET(3, 2)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 72, + "outputs": [ + { + "data": { + "text/plain": "(torch.Size([1, 3, 128, 128]), torch.Size([1, 2, 128, 128]))" + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred = unet(xb)\n", + "xb.shape, pred.shape" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } } } ],