Commit: add comments

SangbumChoi committed Jul 15, 2024
1 parent ebc3862, commit c6dc445

Showing 3 changed files with 69 additions and 9 deletions.
src/transformers/models/grounding_dino/processing_grounding_dino.py

@@ -17,7 +17,7 @@
"""

import sys
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
@@ -69,18 +69,18 @@ class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"truncation": None,
"max_length": None,
"stride": 0,
"pad_to_multiple_of": None,
"return_attention_mask": None,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": True,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
}
},
"images_kwargs": {
"do_convert_annotations": True,
"do_resize": True,
},
}
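
For context on how these `_defaults` are consumed, here is a minimal sketch, assuming the `_merge_kwargs` helper that `ProcessorMixin` provides for the uniformized processors, of call-time kwargs overriding the defaults above:

```python
# Sketch only (not part of this diff): _merge_kwargs layers the class
# _defaults, then tokenizer init kwargs, then call-time kwargs, and
# returns them split per modality.
output_kwargs = processor._merge_kwargs(
    GroundingDinoProcessorKwargs,
    tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
    padding="longest",  # overrides the `"padding": False` default above
)
text_kwargs = output_kwargs["text_kwargs"]      # forwarded to the tokenizer
images_kwargs = output_kwargs["images_kwargs"]  # forwarded to the image processor
```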


@@ -111,7 +111,8 @@ def __call__(
         self,
         images: ImageInput = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
+        audio=None,
+        videos=None,
         **kwargs: Unpack[GroundingDinoProcessorKwargs],
     ) -> BatchEncoding:
         """
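
A usage sketch of the updated signature (illustrative only; the prompt text and option values are placeholders). With `return_tensors` dropped from the explicit parameters, it now travels through the typed `**kwargs` along with the other per-modality options:

```python
# Sketch: calling the processor after this change. Flat kwargs are routed
# to the tokenizer or the image processor via GroundingDinoProcessorKwargs.
inputs = processor(
    images=image,
    text="a cat. a remote control.",
    return_tensors="pt",  # now handled through **kwargs
    padding="longest",    # text_kwargs
    do_resize=True,       # images_kwargs
)
```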
src/transformers/processing_utils.py (28 additions, 0 deletions)
@@ -20,6 +20,7 @@
 import inspect
 import json
 import os
+import pathlib
 import warnings
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
@@ -40,6 +41,7 @@
 )
 from .utils import (
     PROCESSOR_NAME,
+    ExplicitEnum,
     PushToHubMixin,
     TensorType,
     add_model_info_to_auto_map,
@@ -56,6 +58,14 @@

 logger = logging.get_logger(__name__)
 
+AnnotationType = Dict[str, Union[int, str, List[Dict]]]
+
+
+class AnnotationFormat(ExplicitEnum):
+    COCO_DETECTION = "coco_detection"
+    COCO_PANOPTIC = "coco_panoptic"
+
+
 # Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
 transformers_module = direct_transformers_import(Path(__file__).parent)
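
For readers unfamiliar with the annotation payloads these types describe, an illustrative (not normative) value matching the new `AnnotationType` alias in the `COCO_DETECTION` layout; the keys follow the COCO convention used by the detection image processors, and the numbers are made up:

```python
# Illustrative only: a COCO-detection-style annotation that satisfies
# AnnotationType = Dict[str, Union[int, str, List[Dict]]].
annotation: AnnotationType = {
    "image_id": 39769,
    "annotations": [
        {"id": 1, "category_id": 17, "bbox": [100.0, 50.0, 80.0, 60.0], "area": 4800.0, "iscrowd": 0},
    ],
}
```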

@@ -128,6 +138,12 @@ class ImagesKwargs(TypedDict, total=False):
     class methods and docstrings.
     Attributes:
+        annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
+            List of annotations associated with the image or batch of images.
+        return_segmentation_masks (`bool`, *optional*):
+            Whether to return segmentation masks.
+        masks_path (`str` or `pathlib.Path`, *optional*):
+            Path to the directory containing the segmentation masks.
         do_resize (`bool`, *optional*):
             Whether to resize the image.
         size (`Dict[str, int]`, *optional*):
@@ -144,6 +160,8 @@
             Scale factor to use if rescaling the image.
         do_normalize (`bool`, *optional*):
             Whether to normalize the image.
+        do_convert_annotations (`bool`, *optional*):
+            Whether to convert the annotations to the format expected by the model.
         image_mean (`float` or `List[float]`, *optional*):
             Mean to use if normalizing the image.
         image_std (`float` or `List[float]`, *optional*):
@@ -152,12 +170,19 @@
             Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
         do_center_crop (`bool`, *optional*):
             Whether to center crop the image.
+        format (`str` or `AnnotationFormat`, *optional*):
+            Format of the annotations.
         data_format (`ChannelDimension` or `str`, *optional*):
             The channel dimension format for the output image.
         input_data_format (`ChannelDimension` or `str`, *optional*):
             The channel dimension format for the input image.
+        pad_size (`Dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to.
     """
 
+    annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
+    return_segmentation_masks: Optional[bool]
+    masks_path: Optional[Union[str, pathlib.Path]]
     do_resize: Optional[bool]
     size: Optional[Dict[str, int]]
     size_divisor: Optional[int]
@@ -166,12 +191,15 @@
     do_rescale: Optional[bool]
     rescale_factor: Optional[float]
     do_normalize: Optional[bool]
+    do_convert_annotations: Optional[bool]
     image_mean: Optional[Union[float, List[float]]]
     image_std: Optional[Union[float, List[float]]]
     do_pad: Optional[bool]
     do_center_crop: Optional[bool]
+    format: Optional[Union[str, AnnotationFormat]]
     data_format: Optional[ChannelDimension]
     input_data_format: Optional[Union[str, ChannelDimension]]
+    pad_size: Optional[Dict[str, int]]
 
 
 class VideosKwargs(TypedDict, total=False):
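
To make the new fields concrete, a hedged sketch of a detection-style processor call carrying them; the annotation content and sizes are placeholders, not taken from this diff:

```python
# Sketch: the new ImagesKwargs fields flowing through a processor call.
inputs = processor(
    images=image,
    text="a cat.",
    annotations={"image_id": 0, "annotations": []},  # AnnotationType
    format="coco_detection",                         # AnnotationFormat.COCO_DETECTION
    do_convert_annotations=True,                     # convert boxes to the model's format
    pad_size={"height": 800, "width": 800},          # pad every image to this size
    return_tensors="pt",
)
```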
tests/models/grounding_dino/test_processor_grounding_dino.py (32 additions, 1 deletion)
@@ -26,6 +26,8 @@
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available
 
+from ...test_processing_common import ProcessorTesterMixin
+
 
 if is_torch_available():
     import torch
@@ -40,7 +42,9 @@

 @require_torch
 @require_vision
-class GroundingDinoProcessorTest(unittest.TestCase):
+class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = GroundingDinoProcessor
+
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()

@@ -251,3 +255,30 @@ def test_model_input_names(self):
         inputs = processor(text=input_str, images=image_input)
 
         self.assertListEqual(list(inputs.keys()), processor.model_input_names)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 11)
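
As a usage note, the new test can be exercised on its own; a minimal sketch with the stdlib unittest runner, assuming a standard transformers checkout run from the repository root:

```python
# Load and run just the new test method with the stdlib unittest runner.
import unittest

from tests.models.grounding_dino.test_processor_grounding_dino import GroundingDinoProcessorTest

suite = unittest.TestSuite([GroundingDinoProcessorTest("test_unstructured_kwargs_batched")])
unittest.TextTestRunner(verbosity=2).run(suite)
```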
