Skip to content

Commit

Permalink
Implement encoder class
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Aug 1, 2021
1 parent 60fa9fa commit e433431
Showing 1 changed file with 70 additions and 2 deletions.
72 changes: 70 additions & 2 deletions mnist_autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,77 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"outputs": [],
"source": [],
"source": [
"import torch; torch.manual_seed(0)\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.utils\n",
"import torch.distributions\n",
"import torchvision\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Model implementation"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Encoder class"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self, latent_dims):\n",
" super(Encoder, self).__init__()\n",
" self.linear1 = nn.Linear(784, 512)\n",
" self.linear2 = nn.Linear(512, latent_dims)\n",
"\n",
" def forward(self, x):\n",
" x = torch.flatten(x, start_dim=1)\n",
" x = nn.ReLU(self.linear1)\n",
" x = self.linear2(x)\n",
" return x"
],
"metadata": {
"collapsed": false,
"pycharm": {
Expand Down

0 comments on commit e433431

Please sign in to comment.