Skip to content

Commit

Permalink
🚸 Generating batches
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Apr 23, 2021
1 parent 452e262 commit 1b2a304
Showing 1 changed file with 76 additions and 5 deletions.
81 changes: 76 additions & 5 deletions Machine_translation_Pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down Expand Up @@ -524,11 +536,13 @@
" return [token.text for token in spacy_english.tokenizer(text)]\n",
"\n",
"# Load the three IWSLT2017 splits for the ('en','de') language pair, rooted at '.pytorch/.data/'.\n",
"train_iter, valid_iter, test_iter = IWSLT2017(root='.pytorch/.data/', language_pair=('en','de'))\n",
"\n",
"# Materialise every split as a list so it can be traversed more than once:\n",
"# the training pairs are walked here to count tokens and the lists are reused\n",
"# later when batching (the raw iterators are presumably single-pass - confirm).\n",
"train_data = list(train_iter)\n",
"valid_data = list(valid_iter)\n",
"test_data = list(test_iter)\n",
"\n",
"# Token frequency counters, one per language.\n",
"en_counter = Counter()\n",
"de_counter = Counter()\n",
"\n",
"# Each pair is unpacked as (English, German), matching language_pair=('en','de').\n",
"# Iterate the materialised list, not train_iter, which list() has already consumed.\n",
"for (en, de) in train_data:\n",
"    en_counter.update(tokenize_english(en))\n",
"    de_counter.update(tokenize_german(de))\n",
"\n",
Expand All @@ -550,8 +564,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Unique tokens in source (de) vocabs 14101\n",
"Unique tokens in source (en) vocabs 17069\n"
"Unique tokens in source (de) vocabs 16712\n",
"Unique tokens in source (en) vocabs 14018\n"
]
}
],
Expand All @@ -574,7 +588,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"output of the text_transform: [1, 0, 7366, 39, 0, 2]\n"
"output of the text_transform: [1, 92, 15, 54, 241, 2]\n"
]
}
],
Expand All @@ -589,6 +603,63 @@
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Generating batch iterators"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "(tensor([[ 1, 1, 1, ..., 1, 1, 1],\n [ 343, 61, 59, ..., 141, 113, 20],\n [2537, 219, 139, ..., 4919, 44, 45],\n ...,\n [ 3, 3, 3, ..., 5, 3, 3],\n [ 3, 3, 3, ..., 2, 3, 3],\n [ 3, 3, 3, ..., 3, 3, 3]]),\n tensor([[ 1, 1, 1, ..., 1, 1, 1],\n [ 0, 0, 0, ..., 141, 0, 0],\n [ 0, 4791, 0, ..., 0, 0, 0],\n ...,\n [ 3, 3, 3, ..., 141, 3, 3],\n [ 3, 3, 3, ..., 5, 3, 3],\n [ 3, 3, 3, ..., 2, 3, 3]]))"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"BATCH_SIZE = 32\n",
"\n",
"def collate_batch(batch, pad_idx=3):\n",
"    \"\"\"Convert a list of (English, German) sentence pairs into two padded index tensors.\n",
"\n",
"    Args:\n",
"        batch: iterable of (english_text, german_text) pairs, as handed over by DataLoader.\n",
"        pad_idx: integer used to fill the shorter sequences. Defaults to 3, the value the\n",
"            notebook has used so far (presumably the padding token index - confirm against\n",
"            the vocabulary's special tokens).\n",
"\n",
"    Returns:\n",
"        Tuple (english_batch, german_batch). pad_sequence is called with its default\n",
"        batch_first=False, so each tensor is laid out as (longest_sequence, batch_size).\n",
"    \"\"\"\n",
"    en_list, de_list = [], []\n",
"    for (_en, _de) in batch:\n",
"        en_list.append(torch.tensor(english_transform(_en)))\n",
"        # FIXME(review): the German sentence is encoded with english_transform, and the\n",
"        # recorded output shows the German tensor dominated by index 0. Switch this call\n",
"        # to the German-side transform (its name is not visible in this part of the notebook).\n",
"        de_list.append(torch.tensor(english_transform(_de)))\n",
"    # Integer padding value: the sequences are integer index tensors.\n",
"    return pad_sequence(en_list, padding_value=pad_idx), pad_sequence(de_list, padding_value=pad_idx)\n",
"\n",
"# Only the training split is shuffled; the evaluation splits keep their stored order\n",
"# so validation and test passes see the same batches on every run.\n",
"train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n",
"test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)\n",
"valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)\n",
"\n",
"# Peek at one training batch as a sanity check.\n",
"next(iter(train_dataloader))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
Expand Down

0 comments on commit 1b2a304

Please sign in to comment.