
Commit

fix common processor
SangbumChoi committed Jul 24, 2024
1 parent 8104521 commit 5d6a088
Showing 3 changed files with 36 additions and 23 deletions.
35 changes: 20 additions & 15 deletions src/transformers/processing_utils.py
@@ -736,12 +736,12 @@ def _merge_kwargs(
         The order of operations is as follows:
             1) kwargs passed as before have highest priority to preserve BC.
                 ```python
-                high_priority_kwargs = {"crop_size": (224, 224), "padding": "max_length"}
+                high_priority_kwargs = {"crop_size": {"height": 222, "width": 222}, "padding": "max_length"}
                 processor(..., **high_priority_kwargs)
                 ```
             2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
                 ```python
-                processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": (224, 224)})
+                processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}})
                 ```
             3) kwargs passed during instantiation of a modality processor have third priority.
                 ```python
@@ -799,14 +799,20 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
         output_kwargs.update(default_kwargs)

         # gather common kwargs and remove them from individual kwargs if present
-        common_kwargs = {
-            key: value
-            for key, value in kwargs.items()
-            if key not in ModelProcessorKwargs.__annotations__["text_kwargs"].__annotations__
-            and key not in ModelProcessorKwargs.__annotations__["images_kwargs"].__annotations__
-            and key not in ModelProcessorKwargs.__annotations__["audio_kwargs"].__annotations__
-            and key not in ModelProcessorKwargs.__annotations__["videos_kwargs"].__annotations__
-        }
+        common_kwargs = {}
+        for key, value in kwargs.items():
+            if key == "common_kwargs":
+                for common_key, common_value in value.items():
+                    common_kwargs[common_key] = common_value
+            elif key in ["text_kwargs", "images_kwargs", "audio_kwargs", "videos_kwargs"]:
+                pass
+            elif (
+                key not in ModelProcessorKwargs.__annotations__["text_kwargs"].__annotations__
+                and key not in ModelProcessorKwargs.__annotations__["images_kwargs"].__annotations__
+                and key not in ModelProcessorKwargs.__annotations__["audio_kwargs"].__annotations__
+                and key not in ModelProcessorKwargs.__annotations__["videos_kwargs"].__annotations__
+            ):
+                common_kwargs[key] = value

         # ensure common kwargs are propagated to all relevant modalities
         for key, value in common_kwargs.items():
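The rewritten gathering loop above fixes two problems with the old dict comprehension: the structured `text_kwargs`/`images_kwargs`/`audio_kwargs`/`videos_kwargs` dicts are not keys in any modality's annotations, so they themselves leaked into `common_kwargs` and were then propagated to every modality; and an explicit `common_kwargs` dict was stored under its own name rather than unpacked. A minimal standalone sketch of the new rule, with the key sets hardcoded here purely for illustration:

```python
# Sketch only: the real code derives these sets from
# ModelProcessorKwargs.__annotations__ rather than hardcoding them.
MODALITY_DICTS = {"text_kwargs", "images_kwargs", "audio_kwargs", "videos_kwargs"}
MODALITY_KEYS = {"padding", "truncation", "crop_size", "size"}  # illustrative subset

def gather_common_kwargs(kwargs):
    common = {}
    for key, value in kwargs.items():
        if key == "common_kwargs":
            common.update(value)      # explicit common dict is unpacked
        elif key in MODALITY_DICTS:
            continue                  # structured per-modality dicts are never common
        elif key not in MODALITY_KEYS:
            common[key] = value       # loose kwarg claimed by no modality
    return common

print(gather_common_kwargs(
    {"common_kwargs": {"return_tensors": "pt"}, "images_kwargs": {"size": 224}, "padding": True}
))  # -> {'return_tensors': 'pt'}
```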
@@ -820,10 +826,10 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
         # update modality kwargs with passed kwargs
         for modality in output_kwargs:
             for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
-                if modality in kwargs and modality_key in kwargs[modality]:
-                    output_kwargs[modality][modality_key] = kwargs[modality][modality_key]
-                elif modality_key in kwargs:
+                if modality_key in kwargs:
                     output_kwargs[modality][modality_key] = kwargs[modality_key]
+                elif modality in kwargs and modality_key in kwargs[modality]:
+                    output_kwargs[modality][modality_key] = kwargs[modality][modality_key]
         return output_kwargs

     @classmethod
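The reordered branches above are the behavioral core of the fix: a kwarg passed flat (e.g. `size=...`) now takes precedence over the same key passed inside a structured modality dict (e.g. `images_kwargs={"size": ...}`), matching priority 1) over 2) in the `_merge_kwargs` docstring at the top of this file; the updated `test_doubly_passed_kwargs` below pins down exactly this. A sketch of the resulting lookup order, not the library's literal code:

```python
def resolve(modality, modality_key, kwargs, default=None):
    if modality_key in kwargs:                    # passed flat: highest priority
        return kwargs[modality_key]
    if modality_key in kwargs.get(modality, {}):  # passed inside e.g. images_kwargs
        return kwargs[modality][modality_key]
    return default                                # fall back to merged defaults

print(resolve("images_kwargs", "size", {
    "size": {"height": 35, "width": 35},
    "images_kwargs": {"size": {"height": 222, "width": 222}},
}))  # -> {'height': 35, 'width': 35}
```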
@@ -988,5 +994,4 @@ def apply_chat_template(
 ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
 if ProcessorMixin.push_to_hub.__doc__ is not None:
     ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
-        object="processor", object_class="AutoProcessor", object_files="processor files"
-    )
+        object="processor", object_class="AutoProcessor", object_files="processor files")
10 changes: 9 additions & 1 deletion tests/models/grounding_dino/test_processor_grounding_dino.py
@@ -43,6 +43,7 @@
 @require_torch
 @require_vision
 class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    from_pretrained_id = "IDEA-Research/grounding-dino-base"
     processor_class = GroundingDinoProcessor

     def setUp(self):
@@ -67,6 +68,13 @@ def setUp(self):
         with open(self.image_processor_file, "w", encoding="utf-8") as fp:
             json.dump(image_processor_map, fp)

+        image_processor = GroundingDinoImageProcessor()
+        tokenizer = BertTokenizer.from_pretrained(self.from_pretrained_id)
+
+        processor = GroundingDinoProcessor(image_processor, tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+
         self.batch_size = 7
         self.num_queries = 5
         self.embed_dim = 5
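setUp previously wrote only loose tokenizer and image-processor files to disk; the added lines also persist a fully constructed processor so the shared `ProcessorTesterMixin` tests can reload one from the temp directory. A self-contained replay of those lines, with a throwaway path standing in for the test's `tmpdirname`:

```python
from transformers import BertTokenizer, GroundingDinoImageProcessor, GroundingDinoProcessor

tokenizer = BertTokenizer.from_pretrained("IDEA-Research/grounding-dino-base")
image_processor = GroundingDinoImageProcessor()

processor = GroundingDinoProcessor(image_processor, tokenizer)
processor.save_pretrained("/tmp/grounding_dino_processor")  # illustrative path

# The saved artifact round-trips through the standard ProcessorMixin API:
reloaded = GroundingDinoProcessor.from_pretrained("/tmp/grounding_dino_processor")
```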
@@ -281,4 +289,4 @@ def test_unstructured_kwargs_batched(self):
         )
         self.assertEqual(inputs["pixel_values"].shape[2], 214)

-        self.assertEqual(len(inputs["input_ids"][0]), 11)
+        self.assertEqual(len(inputs["input_ids"][0]), 6)
14 changes: 7 additions & 7 deletions tests/test_processing_common.py
@@ -277,13 +277,13 @@ def test_doubly_passed_kwargs(self):

         input_str = ["lower newer"]
         image_input = self.prepare_image_inputs()
-        with self.assertRaises(ValueError):
-            _ = processor(
-                text=input_str,
-                images=image_input,
-                images_kwargs={"crop_size": {"height": 222, "width": 222}},
-                crop_size={"height": 214, "width": 214},
-            )
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            images_kwargs={"size": {"height": 222, "width": 222}},
+            size={"height": 35, "width": 35},
+        )
+        self.assertEqual(inputs["pixel_values"][0].shape[2], 35)

     @require_torch
     @require_vision
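Note the semantic change this test now encodes: passing the same kwarg both flat and inside `images_kwargs` no longer raises a `ValueError`; the flat kwarg silently wins (the resulting height is 35, not 222), in line with the precedence swap in `processing_utils.py` above.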
