From beec3447bec78793d30777cc2698bf49590d24f1 Mon Sep 17 00:00:00 2001 From: ALi Date: Wed, 7 Dec 2022 12:31:22 +0100 Subject: [PATCH] :sparkles: Transformer block from scratch for Machine translation --- seq2seq_translation_transformer.ipynb | 305 ++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 seq2seq_translation_transformer.ipynb diff --git a/seq2seq_translation_transformer.ipynb b/seq2seq_translation_transformer.ipynb new file mode 100644 index 0000000..fb4d486 --- /dev/null +++ b/seq2seq_translation_transformer.ipynb @@ -0,0 +1,305 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + }, + "colab": { + "name": "seq2seq_translation_attention.ipynb", + "provenance": [] + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true, + "id": "T0CUqI56C9V9" + }, + "source": [ + "# Transformer block from scratch for Machine translation" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ql1UG4OeC9WC", + "outputId": "904f003e-b1c5-42af-e61e-72e533dc4758" + }, + "source": [ + "from __future__ import unicode_literals, print_function, division\n", + "from io import open\n", + "import unicodedata\n", + "import string\n", + "import re\n", + "import random\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch import optim\n", + "import torch.nn.functional as F\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "! 
SOS_token = 0
EOS_token = 1


class Lang:
    """Vocabulary for one language: maps words to integer indices and back.

    Index 0 is reserved for SOS (start of sentence) and index 1 for EOS
    (end of sentence), so word counting starts at 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        """Register every space-separated token of *sentence*."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Add *word* to the vocabulary, or bump its count if already seen."""
        if word in self.word2index:
            self.word2count[word] += 1
        else:
            idx = self.n_words
            self.word2index[word] = idx
            self.word2count[word] = 1
            self.index2word[idx] = word
            self.n_words += 1
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip combining marks (accents) after NFD decomposition."""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


def normalizeString(s):
    """Lowercase, trim, pad .!? with a space, and collapse non-letters to spaces."""
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


def readLangs(lang1, lang2, reverse=False):
    """Read tab-separated sentence pairs for the lang1-lang2 corpus.

    Returns (input_lang, output_lang, pairs); when *reverse* is True the
    pair order and the Lang roles are swapped.
    """
    print("Reading lines...")

    # Read the file and split into lines.
    # FIX: use a context manager so the file handle is closed deterministically
    # (the original `open(...).read()` leaked the handle to the GC).
    with open('.pytorch/engfr/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


MAX_LENGTH = 20

# Simple-subject English prefixes used to trim the corpus to short,
# easy sentences (as in the upstream PyTorch tutorial).
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    """Keep a pair iff both sides have < MAX_LENGTH tokens and the English
    side (p[1]) starts with one of eng_prefixes."""
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    """Apply filterPair to every pair, keeping only those that pass."""
    return [pair for pair in pairs if filterPair(pair)]


def prepareData(lang1, lang2, reverse=False):
    """Read, filter, and index the corpus; returns (input_lang, output_lang, pairs)."""
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


# Guarded so the module can be imported without triggering file I/O;
# inside a notebook cell __name__ is "__main__", so behavior there is unchanged.
if __name__ == "__main__":
    input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
    print(random.choice(pairs))


def indexesFromSentence(lang, sentence):
    """Map each space-separated token of *sentence* to its index in *lang*.

    Raises KeyError for out-of-vocabulary words (lookup is word2index[...]).
    """
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    """Encode *sentence* as a (len+1, 1) long tensor ending with EOS_token."""
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
def tensorsFromPair(pair):
    """Convert a (source, target) sentence pair into a pair of index tensors.

    Uses the module-level input_lang / output_lang vocabularies; each tensor
    comes from tensorFromSentence and therefore ends with the EOS token.
    """
    source_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (source_tensor, target_tensor)