use youtokentome instead of huggingface for default tokenizer
lucidrains committed Apr 19, 2021
1 parent 2856439 commit c9d3712
Showing 5 changed files with 70 additions and 9 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -434,16 +434,30 @@ accordingly.

#### Custom Tokenizer

This repository supports <a href="https://huggingface.co/transformers/main_classes/tokenizer.html">Huggingface Tokenizers</a> if you wish to use it instead of the default simple tokenizer. Simply pass in an extra `--bpe_path` when invoking `train_dalle.py` and `generate.py`, with the path to your BPE json file.
This repository supports custom tokenization with <a href="https://github.com/VKCOM/YouTokenToMe">YouTokenToMe</a>, if you wish to use it instead of the default simple tokenizer. Simply pass in an extra `--bpe_path` when invoking `train_dalle.py` and `generate.py`, with the path to your BPE model file.

The only requirement is that `0` be used as the padding token during tokenization.

ex.

```sh
$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.json
$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.model
```

To create a BPE model file from scratch, first install `youtokentome`

```bash
$ pip install youtokentome
```

Next, prepare a large text file that is a representative sample of the type of text you want to encode. Then invoke the `yttm` command-line tool, specifying both the corpus of text and the vocabulary size you wish to use.

```bash
$ yttm bpe --vocab_size 8000 --data ./path/to/big/text/file.txt --model ./path/to/bpe.model
```

That's it! The BPE model file is now saved to `./path/to/bpe.model` and you can begin training!
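
If you prefer to stay in Python, the same model can be trained and sanity-checked with the `youtokentome` API. The sketch below is illustrative (paths are placeholders); note that `yttm` reserves id `0` for `<PAD>` by default, which satisfies the padding requirement above.

```python
import youtokentome as yttm

# train a BPE model from the corpus (equivalent to the CLI invocation above);
# pad_id defaults to 0, matching the requirement that 0 is the padding token
yttm.BPE.train(data = './path/to/big/text/file.txt', model = './path/to/bpe.model', vocab_size = 8000)

# load it back and sanity-check a round trip
bpe = yttm.BPE(model = './path/to/bpe.model')
print(bpe.vocab_size())            # should print 8000
print(bpe.subword_to_id('<PAD>'))  # 0 by default

ids = bpe.encode(['a minimal example'], output_type = yttm.OutputType.ID)
print(bpe.decode(ids))             # should round-trip back to the text
```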

#### Chinese

You can train with a <a href="https://huggingface.co/bert-base-chinese">pretrained Chinese tokenizer</a> offered by Huggingface 🤗 by simply passing in an extra flag `--chinese`.
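
Under the hood this selects the `ChineseTokenizer` class imported alongside the new `YttmTokenizer` in this commit. A hypothetical minimal sketch (the `tokenize` interface is assumed to match the other tokenizer classes in `dalle_pytorch/tokenizer.py`):

```python
from dalle_pytorch.tokenizer import ChineseTokenizer

tokenizer = ChineseTokenizer()  # wraps Huggingface's pretrained bert-base-chinese
tokens = tokenizer.tokenize(['一匹奔跑的马'], context_length = 256)
```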
42 changes: 41 additions & 1 deletion dalle_pytorch/tokenizer.py
@@ -3,7 +3,9 @@

import torch

import youtokentome as yttm
from tokenizers import Tokenizer
from tokenizers.processors import ByteLevel
from transformers import BertTokenizer

import html
@@ -157,8 +159,8 @@ class HugTokenizer:
    def __init__(self, bpe_path = None):
        bpe_path = Path(bpe_path)
        assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'

        tokenizer = Tokenizer.from_file(str(bpe_path))
        tokenizer.post_processor = ByteLevel(trim_offsets = True)
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.get_vocab_size()

@@ -223,3 +225,41 @@ def tokenize(self, texts, context_length = 256, truncate_text = False):
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

# yttm tokenizer

class YttmTokenizer:
    def __init__(self, bpe_path = None):
        bpe_path = Path(bpe_path)
        assert bpe_path.exists(), f'BPE model path {str(bpe_path)} does not exist'

        tokenizer = yttm.BPE(model = str(bpe_path))
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size()

    def decode(self, tokens):
        if torch.is_tensor(tokens):
            tokens = tokens.tolist()

        return self.tokenizer.decode(tokens, ignore_ids = [0])

    def encode(self, texts):
        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
        return list(map(torch.tensor, encoded))

    def tokenize(self, texts, context_length = 256, truncate_text = False):
        if isinstance(texts, str):
            texts = [texts]

        all_tokens = self.encode(texts)

        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate_text:
                    tokens = tokens[:context_length]
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
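
For reference, a minimal usage sketch of the new class (the model path is a placeholder; output shapes follow from the code above):

```python
from dalle_pytorch.tokenizer import YttmTokenizer

tokenizer = YttmTokenizer(bpe_path = './path/to/bpe.model')

tokens = tokenizer.tokenize(['a running horse'], context_length = 256)
print(tokens.shape)              # torch.Size([1, 256]), right-padded with zeros

# decode takes a batch of id sequences; padding ids are dropped via ignore_ids = [0]
print(tokenizer.decode(tokens))  # ['a running horse']
```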
7 changes: 5 additions & 2 deletions generate.py
@@ -16,7 +16,7 @@
# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

# argument parsing

@@ -43,6 +43,8 @@
parser.add_argument('--bpe_path', type = str,
                    help='path to your BPE json file or yttm BPE model file')

parser.add_argument('--hug', dest='hug', action = 'store_true')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')
@@ -57,7 +59,8 @@ def exists(val):
# tokenizer

if exists(args.bpe_path):
    tokenizer = HugTokenizer(args.bpe_path)
    klass = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()

5 changes: 3 additions & 2 deletions setup.py
@@ -4,7 +4,7 @@
  name = 'dalle-pytorch',
  packages = find_packages(),
  include_package_data = True,
  version = '0.10.2',
  version = '0.10.3',
  license='MIT',
  description = 'DALL-E - Pytorch',
  author = 'Phil Wang',
@@ -28,7 +28,8 @@
    'torch>=1.6',
    'torchvision',
    'transformers',
    'tqdm'
    'tqdm',
    'youtokentome'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
7 changes: 5 additions & 2 deletions train_dalle.py
@@ -21,7 +21,7 @@

from dalle_pytorch import distributed_utils
from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

# argument parsing

@@ -48,6 +48,8 @@

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--hug', dest='hug', action = 'store_true')

parser.add_argument('--bpe_path', type = str,
                    help='path to your BPE json file or yttm BPE model file')

@@ -93,7 +95,8 @@ def exists(val):
# tokenizer

if exists(args.bpe_path):
    tokenizer = HugTokenizer(args.bpe_path)
    klass = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()

