-
Notifications
You must be signed in to change notification settings - Fork 936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] '__eq__' method for SampleList #1053
base: main
Are you sure you want to change the base?
Changes from 2 commits
d061b6c
626ab7e
b61213e
6f9bb98
d6b458f
0f4155f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -393,6 +393,54 @@ def to_dict(self) -> Dict[str, Any]: | |
|
||
return sample_dict | ||
|
||
def __eq__(self,other): | ||
"""Compare a sampleList with the current SampleList. | ||
|
||
Returns: | ||
Bool : True or False | ||
""" | ||
|
||
if not isinstance(other,SampleList): | ||
return False | ||
|
||
fields = self.fields() | ||
fields_other = other.fields() | ||
tensor_field = self._get_tensor_field() | ||
tensor_field_other = other._get_tensor_field() | ||
|
||
# Check for tensor fields comparison | ||
if ( | ||
len(fields) != 0 | ||
and len(fields_other) != 0 | ||
and tensor_field is not None | ||
and other[tensor_field_other].size(0) != self[tensor_field].size(0) | ||
): | ||
return False | ||
|
||
a = set(fields) | ||
b = set(fields_other) | ||
|
||
# Comparison between keys and early fail | ||
if a==b: | ||
# Compare all the features | ||
for field in fields: | ||
# Compare Tensors | ||
if ( | ||
isinstance(self[field],torch.Tensor) | ||
and isinstance(other.get_field(field),torch.Tensor) | ||
): | ||
if not torch.equal(self[field],other.get_field(field)): | ||
return False | ||
|
||
# Compare Lists | ||
else: | ||
if not self[field]==other.get_field(field): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe use either key access on both sides or use |
||
return False | ||
|
||
return True | ||
|
||
return False | ||
|
||
|
||
def convert_batch_to_sample_list( | ||
batch: Union[SampleList, Dict[str, Any]] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,22 @@ def test_to_dict(self): | |
|
||
self.assertTrue(all_keys) | ||
self.assertTrue(isinstance(sample_dict, dict)) | ||
|
||
def test_equal(self): | ||
sample_list1 = test_utils.build_random_sample_list() | ||
sample_list2 = sample_list1.copy() | ||
sample_list3 = sample_list1.copy() | ||
sample_list3.add_field('new',list([1,2,3,4])) | ||
sample_list4 = sample_list1.copy() | ||
tensor_size = sample_list1.get_batch_size() | ||
sample_list4.add_field('new',torch.zeros(tensor_size)) | ||
sample_list5 = SampleList() | ||
|
||
self.assertTrue(sample_list1 == sample_list2) | ||
self.assertFalse(sample_list1 == sample_list3) | ||
self.assertFalse(sample_list1 == sample_list4) | ||
self.assertFalse(sample_list2 == sample_list4) | ||
self.assertFalse(sample_list1 == sample_list5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assert more complex cases and cases which are not equal as well. Try testing SampleList inside SampleList or dict inside SampleList. |
||
|
||
|
||
class TestFunctions(unittest.TestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Linter won't be happy with this.