Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid importing all models when instantiating a pipeline #24960

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._model_mapping._model_mapping = self
self._extra_content = {}
self._modules = {}

Expand Down
12 changes: 0 additions & 12 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@
import tensorflow as tf

from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
Expand All @@ -110,13 +105,6 @@
import torch

from ..models.auto.modeling_auto import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(self, *args, **kwargs):
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING)
self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)

def __call__(
self,
Expand Down
8 changes: 5 additions & 3 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logger = logging.get_logger(__name__)

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES


def rescale_stride(stride, ratio):
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(

if self.model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
Copy link
Contributor

@fxmarty fxmarty Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgugger @ArthurZucker Am I correct that this breaks if a user has registered a model with `AutoModelForSpeechSeq2Seq.register(AutoConfig, MyCustomModel)` and wants to use it with the pipeline, given that `MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES` seems to be static and not influenced by `register`?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the description of the PR I would think that no (_extra_content is what should have been kept).
The `self._model_mapping._model_mapping = self` line adds a field to `MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES`, but it's not taken into account when computing `values()`, if I am not mistaken.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No new key is added to `MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES`, but you can check for its `_model_mapping` if it exists, as is done in the diff below in `pipelines/base.py`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thank you!

self.type = "seq2seq"
elif (
feature_extractor._processor_class
Expand All @@ -220,7 +220,9 @@ def __init__(
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")

self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
self.check_model_type(mapping)

def __call__(
self,
Expand Down
14 changes: 10 additions & 4 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,12 +952,18 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
"""
if not isinstance(supported_models, list): # Create from a model mapping
supported_models_names = []
for config, model in supported_models.items():
for _, model_name in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
if isinstance(model_name, tuple):
supported_models_names.extend(list(model_name))
else:
supported_models_names.append(model.__name__)
supported_models_names.append(model_name)
if hasattr(supported_models, "_model_mapping"):
for _, model in supported_models._model_mapping._extra_content.items():
if isinstance(model_name, tuple):
supported_models_names.extend([m.__name__ for m in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
logger.error(
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -48,7 +48,7 @@ class DepthEstimationPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING)
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)

def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES

TESSERACT_LOADED = False
if is_pytesseract_available():
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(self, *args, **kwargs):
if self.model.config.encoder.model_type != "donut-swin":
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
else:
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
if self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/pipelines/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
if is_tf_available():
import tensorflow as tf

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from ..tf_utils import stable_softmax

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -57,9 +57,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)

def _sanitize_parameters(self, top_k=None):
Expand Down
21 changes: 9 additions & 12 deletions src/transformers/pipelines/image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

if is_torch_available():
from ..models.auto.modeling_auto import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
)


Expand Down Expand Up @@ -71,14 +71,11 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

requires_backends(self, "vision")
self.check_model_type(
dict(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.items()
)
)
mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sum no longer works on the ordered dicts, so I had to write it this way.

mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
self.check_model_type(mapping)

def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from ..image_utils import load_image

if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
)

def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self, **kwargs):
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
Expand Down
11 changes: 7 additions & 4 deletions src/transformers/pipelines/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import (
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -53,9 +56,9 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

requires_backends(self, "vision")
self.check_model_type(
dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items())
)
mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't change the type, but token classification here seems like a mistake?

self.check_model_type(mapping)

def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
Expand Down
8 changes: 5 additions & 3 deletions src/transformers/pipelines/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
if is_tf_available():
import tensorflow as tf

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES

Dataset = None

if is_torch_available():
import torch
from torch.utils.data import Dataset

from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


def decode_spans(
Expand Down Expand Up @@ -270,7 +270,9 @@ def __init__(

self._args_parser = QuestionAnsweringArgumentHandler()
self.check_model_type(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)

@staticmethod
Expand Down
25 changes: 11 additions & 14 deletions src/transformers/pipelines/table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import torch

from ..models.auto.modeling_auto import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)

if is_tf_available() and is_tensorflow_probability_available():
import tensorflow as tf
import tensorflow_probability as tfp

from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)


Expand Down Expand Up @@ -122,16 +122,13 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, *
super().__init__(*args, **kwargs)
self._args_parser = args_parser

self.check_model_type(
dict(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items()
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
if self.framework == "tf"
else dict(
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
)
if self.framework == "tf":
mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
mapping.update(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
else:
mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
self.check_model_type(mapping)

self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
getattr(self.model.config, "num_aggregation_labels", None)
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
if is_tf_available():
import tensorflow as tf

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -65,9 +65,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.check_model_type(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)

def _sanitize_parameters(
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/pipelines/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES


def sigmoid(_outputs):
Expand Down Expand Up @@ -84,9 +84,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

self.check_model_type(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)

def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
Expand Down
Loading