Immutability for data collators #30603
Conversation
Oh yeah, one last thing: I've created separate classes for the immutability tests; I thought it got too convoluted otherwise.
This looks good to me - I'm happy with the change to the data collator code itself! The one issue I'd raise is that it seems like there's a lot of code duplication in the tests, with just some changes like torch.tensor -> np.array or different values for return_tensors. This is probably fine, though - that kind of explicit duplication makes it easier to locate test errors.
cc @amyeroberts for core maintainer review (and also let me know if you think I'm wrong about the code duplication)
@Rocketknight1 With code duplication, do you mean across the different classes between pt/tf/np? I agree with that; it's more of a dependency check to see that it doesn't suddenly branch out to something unwanted and that different input datatypes are handled correctly. For example, when the default collator is used, it branches into separate np, tf, and pt calls.
Yes, that's what I was referring to - and I think it's fine to keep it as-is!
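For illustration, a minimal sketch of the dispatch behaviour mentioned above, with toy features rather than the PR's fixtures (default_data_collator routes to its numpy/tf/torch variants based on return_tensors):

from transformers import default_data_collator

features = [{"input_ids": [0, 1, 2], "label": 0}, {"input_ids": [3, 4, 5], "label": 1}]

np_batch = default_data_collator(features, return_tensors="np")  # numpy arrays
pt_batch = default_data_collator(features, return_tensors="pt")  # torch tensors

print(type(np_batch["input_ids"]))  # <class 'numpy.ndarray'>
print(type(pt_batch["input_ids"]))  # <class 'torch.Tensor'>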
Thanks for tackling this and adding tests ❤️
Just a few formatting style comments (to be propagated to the TF and Flax testers too). Otherwise looks great.
# this might occur when we pass {..., "labels": None}
if labels is not None and all(label is None for label in labels):
    labels = None
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
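# (presumably these label-stripped copies are what get padded downstream, so the caller's
#  original feature dicts are left unmodified; the rest of the method is not shown in this hunk)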
nit - I'd call this non_label_features instead of no_labels_features. I'd parse the latter as features which have no corresponding label, which isn't necessarily the case.
Suggested change:
- no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
+ non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
tests/trainer/test_data_collator.py (outdated)
def _compare_assert_with_collator(self, collator, original, batch):
    # we only care about side effects, the results are tested elsewhere
    collator(batch)

    # we go through every item and convert to `primitive` datatypes if necessary
    # then compares for equivalence for the original data and the data that has been passed through the collator
    for i in range(len(original)):
        for key in original[i].keys():
            if isinstance(original[i][key], np.ndarray):
                self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
            elif isinstance(original[i][key], torch.Tensor):
                self.assertEqual(original[i][key].tolist(), batch[i][key].tolist())
            else:
                self.assertEqual(original[i][key], batch[i][key])
- let's iterate over the objects directly - it's more pythonic; range(len(x)) is a bit of a code smell
- not sure about the data type of batch here - is it a dict? If so, it would be better to just do for original_val, batch_val in zip(original.values(), batch.values())
- Two of the branches are the same
Suggested change:

def _compare_assert_with_collator(self, collator, original_data, batched_data):
    # we only care about side effects, the results are tested elsewhere
    collator(batched_data)

    # we go through every item and convert to `primitive` datatypes if necessary
    # then compares for equivalence for the original data and the data that has been passed through the collator
    for original, batch in zip(original_data, batched_data):
        for key in original.keys():
            original_value = original[key]
            batch_value = batch[key]
            if isinstance(original_value, (np.ndarray, torch.Tensor)):
                self.assertEqual(original_value.tolist(), batch_value.tolist())
            else:
                self.assertEqual(original_value, batch_value)
Yeah, the datatype of original_data and batch_data is List[Dict[str, Any]], where Any is usually a list or tensor, etc. I've changed the code to iterate via zip.
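For concreteness, a toy illustration of that shape (made-up values, not the PR's fixtures):

import numpy as np

# List[Dict[str, Any]]: one dict per sample; values are usually lists, arrays, or tensors
original_data = [
    {"input_ids": np.array([0, 1, 2]), "label": 0},
    {"input_ids": np.array([3, 4, 5]), "label": 1},
]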
tests/trainer/test_data_collator.py (outdated)
"""used to convert `item` to `None` type""" | ||
return None | ||
|
||
def _compare_assert_with_collator(self, collator, original, batch): |
Could we rename this? We're not really comparing an assert with a collator - we're checking/comparing/validating the original and collated data are the same
Renamed to _validate_original_data_against_collated_data, but open to other suggestions.
tests/trainer/test_data_collator.py (outdated)
            else:
                self.assertEqual(original[i][key], batch[i][key])

def _compare_assert_with_collator_on_datatypes(
Same for the name here
Renamed to _validate_original_data_against_collated_data_on_specified_keys_and_datatypes. Might be a bit long, though.
tests/trainer/test_data_collator.py (outdated)
        (list, self._turn_to_none),
    ]:
        self._compare_assert_with_collator_on_datatypes(
            default_data_collator, features_base_single_label, "inputs", datatype_input, "label", datatype_label
To make the test clearer - let's call the test with kwargs rather than args
Propagated to other similar calls.
    # some collators do not use labels, or sometimes we want to check if the collator with labels can handle such cases
    if ignore_label:
        for original, batch in zip(features_original, features_batch):
Changed the iteration style here as well, from range(len(...)) to a zip-based style. Since both features are created from the same base_data, it ensures that we encounter all samples in both original and batch.
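As a standalone illustration of that point (toy data, not the PR's fixtures): both lists are derived from the same base data, so they are equal in length and aligned in order, and zip therefore visits every sample exactly once.

import copy

base_data = [{"inputs": [0, 1, 2], "label": 0}, {"inputs": [3, 4, 5], "label": 1}]

# two views built from the same base data, as in the tests described above
features_original = copy.deepcopy(base_data)
features_batch = copy.deepcopy(base_data)

# zip pairs the samples one-to-one; nothing is skipped because the lengths match by construction
for original, batch in zip(features_original, features_batch):
    assert original.keys() == batch.keys()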
Thanks for all the work adding this and iterating!
Commits:
* immutability fix for seq2seq as well as immutability tests for the collators
* ensure we don't act on none labels and formatting
* remove tf/pt in respective tests as they are not required
* more type error fixes tf/np
* remove todo
* apply suggestions from code review
* formatting / style
What does this PR do?
Introduces new tests that check whether a data collator introduces side effects, i.e. whether the given input changes after the call to the collator. Motivated by #30556.
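As a rough sketch of the kind of check these tests perform (toy features and a simplified assertion, not the PR's actual test code):

import copy

import numpy as np
from transformers import default_data_collator

# toy features; the real tests cover several collators, datatypes, and frameworks
features = [{"input_ids": np.array([0, 1, 2]), "label": 0}, {"input_ids": np.array([3, 4, 5]), "label": 1}]
original = copy.deepcopy(features)

default_data_collator(features, return_tensors="np")

# the collator must not have mutated its input in place
for orig, feat in zip(original, features):
    assert orig["input_ids"].tolist() == feat["input_ids"].tolist()
    assert orig["label"] == feat["label"]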
Furthermore, fixes the seq2seq collator to not introduce side effects on the given input's labels (see the sketch below).
As a side note, added some more checks on None labels, especially when "labels": None is given in the dictionary.
Last remarks: on None labels, for now I return them as None again.
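A minimal, hypothetical sketch of the copy-before-modify idea behind that fix, in the spirit of the no_labels_features hunk quoted earlier in the review (the helper name and the label_pad_token_id default are assumptions, not the PR's exact diff):

def pad_labels_without_side_effects(features, label_pad_token_id=-100):
    """Pad every sample's labels to a common length without mutating the caller's dicts."""
    max_label_length = max(len(feature["labels"]) for feature in features)
    padded = []
    for feature in features:
        remainder = [label_pad_token_id] * (max_label_length - len(feature["labels"]))
        # build a fresh dict and a fresh label list instead of writing back into `feature`
        padded.append({**feature, "labels": list(feature["labels"]) + remainder})
    return padded

The caller's feature dicts and label lists stay untouched, which is exactly the property the new immutability tests assert.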
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts @Rocketknight1