diff --git a/multimodal/src/autogluon/multimodal/configs/model/fusion_mlp_image_text_tabular.yaml b/multimodal/src/autogluon/multimodal/configs/model/fusion_mlp_image_text_tabular.yaml
index 843cd97327d..65238ef10ac 100644
--- a/multimodal/src/autogluon/multimodal/configs/model/fusion_mlp_image_text_tabular.yaml
+++ b/multimodal/src/autogluon/multimodal/configs/model/fusion_mlp_image_text_tabular.yaml
@@ -62,6 +62,7 @@ model:
     data_types:
     - "text"
     tokenizer_name: "hf_auto"
+    use_fast: True # Use a fast Rust-based tokenizer if it is supported for a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
     max_text_len: 512 # If None or <=0, then use the max length of pretrained models.
     insert_sep: True
     low_cpu_mem_usage: False
diff --git a/multimodal/src/autogluon/multimodal/data/process_text.py b/multimodal/src/autogluon/multimodal/data/process_text.py
index a98d3a1f37e..1881c271ab6 100644
--- a/multimodal/src/autogluon/multimodal/data/process_text.py
+++ b/multimodal/src/autogluon/multimodal/data/process_text.py
@@ -95,6 +95,7 @@ def __init__(
         train_augment_types: Optional[List[str]] = None,
         template_config: Optional[DictConfig] = None,
         normalize_text: Optional[bool] = False,
+        use_fast: Optional[bool] = True,
     ):
         """
         Parameters
@@ -125,6 +126,11 @@ def __init__(
             Whether to normalize text to resolve encoding problems.
             Examples of normalized texts can be found at
             https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize#15-a-few-examples-of-normalized-texts
+        use_fast
+            Use a fast Rust-based tokenizer if it is supported for a given model.
+            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+            See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast
+
         """
         self.prefix = model.prefix
         self.tokenizer_name = tokenizer_name
@@ -136,6 +142,7 @@ def __init__(
         self.tokenizer = self.get_pretrained_tokenizer(
             tokenizer_name=tokenizer_name,
             checkpoint_name=model.checkpoint_name,
+            use_fast=use_fast,
         )
         if hasattr(self.tokenizer, "deprecation_warnings"):
             # Disable the warning "Token indices sequence length is longer than the specified maximum sequence..."
@@ -410,6 +417,7 @@ def get_special_tokens(tokenizer):
     def get_pretrained_tokenizer(
         tokenizer_name: str,
         checkpoint_name: str,
+        use_fast: Optional[bool] = True,
     ):
         """
         Load the tokenizer for a pre-trained huggingface checkpoint.
@@ -420,6 +428,10 @@
             The tokenizer type, e.g., "bert", "clip", "electra", and "hf_auto".
         checkpoint_name
             Name of a pre-trained checkpoint.
+        use_fast
+            Use a fast Rust-based tokenizer if it is supported for a given model.
+            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+            See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast
 
         Returns
         -------
@@ -427,7 +439,7 @@
         """
         try:
             tokenizer_class = ALL_TOKENIZERS[tokenizer_name]
-            return tokenizer_class.from_pretrained(checkpoint_name)
+            return tokenizer_class.from_pretrained(checkpoint_name, use_fast=use_fast)
         except TypeError as e:
             try:
                 tokenizer_class = ALL_TOKENIZERS["bert"]
diff --git a/multimodal/src/autogluon/multimodal/utils/data.py b/multimodal/src/autogluon/multimodal/utils/data.py
index 77abf751309..223e62433a0 100644
--- a/multimodal/src/autogluon/multimodal/utils/data.py
+++ b/multimodal/src/autogluon/multimodal/utils/data.py
@@ -157,6 +157,7 @@ def create_data_processor(
             train_augment_types=OmegaConf.select(model_config, "text_train_augment_types"),
             template_config=getattr(config.data, "templates", OmegaConf.create({"turn_on": False})),
             normalize_text=getattr(config.data.text, "normalize_text", False),
+            use_fast=OmegaConf.select(model_config, "use_fast", default=True),
         )
     elif data_type == CATEGORICAL:
         data_processor = CategoricalProcessor(
diff --git a/multimodal/tests/hf_model_list.yaml b/multimodal/tests/hf_model_list.yaml
index c70b8e2b608..c2034dcc444 100644
--- a/multimodal/tests/hf_model_list.yaml
+++ b/multimodal/tests/hf_model_list.yaml
@@ -40,6 +40,7 @@ others_2:
 - t5-small
 - microsoft/layoutlmv3-base
 - microsoft/layoutlmv2-base-uncased
+- albert-base-v2
 predictor:
 - CLTL/MedRoBERTa.nl
 - google/electra-small-discriminator
diff --git a/multimodal/tests/unittests/others_2/test_data_processors.py b/multimodal/tests/unittests/others_2/test_data_processors.py
new file mode 100644
index 00000000000..6ce83bbf4d8
--- /dev/null
+++ b/multimodal/tests/unittests/others_2/test_data_processors.py
@@ -0,0 +1,66 @@
+import os
+import shutil
+import tempfile
+
+import pytest
+from transformers import AlbertTokenizer, AlbertTokenizerFast
+
+from autogluon.multimodal import MultiModalPredictor
+from autogluon.multimodal.constants import TEXT
+
+from ..utils.unittest_datasets import AEDataset, HatefulMeMesDataset, IDChangeDetectionDataset, PetFinderDataset
+
+ALL_DATASETS = {
+    "petfinder": PetFinderDataset,
+    "hateful_memes": HatefulMeMesDataset,
+    "ae": AEDataset,
+}
+
+
+@pytest.mark.parametrize(
+    "checkpoint_name,use_fast,tokenizer_type",
+    [
+        (
+            "albert-base-v2",
+            None,
+            AlbertTokenizerFast,
+        ),
+        (
+            "albert-base-v2",
+            True,
+            AlbertTokenizerFast,
+        ),
+        (
+            "albert-base-v2",
+            False,
+            AlbertTokenizer,
+        ),
+    ],
+)
+def test_tokenizer_use_fast(checkpoint_name, use_fast, tokenizer_type):
+    dataset = ALL_DATASETS["ae"]()
+    metric_name = dataset.metric
+
+    predictor = MultiModalPredictor(
+        label=dataset.label_columns[0],
+        problem_type=dataset.problem_type,
+        eval_metric=metric_name,
+    )
+    hyperparameters = {
+        "data.categorical.convert_to_text": True,
+        "data.numerical.convert_to_text": True,
+        "model.hf_text.checkpoint_name": checkpoint_name,
+    }
+    if use_fast is not None:
+        hyperparameters["model.hf_text.use_fast"] = use_fast
+
+    with tempfile.TemporaryDirectory() as save_path:
+        if os.path.isdir(save_path):
+            shutil.rmtree(save_path)
+        predictor.fit(
+            train_data=dataset.train_df,
+            time_limit=5,
+            save_path=save_path,
+            hyperparameters=hyperparameters,
+        )
+        assert isinstance(predictor._data_processors[TEXT][0].tokenizer, tokenizer_type)
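For context on what the flag toggles: the `hf_auto` tokenizer name resolves to Hugging Face's `AutoTokenizer`, whose `use_fast` argument picks between the Rust-backed and pure-Python tokenizer classes that the new test asserts on. A minimal sketch of that behavior outside AutoGluon (assumes `transformers` is installed, plus `sentencepiece` for the slow ALBERT tokenizer):

```python
from transformers import AutoTokenizer

# use_fast=True (the new default) returns the Rust-backed tokenizer when the
# model ships one; use_fast=False forces the pure-Python implementation.
fast = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=True)
slow = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=False)

print(type(fast).__name__)  # AlbertTokenizerFast
print(type(slow).__name__)  # AlbertTokenizer
```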
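And a sketch of how the new option is reached from the public API, mirroring the hyperparameter keys used in the test above; the toy dataframe is a hypothetical stand-in for the AE dataset the test loads:

```python
import pandas as pd

from autogluon.multimodal import MultiModalPredictor

# Hypothetical stand-in data; any text + label dataframe exercises the path.
train_data = pd.DataFrame(
    {
        "review": ["great product", "terrible quality", "works fine", "broke instantly"],
        "label": [1, 0, 1, 0],
    }
)

predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data=train_data,
    time_limit=30,
    hyperparameters={
        "model.hf_text.checkpoint_name": "albert-base-v2",
        "model.hf_text.use_fast": False,  # opt out of the fast tokenizer
    },
)
```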