✨ Transformer block from scratch for Machine translation
alimoezzi committed Dec 7, 2022
1 parent 649331d commit beec344
305 changes: 305 additions & 0 deletions seq2seq_translation_transformer.ipynb
@@ -0,0 +1,305 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
},
"colab": {
"name": "seq2seq_translation_attention.ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"id": "T0CUqI56C9V9"
},
"source": [
"# Transformer block from scratch for Machine translation"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ql1UG4OeC9WC",
"outputId": "904f003e-b1c5-42af-e61e-72e533dc4758"
},
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import string\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"! curl https://download.pytorch.org/tutorial/data.zip --create-dirs -o .pytorch/engfr/data.zip\n",
"! python -m zipfile -e .pytorch/engfr/data.zip .pytorch/engfr/\n",
"! mv .pytorch/engfr/data/* .pytorch/engfr/"
],
"execution_count": 18,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- 0:00:01 --:--:-- 0\n",
" 6 2814k 6 173k 0 0 70764 0 0:00:40 0:00:02 0:00:38 70905\n",
" 32 2814k 32 913k 0 0 256k 0 0:00:10 0:00:03 0:00:07 256k\n",
" 79 2814k 79 2233k 0 0 496k 0 0:00:05 0:00:04 0:00:01 496k\n",
"100 2814k 100 2814k 0 0 558k 0 0:00:05 0:00:05 --:--:-- 633k\n",
"'mv' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
]
},
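{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the captured output shows, the `mv` shell command fails on Windows. A portable alternative is to move the extracted files with Python's `shutil` instead; a minimal sketch, relying on the archive extracting to `.pytorch/engfr/data/` as above:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import shutil\n",
"from pathlib import Path\n",
"\n",
"# Move everything out of .pytorch/engfr/data/ up one level (portable `mv`)\n",
"data_dir = Path('.pytorch/engfr/data')\n",
"if data_dir.exists():\n",
"    for item in data_dir.iterdir():\n",
"        shutil.move(str(item), '.pytorch/engfr/')"
],
"execution_count": null,
"outputs": []
},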
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "uSepgT4-C9WF"
},
"source": [
"## Data preprocessing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kHQ1J4C2C9WG"
},
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
],
"execution_count": 19,
"outputs": []
},
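{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, `Lang` can be exercised on a toy corpus; indices 0 and 1 stay reserved for SOS and EOS:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"toy = Lang('toy')\n",
"toy.addSentence('we are here')\n",
"toy.addSentence('we are done')\n",
"print(toy.word2index)  # {'we': 2, 'are': 3, 'here': 4, 'done': 5}\n",
"print(toy.word2count)  # 'we' and 'are' were each seen twice\n",
"print(toy.n_words)     # 6: four words plus SOS and EOS"
],
"execution_count": null,
"outputs": []
},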
{
"cell_type": "code",
"metadata": {
"id": "3fkgqt4HC9WH"
},
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"\n",
"\n",
"def normalizeString(s):\n",
" s = unicodeToAscii(s.lower().strip())\n",
" s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
" s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n",
" return s"
],
"execution_count": 20,
"outputs": []
},
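{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, normalization strips accents, lowercases, pads sentence-final punctuation with a space, and collapses any other non-letter characters into spaces:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(normalizeString('Nous sommes ici.'))   # -> 'nous sommes ici .'\n",
"print(normalizeString('J’ai déjà mangé!'))  # -> 'j ai deja mange !'"
],
"execution_count": null,
"outputs": []
},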
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "AIvhIi-XC9WJ"
},
"source": [
"## read data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1kG0nTPVC9WJ"
},
"source": [
"def readLangs(lang1, lang2, reverse=False):\n",
" print(\"Reading lines...\")\n",
"\n",
" # Read the file and split into lines\n",
" lines = open('.pytorch/engfr/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\\n')\n",
"\n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
"\n",
" # Reverse pairs, make Lang instances\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" return input_lang, output_lang, pairs"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aH40yeGHC9WL"
},
"source": [
"MAX_LENGTH = 20\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)\n",
"\n",
"def filterPairs(pairs):\n",
" return [pair for pair in pairs if filterPair(pair)]"
],
"execution_count": 22,
"outputs": []
},
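{
"cell_type": "markdown",
"metadata": {},
"source": [
"`filterPair` keeps a pair only when both sentences are shorter than `MAX_LENGTH` tokens and the English side starts with one of the prefixes above (pairs arrive as `[french, english]` once `reverse=True` is passed below):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(filterPair(['nous sommes ici .', 'we are here .']))  # True\n",
"print(filterPair(['il fait froid .', 'it is cold .']))     # False: no matching prefix"
],
"execution_count": null,
"outputs": []
},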
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kD2AZfrCC9WM",
"outputId": "bf035a7b-29e6-4514-c15e-68516a30e6bf"
},
"source": [
"def prepareData(lang1, lang2, reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filterPairs(pairs)\n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" for pair in pairs:\n",
" input_lang.addSentence(pair[0])\n",
" output_lang.addSentence(pair[1])\n",
" print(\"Counted words:\")\n",
" print(input_lang.name, input_lang.n_words)\n",
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, pairs\n",
"\n",
"\n",
"input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
"print(random.choice(pairs))"
],
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 135842 sentence pairs\n",
"Trimmed to 13033 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"fra 5143\n",
"eng 3371\n",
"['nous sommes tous ici .', 'we re all here .']\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KBvyEySxC9WP"
},
"source": [
"def indexesFromSentence(lang, sentence):\n",
" return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
" indexes = indexesFromSentence(lang, sentence)\n",
" indexes.append(EOS_token)\n",
" return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\n",
"\n",
"\n",
"def tensorsFromPair(pair):\n",
" input_tensor = tensorFromSentence(input_lang, pair[0])\n",
" target_tensor = tensorFromSentence(output_lang, pair[1])\n",
" return (input_tensor, target_tensor)"
],
"execution_count": 24,
"outputs": []
},
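{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each sentence becomes a column vector of word indices with `EOS_token` appended; for a random training pair:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"pair = random.choice(pairs)\n",
"input_tensor, target_tensor = tensorsFromPair(pair)\n",
"print(pair)\n",
"print(input_tensor.shape, target_tensor.shape)  # (n_tokens + 1, 1) each: EOS appended"
],
"execution_count": null,
"outputs": []
},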
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"id": "IjA1nldMC9WR"
},
"source": [
"## Model implementation"
]
}
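,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the block this notebook builds toward: multi-head self-attention written from scratch, followed by a position-wise feed-forward network, each wrapped in a residual connection and layer normalization. The sizes in the smoke test are illustrative, and the sketch omits the masking, positional encodings, and cross-attention a full translation model needs."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"\n",
"class MultiHeadSelfAttention(nn.Module):\n",
"    def __init__(self, d_model, n_heads):\n",
"        super().__init__()\n",
"        assert d_model % n_heads == 0\n",
"        self.n_heads = n_heads\n",
"        self.d_head = d_model // n_heads\n",
"        self.qkv = nn.Linear(d_model, 3 * d_model)\n",
"        self.out = nn.Linear(d_model, d_model)\n",
"\n",
"    def forward(self, x):\n",
"        # x: (batch, seq_len, d_model)\n",
"        b, t, d = x.shape\n",
"        q, k, v = self.qkv(x).chunk(3, dim=-1)\n",
"        # split into heads: (batch, n_heads, seq_len, d_head)\n",
"        q = q.view(b, t, self.n_heads, self.d_head).transpose(1, 2)\n",
"        k = k.view(b, t, self.n_heads, self.d_head).transpose(1, 2)\n",
"        v = v.view(b, t, self.n_heads, self.d_head).transpose(1, 2)\n",
"        # scaled dot-product attention\n",
"        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)\n",
"        weights = F.softmax(scores, dim=-1)\n",
"        context = (weights @ v).transpose(1, 2).reshape(b, t, d)\n",
"        return self.out(context)\n",
"\n",
"\n",
"class TransformerBlock(nn.Module):\n",
"    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):\n",
"        super().__init__()\n",
"        self.attn = MultiHeadSelfAttention(d_model, n_heads)\n",
"        self.ff = nn.Sequential(\n",
"            nn.Linear(d_model, d_ff),\n",
"            nn.ReLU(),\n",
"            nn.Linear(d_ff, d_model)\n",
"        )\n",
"        self.norm1 = nn.LayerNorm(d_model)\n",
"        self.norm2 = nn.LayerNorm(d_model)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # residual + layer norm around attention, then around the feed-forward net\n",
"        x = self.norm1(x + self.dropout(self.attn(x)))\n",
"        x = self.norm2(x + self.dropout(self.ff(x)))\n",
"        return x\n",
"\n",
"\n",
"# smoke test with illustrative sizes\n",
"block = TransformerBlock(d_model=128, n_heads=8, d_ff=512).to(device)\n",
"x = torch.randn(1, 10, 128, device=device)  # (batch, seq_len, d_model)\n",
"print(block(x).shape)  # torch.Size([1, 10, 128])"
],
"execution_count": null,
"outputs": []
}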
]
}
