Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for using tiktoken tokenizers #610

Merged
merged 66 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
2964228
unit tests pass
dakinggg Sep 17, 2023
7904938
precommit and add vocab test
dakinggg Sep 17, 2023
29d0564
precommit
dakinggg Sep 17, 2023
8ba4283
precommit
dakinggg Sep 17, 2023
194bc59
fix auto stuff
dakinggg Sep 17, 2023
da37e8c
precommit
dakinggg Sep 17, 2023
246606d
add kwargs to dataset conversion
dakinggg Sep 17, 2023
f821306
precommit
dakinggg Sep 17, 2023
e01fc57
fix non unicode stuff
dakinggg Sep 18, 2023
067bbb4
attempt to ignore test file
dakinggg Sep 18, 2023
44dcf3e
get the ignore right
dakinggg Sep 18, 2023
2acc466
precommit
dakinggg Sep 18, 2023
371185d
precommit
dakinggg Sep 18, 2023
27be411
remove formatter ignore
dakinggg Sep 18, 2023
d00c25b
still precommit
dakinggg Sep 18, 2023
d4bf5b7
license
dakinggg Sep 18, 2023
ff30ce8
more precommit
dakinggg Sep 18, 2023
b34eb8f
pyrigth config change
dakinggg Sep 18, 2023
6fd1874
more tests
dakinggg Sep 18, 2023
c84fee4
add another test
dakinggg Sep 18, 2023
01504eb
precommit
dakinggg Sep 18, 2023
fc7a65d
fix test
dakinggg Sep 18, 2023
bcf49c3
merge
dakinggg Sep 19, 2023
2c0d4e6
rm old implementation
dakinggg Sep 19, 2023
caa51d2
merge with other wrapper
dakinggg Sep 20, 2023
4b05efd
Merge branch 'main' into tiktoken
dakinggg Sep 20, 2023
e6f385f
docs
dakinggg Sep 20, 2023
0e5376a
precommit
dakinggg Sep 20, 2023
0fec0ff
precommit
dakinggg Sep 20, 2023
de3bbbc
test the encoding path
dakinggg Sep 20, 2023
13ce011
precommit
dakinggg Sep 20, 2023
4ad6f87
fix error
dakinggg Sep 20, 2023
a67689f
precommit
dakinggg Sep 20, 2023
bf3ac1b
more thorough tests and fix no split tokens setting
dakinggg Sep 20, 2023
535202b
Merge branch 'main' into tiktoken
dakinggg Sep 20, 2023
e802370
fix tests
dakinggg Sep 20, 2023
56fe3c1
precommit
dakinggg Sep 20, 2023
2ee57b6
precommit
dakinggg Sep 20, 2023
f5acf48
pyright
dakinggg Sep 20, 2023
17002a5
precommit
dakinggg Sep 20, 2023
1f48d62
remove unnecessary cast
dakinggg Sep 20, 2023
85e5740
Merge branch 'main' into tiktoken
dakinggg Sep 20, 2023
75c26de
fix type ignore
dakinggg Sep 20, 2023
fdec6ab
precommit
dakinggg Sep 20, 2023
2893eaf
pr nit
dakinggg Sep 20, 2023
3eb328d
add return types
dakinggg Sep 20, 2023
efed581
docstring typos
dakinggg Sep 20, 2023
473dd0e
precommit
dakinggg Sep 20, 2023
132b6e5
pr comments
dakinggg Sep 20, 2023
32c7bd9
pr comments
dakinggg Sep 20, 2023
925355a
precommit
dakinggg Sep 20, 2023
40ec2c4
use helper
dakinggg Sep 20, 2023
21c7953
precommit
dakinggg Sep 20, 2023
53d0761
fix import
dakinggg Sep 20, 2023
295d42b
precommit
dakinggg Sep 20, 2023
a0c10c1
reprompt tests
dakinggg Sep 20, 2023
7e4e878
pr comments
dakinggg Sep 20, 2023
a669130
fix test
dakinggg Sep 20, 2023
58ecca4
precommit
dakinggg Sep 20, 2023
9e8a0d1
Merge branch 'main' into tiktoken
dakinggg Sep 20, 2023
1e01e11
fix typo
dakinggg Sep 20, 2023
b1a70b8
precommit
dakinggg Sep 20, 2023
2a766fc
clean up tests
dakinggg Sep 20, 2023
483b4f8
precommit
dakinggg Sep 20, 2023
2b5fb5c
Merge branch 'main' into tiktoken
dakinggg Sep 21, 2023
1c5a999
Merge branch 'main' into tiktoken
dakinggg Sep 21, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
hooks:
- id: yapf
name: yapf
exclude: tests/horrible_strings.py
description: A formatter for Python files.
entry: yapf
args: [-i, -vv, -p] # inplace
Expand Down Expand Up @@ -50,6 +51,7 @@ repos:
- id: debug-statements
- id: destroyed-symlinks
- id: double-quote-string-fixer
exclude: tests/horrible_strings.py
- id: end-of-file-fixer
- id: fix-byte-order-marker
- id: mixed-line-ending
Expand Down Expand Up @@ -93,6 +95,7 @@ repos:
hooks:
- id: pyright
name: pyright
exclude: tests/horrible_strings.py
entry: pyright
language: node
types: [python]
Expand All @@ -104,6 +107,7 @@ repos:
hooks:
- id: trufflehog
name: secret scan
exclude: tests/horrible_strings.py
entry: trufflehog filesystem ./
args:
- --only-verified
Expand Down
2 changes: 2 additions & 0 deletions llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper

