diff --git a/Machine_translation_Pytorch.ipynb b/Machine_translation_Pytorch.ipynb index 33bb875..0baf508 100644 --- a/Machine_translation_Pytorch.ipynb +++ b/Machine_translation_Pytorch.ipynb @@ -97,6 +97,18 @@ "collapsed": false } }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, { "cell_type": "code", "execution_count": 2, @@ -524,11 +536,13 @@ " return [token.text for token in spacy_english.tokenizer(text)]\n", "\n", "train_iter, valid_iter, test_iter, = IWSLT2017(root='.pytorch/.data/', language_pair=('en','de'))\n", - "\n", + "train_data = list(train_iter)\n", + "valid_data = list(valid_iter)\n", + "test_data = list(test_iter)\n", "en_counter = Counter()\n", "de_counter = Counter()\n", "\n", - "for (de, en) in train_iter:\n", + "for (en, de) in train_data:\n", " en_counter.update(tokenize_english(en))\n", " de_counter.update(tokenize_german(de))\n", "\n", @@ -550,8 +564,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Unique tokens in source (de) vocabs 14101\n", - "Unique tokens in source (en) vocabs 17069\n" + "Unique tokens in source (de) vocabs 16712\n", + "Unique tokens in source (en) vocabs 14018\n" ] } ], @@ -574,7 +588,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "output of the text_transform: [1, 0, 7366, 39, 0, 2]\n" + "output of the text_transform: [1, 92, 15, 54, 241, 2]\n" ] } ], @@ -589,6 +603,63 @@ "name": "#%%\n" } } + }, + { + "cell_type": "markdown", + "source": [ + "## Generating Batch iterator" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "(tensor([[ 1, 1, 1, ..., 1, 1, 1],\n [ 343, 61, 59, ..., 141, 113, 20],\n [2537, 219, 139, ..., 4919, 44, 45],\n ...,\n [ 3, 3, 3, ..., 5, 3, 3],\n [ 3, 3, 3, ..., 2, 3, 3],\n [ 3, 3, 3, ..., 3, 3, 3]]),\n tensor([[ 1, 1, 1, ..., 1, 1, 1],\n [ 0, 0, 0, ..., 
141, 0, 0],\n [ 0, 4791, 0, ..., 0, 0, 0],\n ...,\n [ 3, 3, 3, ..., 141, 3, 3],\n [ 3, 3, 3, ..., 5, 3, 3],\n [ 3, 3, 3, ..., 2, 3, 3]]))" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.utils.data import DataLoader\n", "from torch.nn.utils.rnn import pad_sequence\n", "BATCH_SIZE=32\n", "def collate_batch(batch):\n", " de_list, en_list = [], []\n", " for (_en, _de) in batch:\n", " en_list.append(torch.tensor(english_transform(_en)))\n", " de_list.append(torch.tensor(german_transform(_de)))\n", " return pad_sequence(en_list, padding_value=3.0), pad_sequence(de_list, padding_value=3.0)\n", "\n", "train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "next(iter(train_dataloader))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": {