add initial design for uniform processors + align model #31197

Merged
merged 49 commits from uniform_processors_1 into main on Jun 13, 2024

Changes from 1 commit

Commits (49)
b85036f  add initial design for uniform processors + align model (molbap, Jun 3, 2024)
bb8ac70  fix mutable default :eyes: (molbap, Jun 3, 2024)
cd8c601  add configuration test (molbap, Jun 3, 2024)
f00c852  handle structured kwargs w defaults + add test (molbap, Jun 3, 2024)
693036f  protect torch-specific test (molbap, Jun 3, 2024)
766da3a  fix style (molbap, Jun 3, 2024)
844394d  fix (molbap, Jun 3, 2024)
c19bbc6  fix assertEqual (molbap, Jun 4, 2024)
3c38119  move kwargs merging to processing common (molbap, Jun 4, 2024)
81ae819  rework kwargs for type hinting (molbap, Jun 5, 2024)
ce4abcd  just get Unpack from extensions (molbap, Jun 7, 2024)
3acdf28  run-slow[align] (molbap, Jun 7, 2024)
404239f  handle kwargs passed as nested dict (molbap, Jun 7, 2024)
603be40  add from_pretrained test for nested kwargs handling (molbap, Jun 7, 2024)
71c9d6c  [run-slow]align (molbap, Jun 7, 2024)
26383c5  update documentation + imports (molbap, Jun 7, 2024)
4521f4f  update audio inputs (molbap, Jun 7, 2024)
b96eb64  protect audio types, silly (molbap, Jun 7, 2024)
9c5c01c  try removing imports (molbap, Jun 7, 2024)
3ccb505  make things simpler (molbap, Jun 7, 2024)
142acf3  simplerer (molbap, Jun 7, 2024)
60a5730  move out kwargs test to common mixin (molbap, Jun 10, 2024)
be6c141  [run-slow]align (molbap, Jun 10, 2024)
84135d7  skip tests for old processors (molbap, Jun 10, 2024)
ce967ac  [run-slow]align, clip (molbap, Jun 10, 2024)
f78ec52  !$#@!! protect imports, darn it (molbap, Jun 10, 2024)
52fd5ad  [run-slow]align, clip (molbap, Jun 10, 2024)
8f21abe  Merge branch 'main' into uniform_processors_1 (molbap, Jun 10, 2024)
d510030  [run-slow]align, clip (molbap, Jun 10, 2024)
fd43bcd  update doc (molbap, Jun 11, 2024)
b2cd7c9  improve documentation for default values (molbap, Jun 11, 2024)
bcbd646  add model_max_length testing (molbap, Jun 11, 2024)
39c1587  Raise if kwargs are specified in two places (molbap, Jun 11, 2024)
1f73bdf  fix (molbap, Jun 11, 2024)
b3f98ba  Merge branch 'main' into uniform_processors_1 (molbap, Jun 11, 2024)
e4d6d12  expand VideoInput (molbap, Jun 12, 2024)
1e09e4a  fix (molbap, Jun 12, 2024)
d4232f0  fix style (molbap, Jun 12, 2024)
162b1a7  remove defaults values (molbap, Jun 12, 2024)
0da1dc3  add comment to indicate documentation on adding kwargs (molbap, Jun 12, 2024)
f955510  Merge branch 'main' into uniform_processors_1 (molbap, Jun 12, 2024)
f6f1dac  protect imports (molbap, Jun 12, 2024)
c4b7e84  [run-slow]align (molbap, Jun 12, 2024)
3ce3608  fix (molbap, Jun 12, 2024)
6b83e39  remove set() that breaks ordering (molbap, Jun 13, 2024)
3818b86  test more (molbap, Jun 13, 2024)
31b7a60  removed unused func (molbap, Jun 13, 2024)
4072336  [run-slow]align (molbap, Jun 13, 2024)
bcce007  Merge branch 'main' into uniform_processors_1 (molbap, Jun 13, 2024)
48 changes: 15 additions & 33 deletions src/transformers/models/align/processing_align.py
@@ -16,7 +16,7 @@
Image/Text processor class for ALIGN
"""

from typing import List, Union
from typing import List, Union, Unpack

from ...image_utils import ImageInput
from ...processing_utils import (
@@ -37,39 +37,26 @@
import torch # noqa: F401


class AlignProcessorKwargs(ProcessingKwargs, total=False):
class AlignProcessorKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False):
molbap marked this conversation as resolved.
"""
Inherits from `ProcessingKwargs` to provide:
1) Additional keys that this model requires to process inputs.
2) Default values for extra keys.
New keys have to be defined as follows to ensure type hinting is done correctly.

```python
common_kwargs: CommonKwargs = {
**CommonKwargs.__annotations__,
}
text_kwargs: TextKwargs = {
**TextKwargs.__annotations__,
"a_new_text_boolean_key": Optional[bool],
}
images_kwargs: ImagesKwargs = {
**ImagesKwargs.__annotations__,
"a_new_image_processing_key": Optional[int]
}
```
images_kwargs: ImagesKwargs = {"new_image_kwarg": Optional[bool]}

"""

common_kwargs: CommonKwargs = {
**CommonKwargs.__annotations__,
}
text_kwargs: TextKwargs = {
**TextKwargs.__annotations__,
}
images_kwargs: ImagesKwargs = {
**ImagesKwargs.__annotations__,
_defaults = {
"text_kwargs": {
"padding": "max_length",
"max_length": 64,
},
}

```
"""

_defaults = {
Collaborator:

Suggested change:

    _defaults = {
        padding: "max_length",
        max_length: 64,
    }

should work no? Or does it not update the default for type-hints?

Contributor Author (molbap):

yes, it works for sure, this was to have a structured dict for defaults. Can change :)

Contributor Author (molbap):

ah, now I remember: it actually can't work like that, since TypedDicts don't support default values. They are made to hold the layout. They can have arbitrary attributes, but those won't be passed along as defaults the way a dataclass field would be, and with a dataclass we'd lose the typing, hence the manual operation.
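For illustration, a minimal standalone sketch of the distinction being made here (the TextKwargs below is a simplified stand-in, not the transformers class):

```python
from dataclasses import dataclass

from typing_extensions import TypedDict


class TextKwargs(TypedDict, total=False):
    padding: str
    max_length: int


# A TypedDict only declares the shape of a dict: instances are plain
# dicts, and nothing fills in default values for missing keys.
kwargs: TextKwargs = {}
print(kwargs.get("padding"))  # None, no default was applied


# A dataclass does carry defaults, but its instances are objects rather
# than dicts, so the dict-shaped typing of **kwargs would be lost.
@dataclass
class TextKwargsWithDefaults:
    padding: str = "max_length"
    max_length: int = 64


print(TextKwargsWithDefaults().padding)  # "max_length"
```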

Collaborator:

ok got it, thanks! Let's maybe add a comment about this!

Collaborator:

Do we have a comment for future code inspectors? I'm assuming here isn't the best place (we don't want it for all models), but I didn't find a corresponding one elsewhere on a quick skim.

Contributor Author (molbap):

On that: there's documentation in processing_utils.ProcessingKwargs, and I added a comment nudging users to check there!

"text_kwargs": {
"padding": "max_length",
@@ -106,9 +93,10 @@ class AlignProcessor(ProcessorMixin):

processor(images=your_pil_image, text=["What is that?"], **all_kwargs)

# passing directly any number of kwargs is also supported, but not recommended
# passing directly any number of kwargs flattened is also supported

processor(images=your_pil_image, text=["What is that?"], padding="do_not_pad")
all_kwargs = {"return_tensors": "pt", "crop_size": {"height": 214, "width": 214}, "padding": "max_length", "max_length": 76}
processor(images=your_pil_image, text=["What is that?"], **all_kwargs)
```

Args:
@@ -132,10 +120,7 @@ def __call__(
images: ImageInput = None,
audio=None,
videos=None,
text_kwargs: AlignProcessorKwargs.text_kwargs = None,
images_kwargs: AlignProcessorKwargs.images_kwargs = None,
common_kwargs: AlignProcessorKwargs.common_kwargs = None,
**kwargs: AlignProcessorKwargs,
**kwargs: Unpack[AlignProcessorKwargs],
) -> BatchEncoding:
"""
Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
@@ -171,9 +156,6 @@
raise ValueError("You must specify either text or images.")
output_kwargs = self._merge_kwargs(
AlignProcessorKwargs,
text_kwargs=text_kwargs,
images_kwargs=images_kwargs,
common_kwargs=common_kwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
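Taken together, the pattern this file now follows can be sketched as below. MyModelProcessorKwargs and MyModelProcessor are hypothetical names, and the sketch tracks this commit's design rather than any final API:

```python
from typing_extensions import Unpack

from transformers.processing_utils import ProcessingKwargs, ProcessorMixin


class MyModelProcessorKwargs(ProcessingKwargs, total=False):
    # TypedDict subclasses cannot hold runtime defaults, so per-model
    # defaults live in _defaults and are merged in at call time.
    _defaults = {
        "text_kwargs": {
            "padding": "max_length",
            "max_length": 64,
        },
    }


class MyModelProcessor(ProcessorMixin):
    # tokenizer / image processor wiring omitted for brevity

    def __call__(self, text=None, images=None, **kwargs: Unpack[MyModelProcessorKwargs]):
        output_kwargs = self._merge_kwargs(
            MyModelProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # each sub-processor then receives its own slice, e.g.
        # self.tokenizer(text, **output_kwargs["text_kwargs"])
        return output_kwargs
```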
75 changes: 34 additions & 41 deletions src/transformers/processing_utils.py
@@ -244,12 +244,22 @@ class CommonKwargs(TypedDict, total=False):
return_tensors: Optional[Union[str, TensorType]]


class ProcessingKwargs(TypedDict, total=False):
common_kwargs: CommonKwargs
text_kwargs: TextKwargs
images_kwargs: ImagesKwargs
audio_kwargs: AudioKwargs
videos_kwargs: VideosKwargs
class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, total=False):
common_kwargs: CommonKwargs = {
**CommonKwargs.__annotations__,
}
text_kwargs: TextKwargs = {
**TextKwargs.__annotations__,
}
images_kwargs: ImagesKwargs = {
**ImagesKwargs.__annotations__,
}
videos_kwargs: VideosKwargs = {
**VideosKwargs.__annotations__,
}
audio_kwargs: AudioKwargs = {
**AudioKwargs.__annotations__,
}


class ProcessorMixin(PushToHubMixin):
@@ -610,11 +620,6 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
def _merge_kwargs(
self,
ModelProcessorKwargs: ProcessingKwargs,
text_kwargs: Optional[TextKwargs] = None,
images_kwargs: Optional[ImagesKwargs] = None,
common_kwargs: Optional[CommonKwargs] = None,
videos_kwargs: Optional[VideosKwargs] = None,
audio_kwargs: Optional[AudioKwargs] = None,
tokenizer_init_kwargs: Optional[Dict] = None,
**kwargs,
) -> Dict[str, Dict]:
@@ -648,30 +653,21 @@ def _merge_kwargs(
Args:
ModelProcessorKwargs (`ProcessingKwargs`):
Typed dictionary of kwargs specifically required by the model passed.
text_kwargs (`TextKwargs`, *optional*):
Typed dictionary of kwargs inputs applied to the text modality processor, i.e. the tokenizer.
images_kwargs (`ImagesKwargs`, *optional*):
Typed dictionary of kwargs inputs applied to the images modality processor.
videos_kwargs (`VideosKwargs`, *optional*):
Typed dictionary of kwargs inputs applied to the videos modality processor.
audio_kwargs (`AudioKwargs`, *optional*):
Typed dictionary of kwargs inputs applied to the audio modality processor.
tokenizer_init_kwargs (`Dict`, *optional*):
Dictionary of kwargs the tokenizer was instantiated with and that need to take precedence over other kwargs.
Dictionary of kwargs the tokenizer was instantiated with and that need to take precedence over defaults.

Returns:
output_kwargs (`Dict`):
Dictionary of per-modality kwargs to be passed to each modality-specific processor.

"""

# Initialize dictionaries
output_kwargs = {
"text_kwargs": text_kwargs or {},
"images_kwargs": images_kwargs or {},
"audio_kwargs": audio_kwargs or {},
"videos_kwargs": videos_kwargs or {},
"common_kwargs": common_kwargs or {},
"text_kwargs": {},
"images_kwargs": {},
"audio_kwargs": {},
"videos_kwargs": {},
"common_kwargs": {},
}

default_kwargs = {
@@ -685,31 +681,28 @@
# get defaults from set model processor kwargs if they exist
for modality in default_kwargs:
default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
# then override with tokenizer-level arguments passed
if tokenizer_init_kwargs:
default_kwargs["text_kwargs"].update(
{k: v for k, v in tokenizer_init_kwargs.items() if k in ModelProcessorKwargs.text_kwargs}
)

# then get passed per-modality dictionaries if they exist
# update modality kwargs with passed kwargs
for modality in output_kwargs:
output_kwargs[modality] = {
**default_kwargs[modality],
**output_kwargs[modality],
**kwargs.pop(modality, {}),
}
# then merge kwargs by name
for modality_key in ModelProcessorKwargs[modality].__annotations__.keys():
modality_kwarg_value = kwargs.pop(modality_key, None)
if modality_kwarg_value is not None:
output_kwargs[modality] = modality_kwarg_value
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
# init with tokenizer init kwargs if necessary
if modality_key in tokenizer_init_kwargs:
output_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key]
# check if we received a structured kwarg dict or not to handle it correctly
if modality in kwargs:
kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
else:
kwarg_value = kwargs.pop(modality_key, "__empty__")
Member:

Might there be a situation where we have the same key both in **kwargs and in a modality kwargs dict? I am getting a strange error if the same key is provided in both:

    result = processor(
        images=image,
        text=["What is that?"],
        crop_size={"height": 256, "width": 256},
        images_kwargs={"crop_size": {"height": 224, "width": 224}},
    )
    TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'height'

I guess the priority should work here, or a proper error should be raised.

Contributor Author (molbap):

ah yes, good catch! I think raising an error if both keys are defined is easier; the responsibility of choosing should be on the user side imo.

Contributor Author (molbap), Jun 11, 2024:

Should be fixed :) The example you included now raises a ValueError:

    ValueError: Keyword argument crop_size was passed two times: in a dictionary for images_kwargs and as a **kwarg.
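As a standalone illustration of that fix (ensure_not_duplicated is a hypothetical helper; the real check lives inside ProcessorMixin._merge_kwargs):

```python
def ensure_not_duplicated(key: str, nested_kwargs: dict, flat_kwargs: dict) -> None:
    # Reject a kwarg that arrives both nested in a modality dict and
    # flattened as a top-level **kwarg, instead of silently picking one.
    if key in nested_kwargs and key in flat_kwargs:
        raise ValueError(
            f"Keyword argument {key} was passed two times: in a dictionary "
            f"for images_kwargs and as a **kwarg."
        )


try:
    ensure_not_duplicated(
        "crop_size",
        nested_kwargs={"crop_size": {"height": 224, "width": 224}},
        flat_kwargs={"crop_size": {"height": 256, "width": 256}},
    )
except ValueError as err:
    print(err)  # the ValueError quoted above
```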

if kwarg_value != "__empty__":
output_kwargs[modality][modality_key] = kwarg_value

# if something remains in kwargs, it belongs to common
output_kwargs["common_kwargs"].update(kwargs)
# all modality-specific kwargs are updated with common kwargs
for modality in output_kwargs:
output_kwargs[modality].update(output_kwargs["common_kwargs"])

return output_kwargs

@classmethod
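To make the resulting precedence concrete, here is a simplified sketch of the merge order implemented above, using plain dicts (illustrative only): model-level _defaults are overridden by tokenizer init kwargs, which are overridden by call-time kwargs, and common kwargs are broadcast to every modality at the end:

```python
# Precedence: model _defaults < tokenizer init kwargs < call-time kwargs.
defaults = {"padding": "max_length", "max_length": 64}  # from _defaults["text_kwargs"]
tokenizer_init_kwargs = {"max_length": 117}             # from tokenizer instantiation
call_kwargs = {"padding": "longest", "return_tensors": "pt"}

text_kwargs = dict(defaults)                 # start from the model defaults
text_kwargs.update(tokenizer_init_kwargs)    # tokenizer init overrides defaults
text_kwargs.update(                          # call-time text kwargs win over both
    {k: v for k, v in call_kwargs.items() if k in ("padding", "max_length")}
)
text_kwargs.update({"return_tensors": call_kwargs["return_tensors"]})  # common kwargs last

print(text_kwargs)  # {'padding': 'longest', 'max_length': 117, 'return_tensors': 'pt'}
```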
48 changes: 42 additions & 6 deletions tests/models/align/test_processor_align.py
@@ -205,7 +205,8 @@ def test_model_input_names(self):

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

def test_defaults_preserved(self):
# TODO move these tests to a common Mixin
def test_defaults_preserved_kwargs(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer(max_length=117)

@@ -218,6 +219,19 @@

self.assertEqual(len(inputs["input_ids"]), 117)

@require_torch
def test_defaults_preserved_image_kwargs(self):
image_processor = self.get_image_processor(crop_size=(234, 234))
tokenizer = self.get_tokenizer(max_length=117)

processor = AlignProcessor(tokenizer=tokenizer, image_processor=image_processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

@require_torch
def test_structured_kwargs(self):
image_processor = self.get_image_processor()
@@ -229,12 +243,34 @@
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
common_kwargs = {"return_tensors": "pt"}
images_kwargs = {"crop_size": {"height": 214, "width": 214}}
text_kwargs = {"padding": "max_length", "max_length": 76}
all_kwargs = {
"return_tensors": "pt",
"crop_size": {"height": 214, "width": 214},
"padding": "max_length",
"max_length": 76,
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

# Combine them into a single dictionary
all_kwargs = {"images_kwargs": images_kwargs, "text_kwargs": text_kwargs, "common_kwargs": common_kwargs}
self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
def test_structured_kwargs_nested(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()

processor = AlignProcessor(tokenizer=tokenizer, image_processor=image_processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)