From c9d371281e7d6f7e9fdde7cf0248a64e10dc74c0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 19 Apr 2021 12:20:26 -0700
Subject: [PATCH] use youtokentome instead of huggingface for default tokenizer

---
 README.md                  | 18 ++++++++++++++--
 dalle_pytorch/tokenizer.py | 42 +++++++++++++++++++++++++++++++++++++++++++++-
 generate.py                |  7 +++++--
 setup.py                   |  5 +++--
 train_dalle.py             |  7 +++++--
 5 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 5ebb43d0..4613e5cf 100644
--- a/README.md
+++ b/README.md
@@ -434,16 +434,30 @@ accordingly.
 
 #### Custom Tokenizer
 
-This repository supports Huggingface Tokenizers if you wish to use it instead of the default simple tokenizer. Simply pass in an extra `--bpe_path` when invoking `train_dalle.py` and `generate.py`, with the path to your BPE json file.
+This repository supports custom tokenization with YouTokenToMe if you wish to use it instead of the default simple tokenizer. Simply pass in an extra `--bpe_path` when invoking `train_dalle.py` and `generate.py`, with the path to your BPE model file.
 
 The only requirement is that you use `0` as the padding during tokenization
 
 ex.
 
 ```sh
-$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.json
+$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.model
 ```
 
+To create a BPE model file from scratch, first install YouTokenToMe
+
+```bash
+$ pip install youtokentome
+```
+
+Then prepare a big text file that is a representative sample of the type of text you want to encode, and invoke the `youtokentome` command-line tool, specifying the vocab size you wish to use in addition to the corpus of text.
+
+```bash
+$ yttm bpe --vocab_size 8000 --data ./path/to/big/text/file.txt --model ./path/to/bpe.model
+```
+
+That's it! The BPE model file is now saved to `./path/to/bpe.model` and you can begin training!
+
 #### Chinese
 
 You can train with a pretrained chinese tokenizer offered by Huggingface 🤗 by simply passing in an extra flag `--chinese`
diff --git a/dalle_pytorch/tokenizer.py b/dalle_pytorch/tokenizer.py
index ef6a5744..b2f0cb07 100644
--- a/dalle_pytorch/tokenizer.py
+++ b/dalle_pytorch/tokenizer.py
@@ -3,7 +3,9 @@
 import torch
 
+import youtokentome as yttm
 from tokenizers import Tokenizer
+from tokenizers.processors import ByteLevel
 from transformers import BertTokenizer
 
 import html
@@ -157,8 +159,8 @@ class HugTokenizer:
     def __init__(self, bpe_path = None):
         bpe_path = Path(bpe_path)
         assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
-
         tokenizer = Tokenizer.from_file(str(bpe_path))
+        tokenizer.post_processor = ByteLevel(trim_offsets = True)
         self.tokenizer = tokenizer
         self.vocab_size = tokenizer.get_vocab_size()
@@ -223,3 +225,41 @@ def tokenize(self, texts, context_length = 256, truncate_text = False):
             result[i, :len(tokens)] = torch.tensor(tokens)
 
         return result
+
+# yttm tokenizer
+
+class YttmTokenizer:
+    def __init__(self, bpe_path = None):
+        bpe_path = Path(bpe_path)
+        assert bpe_path.exists(), f'BPE model path {str(bpe_path)} does not exist'
+
+        tokenizer = yttm.BPE(model = str(bpe_path))
+        self.tokenizer = tokenizer
+        self.vocab_size = tokenizer.vocab_size()
+
+    def decode(self, tokens):
+        if torch.is_tensor(tokens):
+            tokens = tokens.tolist()
+
+        return self.tokenizer.decode(tokens, ignore_ids = [0])
+
+    def encode(self, texts):
+        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
+        return list(map(torch.tensor, encoded))
+
+    def tokenize(self, texts, context_length = 256, truncate_text = False):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        all_tokens = self.encode(texts)
+
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate_text:
+                    tokens = tokens[:context_length]
+                else:
+                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
diff --git a/generate.py b/generate.py
index 89912c96..e85a6d23 100644
--- a/generate.py
+++ b/generate.py
@@ -16,7 +16,7 @@
 # dalle related classes and utils
 
 from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
-from dalle_pytorch.tokenizer import tokenizer, HugTokenizer
+from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
 
 # argument parsing
 
@@ -43,6 +43,8 @@
 parser.add_argument('--bpe_path', type = str, help='path to your huggingface BPE json file')
 
+parser.add_argument('--hug', dest='hug', action = 'store_true')
+
 parser.add_argument('--chinese', dest='chinese', action = 'store_true')
 
 parser.add_argument('--taming', dest='taming', action='store_true')
@@ -57,7 +59,8 @@ def exists(val):
 # tokenizer
 
 if exists(args.bpe_path):
-    tokenizer = HugTokenizer(args.bpe_path)
+    klass = HugTokenizer if args.hug else YttmTokenizer
+    tokenizer = klass(args.bpe_path)
 elif args.chinese:
     tokenizer = ChineseTokenizer()
diff --git a/setup.py b/setup.py
index 98a06416..ee56881a 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.10.2',
+  version = '0.10.3',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
@@ -28,7 +28,8 @@
     'torch>=1.6',
     'torchvision',
     'transformers',
-    'tqdm'
+    'tqdm',
+    'youtokentome'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
diff --git a/train_dalle.py b/train_dalle.py
index f2ab95fe..3ea6f14e 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -21,7 +21,7 @@
 from dalle_pytorch import distributed_utils
 from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
-from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer
+from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
 
 # argument parsing
 
@@ -48,6 +48,8 @@
 parser.add_argument('--taming', dest='taming', action='store_true')
 
+parser.add_argument('--hug', dest='hug', action = 'store_true')
+
 parser.add_argument('--bpe_path', type = str, help='path to your huggingface BPE json file')
 
@@ -93,7 +95,8 @@ def exists(val):
 # tokenizer
 
 if exists(args.bpe_path):
-    tokenizer = HugTokenizer(args.bpe_path)
+    klass = HugTokenizer if args.hug else YttmTokenizer
+    tokenizer = klass(args.bpe_path)
 elif args.chinese:
     tokenizer = ChineseTokenizer()
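
For reference, the sketch below is not part of the patch; it shows the workflow this change enables end to end, training a YouTokenToMe BPE model from Python rather than via the `yttm` CLI and then tokenizing a caption with the new `YttmTokenizer`. The corpus path, model path, and vocab size are placeholder values, and it assumes `youtokentome` plus `dalle-pytorch` with this patch applied are installed.

```python
# Sketch only: './captions.txt', './bpe.model' and the vocab size are placeholders.
import youtokentome as yttm

from dalle_pytorch.tokenizer import YttmTokenizer

# Train a BPE model; equivalent to `yttm bpe --data ... --model ... --vocab_size 8000`.
# youtokentome reserves id 0 for padding by default, which is what the README requires.
yttm.BPE.train(data = './captions.txt', model = './bpe.model', vocab_size = 8000)

tokenizer = YttmTokenizer('./bpe.model')

# tokenize() pads each caption with 0 up to context_length
tokens = tokenizer.tokenize('a cat sitting on a red couch', context_length = 256)
print(tokens.shape)              # torch.Size([1, 256])

# decode() takes a batch of token id sequences and drops padding via ignore_ids = [0]
print(tokenizer.decode(tokens))  # roughly ['a cat sitting on a red couch'], depending on the trained vocab
```

With a model trained this way, passing `--bpe_path ./bpe.model` to `train_dalle.py` or `generate.py` selects `YttmTokenizer` by default, while adding the new `--hug` flag keeps the old behaviour and loads the path as a Huggingface BPE json file instead.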