Immutability for data collators #30603

Merged (7 commits) on May 8, 2024

vasqu (Contributor) commented May 1, 2024

What does this PR do?

Introduces new tests that check if a data collator might introduce side effects, i.e. the given input changes after the call to the collator. Motivated by #30556

Furthermore, fixes the seq2seq collator to not introduce side effects on the given input's labels. This is done by:

  • Passing only relevant features to the tokenizer to pad.
  • Manually crafting the labels afterwards.
  • Reintroducing tokenizer behaviour by converting labels to the respective datatype (pt, tf, np).
    As a side note, added some more checks on None labels, especially when given "labels": None in the dictionary.
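
The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code merged into `data_collator.py`: `collate_seq2seq`, `pad_id`, and `label_pad_id` are hypothetical names, padding is done on plain lists instead of through a tokenizer, and the conversion to pt/tf/np tensors is left out.

```python
import copy


def collate_seq2seq(features, label_name="labels", pad_id=0, label_pad_id=-100):
    """Hypothetical stand-in for a seq2seq collator that leaves `features` untouched."""
    # Collect the labels without modifying the feature dicts.
    labels = [feature.get(label_name) for feature in features]
    # This covers inputs such as {..., "labels": None}.
    if all(label is None for label in labels):
        labels = None

    # Step 1: only the non-label features go through padding.
    non_label_features = [
        {k: v for k, v in feature.items() if k != label_name} for feature in features
    ]
    max_len = max(len(f["input_ids"]) for f in non_label_features)
    batch = {
        "input_ids": [
            list(f["input_ids"]) + [pad_id] * (max_len - len(f["input_ids"]))
            for f in non_label_features
        ]
    }

    # Step 2: build padded labels as new lists instead of padding in place.
    # Assumption: labels are either all None or all sequences.
    if labels is None:
        batch[label_name] = None
    else:
        max_label_len = max(len(label) for label in labels)
        batch[label_name] = [
            list(label) + [label_pad_id] * (max_label_len - len(label))
            for label in labels
        ]
    return batch


features = [
    {"input_ids": [1, 2, 3], "labels": [7, 8]},
    {"input_ids": [4], "labels": [9]},
]
snapshot = copy.deepcopy(features)
batch = collate_seq2seq(features)
# The caller's data is unchanged after collation.
assert features == snapshot
```

The important design choice is that every padded sequence is a fresh list, so the caller's label lists are never extended.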

Last remarks:

  • The test handles most of the cases that are introduced in the base tests but if I missed something give me a heads up :D
  • I'm not sure how to handle it when the user introduces None labels. For now, I return them as None again.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts @Rocketknight1

vasqu (Contributor Author) commented May 1, 2024

Oh yea, one last thing. I've created separate classes for the immutability tests. Thought it got too convoluted otherwise.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Rocketknight1 (Member) left a comment

This looks good to me - I'm happy with the change to the data collator code itself! The one issue I'd raise is that it seems like there's a lot of code duplication in the tests, with just some changes like torch.tensor -> np.array or different values for return_tensors. This is probably fine, though - that kind of explicit duplication makes it easier to locate test errors.

cc @amyeroberts for core maintainer review (and also let me know if you think I'm wrong about the code duplication)

src/transformers/data/data_collator.py (outdated, resolved)
vasqu (Contributor Author) commented May 3, 2024

@Rocketknight1 With code duplications, do you mean across the different classes between pt/tf/np? I agree with that, it's more so a dependency check to see that it suddenly doesn't branch out to something unwanted and that different input datatypes are handled correctly. For example, when the default collator is used, it branches into separate np, tf, and pt calls (i.e. numpy_default_data_collator, tf_default_data_collator, torch_default_data_collator). I haven't deep-dived where else something like this might happen.

Rocketknight1 (Member)

Yes, that's what I was referring to - and I think it's fine to keep it as-is!

amyeroberts (Collaborator) left a comment

Thanks for tackling this and adding tests ❤️

Just a few formatting style comments (to be propagated to TF and Flax testers too). Otherwise looks great.

# this might occur when we pass {..., "labels": None}
if labels is not None and all(label is None for label in labels):
    labels = None
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

Collaborator

nit - I'd call this non_label_features instead of no-labels features. I'd parse the latter as features which have no corresponding label, which isn't necessarily the case.

Suggested change
- no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
+ non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

src/transformers/data/data_collator.py (resolved)
Comment on lines 459 to 472
def _compare_assert_with_collator(self, collator, original, batch):
    # we only care about side effects, the results are tested elsewhere
    collator(batch)

    # we go through every item and convert to `primitive` datatypes if necessary
    # then compares for equivalence for the original data and the data that has been passed through the collator
    for i in range(len(original)):
        for key in original[i].keys():
            if isinstance(original[i][key], np.ndarray):
                self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
            elif isinstance(original[i][key], torch.Tensor):
                self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
            else:
                self.assertEqual(original[i][key], batch[i][key])

Collaborator

  • let's iterate over the objects directly - it's more pythonic - range(len(x)) is a bit of a code smell
  • not sure about the data type of batch here - is it a dict? If so, it would be better to just do for original_val, batch_val in zip(original.values(), batch.values())
  • Two of the branches are the same
Suggested change
- def _compare_assert_with_collator(self, collator, original, batch):
-     # we only care about side effects, the results are tested elsewhere
-     collator(batch)
-     # we go through every item and convert to `primitive` datatypes if necessary
-     # then compares for equivalence for the original data and the data that has been passed through the collator
-     for i in range(len(original)):
-         for key in original[i].keys():
-             if isinstance(original[i][key], np.ndarray):
-                 self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
-             elif isinstance(original[i][key], torch.Tensor):
-                 self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
-             else:
-                 self.assertEqual(original[i][key], batch[i][key])
+ def _compare_assert_with_collator(self, collator, original_data, batched_data):
+     # we only care about side effects, the results are tested elsewhere
+     collator(batched_data)
+     # we go through every item and convert to `primitive` datatypes if necessary
+     # then compares for equivalence for the original data and the data that has been passed through the collator
+     for original, batch in zip(original_data, batched_data):
+         for key in original.keys():
+             original_value = original[key]
+             batch_value = batch[key]
+             if isinstance(original_value, (np.ndarray, torch.Tensor)):
+                 self.assertEqual(original_value.tolist(), batch_value.tolist())
+             else:
+                 self.assertEqual(original_value, batch_value)

vasqu (Contributor Author), May 8, 2024

Yea, the datatype of original_data and batched_data is List[Dict[str, Any]], where Any is usually a list, a tensor, etc.

Have changed the code to iterate via zip.
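
The zip-based check discussed in this thread can be tried outside the test suite with a small standalone sketch. It makes a few assumptions: plain `assert` replaces `self.assertEqual`, a duck-typed `tolist()` check stands in for the explicit `np.ndarray` / `torch.Tensor` test so that no third-party import is needed, and the two example collators are hypothetical.

```python
import copy


def to_primitive(value):
    # Arrays and tensors expose `tolist()`; plain Python values are compared as-is.
    return value.tolist() if hasattr(value, "tolist") else value


def validate_original_data_against_collated_data(collator, original_data, batched_data):
    # Only side effects matter here, so the collator's return value is ignored.
    collator(batched_data)
    # zip stops at the shorter input, so check the lengths explicitly.
    assert len(original_data) == len(batched_data)
    for original, batch in zip(original_data, batched_data):
        for key, original_value in original.items():
            assert to_primitive(original_value) == to_primitive(batch[key]), (
                f"collator changed the input under key {key!r}"
            )


def pure_collator(features):
    # Hypothetical well-behaved collator: builds new dicts and touches nothing.
    return [dict(feature) for feature in features]


def mutating_collator(features):
    # Hypothetical buggy collator: pads the labels in place.
    for feature in features:
        feature["labels"].append(-100)
    return features


data = [{"input_ids": [1, 2], "labels": [3]}]
# Passes silently because the input is unchanged after the call.
validate_original_data_against_collated_data(
    pure_collator, copy.deepcopy(data), copy.deepcopy(data)
)
```

Running the same check with `mutating_collator` raises an `AssertionError`, which is the kind of regression these tests are meant to catch.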

"""used to convert `item` to `None` type"""
return None

def _compare_assert_with_collator(self, collator, original, batch):

Collaborator

Could we rename this? We're not really comparing an assert with a collator - we're checking/comparing/validating the original and collated data are the same

Contributor Author

Renamed to _validate_original_data_against_collated_data but open to other suggestions.

        else:
            self.assertEqual(original[i][key], batch[i][key])

def _compare_assert_with_collator_on_datatypes(

Collaborator

Same for the name here

Contributor Author

Renamed to _validate_original_data_against_collated_data_on_specified_keys_and_datatypes. Might be a bit long tho.

    (list, self._turn_to_none),
]:
    self._compare_assert_with_collator_on_datatypes(
        default_data_collator, features_base_single_label, "inputs", datatype_input, "label", datatype_label

Collaborator

To make the test clearer - let's call the test with kwargs rather than args

Contributor Author

Propagated to other similar calls.


# some collators do not use labels, or sometimes we want to check if the collator with labels can handle such cases
if ignore_label:
    for original, batch in zip(features_original, features_batch):

vasqu (Contributor Author), May 8, 2024

Changed the iteration style here as well from range(len(...)) to zip based style.

Since both features are created from the same base_data, it ensures that we encounter all samples in both original and batch.
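
A rough sketch of how such a datatype-parameterised check could be put together, with keyword-only arguments as requested in the review. All names except `turn_to_none` (modelled on the `_turn_to_none` helper quoted earlier) are hypothetical, and the assumption that `ignore_label` simply drops the label key from both copies is a guess, not something the thread confirms.

```python
def turn_to_none(item):
    """Converter that discards the value and returns None."""
    return None


def build_features(base_data, *, input_key, input_datatype, label_key, label_datatype):
    # Each call converts the base data again, so the two copies share no objects.
    return [
        {
            input_key: input_datatype(sample[input_key]),
            label_key: label_datatype(sample[label_key]),
        }
        for sample in base_data
    ]


def check_no_side_effects(
    collator, base_data, *, input_key, input_datatype, label_key, label_datatype,
    ignore_label=False,
):
    spec = dict(
        input_key=input_key, input_datatype=input_datatype,
        label_key=label_key, label_datatype=label_datatype,
    )
    features_original = build_features(base_data, **spec)
    features_batch = build_features(base_data, **spec)

    # Assumption: ignoring the label means removing it from both copies.
    if ignore_label:
        for original, batch in zip(features_original, features_batch):
            original.pop(label_key)
            batch.pop(label_key)

    collator(features_batch)

    # Both lists come from the same base data, so zip visits every sample.
    for original, batch in zip(features_original, features_batch):
        assert original == batch, "collator modified its input"


def list_collator(features):
    # Hypothetical side-effect-free collator that groups values by key.
    return {key: [feature[key] for feature in features] for key in features[0]}


base_data = [{"inputs": [0, 1, 2], "label": [3]}, {"inputs": [4, 5], "label": [6]}]
for input_datatype, label_datatype in [(list, list), (tuple, list), (list, turn_to_none)]:
    check_no_side_effects(
        list_collator,
        base_data,
        input_key="inputs",
        input_datatype=input_datatype,
        label_key="label",
        label_datatype=label_datatype,
    )
```

Keyword-only parameters (everything after the bare `*`) make each call site state which value is the key and which is the converter, which is the readability gain the review asked for.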

amyeroberts (Collaborator) left a comment

Thanks for all the work adding this and iterating!

@amyeroberts amyeroberts merged commit 71c1985 into huggingface:main May 8, 2024
20 checks passed
@vasqu vasqu deleted the immutability-for-collators branch May 8, 2024 17:08
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 10, 2024
* immutability fix for seq2seq as well as immutability tests for the collators

* ensure we don't act on none labels and formatting

* remove tf/pt in respective tests as they are not required

* more type error fixes tf/np

* remove todo

* apply suggestions from code review

* formatting / style
itazap pushed a commit that referenced this pull request May 14, 2024