diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index d27aeeaf720099..a5d8f6f872aabd 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -81,7 +81,16 @@ ] # noqa -VideoInput = Union[np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]] # noqa +VideoInput = Union[ + List["PIL.Image.Image"], + "np.ndarray", + "torch.Tensor", + List["np.ndarray"], + List["torch.Tensor"], + List[List["PIL.Image.Image"]], + List[List["np.ndarrray"]], + List[List["torch.Tensor"]], +] # noqa class ChannelDimension(ExplicitEnum): diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py index 3bc97afd1ca541..5fdaf051404845 100644 --- a/src/transformers/models/align/processing_align.py +++ b/src/transformers/models/align/processing_align.py @@ -16,8 +16,30 @@ Image/Text processor class for ALIGN """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from typing import List, Union + + +try: + from typing import Unpack +except ImportError: + from typing_extensions import Unpack + +from ...image_utils import ImageInput +from ...processing_utils import ( + ProcessingKwargs, + ProcessorMixin, +) +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class AlignProcessorKwargs(ProcessingKwargs, total=False): + # see processing_utils.ProcessingKwargs documentation for usage. + _defaults = { + "text_kwargs": { + "padding": "max_length", + "max_length": 64, + }, + } class AlignProcessor(ProcessorMixin): @@ -26,12 +48,28 @@ class AlignProcessor(ProcessorMixin): [`BertTokenizer`]/[`BertTokenizerFast`] into a single processor that interits both the image processor and tokenizer functionalities. See the [`~AlignProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information. + The preferred way of passing kwargs is as a dictionary per modality, see usage example below. + ```python + from transformers import AlignProcessor + from PIL import Image + model_id = "kakaobrain/align-base" + processor = AlignProcessor.from_pretrained(model_id) + + processor( + images=your_pil_image, + text=["What is that?"], + images_kwargs = {"crop_size": {"height": 224, "width": 224}}, + text_kwargs = {"padding": "do_not_pad"}, + common_kwargs = {"return_tensors": "pt"}, + ) + ``` Args: image_processor ([`EfficientNetImageProcessor`]): The image processor is a required input. tokenizer ([`BertTokenizer`, `BertTokenizerFast`]): The tokenizer is a required input. + """ attributes = ["image_processor", "tokenizer"] @@ -41,11 +79,18 @@ class AlignProcessor(ProcessorMixin): def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, padding="max_length", max_length=64, return_tensors=None, **kwargs): + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + audio=None, + videos=None, + **kwargs: Unpack[AlignProcessorKwargs], + ) -> BatchEncoding: """ Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text` - and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode - the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` arguments to EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring of the above two methods for more information. @@ -57,20 +102,12 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64, images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): - Activates and controls padding for tokenization of input text. Choose between [`True` or `'longest'`, - `'max_length'`, `False` or `'do_not_pad'`] - max_length (`int`, *optional*, defaults to `max_length`): - Maximum padding value to use to pad the input text during tokenization. - return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: @@ -81,15 +118,22 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64, - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ if text is None and images is None: - raise ValueError("You have to specify either text or images. Both cannot be none.") - + raise ValueError("You must specify either text or images.") + output_kwargs = self._merge_kwargs( + AlignProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # then, we can pass correct kwargs to each processor if text is not None: - encoding = self.tokenizer( - text, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs - ) + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + + # BC for explicit return_tensors + if "return_tensors" in output_kwargs["common_kwargs"]: + return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index a21d265b9d1bda..8f939cadfa05d3 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -22,13 +22,26 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +import numpy as np from .dynamic_module_utils import custom_object_save -from .tokenization_utils_base import PreTrainedTokenizerBase +from .image_utils import ChannelDimension, is_vision_available + + +if is_vision_available(): + from .image_utils import PILImageResampling + +from .tokenization_utils_base import ( + PaddingStrategy, + PreTrainedTokenizerBase, + TruncationStrategy, +) from .utils import ( PROCESSOR_NAME, PushToHubMixin, + TensorType, add_model_info_to_auto_map, add_model_info_to_custom_pipelines, cached_file, @@ -54,6 +67,248 @@ } +class TextKwargs(TypedDict, total=False): + """ + Keyword arguments for text processing. For extended documentation, check out tokenization_utils_base methods and + docstrings associated. + + Attributes: + add_special_tokens (`bool`, *optional*) + Whether or not to add special tokens when encoding the sequences. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*) + Activates and controls padding. + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*): + Activates and controls truncation. + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + stride (`int`, *optional*): + If set, the overflowing tokens will contain some tokens from the end of the truncated sequence. + is_split_into_words (`bool`, *optional*): + Whether or not the input is already pre-tokenized. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. + return_overflowing_tokens (`bool`, *optional*): + Whether or not to return overflowing token sequences. + return_special_tokens_mask (`bool`, *optional*): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*): + Whether or not to return `(char_start, char_end)` for each token. + return_length (`bool`, *optional*): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*): + Whether or not to print more information and warnings. + padding_side (`str`, *optional*): + The side on which padding will be applied. + """ + + add_special_tokens: Optional[bool] + padding: Union[bool, str, PaddingStrategy] + truncation: Union[bool, str, TruncationStrategy] + max_length: Optional[int] + stride: Optional[int] + is_split_into_words: Optional[bool] + pad_to_multiple_of: Optional[int] + return_token_type_ids: Optional[bool] + return_attention_mask: Optional[bool] + return_overflowing_tokens: Optional[bool] + return_special_tokens_mask: Optional[bool] + return_offsets_mapping: Optional[bool] + return_length: Optional[bool] + verbose: Optional[bool] + padding_side: Optional[str] + + +class ImagesKwargs(TypedDict, total=False): + """ + Keyword arguments for image processing. For extended documentation, check the appropriate ImageProcessor + class methods and docstrings. + + Attributes: + do_resize (`bool`, *optional*): + Whether to resize the image. + size (`Dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + size_divisor (`int`, *optional*): + The size by which to make sure both the height and width can be divided. + crop_size (`Dict[str, int]`, *optional*): + Desired output size when applying center-cropping. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. + do_pad (`bool`, *optional*): + Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + do_center_crop (`bool`, *optional*): + Whether to center crop the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + + do_resize: Optional[bool] + size: Optional[Dict[str, int]] + size_divisor: Optional[int] + crop_size: Optional[Dict[str, int]] + resample: Optional[Union["PILImageResampling", int]] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, List[float]]] + image_std: Optional[Union[float, List[float]]] + do_pad: Optional[bool] + do_center_crop: Optional[bool] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + + +class VideosKwargs(TypedDict, total=False): + """ + Keyword arguments for video processing. + + Attributes: + do_resize (`bool`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + size_divisor (`int`, *optional*): + The size by which to make sure both the height and width can be divided. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. + do_pad (`bool`, *optional*): + Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + do_center_crop (`bool`, *optional*): + Whether to center crop the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + + do_resize: Optional[bool] + size: Optional[Dict[str, int]] + size_divisor: Optional[int] + resample: Optional["PILImageResampling"] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, List[float]]] + image_std: Optional[Union[float, List[float]]] + do_pad: Optional[bool] + do_center_crop: Optional[bool] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + + +class AudioKwargs(TypedDict, total=False): + """ + Keyword arguments for audio processing. + + Attributes: + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_attention_mask (`bool`, *optional*): + Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`. + """ + + sampling_rate: Optional[int] + raw_speech: Optional[Union["np.ndarray", List[float], List["np.ndarray"], List[List[float]]]] + padding: Optional[Union[bool, str, PaddingStrategy]] + max_length: Optional[int] + truncation: Optional[bool] + pad_to_multiple_of: Optional[int] + return_attention_mask: Optional[bool] + + +class CommonKwargs(TypedDict, total=False): + return_tensors: Optional[Union[str, TensorType]] + + +class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, total=False): + """ + Base class for kwargs passing to processors. + A model should have its own `ModelProcessorKwargs` class that inherits from `ProcessingKwargs` to provide: + 1) Additional typed keys and that this model requires to process inputs. + 2) Default values for existing keys under a `_defaults` attribute. + New keys have to be defined as follows to ensure type hinting is done correctly. + + ```python + # adding a new image kwarg for this model + class ModelImagesKwargs(ImagesKwargs, total=False): + new_image_kwarg: Optional[bool] + + class ModelProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: ModelImagesKwargs + _defaults = { + "images_kwargs: { + "new_image_kwarg": False, + } + "text_kwargs": { + "padding": "max_length", + }, + } + + ``` + """ + + common_kwargs: CommonKwargs = { + **CommonKwargs.__annotations__, + } + text_kwargs: TextKwargs = { + **TextKwargs.__annotations__, + } + images_kwargs: ImagesKwargs = { + **ImagesKwargs.__annotations__, + } + videos_kwargs: VideosKwargs = { + **VideosKwargs.__annotations__, + } + audio_kwargs: AudioKwargs = { + **AudioKwargs.__annotations__, + } + + class ProcessorMixin(PushToHubMixin): """ This is a mixin used to provide saving/loading functionality for all processor classes. @@ -414,6 +669,111 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): else: return processor + def _merge_kwargs( + self, + ModelProcessorKwargs: ProcessingKwargs, + tokenizer_init_kwargs: Optional[Dict] = None, + **kwargs, + ) -> Dict[str, Dict]: + """ + Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance. + The order of operations is as follows: + 1) kwargs passed as before have highest priority to preserve BC. + ```python + high_priority_kwargs = {"crop_size" = (224, 224), "padding" = "max_length"} + processor(..., **high_priority_kwargs) + ``` + 2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API. + ```python + processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": (224, 224)}}) + ``` + 3) kwargs passed during instantiation of a modality processor have fourth priority. + ```python + tokenizer = tokenizer_class(..., {"padding": "max_length"}) + image_processor = image_processor_class(...) + processor(tokenizer, image_processor) # will pass max_length unless overriden by kwargs at call + ``` + 4) defaults kwargs specified at processor level have lowest priority. + ```python + class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "max_length", + "max_length": 64, + }, + } + ``` + Args: + ModelProcessorKwargs (`ProcessingKwargs`): + Typed dictionary of kwargs specifically required by the model passed. + tokenizer_init_kwargs (`Dict`, *optional*): + Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults. + + Returns: + output_kwargs (`Dict`): + Dictionary of per-modality kwargs to be passed to each modality-specific processor. + + """ + # Initialize dictionaries + output_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + default_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + # get defaults from set model processor kwargs if they exist + for modality in default_kwargs: + default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy() + # update defaults with arguments from tokenizer init + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): + # init with tokenizer init kwargs if necessary + if modality_key in tokenizer_init_kwargs: + default_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key] + # now defaults kwargs are updated with the tokenizers defaults. + # pass defaults to output dictionary + output_kwargs.update(default_kwargs) + + # update modality kwargs with passed kwargs + non_modality_kwargs = set(kwargs) - set(output_kwargs) + for modality in output_kwargs: + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): + # check if we received a structured kwarg dict or not to handle it correctly + if modality in kwargs: + kwarg_value = kwargs[modality].pop(modality_key, "__empty__") + # check if this key was passed as a flat kwarg. + if kwarg_value != "__empty__" and modality_key in non_modality_kwargs: + raise ValueError( + f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg." + ) + elif modality_key in kwargs: + kwarg_value = kwargs.pop(modality_key, "__empty__") + else: + kwarg_value = "__empty__" + if kwarg_value != "__empty__": + output_kwargs[modality][modality_key] = kwarg_value + # if something remains in kwargs, it belongs to common after flattening + if set(kwargs) & set(default_kwargs): + # here kwargs is dictionary-based since it shares keys with default set + [output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()] + else: + # here it's a flat dict + output_kwargs["common_kwargs"].update(kwargs) + + # all modality-specific kwargs are updated with common kwargs + for modality in output_kwargs: + output_kwargs[modality].update(output_kwargs["common_kwargs"]) + return output_kwargs + @classmethod def from_pretrained( cls, diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 8f72b303bbbbcc..ab965c5279b44f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -126,6 +126,8 @@ class EncodingFast: PreTokenizedInputPair = Tuple[List[str], List[str]] EncodedInputPair = Tuple[List[int], List[int]] +# Define type aliases for text-related non-text modalities +AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch.Tensor"]] # Slow tokenizers used to be saved in three separated files SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" diff --git a/tests/models/align/test_processor_align.py b/tests/models/align/test_processor_align.py index 12fbea5a50cdfa..3c904e59a8831a 100644 --- a/tests/models/align/test_processor_align.py +++ b/tests/models/align/test_processor_align.py @@ -26,6 +26,8 @@ from transformers.testing_utils import require_vision from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): from PIL import Image @@ -34,7 +36,9 @@ @require_vision -class AlignProcessorTest(unittest.TestCase): +class AlignProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = AlignProcessor + def setUp(self): self.tmpdirname = tempfile.mkdtemp() @@ -159,7 +163,6 @@ def test_tokenizer(self): encoded_processor = processor(text=input_str) encoded_tok = tokenizer(input_str, padding="max_length", max_length=64) - for key in encoded_tok.keys(): self.assertListEqual(encoded_tok[key], encoded_processor[key]) diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 402e6a73515122..074aa2f1d62545 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -14,10 +14,19 @@ # limitations under the License. +import inspect import json import tempfile + + +try: + from typing import Unpack +except ImportError: + from typing_extensions import Unpack import unittest +import numpy as np + from transformers import CLIPTokenizerFast, ProcessorMixin from transformers.models.auto.processing_auto import processor_class_from_name from transformers.testing_utils import ( @@ -30,9 +39,13 @@ if is_vision_available(): + from PIL import Image + from transformers import CLIPImageProcessor +@require_torch +@require_vision @require_torch class ProcessorTesterMixin: processor_class = None @@ -64,6 +77,15 @@ def get_processor(self): processor = self.processor_class(**components, **self.prepare_processor_dict()) return processor + @require_vision + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + return image_inputs + def test_processor_to_json_string(self): processor = self.get_processor() obj = json.loads(processor.to_json_string()) @@ -82,6 +104,214 @@ def test_processor_from_and_save_pretrained(self): self.assertEqual(processor_second.to_dict(), processor_first.to_dict()) + # These kwargs-related tests ensure that processors are correctly instantiated. + # they need to be applied only if an image_processor exists. + + def skip_processor_without_typed_kwargs(self, processor): + # TODO this signature check is to test only uniformized processors. + # Once all are updated, remove it. + is_kwargs_typed_dict = False + call_signature = inspect.signature(processor.__call__) + for param in call_signature.parameters.values(): + if param.kind == param.VAR_KEYWORD and param.annotation != param.empty: + is_kwargs_typed_dict = ( + hasattr(param.annotation, "__origin__") and param.annotation.__origin__ == Unpack + ) + if not is_kwargs_typed_dict: + self.skipTest(f"{self.processor_class} doesn't have typed kwargs.") + + @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(len(inputs["input_ids"][0]), 117) + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", crop_size=(234, 234)) + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + self.assertEqual(len(inputs["pixel_values"][0][0]), 234) + + @require_vision + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=112) + self.assertEqual(len(inputs["input_ids"][0]), 112) + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", crop_size=(234, 234)) + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, crop_size=[224, 224]) + self.assertEqual(len(inputs["pixel_values"][0][0]), 224) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"height": 214, "width": 214}, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 6) + + @require_torch + @require_vision + def test_doubly_passed_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer"] + image_input = self.prepare_image_inputs() + with self.assertRaises(ValueError): + _ = processor( + text=input_str, + images=image_input, + images_kwargs={"crop_size": {"height": 222, "width": 222}}, + crop_size={"height": 214, "width": 214}, + ) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"crop_size": {"height": 214, "width": 214}}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"crop_size": {"height": 214, "width": 214}}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + class MyProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"]