From 8d89e59a6798d8b350ed6d956a288df1bbaad141 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 15 Jun 2024 15:56:48 +0900 Subject: [PATCH] Cover caption --- src/datumaro/plugins/transforms.py | 13 +++- tests/unit/test_transforms.py | 98 +++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/src/datumaro/plugins/transforms.py b/src/datumaro/plugins/transforms.py index 5a5780009f..ef5339ffe3 100644 --- a/src/datumaro/plugins/transforms.py +++ b/src/datumaro/plugins/transforms.py @@ -1861,8 +1861,6 @@ class Clean(ItemTransform): - **Numeric Media**: For numeric data, the class identifies and handles outliers and missing values. Outliers are either removed or replaced based on a defined strategy, and missing values are filled using appropriate methods such as mean, median, or a predefined value.|n - - **Categorical Media**: For categorical data, the class addresses missing values. - Missing values in categorical columns are filled with the mode or a specified placeholder.|n """ def __init__( @@ -1873,6 +1871,7 @@ def __init__( self._outlier_value = {} self._missing_value = {} + self._sep_token = ":" @staticmethod def remove_unnecessary_char(text): @@ -1982,6 +1981,7 @@ def transform_item(self, item): "Item %s: TableRow info is required for this " "transform" % (item.id,) ) + sep_token = self._sep_token refined_media = self.refine_tabular_media(item) if item.media.has_data else None refined_annotations = [] for ann in item.annotations: @@ -1989,7 +1989,14 @@ def transform_item(self, item): annotation_values = { key: refined_media.data[key] for key in item.annotations[0].values.keys() } # only for tabular - ann.wrap(values=annotation_values) + ann = ann.wrap(values=annotation_values) + elif isinstance(ann, Caption): + value = [ + f"{key}{sep_token}{refined_media.data[key]}" + for key in refined_media.data.keys() + if ann.caption.startswith(key) + ] + ann = ann.wrap(caption=value[0]) refined_annotations.append(ann) return self.wrap_item(item, media=refined_media, annotations=refined_annotations) diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py index a99ddfdacc..24db8a76e6 100644 --- a/tests/unit/test_transforms.py +++ b/tests/unit/test_transforms.py @@ -1474,6 +1474,46 @@ def setUp(self): }, media_type=TableRow, ) + self.orig_astyped_dataset = Dataset.from_iterable( + [ + DatasetItem( + id=i, + subset="train", + media=TableRow(table=table, index=i), + annotations=[ + Tabular( + values={ + "Rating": table.data["Rating"][i], + "Age": table.data["Age"][i], + "Title": table.data["Title"][i], + "Review Text": table.data["Review Text"][i], + "Division Name": table.data["Division Name"][i], + } + ) + ], + ) + for i in range(1, 6) + ], + categories={ + AnnotationType.tabular: TabularCategories.from_iterable( + [ + ( + "Age", + float, + {10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0}, + ), + ("Review Text", str), + ("Rating", CategoricalDtype(), {1.0, 2.0, 3.0, 4.0, 5.0}), + ( + "Division Name", + CategoricalDtype(), + {"General", "General Petite", "Initmates"}, + ), + ] + ) + }, + media_type=TableRow, + ) self.tabular_refined_path = osp.join( get_test_asset_path("tabular_dataset"), "women-clothing", "women_clothing_refined.csv" @@ -1519,6 +1559,46 @@ def setUp(self): }, media_type=TableRow, ) + self.refined_astyped_dataset = Dataset.from_iterable( + [ + DatasetItem( + id=i, + subset="train", + media=TableRow(table=table, index=i), + annotations=[ + Tabular( + values={ + "Rating": table.data["Rating"][i], + "Age": table.data["Age"][i], + "Title": table.data["Title"][i], + "Review Text": table.data["Review Text"][i], + "Division Name": table.data["Division Name"][i], + } + ) + ], + ) + for i in range(1, 6) + ], + categories={ + AnnotationType.tabular: TabularCategories.from_iterable( + [ + ( + "Age", + float, + {10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0}, + ), + ("Review Text", str), + ("Rating", CategoricalDtype(), {1.0, 2.0, 3.0, 4.0, 5.0}), + ( + "Division Name", + CategoricalDtype(), + {"General", "General Petite", "Initmates"}, + ), + ] + ) + }, + media_type=TableRow, + ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_remove_unneccessary_char(self): @@ -1526,9 +1606,9 @@ def test_remove_unneccessary_char(self): cleaned_text = "test check details text enjoy" with self.subTest("with None"): - self.assertIsNone(transforms.Clean.remove_unneccessary_char(None)) + self.assertIsNone(transforms.Clean.remove_unnecessary_char(None)) with self.subTest("with normal text"): - self.assertEqual(transforms.Clean.remove_unneccessary_char(example_text), cleaned_text) + self.assertEqual(transforms.Clean.remove_unnecessary_char(example_text), cleaned_text) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_find_closest_value(self): @@ -1579,3 +1659,17 @@ def test_transform_clean_with_target(self): result_item = result.__getitem__(i) self.assertEqual(expected_item.annotations, result_item.annotations) self.assertEqual(expected_item.media, result_item.media) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_transform_clean_after_astype_ann(self): + dataset = self.orig_astyped_dataset + dataset = dataset.transform("astype_annotations") + result = dataset.transform("clean") + + expected = self.refined_astyped_dataset + expected = expected.transform("astype_annotations") + + for i, expected_item in enumerate(expected): + result_item = result.__getitem__(i) + self.assertEqual(expected_item.annotations, result_item.annotations) + self.assertEqual(expected_item.media, result_item.media)