Skip to content

Commit

Permalink
Cover caption
Browse files Browse the repository at this point in the history
  • Loading branch information
sooahleex committed Jun 22, 2024
1 parent b0bc72c commit 8d89e59
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,8 +1861,6 @@ class Clean(ItemTransform):
- **Numeric Media**: For numeric data, the class identifies and handles outliers and missing values.
Outliers are either removed or replaced based on a defined strategy,
and missing values are filled using appropriate methods such as mean, median, or a predefined value.|n
- **Categorical Media**: For categorical data, the class addresses missing values.
Missing values in categorical columns are filled with the mode or a specified placeholder.|n
"""

def __init__(
Expand All @@ -1873,6 +1871,7 @@ def __init__(

self._outlier_value = {}
self._missing_value = {}
self._sep_token = ":"

@staticmethod
def remove_unnecessary_char(text):
Expand Down Expand Up @@ -1982,14 +1981,22 @@ def transform_item(self, item):
"Item %s: TableRow info is required for this " "transform" % (item.id,)
)

sep_token = self._sep_token
refined_media = self.refine_tabular_media(item) if item.media.has_data else None
refined_annotations = []
for ann in item.annotations:
if isinstance(ann, Tabular):
annotation_values = {
key: refined_media.data[key] for key in item.annotations[0].values.keys()
} # only for tabular
ann.wrap(values=annotation_values)
ann = ann.wrap(values=annotation_values)
elif isinstance(ann, Caption):
value = [
f"{key}{sep_token}{refined_media.data[key]}"
for key in refined_media.data.keys()
if ann.caption.startswith(key)
]
ann = ann.wrap(caption=value[0])
refined_annotations.append(ann)

return self.wrap_item(item, media=refined_media, annotations=refined_annotations)
98 changes: 96 additions & 2 deletions tests/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,46 @@ def setUp(self):
},
media_type=TableRow,
)
self.orig_astyped_dataset = Dataset.from_iterable(
[
DatasetItem(
id=i,
subset="train",
media=TableRow(table=table, index=i),
annotations=[
Tabular(
values={
"Rating": table.data["Rating"][i],
"Age": table.data["Age"][i],
"Title": table.data["Title"][i],
"Review Text": table.data["Review Text"][i],
"Division Name": table.data["Division Name"][i],
}
)
],
)
for i in range(1, 6)
],
categories={
AnnotationType.tabular: TabularCategories.from_iterable(
[
(
"Age",
float,
{10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0},
),
("Review Text", str),
("Rating", CategoricalDtype(), {1.0, 2.0, 3.0, 4.0, 5.0}),
(
"Division Name",
CategoricalDtype(),
{"General", "General Petite", "Initmates"},
),
]
)
},
media_type=TableRow,
)

self.tabular_refined_path = osp.join(
get_test_asset_path("tabular_dataset"), "women-clothing", "women_clothing_refined.csv"
Expand Down Expand Up @@ -1519,16 +1559,56 @@ def setUp(self):
},
media_type=TableRow,
)
self.refined_astyped_dataset = Dataset.from_iterable(
[
DatasetItem(
id=i,
subset="train",
media=TableRow(table=table, index=i),
annotations=[
Tabular(
values={
"Rating": table.data["Rating"][i],
"Age": table.data["Age"][i],
"Title": table.data["Title"][i],
"Review Text": table.data["Review Text"][i],
"Division Name": table.data["Division Name"][i],
}
)
],
)
for i in range(1, 6)
],
categories={
AnnotationType.tabular: TabularCategories.from_iterable(
[
(
"Age",
float,
{10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0},
),
("Review Text", str),
("Rating", CategoricalDtype(), {1.0, 2.0, 3.0, 4.0, 5.0}),
(
"Division Name",
CategoricalDtype(),
{"General", "General Petite", "Initmates"},
),
]
)
},
media_type=TableRow,
)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_remove_unneccessary_char(self):
example_text = "This is a test 😊! Check out https://example.com for more <b>details</b> about this text. Enjoy!!!"
cleaned_text = "test check details text enjoy"

with self.subTest("with None"):
self.assertIsNone(transforms.Clean.remove_unneccessary_char(None))
self.assertIsNone(transforms.Clean.remove_unnecessary_char(None))
with self.subTest("with normal text"):
self.assertEqual(transforms.Clean.remove_unneccessary_char(example_text), cleaned_text)
self.assertEqual(transforms.Clean.remove_unnecessary_char(example_text), cleaned_text)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_find_closest_value(self):
Expand Down Expand Up @@ -1579,3 +1659,17 @@ def test_transform_clean_with_target(self):
result_item = result.__getitem__(i)
self.assertEqual(expected_item.annotations, result_item.annotations)
self.assertEqual(expected_item.media, result_item.media)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_transform_clean_after_astype_ann(self):
dataset = self.orig_astyped_dataset
dataset = dataset.transform("astype_annotations")
result = dataset.transform("clean")

expected = self.refined_astyped_dataset
expected = expected.transform("astype_annotations")

for i, expected_item in enumerate(expected):
result_item = result.__getitem__(i)
self.assertEqual(expected_item.annotations, result_item.annotations)
self.assertEqual(expected_item.media, result_item.media)

0 comments on commit 8d89e59

Please sign in to comment.