Skip to content

Commit

Permalink
Create a custom dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Mar 15, 2021
1 parent add8d0f commit 307053b
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion RNN_building_blocks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torchvision import datasets, transforms\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
Expand All @@ -34,6 +35,45 @@
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"from torchvision import datasets, transforms\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"class mySeries(Dataset):\n",
" \"\"\"Series for RNN dataset.\"\"\"\n",
"\n",
" def __init__(self, length=1000, transform=None):\n",
" self.transform = transform\n",
" self.length = length\n",
"\n",
" def __len__(self):\n",
" return self.length\n",
"\n",
" def __getitem__(self, idx):\n",
" if torch.is_tensor(idx):\n",
" idx = idx.tolist()\n",
"\n",
" X = idx+2\n",
" y = X+2\n",
"\n",
" return X,y\n",
"\n",
"# mySeries dataset\n",
"trainset = mySeries()\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=64)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
Expand Down

0 comments on commit 307053b

Please sign in to comment.