From f2f16abae62d2c7e68a22c521910e349ffdbb9af Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 29 Jan 2024 16:07:35 +0000 Subject: [PATCH] [Whisper] Make tokenizer normalization public (#28136) * [Whisper] Make tokenizer normalization public * add to docs --- docs/source/en/model_doc/whisper.md | 4 ++++ .../models/whisper/tokenization_whisper.py | 21 ++++++++++++++++--- .../whisper/tokenization_whisper_fast.py | 21 +++++++++++++++++-- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 37411209bf9157..e384d2be908c0b 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -102,6 +102,8 @@ python convert_hf_to_openai.py \ - save_vocabulary - batch_decode - decode + - basic_normalize + - normalize ## WhisperTokenizerFast @@ -113,6 +115,8 @@ python convert_hf_to_openai.py \ - save_vocabulary - batch_decode - decode + - basic_normalize + - normalize ## WhisperFeatureExtractor diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 127f5be6193d72..f853c60e260f50 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -15,6 +15,7 @@ """Tokenization classes for Whisper.""" import json import os +import warnings from functools import lru_cache from typing import List, Optional, Tuple, Union @@ -507,6 +508,20 @@ def _convert_id_to_token(self, index): return self.decoder.get(index, "") def _normalize(self, text): + warnings.warn( + "The private method `_normalize` is deprecated and will be removed in v5 of Transformers." + "You can normalize an input string using the Whisper English normalizer using the `normalize` method." + ) + return self.normalize(text) + + def _basic_normalize(self, text, remove_diacritics=False): + warnings.warn( + "The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers." + "You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method." + ) + return self.basic_normalize(text, remove_diacritics=remove_diacritics) + + def normalize(self, text): """ Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on english text. @@ -515,7 +530,7 @@ def _normalize(self, text): return normalizer(text) @staticmethod - def _basic_normalize(text, remove_diacritics=False): + def basic_normalize(text, remove_diacritics=False): """ Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on multilingual text. @@ -745,10 +760,10 @@ def _decode( text = "".join(sub_texts) if normalize: - clean_text = self._normalize(text) + clean_text = self.normalize(text) return clean_text elif basic_normalize: - clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics) + clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics) return clean_text else: return text diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 509175be994f75..dc5a3e0dc1f784 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -16,6 +16,7 @@ import json import os import re +import warnings from functools import lru_cache from typing import List, Optional, Tuple @@ -427,6 +428,22 @@ def _decode( # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._normalize def _normalize(self, text): + warnings.warn( + "The private method `_normalize` is deprecated and will be removed in v5 of Transformers." + "You can normalize an input string using the Whisper English normalizer using the `normalize` method." + ) + return self.normalize(text) + + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize + def _basic_normalize(self, text, remove_diacritics=False): + warnings.warn( + "The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers." + "You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method." + ) + return self.basic_normalize(text, remove_diacritics=remove_diacritics) + + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.normalize + def normalize(self, text): """ Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on english text. @@ -435,8 +452,8 @@ def _normalize(self, text): return normalizer(text) @staticmethod - # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize - def _basic_normalize(text, remove_diacritics=False): + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.basic_normalize + def basic_normalize(text, remove_diacritics=False): """ Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on multilingual text.