Adding Llama FastTokenizer support.
- Requires a `tokenizers` version that includes huggingface/tokenizers#1183 (`tokenizers>=0.13.3`).
- Only supports `byte_fallback` for Llama; raises otherwise (safety net).
- Lots of open questions around special tokens.

How to test:

```python
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

# Flip to True to reload a previously converted tokenizer from disk
# instead of redoing the (slow) conversion.
reuse_saved = False
if reuse_saved:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",
    "生活的真谛是[MASK]。",
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
```
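For context, `byte_fallback` is the behavior the safety net above protects: a character that has no vocabulary entry is encoded as its raw UTF-8 bytes instead of collapsing to `<unk>`. A minimal illustration using the tokenizer converted by the script above (token strings are indicative):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tok.json")  # saved by the script above

# U+1F917 has no vocab entry, so it is split into its four UTF-8 bytes.
print(tok.encode("🤗").tokens)
# e.g. ['▁', '<0xF0>', '<0x9F>', '<0xA4>', '<0x97>']
```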

Squashed commit history:

- The converter + some test script.
- The test script.
- Tmp save.
- Adding Fast tokenizer + tests.
- Adding the tokenization tests.
- Correct combination.
- Small fix.
- Fixing tests.
- Fixing with latest update.
- Rebased.
- Fix copies + normalized added tokens + copies.
- Adding doc.
- TMP.
- Doc + split files.
- Doc.
- Versions + try import.
- Fix Camembert + warnings -> Error.
- Fix by ArthurZucker.
- Not a decorator.
Narsil committed Apr 5, 2023
1 parent 126eafe commit 2c052e2
Showing 11 changed files with 204 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -334,7 +334,7 @@ Flax), PyTorch, and/or TensorFlow.
| LED | | | | | |
| LeViT | | | | | |
| LiLT | | | | | |
-| LLaMA | ✅ | ❌ | ✅ | ❌ | ❌ |
+| LLaMA | ✅ | ✅ | ✅ | ❌ | ❌ |
| Longformer | | | | | |
| LongT5 | | | | | |
| LUKE | | | | | |
8 changes: 8 additions & 0 deletions docs/source/en/model_doc/llama.mdx
@@ -59,6 +59,14 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
    - create_token_type_ids_from_sequences
    - save_vocabulary

## LlamaTokenizerFast

[[autodoc]] LlamaTokenizerFast
    - build_inputs_with_special_tokens
    - get_special_tokens_mask
    - create_token_type_ids_from_sequences
    - save_vocabulary

## LlamaModel

[[autodoc]] LlamaModel
3 changes: 2 additions & 1 deletion setup.py
@@ -78,7 +78,7 @@
import shutil
from pathlib import Path

-from setuptools import setup, Command
+from setuptools import Command, setup


# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
@@ -251,6 +251,7 @@ def run(self):
        with open(target, "w", encoding="utf-8", newline="\n") as f:
            f.write("\n".join(content))


extras = {}

extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp")
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -740,6 +740,7 @@
_import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
_import_structure["models.led"].append("LEDTokenizerFast")
_import_structure["models.llama"].append("LlamaTokenizerFast")
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
@@ -4388,6 +4389,7 @@
    from .models.layoutlmv3 import LayoutLMv3TokenizerFast
    from .models.layoutxlm import LayoutXLMTokenizerFast
    from .models.led import LEDTokenizerFast
+    from .models.llama import LlamaTokenizerFast
    from .models.longformer import LongformerTokenizerFast
    from .models.lxmert import LxmertTokenizerFast
    from .models.markuplm import MarkupLMTokenizerFast
89 changes: 81 additions & 8 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,10 +19,9 @@
allow to make our dependency on SentencePiece optional.
"""

-import warnings
from typing import Dict, List, Tuple

-from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece

from .utils import requires_backends
@@ -450,12 +449,13 @@ def __init__(self, *args):
        self.proto = m

        if self.proto.trainer_spec.byte_fallback:
-            warnings.warn(
-                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
-                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
-                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
-                "unknown tokens into a sequence of byte tokens matching the original piece of text."
-            )
+            if not getattr(self, "handle_byte_fallback", None):
+                raise RuntimeError(
+                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
+                )

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]
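A sketch of what this safety net now does for a converter that has not opted in via `handle_byte_fallback` (the checkpoint id is a placeholder, not a real model):

```python
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

# Placeholder: any sentencepiece checkpoint trained with byte_fallback whose
# converter class does not set `handle_byte_fallback = True`.
slow = AutoTokenizer.from_pretrained("org/spm-model-with-byte-fallback", use_fast=False)
convert_slow_tokenizer(slow)  # raises RuntimeError instead of merely warning
```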
@@ -1094,6 +1094,78 @@ def post_processor(self):
        )


class LlamaConverter(SpmConverter):
    handle_byte_fallback = True

    def vocab(self, proto):
        vocab = [
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 0
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
                decoders.Strip(content=" ", left=1),
            ]
        )

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
        if model_type == 1:
            raise RuntimeError("Llama is supposed to be a BPE model!")
        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
            )
            tokenizer.add_special_tokens(
                [
                    AddedToken("<unk>", normalized=True),
                    AddedToken("<s>", normalized=True),
                    AddedToken("</s>", normalized=True),
                ]
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        return tokenizer

    def normalizer(self, proto):
        return normalizers.Sequence(
            [
                normalizers.Prepend(prepend="▁"),
                normalizers.Replace(pattern=" ", content="▁"),
            ]
        )

    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A",
            pair="<s> $A $B",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
            ],
        )


class MarkupLMConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
@@ -1183,6 +1255,7 @@ def converted(self) -> Tokenizer:
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
+    "LlamaTokenizer": LlamaConverter,
}


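To make the converted pipeline concrete, here is a small sketch of the normalizer and decoder built above, run in isolation (requires a `tokenizers` release with the new components, i.e. >=0.13.3; outputs are indicative). The `TemplateProcessing` post-processor is what prepends the `<s>` (BOS) token to every encoded sequence.

```python
from tokenizers import decoders, normalizers

# Normalization replaces sentencepiece's pre-tokenizer: prepend one "▁"
# and turn every space into "▁".
norm = normalizers.Sequence(
    [
        normalizers.Prepend(prepend="▁"),
        normalizers.Replace(pattern=" ", content="▁"),
    ]
)
print(norm.normalize_str("Hello world"))  # '▁Hello▁world'

# Decoding reverses it: "▁" back to spaces, byte tokens fused into text,
# and the single prepended space stripped from the left.
dec = decoders.Sequence(
    [
        decoders.Replace("▁", " "),
        decoders.ByteFallback(),
        decoders.Fuse(),
        decoders.Strip(content=" ", left=1),
    ]
)
print(dec.decode(["▁Hello", "▁world"]))  # 'Hello world'
```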
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -172,7 +172,13 @@
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, None)),
(
"llama",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
(
"longt5",
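With this mapping in place, `AutoTokenizer` resolves Llama checkpoints to the fast class by default when `tokenizers` is installed (checkpoint name borrowed from the test suite below):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
print(type(tok).__name__)  # 'LlamaTokenizerFast'
```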
17 changes: 17 additions & 0 deletions src/transformers/models/llama/__init__.py
@@ -17,6 +17,7 @@
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
+    is_tokenizers_available,
    is_torch_available,
)

@@ -33,6 +34,14 @@
else:
    _import_structure["tokenization_llama"] = ["LlamaTokenizer"]

+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
+
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
@@ -58,6 +67,14 @@
else:
    from .tokenization_llama import LlamaTokenizer

+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    from .tokenization_llama_fast import LlamaTokenizerFast
+
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
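These guarded imports let the package degrade gracefully: without `tokenizers` only the slow class is exported, and the sentencepiece-backed implementation can always be requested explicitly:

```python
from transformers import AutoTokenizer

# Force the slow, sentencepiece-backed tokenizer even when `tokenizers`
# is installed.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False)
print(type(tok).__name__)  # 'LlamaTokenizer'
```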
19 changes: 19 additions & 0 deletions src/transformers/models/llama/tokenization_llama_fast.py
@@ -0,0 +1,19 @@
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils.versions import require_version


require_version("tokenizers>=0.13.3")


class LlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
    """

    def __init__(
        self,
        *args,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        super().__init__(*args, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
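A quick usage sketch of the new class (ids are indicative, matching the integration tests; the leading 1 is the `<s>` BOS token added by the post-processor):

```python
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

ids = tok("This is a test")["input_ids"]
print(ids)  # e.g. [1, 910, 338, 263, 1243]; id 1 is <s>
print(tok.decode(ids, skip_special_tokens=True))  # 'This is a test'
```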
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tokenizers_objects.py
@@ -219,6 +219,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class LlamaTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class LongformerTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

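The dummy keeps `LlamaTokenizerFast` importable in environments without `tokenizers`; it is instantiation that fails, with an error naming the missing backend. A sketch of the behavior in such an environment:

```python
from transformers import LlamaTokenizerFast  # import always succeeds

# Without the `tokenizers` backend installed, instantiation fails early:
try:
    LlamaTokenizerFast()
except ImportError as e:
    print(e)  # points the user at installing `tokenizers`
```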
69 changes: 58 additions & 11 deletions tests/models/llama/test_tokenization_llama.py
@@ -1,3 +1,4 @@
+# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,8 +24,10 @@
    SPIECE_UNDERLINE,
    AddedToken,
    LlamaTokenizer,
+    LlamaTokenizerFast,
    is_torch_available,
)
+from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
    get_tests_dir,
    nested_simplify,
@@ -287,13 +290,11 @@ def test_tokenizer_integration(self):
@require_sentencepiece
@require_tokenizers
class LlamaIntegrationTest(unittest.TestCase):
-    checkpoint_name = "hf-internal-testing/llama-tokenizer"
-
    @classmethod
    def setUpClass(cls):
-        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(cls.checkpoint_name)
-        cls.rust_tokenizer = cls.tokenizer  # TODO @narsil replace with the rust one
-        cls.pad_token_id = 1
+        checkpoint_name = "hf-internal-testing/llama-tokenizer"
+        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(checkpoint_name)
+        cls.rust_tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint_name)
+        return cls

    @require_torch
@@ -314,6 +315,27 @@ def integration_tests(self):
            },
        )

    @slow
    def test_conversion(self):
        # This is excruciatingly slow since it has to recreate the entire merge
        # list from the original vocabulary in spm
        self.rust_tokenizer.save_pretrained("./out")
        with tempfile.TemporaryDirectory() as dirname:
            self.rust_tokenizer.save_pretrained(dirname)

            with open(os.path.join(dirname, "tokenizer.json"), "r") as f:
                old_serialized = f.read()

        new_tokenizer = convert_slow_tokenizer(self.tokenizer)
        with tempfile.NamedTemporaryFile() as f:
            new_tokenizer.save(f.name)
            # Re-opening since `f` is in bytes.
            new_serialized = open(f.name, "r").read()
            with open("out_tokenizer.json", "w") as g:
                g.write(new_serialized)

            self.assertEqual(old_serialized, new_serialized)

    def test_simple_encode_decode(self):
        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer
@@ -362,18 +384,43 @@ def test_simple_encode_decode(self):
        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])

+    def test_no_differences_showcase(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
+        self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
+
+        self.assertEqual(pyth_tokenizer.encode("  "), [1, 1678])
+        self.assertEqual(rust_tokenizer.encode("  "), [1, 1678])
+
+        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
+        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
+
+        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
+        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
+
-        self.assertEqual(pyth_tokenizer.encode(""), [1])
-        self.assertEqual(rust_tokenizer.encode(""), [1])
+    def test_no_differences_decode(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+
+        self.assertEqual(pyth_tokenizer.decode([869]), ".")
+        self.assertEqual(rust_tokenizer.decode([869]), ".")
+
+        self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
+        self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
+
+    def test_no_differences_special_tokens(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
+        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
    @unittest.skipIf(
        os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
        "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
Expand All @@ -392,8 +439,8 @@ def test_integration_test_xnli(self):

self.assertEqual(encoded1, encoded2)

decoded1 = pyth_tokenizer.decode(encoded1)
decoded2 = rust_tokenizer.decode(encoded2)
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

self.assertEqual(decoded1, decoded2)

@@ -406,7 +453,7 @@

            self.assertEqual(encoded1, encoded2)

-            decoded1 = pyth_tokenizer.decode(encoded1)
-            decoded2 = rust_tokenizer.decode(encoded2)
+            decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
+            decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

            self.assertEqual(decoded1, decoded2)
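Note that the XNLI round-trip tests above are opt-in: they only run when `RUN_TOKENIZER_INTEGRATION=1` is set in the environment, since encoding and decoding the full dataset is slow.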