From 009f549f882b91dbedf97ae90878b9bf8a86b3cc Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 16 Jan 2025 16:39:54 +0000 Subject: [PATCH 1/5] add qwen2_vl image processor fast --- docs/source/en/model_doc/qwen2_vl.md | 5 + src/transformers/__init__.py | 3 + .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/qwen2_vl/__init__.py | 1 + .../image_processing_qwen2_vl_fast.py | 422 ++++++++++++++++++ .../utils/dummy_torchvision_objects.py | 14 + .../test_image_processing_qwen2_vl.py | 286 ++++++------ tests/test_image_processing_common.py | 5 +- 8 files changed, 599 insertions(+), 139 deletions(-) create mode 100644 src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md index 7c864b860bd8..c39728ef71ec 100644 --- a/docs/source/en/model_doc/qwen2_vl.md +++ b/docs/source/en/model_doc/qwen2_vl.md @@ -315,6 +315,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2VLImageProcessor - preprocess +## Qwen2VLImageProcessorFast + +[[autodoc]] Qwen2VLImageProcessorFast + - preprocess + ## Qwen2VLProcessor [[autodoc]] Qwen2VLProcessor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8b089276666a..0c7ae539657e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1299,6 +1299,7 @@ _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast") _import_structure["models.detr"].append("DetrImageProcessorFast") _import_structure["models.pixtral"].append("PixtralImageProcessorFast") + _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast") _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") @@ -6397,7 +6398,9 @@ from .models.deformable_detr import DeformableDetrImageProcessorFast from .models.detr import DetrImageProcessorFast from .models.pixtral import PixtralImageProcessorFast + from .models.qwen2_vl import Qwen2VLImageProcessorFast from .models.rt_detr import RTDetrImageProcessorFast + from .models.timm_wrapper import TimmWrapperImageProcessor from .models.vit import ViTImageProcessorFast try: diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 39df65d80457..1cb067f386f1 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -125,7 +125,7 @@ ("poolformer", ("PoolFormerImageProcessor",)), ("pvt", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)), - ("qwen2_vl", ("Qwen2VLImageProcessor",)), + ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), ("regnet", ("ConvNextImageProcessor",)), ("resnet", ("ConvNextImageProcessor",)), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), diff --git a/src/transformers/models/qwen2_vl/__init__.py b/src/transformers/models/qwen2_vl/__init__.py index 6d859059f35b..70a719cc3a2f 100644 --- a/src/transformers/models/qwen2_vl/__init__.py +++ b/src/transformers/models/qwen2_vl/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .configuration_qwen2_vl import * from .image_processing_qwen2_vl import * + from .image_processing_qwen2_vl_fast import * from .modeling_qwen2_vl import * from .processing_qwen2_vl import * else: diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py 
b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py new file mode 100644 index 000000000000..a08b838facc6 --- /dev/null +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Qwen2-VL.""" + +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, +) +from ...image_transforms import ( + convert_to_rgb, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + VideoInput, + get_image_size, + get_image_type, + infer_channel_dimension_format, + make_list_of_images, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) +from .image_processing_qwen2_vl import make_batched_images, make_batched_videos, smart_resize + + +if is_torch_available(): + import torch + +if is_vision_available(): + from ...image_utils import pil_torch_interpolation_mapping + + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + +logger = logging.get_logger(__name__) + + +class Qwen2VLImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast Qwen2-VL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. 
This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + min_pixels: int = 56 * 56, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + do_resize: bool = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional[Union[str, torch.device]] = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`List[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. 
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + image_type = get_image_type(images[0]) + if image_type == ImageType.PIL: + images = [F.pil_to_tensor(image) for image in images] + elif image_type == ImageType.NUMPY: + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + images = [torch.from_numpy(image).contiguous() for image in images] + + if device is not None: + images = [image.to(device) for image in images] + + # We assume that all images have the same channel dimension format. 
+ if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + if input_data_format == ChannelDimension.LAST: + images = [image.permute(2, 0, 1).contiguous() for image in images] + input_data_format = ChannelDimension.FIRST + + if do_rescale and do_normalize: + # fused rescale and normalize + image_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + interpolation = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = F.resize(image, size=(resized_height, resized_width), interpolation=interpolation) + + if do_rescale and do_normalize: + # fused rescale and normalize + image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std) + elif do_rescale: + image = image * rescale_factor + elif do_normalize: + image = F.normalize(image, image_mean, image_std) + + processed_images.append(image) + + patches = torch.stack(processed_images) + if patches.shape[0] % self.temporal_patch_size != 0: + repeats = patches[-1].unsqueeze(0).repeat(self.temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=0) + + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + + patches = patches.view( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size + ) + + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + videos: VideoInput = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`): + Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. 
Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + device = kwargs.pop("device", None) + + # Make hashable for cache + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + if images is not None: + images = make_batched_images(images) + if videos is not None: + videos = make_batched_videos(videos) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if images is not None: + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = torch.stack(pixel_values) + vision_grid_thws = torch.tensor(vision_grid_thws) + data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} + + if videos is not None: + pixel_values, vision_grid_thws = [], [] + for images in videos: + patches, video_grid_thw = self._preprocess( + images, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + pixel_values.extend(patches) + vision_grid_thws.append(video_grid_thw) + pixel_values = torch.stack(pixel_values) + vision_grid_thws = torch.tensor(vision_grid_thws) + data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Qwen2VLImageProcessorFast"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 747f75386490..cf398bf3da01 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -30,6 +30,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class Qwen2VLImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class RTDetrImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] @@ 
-37,6 +44,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class TimmWrapperImageProcessor(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py b/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py index 76220dc66e96..317e0e28ad14 100644 --- a/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py @@ -20,7 +20,7 @@ from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs, prepare_video_inputs @@ -33,6 +33,9 @@ from transformers import Qwen2VLImageProcessor + if is_torchvision_available(): + from transformers import Qwen2VLImageProcessorFast + class Qwen2VLImageProcessingTester: def __init__( @@ -114,6 +117,7 @@ def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = Qwen2VLImageProcessor if is_vision_available() else None + fast_image_processing_class = Qwen2VLImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -124,28 +128,30 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "min_pixels")) - self.assertTrue(hasattr(image_processing, "max_pixels")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) - self.assertTrue(hasattr(image_processing, "patch_size")) - self.assertTrue(hasattr(image_processing, "temporal_patch_size")) - self.assertTrue(hasattr(image_processing, "merge_size")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "min_pixels")) + self.assertTrue(hasattr(image_processing, "max_pixels")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "patch_size")) + self.assertTrue(hasattr(image_processing, "temporal_patch_size")) + self.assertTrue(hasattr(image_processing, "merge_size")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.min_pixels, 56 * 56) - 
self.assertEqual(image_processor.max_pixels, 28 * 28 * 1280) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.min_pixels, 56 * 56) + self.assertEqual(image_processor.max_pixels, 28 * 28 * 1280) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, min_pixels=256 * 256, max_pixels=640 * 640 - ) - self.assertEqual(image_processor.min_pixels, 256 * 256) - self.assertEqual(image_processor.max_pixels, 640 * 640) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, min_pixels=256 * 256, max_pixels=640 * 640 + ) + self.assertEqual(image_processor.min_pixels, 256 * 256) + self.assertEqual(image_processor.max_pixels, 640 * 640) def test_select_best_resolution(self): # Test with a final resize resolution @@ -153,134 +159,140 @@ def test_select_best_resolution(self): self.assertEqual(best_resolution, (560, 280)) def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - for image in image_inputs: - self.assertIsInstance(image[0], Image.Image) - - # Test not batched input - prcocess_out = image_processing(image_inputs[0], return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (4900, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) - - # Test batched - prcocess_out = image_processing(image_inputs, return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (34300, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image[0], Image.Image) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) def test_call_numpy(self): - # Initialize image_processing - image_processing = 
self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) - for image in image_inputs: - self.assertIsInstance(image[0], np.ndarray) - - # Test not batched input - prcocess_out = image_processing(image_inputs[0], return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (4900, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) - - # Test batched - prcocess_out = image_processing(image_inputs, return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (34300, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image[0], np.ndarray) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) def test_call_pytorch(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) - - for image in image_inputs: - self.assertIsInstance(image[0], torch.Tensor) - - # Test not batched input - prcocess_out = image_processing(image_inputs[0], return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (4900, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) - - # Test batched - prcocess_out = image_processing(image_inputs, return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (34300, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) - self.assertEqual(tuple(encoded_images.shape), 
expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image[0], torch.Tensor) + + # Test not batched input + prcocess_out = image_processing(image_inputs[0], return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (4900, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) @unittest.skip(reason="Qwen2VLImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") def test_call_numpy_4_channels(self): pass def test_nested_input(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - - # Test batched as a list of images - prcocess_out = image_processing(image_inputs, return_tensors="pt") - encoded_images = prcocess_out.pixel_values - image_grid_thws = prcocess_out.image_grid_thw - expected_output_image_shape = (34300, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) - - # Test batched as a nested list of images, where each sublist is one batch - image_inputs_nested = image_inputs[:3] + image_inputs[3:] - prcocess_out = image_processing(image_inputs_nested, return_tensors="pt") - encoded_images_nested = prcocess_out.pixel_values - image_grid_thws_nested = prcocess_out.image_grid_thw - expected_output_image_shape = (34300, 1176) - expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) - self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) - self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) - - # Image processor should return same pixel values, independently of ipnut format - self.assertTrue((encoded_images_nested == encoded_images).all()) - self.assertTrue((image_grid_thws_nested == expected_image_grid_thws).all()) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images + prcocess_out = image_processing(image_inputs, return_tensors="pt") + encoded_images = prcocess_out.pixel_values + image_grid_thws = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images.shape), 
expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = image_inputs[:3] + image_inputs[3:] + prcocess_out = image_processing(image_inputs_nested, return_tensors="pt") + encoded_images_nested = prcocess_out.pixel_values + image_grid_thws_nested = prcocess_out.image_grid_thw + expected_output_image_shape = (34300, 1176) + expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) + + # Image processor should return same pixel values, independently of ipnut format + self.assertTrue((encoded_images_nested == encoded_images).all()) + self.assertTrue((image_grid_thws_nested == expected_image_grid_thws).all()) def test_video_inputs(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - expected_dims_by_frames = {1: 34300, 2: 34300, 3: 68600, 4: 68600, 5: 102900, 6: 102900} - - for num_frames, expected_dims in expected_dims_by_frames.items(): - image_processor_tester = Qwen2VLImageProcessingTester(self, num_frames=num_frames) - video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) - prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") - encoded_video = prcocess_out.pixel_values_videos - expected_output_video_shape = (expected_dims, 1176) - self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + expected_dims_by_frames = {1: 34300, 2: 34300, 3: 68600, 4: 68600, 5: 102900, 6: 102900} + + for num_frames, expected_dims in expected_dims_by_frames.items(): + image_processor_tester = Qwen2VLImageProcessingTester(self, num_frames=num_frames) + video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) + prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") + encoded_video = prcocess_out.pixel_values_videos + expected_output_video_shape = (expected_dims, 1176) + self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) def test_custom_patch_size(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - - for patch_size in (1, 3, 5, 7): - image_processor_tester = Qwen2VLImageProcessingTester(self, patch_size=patch_size) - video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) - prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") - encoded_video = prcocess_out.pixel_values_videos - expected_output_video_shape = (171500, 1176) - self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + + for patch_size in (1, 3, 5, 7): + image_processor_tester = Qwen2VLImageProcessingTester(self, patch_size=patch_size) + video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) + prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") + encoded_video = prcocess_out.pixel_values_videos + expected_output_video_shape = (171500, 1176) + self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) diff --git 
a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 971462f9e352..25930ea2790c 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -181,7 +181,10 @@ def test_slow_fast_equivalence(self): encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") - self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2)) + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + ) @require_vision @require_torch From df6fb86dee7b3493f5106e2bf1623cd7d5f593cd Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 16 Jan 2025 16:49:59 +0000 Subject: [PATCH 2/5] add device to ImagesKwargs --- src/transformers/processing_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 611e2aa3f20c..b94230c7d4a1 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -171,6 +171,8 @@ class methods and docstrings. The channel dimension format for the output image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. + device (`str`, *optional*): + The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing. """ do_resize: Optional[bool] @@ -188,6 +190,7 @@ class methods and docstrings. do_center_crop: Optional[bool] data_format: Optional[ChannelDimension] input_data_format: Optional[Union[str, ChannelDimension]] + device: Optional[str] class VideosKwargs(TypedDict, total=False): From 720bb66125f82377b050274ff142299c1baceeef Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 16 Jan 2025 16:58:51 +0000 Subject: [PATCH 3/5] remove automatic fix copies --- src/transformers/utils/dummy_torchvision_objects.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index cf398bf3da01..86c997ea7a13 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -44,13 +44,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) -class TimmWrapperImageProcessor(metaclass=DummyObject): - _backends = ["torchvision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torchvision"]) - - class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] From d1229c9fe006c00927ebbb8c1e103cfa50dc4dca Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 16 Jan 2025 17:09:14 +0000 Subject: [PATCH 4/5] fix fast_is_faster_than_slow --- tests/test_image_processing_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 25930ea2790c..1f2d1d0fe7e1 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -196,6 +196,8 @@ def test_fast_is_faster_than_slow(self): self.skipTest(reason="Skipping speed test as one of the image processors is not defined") def measure_time(image_processor, image): + # Warmup + _ = image_processor(image, return_tensors="pt") start = time.time() _ = image_processor(image, return_tensors="pt") 
return time.time() - start From 90d0945a1d793c64416d5552f00738d8cc056de9 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 21 Jan 2025 16:35:22 +0000 Subject: [PATCH 5/5] remove unnecessary import --- src/transformers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0c7ae539657e..c8c709238427 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6400,7 +6400,6 @@ from .models.pixtral import PixtralImageProcessorFast from .models.qwen2_vl import Qwen2VLImageProcessorFast from .models.rt_detr import RTDetrImageProcessorFast - from .models.timm_wrapper import TimmWrapperImageProcessor from .models.vit import ViTImageProcessorFast try:
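
---

For readers trying out this PR, below is a minimal usage sketch of the new fast processor. It is not part of the patch: the checkpoint name, image URL, and the exact printed shapes are illustrative assumptions, and the `device` keyword relies on the `ImagesKwargs` addition from PATCH 2/5.

```python
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor

# Example image; any RGB image works here.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# use_fast=True selects Qwen2VLImageProcessorFast now that it is registered
# alongside the slow processor in image_processing_auto.py.
processor = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=True)

# The optional `device` kwarg (added to ImagesKwargs in this PR) lets the
# torchvision-based resize/rescale/normalize steps run on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = processor(images=image, return_tensors="pt", device=device)

print(inputs.pixel_values.shape)  # flattened patches: (num_patches, channel * temporal_patch_size * patch_size**2)
print(inputs.image_grid_thw)      # per-image (temporal, height, width) patch grid
```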