Skip to content

Commit

Permalink
Set label name with parents to avoid duplicates for AstypeAnnotations (
Browse files Browse the repository at this point in the history
…#1492)

- Exclude `nan` values from labels
- Prefix label names with their parent to avoid duplicate names in `AstypeAnnotations`
- Add a unit test for a tabular dataset that includes missing values
- Add a unit test for `AstypeAnnotations` when a label value is `nan`
  • Loading branch information
sooahleex authored May 9, 2024
1 parent 62ec011 commit 072c8a8
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1471>)
- Add ExtractedMask and update importers who can use it to use it
(<https://github.com/openvinotoolkit/datumaro/pull/1480>)
- Set label name with parents to avoid duplicates for AstypeAnnotations
(<https://github.com/openvinotoolkit/datumaro/pull/1492>)

### Bug fixes
- Split the video directory into subsets to avoid overwriting
Expand Down
8 changes: 7 additions & 1 deletion src/datumaro/plugins/data_formats/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def _parse(
target_dtype = table.dtype(target_)
if target_dtype in [int, float, pd.api.types.CategoricalDtype()]:
# 'int' can be categorical, but we don't know this unless user gives information.
labels = set(table.features(target_, unique=True))
labels = set(
[
feature
for feature in table.features(target_, unique=True)
if not pd.isna(feature)
]
)
if category is None:
categories.add(target_, target_dtype, labels)
else: # update labels if they are different.
Expand Down
17 changes: 10 additions & 7 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,8 @@ def __init__(
):
super().__init__(extractor)

self._sep_token = ":"

if extractor.media_type() and not issubclass(extractor.media_type(), TableRow):
raise MediaTypeError(
"Media type is not table. This transform only support tabular media"
Expand Down Expand Up @@ -1523,6 +1525,7 @@ def __init__(
dst_parent = src_cat.name
dst_labels = sorted(src_cat.labels)
for dst_label in dst_labels:
dst_label = dst_parent + self._sep_token + str(dst_label)
dst_index = dst_label_cat.add(dst_label, parent=dst_parent, attributes={})
self._id_mapping[dst_label] = dst_index
dst_label_cat.add_label_group(src_cat.name, src_cat.labels, group_type=0)
Expand All @@ -1533,12 +1536,12 @@ def categories(self):
return self._categories

def transform_item(self, item: DatasetItem):
    """Convert the item's tabular annotation into label/caption annotations.

    Each ``(column, value)`` pair of the first annotation's ``values`` is
    turned into:

    * a ``Label`` when the column is registered as categorical; the label
      index is looked up under ``"<column><sep_token><value>"``, the same
      key format used when the label categories are built;
    * a ``Caption`` holding the raw value otherwise.

    Missing values produce no annotation: ``None`` is skipped for every
    column, and float NaN is skipped for categorical columns, because no
    label is registered for it and the lookup would raise ``KeyError``.

    Args:
        item: Dataset item whose first annotation exposes a ``values`` mapping.

    Returns:
        The item wrapped with the converted annotations.
    """
    import math  # local import: only needed for the NaN check below

    annotations = []
    for name, value in item.annotations[0].values.items():
        if value is None:
            continue

        if self._tabular_cat_types.get(name) == CategoricalDtype():
            # A missing categorical cell may arrive as float NaN rather than None.
            if isinstance(value, float) and math.isnan(value):
                continue
            label_name = name + self._sep_token + str(value)
            annotations.append(Label(label=self._id_mapping[label_name]))
        else:
            annotations.append(Caption(value))

    return self.wrap_item(item, annotations=annotations)
25 changes: 25 additions & 0 deletions tests/assets/tabular_dataset/electricity_missing.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
date,day,period,nswprice,nswdemand,vicprice,vicdemand,transfer,class
0.425556,5,0.340426,0.076108,0.392889,0.003467,0.422915,0.414912,UP
0.425512,4,0.617021,0.060376,0.483041,0.003467,0.422915,0.414912,DOWN
0.013982,4,0.042553,0.061967,0.521125,0.003467,0.422915,0.414912,DOWN
0.907349,3,0.06383,0.080581,0.331003,0.00538,0.47566,0.441228,DOWN
0.889341,0,0.361702,0.027141,0.379649,0.001624,0.248317,0.69386,DOWN
0.433565,3,0.787234,0.082803,0.447337,0.003467,0.422915,0.414912,UP
0.894474,4,0.787234,0.088087,0.840672,0.006012,0.752978,0.455702,UP
0.866997,5,0.446809,0.037739,0.506992,0.002495,0.339202,0.664474,UP
0.460909,3,0.319149,0.054672,0.585689,0.003741,0.448731,0.389912,DOWN
0.031857,4,0.255319,0.055242,0.115739,0.003467,0.422915,0.414912,DOWN
0.876023,5,1.0,0.028822,0.369087,0.001477,0.336872,0.769298,UP
0.030707,6,0.042553,0.047526,0.132104,0.003467,0.422915,0.414912,DOWN
0.500111,6,0.914894,0.06617,0.300952,0.00446,0.287416,0.420175,UP
0.890093,3,0.744681,0.338747,0.960875,0.023332,0.857328,0.325,UP
0.898544,6,0.531915,0.090068,0.476941,0.005544,0.35448,0.716228,UP
0.434406,1,0.340426,0.051039,0.518596,0.003467,0.422915,0.414912,UP
0.881023,0,0.787234,0.029302,0.409104,0.001847,0.418436,0.746053,UP
0.872174,4,0.148936,0.019125,0.142368,0.000841,0.250388,0.875,DOWN
0.469094,0,0.680851,0.029152,0.267034,0.001917,0.23796,0.587281,DOWN
0.871776,2,0.595745,0.0269,0.530348,0.001687,0.682548,0.630702,UP
0.456086,5,0.93617,0.053591,0.57468,0.003671,0.387364,0.565789,
0.486837,3,0.446809,0.07959,0.499851,0.005501,0.493009,0.296491,
0.009513,0,0.170213,0.041341,0.191461,0.003467,0.422915,0.414912,
0.429052,1,0.659574,0.100546,0.512794,0.003467,0.422915,0.414912,
26 changes: 26 additions & 0 deletions tests/unit/test_tabular_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def fxt_electricity(fxt_tabular_root):
yield Dataset.import_from(path, "tabular")


@pytest.fixture()
def fxt_electricity_missing(fxt_tabular_root):
    """Yield the ``electricity_missing.csv`` asset imported as a tabular dataset.

    ``nswprice`` is used as the input column and ``class`` as the output column.
    """
    csv_path = osp.join(fxt_tabular_root, "electricity_missing.csv")
    target_columns = {"input": "nswprice", "output": "class"}
    yield Dataset.import_from(csv_path, "tabular", target=target_columns)


@pytest.fixture()
def fxt_buddy_target():
yield {"input": "length(m)", "output": ["breed_category", "pet_category"]}
Expand Down Expand Up @@ -178,3 +184,23 @@ def test_target_dtype(self, fxt_tabular_root, target, expected_included_labels)
)
def test_string_to_dict(self, input_string, expected_result):
assert string_to_dict(input_string) == expected_result

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_import_tabular_file_with_missing_value(self, fxt_electricity_missing) -> None:
    """Importing a table with missing target cells must not create a NaN label.

    Checks that every row is imported, that the ``class`` category only holds
    the real labels, and that the missing cells are kept on the items as NaN.
    """
    import math

    # The fixture yields a Dataset instance, not the Dataset class.
    dataset: Dataset = fxt_electricity_missing

    tabular_cat = dataset.categories()[AnnotationType.tabular].items[0]
    assert [(tabular_cat.name, tabular_cat.dtype)] == [("class", CategoricalDtype())]
    assert len(dataset) == 24
    # NaN must not be registered as a label.
    assert tabular_cat.labels == {"UP", "DOWN"}

    # Only float values are tested for NaN; the float check keeps math.isnan
    # away from non-float values such as label strings.
    class_values = (item.annotations[0].values["class"] for item in dataset)
    num_nan_annotations = sum(
        1 for value in class_values if isinstance(value, float) and math.isnan(value)
    )
    assert num_nan_annotations == 4
77 changes: 76 additions & 1 deletion tests/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,37 @@ def setUp(self):
categories={},
media_type=TableRow,
)
self.table_label_nan = Table.from_list(
[{"class": "DOWN"}, {"class": "UP"}, {"class": None}]
)
self.dataset_label_nan = Dataset.from_iterable(
[
DatasetItem(
id="0",
subset="train",
media=TableRow(table=self.table_label_nan, index=0),
annotations=[Tabular(values={"class": "DOWN"})],
),
DatasetItem(
id="1",
subset="train",
media=TableRow(table=self.table_label_nan, index=1),
annotations=[Tabular(values={"class": "UP"})],
),
DatasetItem(
id="2",
subset="train",
media=TableRow(table=self.table_label_nan, index=2),
annotations=[Tabular(values={"class": None})],
),
],
categories={
AnnotationType.tabular: TabularCategories.from_iterable(
[("class", CategoricalDtype(), {"DOWN", "UP"})]
)
},
media_type=TableRow,
)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_split_arg_valid(self):
Expand Down Expand Up @@ -1325,7 +1356,7 @@ def test_transform_annotation_type_label(self):
],
categories={
AnnotationType.label: LabelCategories.from_iterable(
[("DOWN", "class"), ("UP", "class")]
[("class:DOWN", "class"), ("class:UP", "class")]
)
},
media_type=TableRow,
Expand Down Expand Up @@ -1369,3 +1400,47 @@ def test_transform_annotation_type_caption(self):
result = transforms.AstypeAnnotations(dataset)

compare_datasets(self, expected, result)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_transform_annotation_type_label_with_nan(self):
    """AstypeAnnotations must drop a missing categorical value instead of labelling it.

    Rows 0 and 1 carry "DOWN" and "UP" and become labels 0 and 1; row 2 has no
    class value and must end up with no annotations. Label names are expected
    to be prefixed with the column name, while the label group keeps the raw
    values.
    """
    table = self.table_label_nan

    # Expected annotations per table row, in row order.
    annotations_per_row = [[Label(label=0)], [Label(label=1)], []]
    expected_items = [
        DatasetItem(
            id=str(row),
            subset="train",
            media=TableRow(table=table, index=row),
            annotations=row_annotations,
        )
        for row, row_annotations in enumerate(annotations_per_row)
    ]
    expected_label_cat = LabelCategories.from_iterable(
        [("class:DOWN", "class"), ("class:UP", "class")]
    )
    expected = Dataset.from_iterable(
        expected_items,
        categories={AnnotationType.label: expected_label_cat},
        media_type=TableRow,
    )

    result = transforms.AstypeAnnotations(self.dataset_label_nan)

    label_cat = result._categories.get(AnnotationType.label, None)
    assert label_cat

    # Check label_groups of categories
    first_group = label_cat.label_groups[0]
    assert first_group.name == "class"
    assert sorted(first_group.labels) == ["DOWN", "UP"]

    compare_datasets(self, expected, result)

0 comments on commit 072c8a8

Please sign in to comment.