Adding Llama FastTokenizer support. (huggingface#22264)
* Adding Llama FastTokenizer support.

- Requires the `tokenizers` version from huggingface/tokenizers#1183.
- Only supports byte_fallback for Llama; raises otherwise (safety net).
- Lots of open questions around special tokens.

How to test:

```python

from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

# Set to True to reload a previously saved "tok.json" instead of re-running the conversion.
if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",
    "生活的真谛是[MASK]。",
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
```

The converter + some test script.

The test script.

Tmp save.

Adding Fast tokenizer + tests.

Adding the tokenization tests.

Correct combination.

Small fix.

Fixing tests.

Fixing with latest update.

Rebased.

fix copies + normalized added tokens  + copies.

Adding doc.

TMP.

Doc + split files.

Doc.

Versions + try import.

Fix Camembert + warnings -> Error.

Fix by ArthurZucker.

Not a decorator.

* Fixing comments.

* Adding more to docstring.

* Doc rewriting.
Narsil authored and xloem committed Apr 9, 2023
1 parent cc78fe4 commit c8791b5
Showing 11 changed files with 266 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -263,7 +263,7 @@ Flax), PyTorch, and/or TensorFlow.
| LayoutLMv3 | | | | | |
| LED | | | | | |
| LeViT | | | | | |
| LLaMA | ✅ | ❌ | ✅ | ❌ | ❌ |
| LLaMA | ✅ | ✅ | ✅ | ❌ | ❌ |
| Longformer | | | | | |
| LongT5 | | | | | |
| LUKE | | | | | |
8 changes: 8 additions & 0 deletions docs/source/en/model_doc/llama.mdx
@@ -59,6 +59,14 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
- create_token_type_ids_from_sequences
- save_vocabulary

## LlamaTokenizerFast

[[autodoc]] LlamaTokenizerFast
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary

## LlamaModel

[[autodoc]] LlamaModel
3 changes: 1 addition & 2 deletions setup.py
@@ -70,10 +70,9 @@
import os
import re
import shutil
from distutils.core import Command
from pathlib import Path

from setuptools import find_packages, setup
from setuptools import Command, find_packages, setup


# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -591,6 +591,7 @@
_import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
_import_structure["models.led"].append("LEDTokenizerFast")
_import_structure["models.llama"].append("LlamaTokenizerFast")
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
@@ -3564,6 +3565,7 @@
from .models.layoutlmv3 import LayoutLMv3TokenizerFast
from .models.layoutxlm import LayoutXLMTokenizerFast
from .models.led import LEDTokenizerFast
from .models.llama import LlamaTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.markuplm import MarkupLMTokenizerFast
89 changes: 81 additions & 8 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,10 +19,9 @@
allow to make our dependency on SentencePiece optional.
"""

import warnings
from typing import Dict, List, Tuple

from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece

from .utils import requires_backends
@@ -443,12 +442,13 @@ def __init__(self, *args):
self.proto = m

if self.proto.trainer_spec.byte_fallback:
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)
if not getattr(self, "handle_byte_fallback", None):
raise RuntimeError(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)

def vocab(self, proto):
return [(piece.piece, piece.score) for piece in proto.pieces]
@@ -1043,6 +1043,78 @@ def post_processor(self):
)


class LlamaConverter(SpmConverter):
handle_byte_fallback = True

def vocab(self, proto):
vocab = [
("<unk>", 0.0),
("<s>", 0.0),
("</s>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab

def unk_id(self, proto):
unk_id = 0
return unk_id

def decoder(self, replacement, add_prefix_space):
return decoders.Sequence(
[
decoders.Replace("▁", " "),
decoders.ByteFallback(),
decoders.Fuse(),
decoders.Strip(content=" ", left=1),
]
)

def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
raise RuntimeError("Llama is supposed to be a BPE model!")
elif model_type == 2:
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken("<unk>", normalized=True),
AddedToken("<s>", normalized=True),
AddedToken("</s>", normalized=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)

return tokenizer

def normalizer(self, proto):
return normalizers.Sequence(
[
normalizers.Prepend(prepend="▁"),
normalizers.Replace(pattern=" ", content="▁"),
]
)

def pre_tokenizer(self, replacement, add_prefix_space):
return None

def post_processor(self):
return processors.TemplateProcessing(
single="<s> $A",
pair="<s> $A $B",
special_tokens=[
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
],
)


class MarkupLMConverter(Converter):
def converted(self) -> Tokenizer:
ot = self.original_tokenizer
@@ -1131,6 +1203,7 @@ def converted(self) -> Tokenizer:
"XLNetTokenizer": XLNetConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
}


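As a quick illustration of what the new `LlamaConverter` produces, here is a minimal sketch (not part of the diff) reusing the `hf-internal-testing/llama-tokenizer` checkpoint referenced in the docstring further below; any local LLaMA SentencePiece checkpoint should behave the same:

```python
# Sketch only: assumes the hf-internal-testing/llama-tokenizer checkpoint is reachable.
from transformers import LlamaTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
fast_backend = convert_slow_tokenizer(slow)  # dispatched to LlamaConverter via SLOW_TO_FAST_CONVERTERS

enc = fast_backend.encode("This is a test")
assert enc.ids[0] == slow.bos_token_id  # the TemplateProcessing post-processor prepends "<s>"

# With byte_fallback, characters missing from the vocabulary decompose into <0xXX> byte
# tokens instead of collapsing to a single <unk>.
print(fast_backend.encode("生活的真谛是").tokens)
```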
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -140,7 +140,13 @@
("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, None)),
(
"llama",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
(
"longt5",
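With this mapping entry, `AutoTokenizer` can now resolve to the fast class. A minimal sketch, again reusing the `hf-internal-testing/llama-tokenizer` checkpoint from the docstring below and assuming the `tokenizers` backend is installed:

```python
# Sketch: AutoTokenizer returns the fast class when the `tokenizers` backend is available.
from transformers import AutoTokenizer

fast_tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
print(type(fast_tok).__name__)  # LlamaTokenizerFast

slow_tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False)
print(type(slow_tok).__name__)  # LlamaTokenizer
```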
17 changes: 17 additions & 0 deletions src/transformers/models/llama/__init__.py
@@ -17,6 +17,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)

@@ -33,6 +34,14 @@
else:
_import_structure["tokenization_llama"] = ["LlamaTokenizer"]

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -58,6 +67,14 @@
else:
from .tokenization_llama import LlamaTokenizer

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_llama_fast import LlamaTokenizerFast

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
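These try/except guards mirror the dummy object added at the end of this diff: when the `tokenizers` backend is missing, `LlamaTokenizerFast` still imports but raises as soon as it is instantiated. A rough sketch of that behavior (not part of the diff):

```python
# Sketch: with the `tokenizers` backend missing, the dummy class raises on instantiation.
from transformers import LlamaTokenizerFast
from transformers.utils import is_tokenizers_available

if not is_tokenizers_available():
    try:
        LlamaTokenizerFast()  # DummyObject backed by requires_backends(..., ["tokenizers"])
    except ImportError as err:
        print(err)  # asks the user to install `tokenizers`
```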
82 changes: 82 additions & 0 deletions src/transformers/models/llama/tokenization_llama_fast.py
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils.versions import require_version


require_version("tokenizers>=0.13.3")


class LlamaTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
Notably, this uses ByteFallback and no normalization.
```
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.encode("Hello this is a test")
>>> [1, 15043, 445, 338, 263, 1243]
```
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`):
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
spaces.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
"""

padding_side = "left"

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
**kwargs,
)
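Because the class sets `padding_side = "left"`, batched padding lands on the left by default. A short sketch using the same placeholder checkpoint as the docstring; note that LLaMA ships without a pad token, so one is assigned here purely for illustration:

```python
# Sketch: left padding with the new fast tokenizer (assumes the docstring checkpoint).
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tok.pad_token = tok.eos_token  # no pad token in the original vocab; reuse EOS for padding
batch = tok(["Hello", "Hello this is a longer test"], padding=True)
print(batch["input_ids"][0])  # the shorter sequence is padded on the left
```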
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tokenizers_objects.py
@@ -220,6 +220,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class LlamaTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class LongformerTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
