✨ Transformer block from scratch for Machine translation
Showing 1 changed file with 305 additions and 0 deletions.
@@ -0,0 +1,305 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
},
"colab": {
"name": "seq2seq_translation_attention.ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"id": "T0CUqI56C9V9"
},
"source": [
"# Transformer block from scratch for Machine translation"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ql1UG4OeC9WC",
"outputId": "904f003e-b1c5-42af-e61e-72e533dc4758"
},
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import string\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"! curl https://download.pytorch.org/tutorial/data.zip --create-dirs -o .pytorch/engfr/data.zip\n",
"! python -m zipfile -e .pytorch/engfr/data.zip .pytorch/engfr/\n",
"! mv .pytorch/engfr/data/* .pytorch/engfr/"
],
"execution_count": 18,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- 0:00:01 --:--:-- 0\n",
" 6 2814k 6 173k 0 0 70764 0 0:00:40 0:00:02 0:00:38 70905\n",
" 32 2814k 32 913k 0 0 256k 0 0:00:10 0:00:03 0:00:07 256k\n",
" 79 2814k 79 2233k 0 0 496k 0 0:00:05 0:00:04 0:00:01 496k\n",
"100 2814k 100 2814k 0 0 558k 0 0:00:05 0:00:05 --:--:-- 633k\n",
"'mv' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
]
},
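{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `mv` call above is a Unix shell command; the `'mv' is not recognized` line in the output shows it failing on Windows. A portable sketch using Python's `shutil`, assuming the archive unpacked into `.pytorch/engfr/data/`:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pathlib\n",
"import shutil\n",
"\n",
"# Cross-platform stand-in for: mv .pytorch/engfr/data/* .pytorch/engfr/\n",
"src = pathlib.Path('.pytorch/engfr/data')\n",
"if src.exists():\n",
"    for f in src.iterdir():\n",
"        shutil.move(str(f), '.pytorch/engfr/')"
],
"execution_count": null,
"outputs": []
},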
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "uSepgT4-C9WF"
},
"source": [
"## Data preprocessing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kHQ1J4C2C9WG"
},
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"\n",
"class Lang:\n",
"    def __init__(self, name):\n",
"        self.name = name\n",
"        self.word2index = {}\n",
"        self.word2count = {}\n",
"        self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
"        self.n_words = 2  # Count SOS and EOS\n",
"\n",
"    def addSentence(self, sentence):\n",
"        for word in sentence.split(' '):\n",
"            self.addWord(word)\n",
"\n",
"    def addWord(self, word):\n",
"        if word not in self.word2index:\n",
"            self.word2index[word] = self.n_words\n",
"            self.word2count[word] = 1\n",
"            self.index2word[self.n_words] = word\n",
"            self.n_words += 1\n",
"        else:\n",
"            self.word2count[word] += 1"
],
"execution_count": 19,
"outputs": []
},
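{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch, not output from the original run): each new word gets the next free index, and repeats only bump the count."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"demo = Lang('demo')\n",
"demo.addSentence('we are here we are')\n",
"print(demo.n_words)           # 5: SOS, EOS, we, are, here\n",
"print(demo.word2count['we'])  # 2\n",
"print(demo.index2word[2])     # we"
],
"execution_count": null,
"outputs": []
},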
{
"cell_type": "code",
"metadata": {
"id": "3fkgqt4HC9WH"
},
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
"    return ''.join(\n",
"        c for c in unicodedata.normalize('NFD', s)\n",
"        if unicodedata.category(c) != 'Mn'\n",
"    )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"\n",
"\n",
"def normalizeString(s):\n",
"    s = unicodeToAscii(s.lower().strip())\n",
"    s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
"    s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n",
"    return s"
],
"execution_count": 20,
"outputs": []
},
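{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (a sketch): accents are stripped by the NFD pass, sentence punctuation gets padded with a space, and everything else collapses to single spaces."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(normalizeString('Ça va bien !'))        # ca va bien !\n",
"print(normalizeString(\"He's 25 years old.\"))  # he s years old ."
],
"execution_count": null,
"outputs": []
},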
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "AIvhIi-XC9WJ"
},
"source": [
"## Read data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1kG0nTPVC9WJ"
},
"source": [
"def readLangs(lang1, lang2, reverse=False):\n",
"    print(\"Reading lines...\")\n",
"\n",
"    # Read the file and split into lines\n",
"    lines = open('.pytorch/engfr/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\\n')\n",
"\n",
"    # Split every line into pairs and normalize\n",
"    pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
"\n",
"    # Reverse pairs, make Lang instances\n",
"    if reverse:\n",
"        pairs = [list(reversed(p)) for p in pairs]\n",
"        input_lang = Lang(lang2)\n",
"        output_lang = Lang(lang1)\n",
"    else:\n",
"        input_lang = Lang(lang1)\n",
"        output_lang = Lang(lang2)\n",
"\n",
"    return input_lang, output_lang, pairs"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aH40yeGHC9WL"
},
"source": [
"MAX_LENGTH = 20\n",
"\n",
"eng_prefixes = (\n",
"    \"i am \", \"i m \",\n",
"    \"he is\", \"he s \",\n",
"    \"she is\", \"she s \",\n",
"    \"you are\", \"you re \",\n",
"    \"we are\", \"we re \",\n",
"    \"they are\", \"they re \"\n",
")\n",
"\n",
"\n",
"def filterPair(p):\n",
"    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)\n",
"\n",
"def filterPairs(pairs):\n",
"    return [pair for pair in pairs if filterPair(pair)]"
],
"execution_count": 22,
"outputs": []
},
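{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `reverse=True` below, `p[0]` is the French side and `p[1]` the English side, so the prefix test applies to English. For example (a sketch):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(filterPair(['nous sommes ici .', 'we re here .']))  # True\n",
"print(filterPair(['il mange .', 'tom is eating .']))      # False: no matching prefix"
],
"execution_count": null,
"outputs": []
},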
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kD2AZfrCC9WM",
"outputId": "bf035a7b-29e6-4514-c15e-68516a30e6bf"
},
"source": [
"def prepareData(lang1, lang2, reverse=False):\n",
"    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
"    print(\"Read %s sentence pairs\" % len(pairs))\n",
"    pairs = filterPairs(pairs)\n",
"    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
"    print(\"Counting words...\")\n",
"    for pair in pairs:\n",
"        input_lang.addSentence(pair[0])\n",
"        output_lang.addSentence(pair[1])\n",
"    print(\"Counted words:\")\n",
"    print(input_lang.name, input_lang.n_words)\n",
"    print(output_lang.name, output_lang.n_words)\n",
"    return input_lang, output_lang, pairs\n",
"\n",
"\n",
"input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
"print(random.choice(pairs))"
],
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 135842 sentence pairs\n",
"Trimmed to 13033 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"fra 5143\n",
"eng 3371\n",
"['nous sommes tous ici .', 'we re all here .']\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KBvyEySxC9WP"
},
"source": [
"def indexesFromSentence(lang, sentence):\n",
"    return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
"    indexes = indexesFromSentence(lang, sentence)\n",
"    indexes.append(EOS_token)\n",
"    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\n",
"\n",
"\n",
"def tensorsFromPair(pair):\n",
"    input_tensor = tensorFromSentence(input_lang, pair[0])\n",
"    target_tensor = tensorFromSentence(output_lang, pair[1])\n",
"    return (input_tensor, target_tensor)"
],
"execution_count": 24,
"outputs": []
},
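{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each sentence becomes a column vector of word indices with `EOS_token` appended. Using the pair printed above (a sketch):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"inp, tgt = tensorsFromPair(['nous sommes tous ici .', 'we re all here .'])\n",
"print(inp.shape, tgt.shape)    # torch.Size([6, 1]) torch.Size([6, 1]) -- 5 words + EOS\n",
"print(tgt.squeeze().tolist())  # index sequence ending in EOS_token (1)"
],
"execution_count": null,
"outputs": []
},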
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "IjA1nldMC9WR"
},
"source": [
"## Model implementation"
]
},
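{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of a Transformer encoder block in PyTorch: multi-head scaled dot-product self-attention followed by a position-wise feed-forward network, each wrapped in a residual connection and layer norm. The hyperparameter names and defaults (`d_model`, `n_heads`, `d_ff`) are assumptions, not values from the original notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"class TransformerBlock(nn.Module):\n",
"    def __init__(self, d_model=256, n_heads=8, d_ff=512, dropout=0.1):\n",
"        super().__init__()\n",
"        assert d_model % n_heads == 0\n",
"        self.n_heads = n_heads\n",
"        self.d_head = d_model // n_heads\n",
"        self.qkv = nn.Linear(d_model, 3 * d_model)  # joint Q, K, V projection\n",
"        self.out = nn.Linear(d_model, d_model)\n",
"        self.ff = nn.Sequential(\n",
"            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))\n",
"        self.norm1 = nn.LayerNorm(d_model)\n",
"        self.norm2 = nn.LayerNorm(d_model)\n",
"        self.drop = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # x: (batch, seq_len, d_model)\n",
"        b, t, d = x.shape\n",
"        q, k, v = self.qkv(x).chunk(3, dim=-1)\n",
"        # split into heads: (batch, n_heads, seq_len, d_head)\n",
"        q, k, v = (z.view(b, t, self.n_heads, self.d_head).transpose(1, 2)\n",
"                   for z in (q, k, v))\n",
"        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5\n",
"        attn = F.softmax(scores, dim=-1)\n",
"        ctx = (attn @ v).transpose(1, 2).reshape(b, t, d)\n",
"        x = self.norm1(x + self.drop(self.out(ctx)))  # attention sublayer\n",
"        x = self.norm2(x + self.drop(self.ff(x)))     # feed-forward sublayer\n",
"        return x\n",
"\n",
"block = TransformerBlock().to(device)\n",
"print(block(torch.zeros(2, 10, 256, device=device)).shape)  # torch.Size([2, 10, 256])"
],
"execution_count": null,
"outputs": []
}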
]
} |