except ImportError as e:
try:
Expand Down Expand Up @@ -64,6 +65,7 @@
'build_alibi_bias',
'optim',
'utils',
'TiktokenTokenizerWrapper',
]

__version__ = '0.2.0'
3 changes: 3 additions & 0 deletions llmfoundry/callbacks/monolithic_ckpt_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def _save_checkpoint(self, state: State, logger: Logger) -> None:
) if self.upload_to_object_store else contextlib.nullcontext(
enter_result=save_dir)
with dir_context_mgr as temp_save_dir:
# pyright doesn't know about enter_result
assert isinstance(temp_save_dir, str)

save_path = str(Path(temp_save_dir) / Path(filename))
dirname = os.path.dirname(save_path)
if dirname:
Expand Down
3 changes: 1 addition & 2 deletions llmfoundry/models/inference_api_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from llmfoundry.models.inference_api_wrapper.interface import \
InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper)

__all__ = [
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
'OpenAITokenizerWrapper',
'InferenceAPIEvalWrapper',
]
93 changes: 6 additions & 87 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,102 +11,21 @@
import torch
from composer.core.types import Batch
from composer.utils.import_helpers import MissingConditionalImportError
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from transformers import AutoTokenizer

log = logging.getLogger(__name__)

from llmfoundry.models.inference_api_wrapper.interface import \
InferenceAPIEvalWrapper

__all__ = [
'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper',
'OpenAITokenizerWrapper'
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
]

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

MAX_RETRIES = 10


class OpenAITokenizerWrapper(AutoTokenizer):
# this API is experimental and for evaluation only. It is subject to change as we add support for training
def __init__(self, name: str) -> None:
try:
import tiktoken
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='tiktoken',
conda_channel='conda-forge') from e
self.tokenizer = tiktoken.encoding_for_model(name)

def __call__(self, x: str, add_special_tokens: bool = False):
if add_special_tokens:
raise ValueError(
'OpenAITokenizerWrapper only supports add_special_tokens=False')
return self.encode(x)

def encode(self,
x: Union[str, List[str]],
add_special_tokens: bool = False):
if add_special_tokens:
raise ValueError(
'OpenAITokenizerWrapper only supports add_special_tokens=False')
if isinstance(x, str):
return {
'input_ids':
self.tokenizer.encode(x, allowed_special={'<|endoftext|>'})
}
elif isinstance(x,
list): # pyright: ignore [reportUnnecessaryIsInstance]
return {
'input_ids':
self.tokenizer.encode_batch(
x, allowed_special={'<|endoftext|>'})
}
else:
raise ValueError(
f'`encode` argument must be str or List[str], got: {type(x)}')

def decode(
self,
x: Union[List[int], List[List[int]]],
):
if len(x) > 0 and isinstance(x[0], list):
return self.tokenizer.decode_batch(
x) # pyright: ignore [reportGeneralTypeIssues]
else:
assert isinstance(x, list)
return self.tokenizer.decode(
x) # pyright: ignore [reportGeneralTypeIssues]

@property
def pad_token_id(self):
return self.tokenizer.eot_token

@property
def eos_token_id(self):
return self.tokenizer.eot_token

@property
def vocab_size(self):
return self.tokenizer.n_vocab

def construct_logit_tensor(self, logprobs: Dict[str, float]):
"""Construct tensor of shape (vocab_size,) mapping words to logprobs.

Args:
logprobs (Dict[str, float]): Dictionary mapping tokens to log probabilities assigned to them by the model.
"""
tensor = torch.tensor([min(logprobs.values()) - 1] * (self.vocab_size))
for k in logprobs:
encoding = self.encode(k)['input_ids']
idx = encoding[0]
tensor[idx] = logprobs[k]
return tensor


class OpenAIEvalInterface(InferenceAPIEvalWrapper):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
Expand Down Expand Up @@ -185,7 +104,7 @@ def retokenize(self, tokens: List[int], cont_idxs: List[int]):
re-tokenize with the space removed.
"""
original_len = len(tokens)
retokenized_continuation = self.tokenizer.encode(
retokenized_continuation = self.tokenizer(
self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] +
1]).strip())['input_ids']

Expand Down Expand Up @@ -275,8 +194,8 @@ def process_result(self, completion: Optional[dict]):
assert isinstance(completion, dict)
if len(completion['choices']) > 0:
tensors = []
for t in self.tokenizer.encode(completion['choices'][0]['message']
['content'])['input_ids']:
for t in self.tokenizer(completion['choices'][0]['message']
['content'])['input_ids']:
tensors.append(
self.tokenizer.construct_logit_tensor(
{self.tokenizer.decode([t]): 0.0}))
Expand Down
8 changes: 8 additions & 0 deletions llmfoundry/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper

__all__ = [
'TiktokenTokenizerWrapper',
]
Loading
Loading