From fd9e35f7d8d8c500293d48a26a4174a0a3d787e3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 20 Sep 2024 23:33:03 +0200 Subject: [PATCH] [Bugfix][Core] Fix tekken edge case for mistral tokenizer (#8640) --- .../decoder_only/language/test_mistral.py | 26 ++++++++++++++- vllm/transformers_utils/tokenizers/mistral.py | 32 +++++++++++++++++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 26f90456849f1..174b905d9cbb9 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,7 +4,7 @@ """ import pytest -from vllm import SamplingParams +from vllm import LLM, SamplingParams from ...utils import check_logprobs_close @@ -16,6 +16,10 @@ ] SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +SYMBOLIC_LANG_PROMPTS = [ + "勇敢な船乗りについての詩を書く", # japanese + "寫一首關於勇敢的水手的詩", # chinese +] # for function calling TOOLS = [{ @@ -131,6 +135,26 @@ def test_mistral_format( ) +@pytest.mark.parametrize("model", MODELS[1:]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("prompt", SYMBOLIC_LANG_PROMPTS) +def test_mistral_symbolic_languages( + model: str, + dtype: str, + prompt: str, +) -> None: + prompt = "hi" + msg = {"role": "user", "content": prompt} + llm = LLM(model=model, + dtype=dtype, + max_model_len=8192, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") + outputs = llm.chat([msg], sampling_params=SAMPLING_PARAMS) + assert "�" not in outputs[0].outputs[0].text.strip() + + @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling def test_mistral_function_calling( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 7a228a3efa6e8..788133059f12d 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -175,10 +175,29 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(t for t in tokens - if t not in self.tokenizer._all_special_tokens) + tokens = [ + t for t in tokens + if t not in self.tokenizer._all_special_tokens + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self.tokenizer.num_special_tokens + byte_tokens = [ + t.encode("utf-8") if not isinstance(t, bytes) else t + for t in tokens + ] + ids = [ + self.tokenizer._tekken_token2id_nospecial[t] + shift + for t in byte_tokens + ] + decoded = self.tokenizer.decode(ids) + else: + decoded = "".join(tokens) else: - return self.tokenizer.decode(tokens) # type: ignore[arg-type] + decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type] + + return decoded def decode(self, ids: Union[List[int], int]) -> str: if isinstance(ids, int): @@ -200,4 +219,11 @@ def convert_ids_to_tokens( self.tokenizer) tokens = [self.tokenizer.id_to_piece(id) for id in ids] + + if any(t.strip() == "�" for t in tokens): + # if any stripped decoded token is undefined + # because it's invalid unicode then pass bytes + # See: https://github.com/vllm-project/vllm/pull/8640 + tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + return tokens