Skip to content

Commit

Permalink
✨ Implement unet model using 3 upsampling and downsampling features
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Jul 19, 2021
1 parent 555bd11 commit 2b55a36
Showing 1 changed file with 105 additions and 2 deletions.
107 changes: 105 additions & 2 deletions unet_semantic_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,112 @@
},
{
"cell_type": "markdown",
"source": [],
"source": [
"## UNET Model"
],
"metadata": {
"collapsed": false
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 69,
"outputs": [],
"source": [
"class UNET(nn.Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super().__init__()\n",
" self.conv1 = self.contract_block(in_channels, 32, 7, 3)\n",
" self.conv2 = self.contract_block(32, 64, 3, 1)\n",
" self.conv3 = self.contract_block(64, 128, 3, 1)\n",
"\n",
" self.upconv3 = self.expand_block(128, 64, 3, 1)\n",
" self.upconv2 = self.expand_block(64 * 2, 32, 3, 1)\n",
" self.upconv1 = self.expand_block(32 * 2, out_channels, 3, 1)\n",
"\n",
" def contract_block(self, in_channels, out_channels, kernel_size, padding):\n",
" contract = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.ReLU(),\n",
" nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
" )\n",
" return contract\n",
"\n",
" def expand_block(self, in_channels, out_channels, kernel_size, padding):\n",
" expand = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.ReLU(),\n",
" nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" )\n",
" return expand\n",
"\n",
" def forward(self, x):\n",
" # down-sampling\n",
" conv1 = self.conv1(x)\n",
" conv2 = self.conv2(conv1)\n",
" conv3 = self.conv3(conv2)\n",
"\n",
" # upsampling\n",
" upconv3 = self.upconv3(conv3)\n",
" upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))\n",
" upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))\n",
"\n",
" return upconv1"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 70,
"outputs": [],
"source": [
"unet = UNET(3, 2)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 72,
"outputs": [
{
"data": {
"text/plain": "(torch.Size([1, 3, 128, 128]), torch.Size([1, 2, 128, 128]))"
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = unet(xb)\n",
"xb.shape, pred.shape"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
Expand Down

0 comments on commit 2b55a36

Please sign in to comment.