From d061b6c34f195ba0931dc7d44949ec3986f37f84 Mon Sep 17 00:00:00 2001 From: shubham-sahoo Date: Mon, 23 Aug 2021 15:44:46 +0530 Subject: [PATCH 1/2] [feat] __eq__ method for SampleList Check if two SampleLists are equal or not --- mmf/common/sample.py | 48 +++++++++++++++++++++++++++++++++++++ tests/common/test_sample.py | 16 +++++++++++++ 2 files changed, 64 insertions(+) diff --git a/mmf/common/sample.py b/mmf/common/sample.py index bafffa460..a74b51c4e 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -393,6 +393,54 @@ def to_dict(self) -> Dict[str, Any]: return sample_dict + def __eq__(self,other): + """Compare a sampleList with the current SampleList. + + Returns: + Bool : True or False + """ + + if not isinstance(other,SampleList): + return False + + fields = self.fields() + fields_other = other.fields() + tensor_field = self._get_tensor_field() + tensor_field_other = other._get_tensor_field() + + # Check for tensor fields comparison + if ( + len(fields) != 0 + and len(fields_other) != 0 + and tensor_field is not None + and other[tensor_field_other].size(0) != self[tensor_field].size(0) + ): + return False + + a = set(fields) + b = set(fields_other) + + # Comparison between keys and early fail + if a==b: + # Compare all the features + for field in fields: + # Compare Tensors + if ( + isinstance(self[field],torch.Tensor) + and isinstance(other.get_field(field),torch.Tensor) + ): + if not torch.equal(self[field],other.get_field(field)): + return False + + # Compare Lists + else: + if not self[field]==other.get_field(field): + return False + + return True + + return False + def convert_batch_to_sample_list( batch: Union[SampleList, Dict[str, Any]] diff --git a/tests/common/test_sample.py b/tests/common/test_sample.py index 1f90f214e..257a1982a 100644 --- a/tests/common/test_sample.py +++ b/tests/common/test_sample.py @@ -75,6 +75,22 @@ def test_to_dict(self): self.assertTrue(all_keys) self.assertTrue(isinstance(sample_dict, dict)) + + def test_equal(self): + sample_list1 = test_utils.build_random_sample_list() + sample_list2 = sample_list1.copy() + sample_list3 = sample_list1.copy() + sample_list3.add_field('new',list([1,2,3,4])) + sample_list4 = sample_list1.copy() + tensor_size = sample_list1.get_batch_size() + sample_list4.add_field('new',torch.zeros(tensor_size)) + sample_list5 = SampleList() + + self.assertTrue(sample_list1 == sample_list2) + self.assertFalse(sample_list1 == sample_list3) + self.assertFalse(sample_list1 == sample_list4) + self.assertFalse(sample_list2 == sample_list4) + self.assertFalse(sample_list1 == sample_list5) class TestFunctions(unittest.TestCase): From d6b458fb0826ad92e3265ab04c6ba465fc6d5cfc Mon Sep 17 00:00:00 2001 From: shubham-sahoo Date: Sat, 28 Aug 2021 14:24:12 +0530 Subject: [PATCH 2/2] [feat] '__eq__' method for SampleList Method and tests modified --- mmf/common/sample.py | 19 +++++++++++++------ tests/common/test_sample.py | 31 +++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/mmf/common/sample.py b/mmf/common/sample.py index a74b51c4e..cdf0450fe 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -413,28 +413,35 @@ def __eq__(self,other): len(fields) != 0 and len(fields_other) != 0 and tensor_field is not None + and tensor_field_other is not None and other[tensor_field_other].size(0) != self[tensor_field].size(0) ): return False - a = set(fields) - b = set(fields_other) + fields_set = set(fields) + fields_set_other = set(fields_other) # Comparison between keys and early fail - if a==b: + if fields_set==fields_set_other: # Compare all the features for field in fields: # Compare Tensors if ( - isinstance(self[field],torch.Tensor) + isinstance(self.get_field(field),torch.Tensor) and isinstance(other.get_field(field),torch.Tensor) ): - if not torch.equal(self[field],other.get_field(field)): + if not torch.equal(self.get_field(field),other.get_field(field)): return False + # Check for same data type + elif ( + type(self.get_field(field)) is not type(other.get_field(field)) + ): + return False + # Compare Lists else: - if not self[field]==other.get_field(field): + if not self.get_field(field)==other.get_field(field): return False return True diff --git a/tests/common/test_sample.py b/tests/common/test_sample.py index 257a1982a..75560c41e 100644 --- a/tests/common/test_sample.py +++ b/tests/common/test_sample.py @@ -80,19 +80,38 @@ def test_equal(self): sample_list1 = test_utils.build_random_sample_list() sample_list2 = sample_list1.copy() sample_list3 = sample_list1.copy() - sample_list3.add_field('new',list([1,2,3,4])) + + sample_list3.add_field('new',list([1,2,3,4,5])) sample_list4 = sample_list1.copy() tensor_size = sample_list1.get_batch_size() sample_list4.add_field('new',torch.zeros(tensor_size)) + sample_list5 = SampleList() + sample_list6 = SampleList() + sample_list6.add_field('new',SampleList()) - self.assertTrue(sample_list1 == sample_list2) - self.assertFalse(sample_list1 == sample_list3) - self.assertFalse(sample_list1 == sample_list4) - self.assertFalse(sample_list2 == sample_list4) - self.assertFalse(sample_list1 == sample_list5) + sample_list7 = SampleList() + dict_example = {'a':1, 'b':2} + sample_list7.add_field('new',dict_example) + + sample_list8 = sample_list1.copy() + sample_list8.add_field('new',torch.ones(tensor_size)) + self.assertTrue(sample_list1 == sample_list2) + self.assertTrue(sample_list1 != sample_list3) + self.assertTrue(sample_list1 != sample_list4) + self.assertTrue(sample_list2 != sample_list4) + self.assertTrue(sample_list1 != sample_list5) + self.assertTrue(sample_list1 != sample_list6) + self.assertTrue(sample_list1 != sample_list7) + self.assertTrue(sample_list5 != sample_list6) + self.assertTrue(sample_list6 != sample_list7) + self.assertTrue(sample_list6 != sample_list5) + self.assertTrue(sample_list6 != sample_list1) + self.assertTrue(sample_list1 != sample_list8) + self.assertFalse(sample_list4 == sample_list8) + class TestFunctions(unittest.TestCase): def test_to_device(self): sample_list = test_utils.build_random_sample_list()